mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
minor fixes
This commit is contained in:
parent
1a591faf79
commit
7cf9747b55
3 changed files with 48 additions and 2 deletions
|
@ -12,6 +12,8 @@ from rich.text import Text
|
||||||
def sharding_info(arr, name=None):
|
def sharding_info(arr, name=None):
|
||||||
if isinstance(arr, np.ndarray):
|
if isinstance(arr, np.ndarray):
|
||||||
return print_sharding_info(arr, None, name)
|
return print_sharding_info(arr, None, name)
|
||||||
|
if not isinstance(arr, Array):
|
||||||
|
raise ValueError(f"is not a jax array, got {type(arr)}")
|
||||||
|
|
||||||
def _info(sharding):
|
def _info(sharding):
|
||||||
print_sharding_info(arr, sharding, name)
|
print_sharding_info(arr, sharding, name)
|
||||||
|
|
|
@ -79,7 +79,7 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
|
||||||
|
|
||||||
dims = list(range(len(shape)))
|
dims = list(range(len(shape)))
|
||||||
if isinstance(sharding, PmapSharding):
|
if isinstance(sharding, PmapSharding):
|
||||||
console.print("[red bold] Output for PmapSharding")
|
console.print("[red bold]Output for PmapSharding might be incorrect")
|
||||||
if len(shape) > 2:
|
if len(shape) > 2:
|
||||||
raise NotImplementedError("can only visualize PmapSharding with shapes with less than 3 dimensions")
|
raise NotImplementedError("can only visualize PmapSharding with shapes with less than 3 dimensions")
|
||||||
if len(shape) > 2 and not isinstance(sharding, PmapSharding):
|
if len(shape) > 2 and not isinstance(sharding, PmapSharding):
|
||||||
|
|
|
@ -213,6 +213,42 @@ def test_in_jit(capsys):
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pmap(capsys):
|
||||||
|
arr = jax.numpy.zeros(shape=(8, 8*3), dtype=jax.numpy.complex64)
|
||||||
|
arr = jax.pmap(lambda x: x ** 2)(arr)
|
||||||
|
sharding_info(arr)
|
||||||
|
sharding_vis(arr)
|
||||||
|
|
||||||
|
assert capsys.readouterr().out == """
|
||||||
|
╭────────────────────────────────────────────────────────────────────────╮
|
||||||
|
│ shape: (8, 24) │
|
||||||
|
│ dtype: complex64 │
|
||||||
|
│ size: 1.5 KiB │
|
||||||
|
│ PmapSharding(sharding_spec=ShardingSpec((Unstacked(8), NoSharding()), │
|
||||||
|
│ (ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7], │
|
||||||
|
│ device_platform=CPU, device_shape=(8,)) │
|
||||||
|
╰────────────────────────────────────────────────────────────────────────╯
|
||||||
|
Output for PmapSharding might be incorrect
|
||||||
|
┌─────────────────────────────────────────────────────────────────────────┐
|
||||||
|
│ CPU 0 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 1 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 2 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 3 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 4 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 5 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 6 │
|
||||||
|
├─────────────────────────────────────────────────────────────────────────┤
|
||||||
|
│ CPU 7 │
|
||||||
|
└─────────────────────────────────────────────────────────────────────────┘
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
def test_numpy(capsys):
|
def test_numpy(capsys):
|
||||||
arr = np.zeros(shape=(10, 10, 10))
|
arr = np.zeros(shape=(10, 10, 10))
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
|
@ -275,5 +311,13 @@ def test_3d_sharded(capsys):
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_array(capsys):
|
||||||
|
arr = [1, 2, 3]
|
||||||
|
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||||
|
sharding_info(arr)
|
||||||
|
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||||
|
sharding_vis(arr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_3d_sharded(None)
|
test_pmap(None)
|
||||||
|
|
Loading…
Reference in a new issue