1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

one more jit example

This commit is contained in:
Lukas Winkler 2023-09-04 16:42:32 +02:00
parent 7cf9747b55
commit f75389aa0a

View file

@ -311,6 +311,32 @@ def test_3d_sharded(capsys):
""".lstrip()
def test_indirectly_sharded(capsys):
"""
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
"""
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr = jax.device_put(arr)
def func(x):
y = jax.numpy.zeros(shape=(16, 16, 16))
sharding_info(y)
return x * y
func = jax.jit(func, out_shardings=simple_sharding)
arr = func(arr)
assert capsys.readouterr().out == """
shape: (16, 16, 16)
dtype: float32
size: 16.0 KiB
called in jit
GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7})
axis 1 is sharded: CPU 0 contains 0:2 (of 16)
""".lstrip()
def test_non_array(capsys):
arr = [1, 2, 3]
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
@ -320,4 +346,4 @@ def test_non_array(capsys):
if __name__ == '__main__':
test_pmap(None)
test_indirectly_sharded(None)