mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
adapt tests for results with jax 0.4.19
This commit is contained in:
parent
eaf968a9c6
commit
fd64100565
1 changed files with 59 additions and 43 deletions
102
tests/jaxtest.py
102
tests/jaxtest.py
|
@ -30,12 +30,12 @@ def test_simple(capsys):
|
|||
sharding_vis(arr)
|
||||
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────╮
|
||||
│ shape: (3,) │
|
||||
│ dtype: int32 │
|
||||
│ size: 12.0 B │
|
||||
│ GSPMDSharding({replicated}) │
|
||||
╰─────────────────────────────╯
|
||||
╭──────────────╮
|
||||
│ shape: (3,) │
|
||||
│ dtype: int32 │
|
||||
│ size: 12.0 B │
|
||||
│ not sharded │
|
||||
╰──────────────╯
|
||||
┌───────┐
|
||||
│ CPU 0 │
|
||||
└───────┘
|
||||
|
@ -103,13 +103,13 @@ def test_operator_sharded(capsys):
|
|||
sharding_info(arr)
|
||||
sharding_vis(arr)
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
╭───────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ NamedSharding: P(None, 'gpus') │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰───────────────────────────────────────────────╯
|
||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||
│ │ │ │ │ │ │ │ │
|
||||
|
@ -136,13 +136,13 @@ def test_jit_out_sharding_sharded(capsys):
|
|||
sharding_info(arr)
|
||||
sharding_vis(arr)
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
╭───────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ NamedSharding: P(None, 'gpus') │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰───────────────────────────────────────────────╯
|
||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||
│ │ │ │ │ │ │ │ │
|
||||
|
@ -190,14 +190,22 @@ def test_in_jit(capsys):
|
|||
func = jax.jit(func)
|
||||
func(arr)
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ called in jit │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
╭───────────────────────────────────────────────╮
|
||||
│ shape: (32, 32, 32) │
|
||||
│ dtype: complex64 │
|
||||
│ size: 256.0 KiB │
|
||||
│ called in jit │
|
||||
│ PositionalSharding: │
|
||||
│ [[[{CPU 0}] │
|
||||
│ [{CPU 1}] │
|
||||
│ [{CPU 2}] │
|
||||
│ [{CPU 3}] │
|
||||
│ [{CPU 4}] │
|
||||
│ [{CPU 5}] │
|
||||
│ [{CPU 6}] │
|
||||
│ [{CPU 7}]]] │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
||||
╰───────────────────────────────────────────────╯
|
||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||
│ │ │ │ │ │ │ │ │
|
||||
|
@ -325,14 +333,22 @@ def test_indirectly_sharded(capsys):
|
|||
func = jax.jit(func, out_shardings=simple_sharding)
|
||||
arr = func(arr)
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (16, 16, 16) │
|
||||
│ dtype: float32 │
|
||||
│ size: 16.0 KiB │
|
||||
│ called in jit │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
╭───────────────────────────────────────────────╮
|
||||
│ shape: (16, 16, 16) │
|
||||
│ dtype: float32 │
|
||||
│ size: 16.0 KiB │
|
||||
│ called in jit │
|
||||
│ PositionalSharding: │
|
||||
│ [[[{CPU 0}] │
|
||||
│ [{CPU 1}] │
|
||||
│ [{CPU 2}] │
|
||||
│ [{CPU 3}] │
|
||||
│ [{CPU 4}] │
|
||||
│ [{CPU 5}] │
|
||||
│ [{CPU 6}] │
|
||||
│ [{CPU 7}]]] │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||
╰───────────────────────────────────────────────╯
|
||||
""".lstrip()
|
||||
|
||||
|
||||
|
@ -350,13 +366,13 @@ def test_with_sharding_constraint(capsys):
|
|||
sharding_info(arr)
|
||||
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (16, 16, 16) │
|
||||
│ dtype: float32 │
|
||||
│ size: 16.0 KiB │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
╭───────────────────────────────────────────────╮
|
||||
│ shape: (16, 16, 16) │
|
||||
│ dtype: float32 │
|
||||
│ size: 16.0 KiB │
|
||||
│ GSPMDSharding({devices=[1,8,1]<=[8]}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||
╰───────────────────────────────────────────────╯
|
||||
""".lstrip()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue