mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
one more test
This commit is contained in:
parent
f75389aa0a
commit
eaf968a9c6
1 changed files with 24 additions and 1 deletions
|
@ -316,7 +316,6 @@ def test_indirectly_sharded(capsys):
|
|||
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
||||
"""
|
||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
arr = jax.device_put(arr)
|
||||
|
||||
def func(x):
|
||||
y = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
|
@ -337,6 +336,30 @@ def test_indirectly_sharded(capsys):
|
|||
""".lstrip()
|
||||
|
||||
|
||||
def test_with_sharding_constraint(capsys):
    """
    The input is never sharded via device_put; instead the sharding is
    requested explicitly inside the jit-compiled function with
    jax.lax.with_sharding_constraint, and the returned array comes back
    carrying that (GSPMD) sharding.
    """
    arr = jax.numpy.zeros(shape=(16, 16, 16))

    def func(x):
        # Explicitly constrain the value to simple_sharding; XLA applies
        # it to the function's output.
        return jax.lax.with_sharding_constraint(x, simple_sharding)

    func = jax.jit(func)
    arr = func(arr)
    sharding_info(arr)
    assert capsys.readouterr().out == """
╭─────────────────────────────────────────────────╮
│ shape: (16, 16, 16)                             │
│ dtype: float32                                  │
│ size: 16.0 KiB                                  │
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16)   │
╰─────────────────────────────────────────────────╯
""".lstrip()
|
||||
|
||||
|
||||
def test_non_array(capsys):
|
||||
arr = [1, 2, 3]
|
||||
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||
|
|
Loading…
Reference in a new issue