From 4381a7998ea69d55a8067d8ccb4908ae9a3104dc Mon Sep 17 00:00:00 2001
From: Lukas Winkler
Date: Mon, 29 Jul 2024 15:29:03 +0200
Subject: [PATCH] change output format slightly and add shard_map example

---
 jax_array_info/sharding_info.py |   6 +-
 tests/jaxtest.py                | 200 +++++++++++++++++++-------------
 2 files changed, 125 insertions(+), 81 deletions(-)

diff --git a/jax_array_info/sharding_info.py b/jax_array_info/sharding_info.py
index 35d6842..b8cf7f1 100644
--- a/jax_array_info/sharding_info.py
+++ b/jax_array_info/sharding_info.py
@@ -63,7 +63,11 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
     for i, sl in enumerate(slcs):
         if sl.start is None:
             continue
-        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (of {shape[i]})")
+        local_size = sl.stop - sl.start
+        global_size = shape[i]
+        num_shards = global_size // local_size
+        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (1/{num_shards})")
+        console.print(f" Total size: {global_size}")
 
 
 def print_sharding_info(arr: Array, sharding: Sharding, name=None):
diff --git a/tests/jaxtest.py b/tests/jaxtest.py
index 945427e..31dd91a 100644
--- a/tests/jaxtest.py
+++ b/tests/jaxtest.py
@@ -1,10 +1,12 @@
 import os
+from functools import partial
 
 import jax.numpy
 import numpy as np
 import pytest
 from jax._src.sharding_impls import PositionalSharding
 from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 
 from jax_array_info import sharding_info, sharding_vis, print_array_stats
@@ -25,11 +27,11 @@ mesh_3d = Mesh(devices_3d, axis_names=('a', 'b', 'c'))
 
 
 def test_simple(capsys):
     arr = jax.numpy.array([1, 2, 3])
-    sharding_info(arr)
+    sharding_info(arr, "arr")
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭──────────────╮
+╭──── arr ─────╮
 │ shape: (3,)  │
 │ dtype: int32 │
 │ size: 12.0 B │
@@ -73,13 +75,14 @@ def test_device_put_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│  Total size: 32                             │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
@@ -102,13 +105,14 @@ def test_operator_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│  Total size: 32                             │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
@@ -135,13 +139,14 @@ def test_jit_out_sharding_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│  Total size: 32                             │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
@@ -169,7 +174,8 @@ def test_positional_sharded(capsys):
 │ size: 256.0 B                                                     │
 │ PositionalSharding:                                               │
 │ [{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}] │
-│ axis 0 is sharded: CPU 0 contains 0:4 (of 32)                     │
+│ axis 0 is sharded: CPU 0 contains 0:4 (1/8)                       │
+│  Total size: 32                                                   │
 ╰───────────────────────────────────────────────────────────────────╯
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
@@ -189,22 +195,23 @@ def test_in_jit(capsys):
     func = jax.jit(func)
     func(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ called in jit                                 │
-│ PositionalSharding:                           │
-│ [[[{CPU 0}]                                   │
-│   [{CPU 1}]                                   │
-│   [{CPU 2}]                                   │
-│   [{CPU 3}]                                   │
-│   [{CPU 4}]                                   │
-│   [{CPU 5}]                                   │
-│   [{CPU 6}]                                   │
-│   [{CPU 7}]]]                                 │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ called in jit                               │
+│ PositionalSharding:                         │
+│ [[[{CPU 0}]                                 │
+│   [{CPU 1}]                                 │
+│   [{CPU 2}]                                 │
+│   [{CPU 3}]                                 │
+│   [{CPU 4}]                                 │
+│   [{CPU 5}]                                 │
+│   [{CPU 6}]                                 │
+│   [{CPU 7}]]]                               │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│  Total size: 32                             │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
@@ -277,14 +284,16 @@ def test_2d_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭────────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                            │
-│ dtype: complex64                               │
-│ size: 256.0 KiB                                │
-│ NamedSharding: P(None, 'a', 'b')               │
-│ axis 1 is sharded: CPU 0 contains 0:8 (of 32)  │
-│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
-╰────────────────────────────────────────────────╯
+╭──────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                          │
+│ dtype: complex64                             │
+│ size: 256.0 KiB                              │
+│ NamedSharding: P(None, 'a', 'b')             │
+│ axis 1 is sharded: CPU 0 contains 0:8 (1/4)  │
+│  Total size: 32                              │
+│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
+│  Total size: 32                              │
+╰──────────────────────────────────────────────╯
 ───────────── showing dims [1, 2] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┐
 │ CPU 0 │ CPU 1 │
@@ -306,15 +315,44 @@ def test_3d_sharded(capsys):
                        match=r"can only visualize up to 2 sharded dimension. \[0, 1, 2\] are sharded."):
         sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭────────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                            │
-│ dtype: complex64                               │
-│ size: 256.0 KiB                                │
-│ NamedSharding: P('a', 'b', 'c')                │
-│ axis 0 is sharded: CPU 0 contains 0:16 (of 32) │
-│ axis 1 is sharded: CPU 0 contains 0:16 (of 32) │
-│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
-╰────────────────────────────────────────────────╯
+╭──────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                          │
+│ dtype: complex64                             │
+│ size: 256.0 KiB                              │
+│ NamedSharding: P('a', 'b', 'c')              │
+│ axis 0 is sharded: CPU 0 contains 0:16 (1/2) │
+│  Total size: 32                              │
+│ axis 1 is sharded: CPU 0 contains 0:16 (1/2) │
+│  Total size: 32                              │
+│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
+│  Total size: 32                              │
+╰──────────────────────────────────────────────╯
+""".lstrip()
+
+
+def test_shard_map(capsys):
+    """
+    https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
+    """
+    arr = jax.numpy.zeros(shape=(16, 16))
+
+    @partial(shard_map, mesh=mesh, in_specs=P(None, 'gpus'), out_specs=P(None, 'gpus'))
+    def test(a):
+        # sharding_info(a,"input") # doesn't seem to work inside a shard_map
+        return a ** 2
+
+    out = test(arr)
+
+    sharding_info(out)
+    assert capsys.readouterr().out == """
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16)                             │
+│ dtype: float32                              │
+│ size: 1.0 KiB                               │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│  Total size: 16                             │
+╰─────────────────────────────────────────────╯
 """.lstrip()
 
 
@@ -332,22 +370,23 @@ def test_indirectly_sharded(capsys):
     func = jax.jit(func, out_shardings=simple_sharding)
     arr = func(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (16, 16, 16)                           │
-│ dtype: float32                                │
-│ size: 16.0 KiB                                │
-│ called in jit                                 │
-│ PositionalSharding:                           │
-│ [[[{CPU 0}]                                   │
-│   [{CPU 1}]                                   │
-│   [{CPU 2}]                                   │
-│   [{CPU 3}]                                   │
-│   [{CPU 4}]                                   │
-│   [{CPU 5}]                                   │
-│   [{CPU 6}]                                   │
-│   [{CPU 7}]]]                                 │
-│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16, 16)                         │
+│ dtype: float32                              │
+│ size: 16.0 KiB                              │
+│ called in jit                               │
+│ PositionalSharding:                         │
+│ [[[{CPU 0}]                                 │
+│   [{CPU 1}]                                 │
+│   [{CPU 2}]                                 │
+│   [{CPU 3}]                                 │
+│   [{CPU 4}]                                 │
+│   [{CPU 5}]                                 │
+│   [{CPU 6}]                                 │
+│   [{CPU 7}]]]                               │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│  Total size: 16                             │
+╰─────────────────────────────────────────────╯
 """.lstrip()
 
 
@@ -362,13 +401,14 @@ def test_with_sharding_constraint(capsys):
 
     sharding_info(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (16, 16, 16)                           │
-│ dtype: float32                                │
-│ size: 16.0 KiB                                │
-│ GSPMDSharding({devices=[1,8,1]<=[8]})         │
-│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16, 16)                         │
+│ dtype: float32                              │
+│ size: 16.0 KiB                              │
+│ GSPMDSharding({devices=[1,8,1]<=[8]})       │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│  Total size: 16                             │
+╰─────────────────────────────────────────────╯
 """.lstrip()
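
Note (not part of the patch): below is a minimal standalone sketch of running the new shard_map example against the updated per-axis output. It assumes 8 CPU devices are available (e.g. simulated via XLA_FLAGS before importing jax, as host-only test setups commonly do) and reuses the 'gpus' mesh axis name from tests/jaxtest.py; the function name `square` is illustrative only.

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # assumption: simulate 8 CPU devices

from functools import partial

import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

from jax_array_info import sharding_info

# 1D mesh over the 8 devices, axis named 'gpus' as in the tests
mesh = Mesh(mesh_utils.create_device_mesh((8,)), axis_names=('gpus',))

@partial(shard_map, mesh=mesh, in_specs=P(None, 'gpus'), out_specs=P(None, 'gpus'))
def square(a):
    # runs per shard; each shard sees a (16, 2) block of the (16, 16) input
    return a ** 2

out = square(jnp.zeros(shape=(16, 16)))
sharding_info(out)
# With this patch applied, the report should include:
#   axis 1 is sharded: CPU 0 contains 0:2 (1/8)
#    Total size: 16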