From f75389aa0a4495992ad1544fa85b32800a87cd88 Mon Sep 17 00:00:00 2001
From: Lukas Winkler
Date: Mon, 4 Sep 2023 16:42:32 +0200
Subject: [PATCH] one more jit example

---
 tests/jaxtest.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/tests/jaxtest.py b/tests/jaxtest.py
index 890beb3..0026137 100644
--- a/tests/jaxtest.py
+++ b/tests/jaxtest.py
@@ -214,7 +214,7 @@ def test_in_jit(capsys):
 
 
 def test_pmap(capsys):
-    arr = jax.numpy.zeros(shape=(8, 8*3), dtype=jax.numpy.complex64)
+    arr = jax.numpy.zeros(shape=(8, 8 * 3), dtype=jax.numpy.complex64)
     arr = jax.pmap(lambda x: x ** 2)(arr)
     sharding_info(arr)
     sharding_vis(arr)
@@ -311,6 +311,32 @@ def test_3d_sharded(capsys):
 """.lstrip()
 
 
+def test_indirectly_sharded(capsys):
+    """
+    y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
+    """
+    arr = jax.numpy.zeros(shape=(16, 16, 16))
+    arr = jax.device_put(arr)
+
+    def func(x):
+        y = jax.numpy.zeros(shape=(16, 16, 16))
+        sharding_info(y)
+        return x * y
+
+    func = jax.jit(func, out_shardings=simple_sharding)
+    arr = func(arr)
+    assert capsys.readouterr().out == """
+╭─────────────────────────────────────────────────╮
+│ shape: (16, 16, 16)                             │
+│ dtype: float32                                  │
+│ size: 16.0 KiB                                  │
+│ called in jit                                   │
+│ GSPMDSharding({devices=[1,8,1]0,1,2,3,4,5,6,7}) │
+│ axis 1 is sharded: CPU 0 contains 0:2 (of 16)   │
+╰─────────────────────────────────────────────────╯
+""".lstrip()
+
+
 def test_non_array(capsys):
     arr = [1, 2, 3]
     with pytest.raises(ValueError, match="is not a jax array, got "):
@@ -320,4 +346,4 @@ def test_non_array(capsys):
 
 
 if __name__ == '__main__':
-    test_pmap(None)
+    test_indirectly_sharded(None)
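
Note (not part of the patch): the behaviour the new test documents can be sketched without the repo's sharding_info helper. The snippet below is an illustration only; the PositionalSharding here stands in for the simple_sharding used by the test, and the 8 host CPU devices are an assumption matching the test environment.

    # Minimal sketch (assumptions: 8 CPU devices, PositionalSharding instead of simple_sharding)
    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # must be set before importing jax

    import jax
    import jax.numpy as jnp
    from jax.sharding import PositionalSharding

    # Shard axis 1 across all 8 devices, matching the expected "devices=[1,8,1]" output
    sharding = PositionalSharding(jax.devices()).reshape(1, 8, 1)

    def func(x):
        # y is created inside the traced function and never sharded by hand
        y = jnp.zeros(shape=(16, 16, 16))
        # report the sharding the compiler picks for y (roughly what sharding_info(y) prints);
        # jax.debug.inspect_array_sharding is available in recent JAX versions
        jax.debug.inspect_array_sharding(y, callback=print)
        return x * y

    out = jax.jit(func, out_shardings=sharding)(jnp.zeros(shape=(16, 16, 16)))
    print(out.sharding)  # the requested (1, 8, 1) sharding

Because out_shardings is requested on the jit-compiled function, the compiler assigns a sharding to the intermediate y as well, which is consistent with the GSPMDSharding({devices=[1,8,1]...}) line the test expects.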