mirror of https://github.com/Findus23/jax-array-info.git

adapt tests for results with jax 0.4.19

commit fd64100565
parent eaf968a9c6
Lukas Winkler 2023-10-25 13:37:36 +02:00


@@ -30,12 +30,12 @@ def test_simple(capsys):
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-shape: (3,)
-dtype: int32
-size: 12.0 B
-GSPMDSharding({replicated})
+shape: (3,)
+dtype: int32
+size: 12.0 B
+not sharded
 CPU 0
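For reference, a minimal sketch of what this first hunk exercises (the array construction is an assumption; only the expected output comes from the test): with jax 0.4.19 a plain single-device array is reported as "not sharded" rather than as GSPMDSharding({replicated}).

import jax.numpy as jnp
from jax_array_info import sharding_info, sharding_vis  # function names taken from the diff

arr = jnp.arange(3)   # shape (3,), dtype int32, 12 B
sharding_info(arr)    # jax 0.4.19 prints "not sharded" here
sharding_vis(arr)     # instead of "GSPMDSharding({replicated})"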
@@ -103,13 +103,13 @@ def test_operator_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-shape: (32, 32, 32)
-dtype: complex64
-size: 256.0 KiB
-GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+shape: (32, 32, 32)
+dtype: complex64
+size: 256.0 KiB
+NamedSharding: P(None, 'gpus')
+axis 1 is sharded: CPU 0 contains 0:4 (of 32)
 showing dims [0, 1] from original shape (32, 32, 32)
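The new NamedSharding: P(None, 'gpus') line corresponds to an eight-device mesh whose single axis is named 'gpus'. A hedged sketch of such a sharding, assuming the usual host-device trick for CPU-only test runs (none of this setup is shown in the diff):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # 8 fake CPU devices

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("gpus",))
sharding = NamedSharding(mesh, P(None, "gpus"))  # replicate axes 0 and 2, shard axis 1

arr = jax.device_put(jnp.zeros((32, 32, 32), dtype=jnp.complex64), sharding)
# every device holds a (32, 4, 32) slice, hence "CPU 0 contains 0:4 (of 32)"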
@@ -136,13 +136,13 @@ def test_jit_out_sharding_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-shape: (32, 32, 32)
-dtype: complex64
-size: 256.0 KiB
-GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+shape: (32, 32, 32)
+dtype: complex64
+size: 256.0 KiB
+NamedSharding: P(None, 'gpus')
+axis 1 is sharded: CPU 0 contains 0:4 (of 32)
 showing dims [0, 1] from original shape (32, 32, 32)
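test_jit_out_sharding_sharded reaches the same expected output by asking jit for the sharding instead of applying it with device_put; a sketch continuing from the setup above (the actual function body is assumed):

from jax_array_info import sharding_info  # as in the first sketch

func = jax.jit(lambda x: x + 1, out_shardings=sharding)  # sharding from the sketch above
arr = func(jnp.zeros((32, 32, 32), dtype=jnp.complex64))
sharding_info(arr)  # jax 0.4.19 reports NamedSharding: P(None, 'gpus') for the result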
@@ -190,14 +190,22 @@ def test_in_jit(capsys):
     func = jax.jit(func)
     func(arr)
     assert capsys.readouterr().out == """
-shape: (32, 32, 32)
-dtype: complex64
-size: 256.0 KiB
-called in jit
-GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+shape: (32, 32, 32)
+dtype: complex64
+size: 256.0 KiB
+called in jit
+PositionalSharding:
+[[[{CPU 0}]
+  [{CPU 1}]
+  [{CPU 2}]
+  [{CPU 3}]
+  [{CPU 4}]
+  [{CPU 5}]
+  [{CPU 6}]
+  [{CPU 7}]]]
+axis 1 is sharded: CPU 0 contains 0:4 (of 32)
 showing dims [0, 1] from original shape (32, 32, 32)
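Inside a jitted function, jax 0.4.19 reports the sharding as a PositionalSharding of shape (1, 8, 1), matching the eight-row repr in the new expected output. A hedged sketch of the in-jit inspection (the array and function bodies are assumptions):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding
from jax_array_info import sharding_info

sharding = PositionalSharding(jax.devices()).reshape(1, 8, 1)  # matches the repr above

def func(x):
    sharding_info(x)  # x is a tracer here, hence the extra "called in jit" line
    return x * 2

arr = jax.device_put(jnp.zeros((32, 32, 32), dtype=jnp.complex64), sharding)
jax.jit(func)(arr)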
@@ -325,14 +333,22 @@ def test_indirectly_sharded(capsys):
     func = jax.jit(func, out_shardings=simple_sharding)
     arr = func(arr)
     assert capsys.readouterr().out == """
-shape: (16, 16, 16)
-dtype: float32
-size: 16.0 KiB
-called in jit
-GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
-axis 1 is sharded: CPU 0 contains 0:2 (of 16)
+shape: (16, 16, 16)
+dtype: float32
+size: 16.0 KiB
+called in jit
+PositionalSharding:
+[[[{CPU 0}]
+  [{CPU 1}]
+  [{CPU 2}]
+  [{CPU 3}]
+  [{CPU 4}]
+  [{CPU 5}]
+  [{CPU 6}]
+  [{CPU 7}]]]
+axis 1 is sharded: CPU 0 contains 0:2 (of 16)
 """.lstrip()
@@ -350,13 +366,13 @@ def test_with_sharding_constraint(capsys):
     sharding_info(arr)
     assert capsys.readouterr().out == """
-shape: (16, 16, 16)
-dtype: float32
-size: 16.0 KiB
-GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
-axis 1 is sharded: CPU 0 contains 0:2 (of 16)
+shape: (16, 16, 16)
+dtype: float32
+size: 16.0 KiB
+GSPMDSharding({devices=[1,8,1]<=[8]})
+axis 1 is sharded: CPU 0 contains 0:2 (of 16)
 """.lstrip()