mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
one more test
This commit is contained in:
parent
f75389aa0a
commit
eaf968a9c6
1 changed files with 24 additions and 1 deletions
|
@ -316,7 +316,6 @@ def test_indirectly_sharded(capsys):
|
||||||
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
||||||
"""
|
"""
|
||||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||||
arr = jax.device_put(arr)
|
|
||||||
|
|
||||||
def func(x):
|
def func(x):
|
||||||
y = jax.numpy.zeros(shape=(16, 16, 16))
|
y = jax.numpy.zeros(shape=(16, 16, 16))
|
||||||
|
@ -337,6 +336,30 @@ def test_indirectly_sharded(capsys):
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_sharding_constraint(capsys):
|
||||||
|
"""
|
||||||
|
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
||||||
|
"""
|
||||||
|
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||||
|
|
||||||
|
def func(x):
|
||||||
|
return jax.lax.with_sharding_constraint(x, simple_sharding)
|
||||||
|
|
||||||
|
func = jax.jit(func)
|
||||||
|
arr = func(arr)
|
||||||
|
sharding_info(arr)
|
||||||
|
|
||||||
|
assert capsys.readouterr().out == """
|
||||||
|
╭─────────────────────────────────────────────────╮
|
||||||
|
│ shape: (16, 16, 16) │
|
||||||
|
│ dtype: float32 │
|
||||||
|
│ size: 16.0 KiB │
|
||||||
|
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||||
|
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||||
|
╰─────────────────────────────────────────────────╯
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
def test_non_array(capsys):
|
def test_non_array(capsys):
|
||||||
arr = [1, 2, 3]
|
arr = [1, 2, 3]
|
||||||
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||||
|
|
Loading…
Reference in a new issue