1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

minor fixes

This commit is contained in:
Lukas Winkler 2023-09-04 16:33:38 +02:00
parent 1a591faf79
commit 7cf9747b55
3 changed files with 48 additions and 2 deletions

View file

@ -12,6 +12,8 @@ from rich.text import Text
def sharding_info(arr, name=None):
if isinstance(arr, np.ndarray):
return print_sharding_info(arr, None, name)
if not isinstance(arr, Array):
raise ValueError(f"is not a jax array, got {type(arr)}")
def _info(sharding):
print_sharding_info(arr, sharding, name)

View file

@ -79,7 +79,7 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
dims = list(range(len(shape)))
if isinstance(sharding, PmapSharding):
console.print("[red bold] Output for PmapSharding")
console.print("[red bold]Output for PmapSharding might be incorrect")
if len(shape) > 2:
raise NotImplementedError("can only visualize PmapSharding with shapes with less than 3 dimensions")
if len(shape) > 2 and not isinstance(sharding, PmapSharding):

View file

@ -213,6 +213,42 @@ def test_in_jit(capsys):
""".lstrip()
def test_pmap(capsys):
arr = jax.numpy.zeros(shape=(8, 8*3), dtype=jax.numpy.complex64)
arr = jax.pmap(lambda x: x ** 2)(arr)
sharding_info(arr)
sharding_vis(arr)
assert capsys.readouterr().out == """
shape: (8, 24)
dtype: complex64
size: 1.5 KiB
PmapSharding(sharding_spec=ShardingSpec((Unstacked(8), NoSharding()),
(ShardedAxis(axis=0),)), device_ids=[0, 1, 2, 3, 4, 5, 6, 7],
device_platform=CPU, device_shape=(8,))
Output for PmapSharding might be incorrect
CPU 0
CPU 1
CPU 2
CPU 3
CPU 4
CPU 5
CPU 6
CPU 7
""".lstrip()
def test_numpy(capsys):
arr = np.zeros(shape=(10, 10, 10))
sharding_info(arr)
@ -275,5 +311,13 @@ def test_3d_sharded(capsys):
""".lstrip()
def test_non_array(capsys):
arr = [1, 2, 3]
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
sharding_info(arr)
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
sharding_vis(arr)
if __name__ == '__main__':
test_3d_sharded(None)
test_pmap(None)