1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

one more test

This commit is contained in:
Lukas Winkler 2023-09-04 16:47:53 +02:00
parent f75389aa0a
commit eaf968a9c6

View file

@ -316,7 +316,6 @@ def test_indirectly_sharded(capsys):
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
"""
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr = jax.device_put(arr)
def func(x):
y = jax.numpy.zeros(shape=(16, 16, 16))
@ -337,6 +336,30 @@ def test_indirectly_sharded(capsys):
""".lstrip()
def test_with_sharding_constraint(capsys):
"""
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
"""
arr = jax.numpy.zeros(shape=(16, 16, 16))
def func(x):
return jax.lax.with_sharding_constraint(x, simple_sharding)
func = jax.jit(func)
arr = func(arr)
sharding_info(arr)
assert capsys.readouterr().out == """
shape: (16, 16, 16)
dtype: float32
size: 16.0 KiB
GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
axis 1 is sharded: CPU 0 contains 0:2 (of 16)
""".lstrip()
def test_non_array(capsys):
arr = [1, 2, 3]
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):