mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
one more jit example
This commit is contained in:
parent
7cf9747b55
commit
f75389aa0a
1 changed files with 28 additions and 2 deletions
|
@ -214,7 +214,7 @@ def test_in_jit(capsys):
|
|||
|
||||
|
||||
def test_pmap(capsys):
|
||||
arr = jax.numpy.zeros(shape=(8, 8*3), dtype=jax.numpy.complex64)
|
||||
arr = jax.numpy.zeros(shape=(8, 8 * 3), dtype=jax.numpy.complex64)
|
||||
arr = jax.pmap(lambda x: x ** 2)(arr)
|
||||
sharding_info(arr)
|
||||
sharding_vis(arr)
|
||||
|
@ -311,6 +311,32 @@ def test_3d_sharded(capsys):
|
|||
""".lstrip()
|
||||
|
||||
|
||||
def test_indirectly_sharded(capsys):
|
||||
"""
|
||||
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
||||
"""
|
||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
arr = jax.device_put(arr)
|
||||
|
||||
def func(x):
|
||||
y = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
sharding_info(y)
|
||||
return x * y
|
||||
|
||||
func = jax.jit(func, out_shardings=simple_sharding)
|
||||
arr = func(arr)
|
||||
assert capsys.readouterr().out == """
|
||||
╭─────────────────────────────────────────────────╮
|
||||
│ shape: (16, 16, 16) │
|
||||
│ dtype: float32 │
|
||||
│ size: 16.0 KiB │
|
||||
│ called in jit │
|
||||
│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
|
||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
||||
╰─────────────────────────────────────────────────╯
|
||||
""".lstrip()
|
||||
|
||||
|
||||
def test_non_array(capsys):
|
||||
arr = [1, 2, 3]
|
||||
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||
|
@ -320,4 +346,4 @@ def test_non_array(capsys):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_pmap(None)
|
||||
test_indirectly_sharded(None)
|
||||
|
|
Loading…
Reference in a new issue