mirror of https://github.com/Findus23/jax-array-info.git

change output format slightly and add shard_map example

Lukas Winkler 2024-07-29 15:29:03 +02:00
parent 9ac3ea6715
commit 4381a7998e
Signed by: lukas
GPG key ID: 54DE4D798D244853
2 changed files with 125 additions and 81 deletions


@@ -63,7 +63,11 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
     for i, sl in enumerate(slcs):
         if sl.start is None:
             continue
-        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (of {shape[i]})")
+        local_size = sl.stop - sl.start
+        global_size = shape[i]
+        num_shards = global_size // local_size
+        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (1/{num_shards})")
+        console.print(f"  Total size: {global_size}")


 def print_sharding_info(arr: Array, sharding: Sharding, name=None):
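Note: the new format reports which fraction of the axis a single device holds ("1/8") instead of repeating the raw axis length ("of 32"), and moves the total size to its own line. A standalone sketch of the same arithmetic (plain Python, independent of the library; format_axis_sharding is a hypothetical helper named here for illustration):

    def format_axis_sharding(axis: int, sl: slice, axis_len: int) -> str:
        # Number of elements device 0 holds on this axis.
        local_size = sl.stop - sl.start
        # Assumes even sharding, which is what jax produces here.
        num_shards = axis_len // local_size
        return (f"axis {axis} is sharded: CPU 0 contains "
                f"{sl.start}:{sl.stop} (1/{num_shards})\n"
                f"  Total size: {axis_len}")

    print(format_axis_sharding(1, slice(0, 4), 32))
    # axis 1 is sharded: CPU 0 contains 0:4 (1/8)
    #   Total size: 32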


@@ -1,10 +1,12 @@
 import os
+from functools import partial

 import jax.numpy
 import numpy as np
 import pytest
 from jax._src.sharding_impls import PositionalSharding
 from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

 from jax_array_info import sharding_info, sharding_vis, print_array_stats
@@ -25,11 +27,11 @@ mesh_3d = Mesh(devices_3d, axis_names=('a', 'b', 'c'))

 def test_simple(capsys):
     arr = jax.numpy.array([1, 2, 3])
-    sharding_info(arr)
+    sharding_info(arr, "arr")
     sharding_vis(arr)
     assert capsys.readouterr().out == """
+arr
 shape: (3,)
 dtype: int32
 size: 12.0 B
@@ -73,13 +75,14 @@ def test_device_put_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 NamedSharding: P(None, 'gpus')
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+axis 1 is sharded: CPU 0 contains 0:4 (1/8)
+  Total size: 32

 showing dims [0, 1] from original shape (32, 32, 32)
@@ -102,13 +105,14 @@ def test_operator_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 NamedSharding: P(None, 'gpus')
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+axis 1 is sharded: CPU 0 contains 0:4 (1/8)
+  Total size: 32

 showing dims [0, 1] from original shape (32, 32, 32)
@@ -135,13 +139,14 @@ def test_jit_out_sharding_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 NamedSharding: P(None, 'gpus')
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+axis 1 is sharded: CPU 0 contains 0:4 (1/8)
+  Total size: 32

 showing dims [0, 1] from original shape (32, 32, 32)
@@ -169,7 +174,8 @@ def test_positional_sharded(capsys):
 size: 256.0 B
 PositionalSharding:
 [{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}]
-axis 0 is sharded: CPU 0 contains 0:4 (of 32)
+axis 0 is sharded: CPU 0 contains 0:4 (1/8)
+  Total size: 32

 CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
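Note: the flat device list above is a PositionalSharding over a 1-D array. A hedged reconstruction of an input matching the printed shape and size (32 complex64 elements = 256 B, so 4 elements per device); the actual test input may differ:

    import jax
    import jax.numpy as jnp
    from jax.sharding import PositionalSharding

    # 8 devices -> one shard of 4 elements each along axis 0.
    sharding = PositionalSharding(jax.devices())
    arr = jax.device_put(jnp.zeros(32, dtype=jnp.complex64), sharding)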
@@ -189,22 +195,23 @@ def test_in_jit(capsys):
     func = jax.jit(func)
     func(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 called in jit
 PositionalSharding:
 [[[{CPU 0}]
   [{CPU 1}]
   [{CPU 2}]
   [{CPU 3}]
   [{CPU 4}]
   [{CPU 5}]
   [{CPU 6}]
   [{CPU 7}]]]
-axis 1 is sharded: CPU 0 contains 0:4 (of 32)
+axis 1 is sharded: CPU 0 contains 0:4 (1/8)
+  Total size: 32

 showing dims [0, 1] from original shape (32, 32, 32)
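Note: the "called in jit" line means sharding_info ran on a traced value, whose sharding is only fixed by the compiler. One public way to observe it during tracing is jax.debug.inspect_array_sharding; a minimal sketch, not necessarily how this library implements the check:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def func(x):
        y = x ** 2
        # The callback fires once the compiler has decided y's sharding.
        jax.debug.inspect_array_sharding(y, callback=print)
        return y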
@@ -277,14 +284,16 @@ def test_2d_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 NamedSharding: P(None, 'a', 'b')
-axis 1 is sharded: CPU 0 contains 0:8 (of 32)
-axis 2 is sharded: CPU 0 contains 0:16 (of 32)
+axis 1 is sharded: CPU 0 contains 0:8 (1/4)
+  Total size: 32
+axis 2 is sharded: CPU 0 contains 0:16 (1/2)
+  Total size: 32

 showing dims [1, 2] from original shape (32, 32, 32)
 CPU 0 CPU 1
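Note: the fractions here (1/4 on axis 1, 1/2 on axis 2) imply four devices along mesh axis 'a' and two along 'b'. A sketch of a setup that reproduces them, with the mesh shape as an assumption; the actual mesh_2d definition lives earlier in the test file:

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devices_2d = mesh_utils.create_device_mesh((4, 2))
    mesh_2d = Mesh(devices_2d, axis_names=('a', 'b'))
    # axis 1 is split 4 ways (32 -> 8), axis 2 is split 2 ways (32 -> 16)
    arr = jax.device_put(
        jnp.zeros((32, 32, 32), dtype=jnp.complex64),
        NamedSharding(mesh_2d, P(None, 'a', 'b')),
    )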
@@ -306,15 +315,44 @@ def test_3d_sharded(capsys):
                        match=r"can only visualize up to 2 sharded dimension. \[0, 1, 2\] are sharded."):
         sharding_vis(arr)
     assert capsys.readouterr().out == """
 shape: (32, 32, 32)
 dtype: complex64
 size: 256.0 KiB
 NamedSharding: P('a', 'b', 'c')
-axis 0 is sharded: CPU 0 contains 0:16 (of 32)
-axis 1 is sharded: CPU 0 contains 0:16 (of 32)
-axis 2 is sharded: CPU 0 contains 0:16 (of 32)
+axis 0 is sharded: CPU 0 contains 0:16 (1/2)
+  Total size: 32
+axis 1 is sharded: CPU 0 contains 0:16 (1/2)
+  Total size: 32
+axis 2 is sharded: CPU 0 contains 0:16 (1/2)
+  Total size: 32
 """.lstrip()


+def test_shard_map(capsys):
+    """
+    https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
+    """
+    arr = jax.numpy.zeros(shape=(16, 16))
+
+    @partial(shard_map, mesh=mesh, in_specs=P(None, 'gpus'), out_specs=P(None, 'gpus'))
+    def test(a):
+        # sharding_info(a, "input")  # doesn't seem to work inside a shard_map
+        return a ** 2
+
+    out = test(arr)
+    sharding_info(out)
+    assert capsys.readouterr().out == """
+shape: (16, 16)
+dtype: float32
+size: 1.0 KiB
+NamedSharding: P(None, 'gpus')
+axis 1 is sharded: CPU 0 contains 0:2 (1/8)
+  Total size: 16
+""".lstrip()
@@ -332,22 +370,23 @@ def test_indirectly_sharded(capsys):
     func = jax.jit(func, out_shardings=simple_sharding)
     arr = func(arr)
     assert capsys.readouterr().out == """
 shape: (16, 16, 16)
 dtype: float32
 size: 16.0 KiB
 called in jit
 PositionalSharding:
 [[[{CPU 0}]
   [{CPU 1}]
   [{CPU 2}]
   [{CPU 3}]
   [{CPU 4}]
   [{CPU 5}]
   [{CPU 6}]
   [{CPU 7}]]]
-axis 1 is sharded: CPU 0 contains 0:2 (of 16)
+axis 1 is sharded: CPU 0 contains 0:2 (1/8)
+  Total size: 16
 """.lstrip()
@@ -362,13 +401,14 @@ def test_with_sharding_constraint(capsys):
     sharding_info(arr)
     assert capsys.readouterr().out == """
 shape: (16, 16, 16)
 dtype: float32
 size: 16.0 KiB
 GSPMDSharding({devices=[1,8,1]<=[8]})
-axis 1 is sharded: CPU 0 contains 0:2 (of 16)
+axis 1 is sharded: CPU 0 contains 0:2 (1/8)
+  Total size: 16
 """.lstrip()