mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
change output format slightly and add shard_map example
This commit is contained in:
parent
9ac3ea6715
commit
4381a7998e
2 changed files with 125 additions and 81 deletions
|
@ -63,7 +63,11 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
||||||
for i, sl in enumerate(slcs):
|
for i, sl in enumerate(slcs):
|
||||||
if sl.start is None:
|
if sl.start is None:
|
||||||
continue
|
continue
|
||||||
console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (of {shape[i]})")
|
local_size = sl.stop - sl.start
|
||||||
|
global_size = shape[i]
|
||||||
|
num_shards = global_size // local_size
|
||||||
|
console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (1/{num_shards})")
|
||||||
|
console.print(f" Total size: {global_size}")
|
||||||
|
|
||||||
|
|
||||||
def print_sharding_info(arr: Array, sharding: Sharding, name=None):
|
def print_sharding_info(arr: Array, sharding: Sharding, name=None):
|
||||||
|
|
200
tests/jaxtest.py
200
tests/jaxtest.py
|
@ -1,10 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import jax.numpy
|
import jax.numpy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from jax._src.sharding_impls import PositionalSharding
|
from jax._src.sharding_impls import PositionalSharding
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||||
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
||||||
|
|
||||||
|
@ -25,11 +27,11 @@ mesh_3d = Mesh(devices_3d, axis_names=('a', 'b', 'c'))
|
||||||
|
|
||||||
def test_simple(capsys):
|
def test_simple(capsys):
|
||||||
arr = jax.numpy.array([1, 2, 3])
|
arr = jax.numpy.array([1, 2, 3])
|
||||||
sharding_info(arr)
|
sharding_info(arr, "arr")
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
|
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭──────────────╮
|
╭──── arr ─────╮
|
||||||
│ shape: (3,) │
|
│ shape: (3,) │
|
||||||
│ dtype: int32 │
|
│ dtype: int32 │
|
||||||
│ size: 12.0 B │
|
│ size: 12.0 B │
|
||||||
|
@ -73,13 +75,14 @@ def test_device_put_sharded(capsys):
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ NamedSharding: P(None, 'gpus') │
|
│ NamedSharding: P(None, 'gpus') │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 32 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
│ │ │ │ │ │ │ │ │
|
│ │ │ │ │ │ │ │ │
|
||||||
|
@ -102,13 +105,14 @@ def test_operator_sharded(capsys):
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ NamedSharding: P(None, 'gpus') │
|
│ NamedSharding: P(None, 'gpus') │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 32 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
│ │ │ │ │ │ │ │ │
|
│ │ │ │ │ │ │ │ │
|
||||||
|
@ -135,13 +139,14 @@ def test_jit_out_sharding_sharded(capsys):
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ NamedSharding: P(None, 'gpus') │
|
│ NamedSharding: P(None, 'gpus') │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 32 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
│ │ │ │ │ │ │ │ │
|
│ │ │ │ │ │ │ │ │
|
||||||
|
@ -169,7 +174,8 @@ def test_positional_sharded(capsys):
|
||||||
│ size: 256.0 B │
|
│ size: 256.0 B │
|
||||||
│ PositionalSharding: │
|
│ PositionalSharding: │
|
||||||
│ [{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}] │
|
│ [{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}] │
|
||||||
│ axis 0 is sharded: CPU 0 contains 0:4 (of 32) │
|
│ axis 0 is sharded: CPU 0 contains 0:4 (1/8) │
|
||||||
|
│ Total size: 32 │
|
||||||
╰───────────────────────────────────────────────────────────────────╯
|
╰───────────────────────────────────────────────────────────────────╯
|
||||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
|
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
|
||||||
|
@ -189,22 +195,23 @@ def test_in_jit(capsys):
|
||||||
func = jax.jit(func)
|
func = jax.jit(func)
|
||||||
func(arr)
|
func(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ called in jit │
|
│ called in jit │
|
||||||
│ PositionalSharding: │
|
│ PositionalSharding: │
|
||||||
│ [[[{CPU 0}] │
|
│ [[[{CPU 0}] │
|
||||||
│ [{CPU 1}] │
|
│ [{CPU 1}] │
|
||||||
│ [{CPU 2}] │
|
│ [{CPU 2}] │
|
||||||
│ [{CPU 3}] │
|
│ [{CPU 3}] │
|
||||||
│ [{CPU 4}] │
|
│ [{CPU 4}] │
|
||||||
│ [{CPU 5}] │
|
│ [{CPU 5}] │
|
||||||
│ [{CPU 6}] │
|
│ [{CPU 6}] │
|
||||||
│ [{CPU 7}]]] │
|
│ [{CPU 7}]]] │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 32 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
|
||||||
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
│ │ │ │ │ │ │ │ │
|
│ │ │ │ │ │ │ │ │
|
||||||
|
@ -277,14 +284,16 @@ def test_2d_sharded(capsys):
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭────────────────────────────────────────────────╮
|
╭──────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ NamedSharding: P(None, 'a', 'b') │
|
│ NamedSharding: P(None, 'a', 'b') │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:8 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:8 (1/4) │
|
||||||
│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
|
│ Total size: 32 │
|
||||||
╰────────────────────────────────────────────────╯
|
│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
|
||||||
|
│ Total size: 32 │
|
||||||
|
╰──────────────────────────────────────────────╯
|
||||||
───────────── showing dims [1, 2] from original shape (32, 32, 32) ─────────────
|
───────────── showing dims [1, 2] from original shape (32, 32, 32) ─────────────
|
||||||
┌───────┬───────┐
|
┌───────┬───────┐
|
||||||
│ CPU 0 │ CPU 1 │
|
│ CPU 0 │ CPU 1 │
|
||||||
|
@ -306,15 +315,44 @@ def test_3d_sharded(capsys):
|
||||||
match=r"can only visualize up to 2 sharded dimension. \[0, 1, 2\] are sharded."):
|
match=r"can only visualize up to 2 sharded dimension. \[0, 1, 2\] are sharded."):
|
||||||
sharding_vis(arr)
|
sharding_vis(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭────────────────────────────────────────────────╮
|
╭──────────────────────────────────────────────╮
|
||||||
│ shape: (32, 32, 32) │
|
│ shape: (32, 32, 32) │
|
||||||
│ dtype: complex64 │
|
│ dtype: complex64 │
|
||||||
│ size: 256.0 KiB │
|
│ size: 256.0 KiB │
|
||||||
│ NamedSharding: P('a', 'b', 'c') │
|
│ NamedSharding: P('a', 'b', 'c') │
|
||||||
│ axis 0 is sharded: CPU 0 contains 0:16 (of 32) │
|
│ axis 0 is sharded: CPU 0 contains 0:16 (1/2) │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:16 (of 32) │
|
│ Total size: 32 │
|
||||||
│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
|
│ axis 1 is sharded: CPU 0 contains 0:16 (1/2) │
|
||||||
╰────────────────────────────────────────────────╯
|
│ Total size: 32 │
|
||||||
|
│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
|
||||||
|
│ Total size: 32 │
|
||||||
|
╰──────────────────────────────────────────────╯
|
||||||
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_map(capsys):
|
||||||
|
"""
|
||||||
|
https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
|
||||||
|
"""
|
||||||
|
arr = jax.numpy.zeros(shape=(16, 16))
|
||||||
|
|
||||||
|
@partial(shard_map, mesh=mesh, in_specs=P(None, 'gpus'), out_specs=P(None, 'gpus'))
|
||||||
|
def test(a):
|
||||||
|
# sharding_info(a,"input") # doesn't seem to work inside a shard_map
|
||||||
|
return a ** 2
|
||||||
|
|
||||||
|
out = test(arr)
|
||||||
|
|
||||||
|
sharding_info(out)
|
||||||
|
assert capsys.readouterr().out == """
|
||||||
|
╭─────────────────────────────────────────────╮
|
||||||
|
│ shape: (16, 16) │
|
||||||
|
│ dtype: float32 │
|
||||||
|
│ size: 1.0 KiB │
|
||||||
|
│ NamedSharding: P(None, 'gpus') │
|
||||||
|
│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
|
||||||
|
│ Total size: 16 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
@ -332,22 +370,23 @@ def test_indirectly_sharded(capsys):
|
||||||
func = jax.jit(func, out_shardings=simple_sharding)
|
func = jax.jit(func, out_shardings=simple_sharding)
|
||||||
arr = func(arr)
|
arr = func(arr)
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (16, 16, 16) │
|
│ shape: (16, 16, 16) │
|
||||||
│ dtype: float32 │
|
│ dtype: float32 │
|
||||||
│ size: 16.0 KiB │
|
│ size: 16.0 KiB │
|
||||||
│ called in jit │
|
│ called in jit │
|
||||||
│ PositionalSharding: │
|
│ PositionalSharding: │
|
||||||
│ [[[{CPU 0}] │
|
│ [[[{CPU 0}] │
|
||||||
│ [{CPU 1}] │
|
│ [{CPU 1}] │
|
||||||
│ [{CPU 2}] │
|
│ [{CPU 2}] │
|
||||||
│ [{CPU 3}] │
|
│ [{CPU 3}] │
|
||||||
│ [{CPU 4}] │
|
│ [{CPU 4}] │
|
||||||
│ [{CPU 5}] │
|
│ [{CPU 5}] │
|
||||||
│ [{CPU 6}] │
|
│ [{CPU 6}] │
|
||||||
│ [{CPU 7}]]] │
|
│ [{CPU 7}]]] │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 16 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
@ -362,13 +401,14 @@ def test_with_sharding_constraint(capsys):
|
||||||
sharding_info(arr)
|
sharding_info(arr)
|
||||||
|
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
╭───────────────────────────────────────────────╮
|
╭─────────────────────────────────────────────╮
|
||||||
│ shape: (16, 16, 16) │
|
│ shape: (16, 16, 16) │
|
||||||
│ dtype: float32 │
|
│ dtype: float32 │
|
||||||
│ size: 16.0 KiB │
|
│ size: 16.0 KiB │
|
||||||
│ GSPMDSharding({devices=[1,8,1]<=[8]}) │
|
│ GSPMDSharding({devices=[1,8,1]<=[8]}) │
|
||||||
│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
|
│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
|
||||||
╰───────────────────────────────────────────────╯
|
│ Total size: 16 │
|
||||||
|
╰─────────────────────────────────────────────╯
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue