Mirror of https://github.com/Findus23/jax-array-info.git (synced 2024-09-19 15:53:47 +02:00)
Commit 4381a7998e (parent 9ac3ea6715)
change output format slightly and add shard_map example

2 changed files with 125 additions and 81 deletions

@@ -63,7 +63,11 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
     for i, sl in enumerate(slcs):
         if sl.start is None:
             continue
-        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (of {shape[i]})")
+        local_size = sl.stop - sl.start
+        global_size = shape[i]
+        num_shards = global_size // local_size
+        console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (1/{num_shards})")
+        console.print(f"  Total size: {global_size}")


 def print_sharding_info(arr: Array, sharding: Sharding, name=None):
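
The new format reports each sharded axis as a fraction of the shard count instead of the global extent, plus the axis total. A standalone sketch of the arithmetic (plain Python; the slice map and shape are made-up stand-ins for what the function receives):

# Stand-in inputs: axis 1 split into 8 shards of 4 elements each.
slcs = [slice(None, None), slice(0, 4)]
shape = (32, 32)
device_kind = "CPU"

for i, sl in enumerate(slcs):
    if sl.start is None:  # axis is not sharded
        continue
    local_size = sl.stop - sl.start         # elements held by device 0
    global_size = shape[i]
    num_shards = global_size // local_size  # assumes an even split
    print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (1/{num_shards})")
    print(f"  Total size: {global_size}")

This prints "axis 1 is sharded: CPU 0 contains 0:4 (1/8)", matching the expected output asserted in the tests below.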

tests/jaxtest.py (200 changed lines)

@@ -1,10 +1,12 @@
 import os
+from functools import partial

 import jax.numpy
 import numpy as np
 import pytest
 from jax._src.sharding_impls import PositionalSharding
 from jax.experimental import mesh_utils
+from jax.experimental.shard_map import shard_map
 from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 from jax_array_info import sharding_info, sharding_vis, print_array_stats
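
The two new imports support the shard_map test added further down. The tests shard across eight devices; a plausible setup (not part of this diff, so the XLA flag and the mesh/mesh_3d definitions are assumptions reconstructed from the fixtures the hunks reference):

import os

# Pretend the host has 8 devices; must be set before jax initializes a backend.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('gpus',))              # 1D mesh used by most tests

devices_3d = mesh_utils.create_device_mesh((2, 2, 2))
mesh_3d = Mesh(devices_3d, axis_names=('a', 'b', 'c'))  # used by test_3d_sharded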

@@ -25,11 +27,11 @@ mesh_3d = Mesh(devices_3d, axis_names=('a', 'b', 'c'))

 def test_simple(capsys):
     arr = jax.numpy.array([1, 2, 3])
-    sharding_info(arr)
+    sharding_info(arr, "arr")
     sharding_vis(arr)

     assert capsys.readouterr().out == """
-╭──────────────╮
+╭──── arr ─────╮
 │ shape: (3,)  │
 │ dtype: int32 │
 │ size: 12.0 B │
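
The optional second argument names the array and titles the panel, which the updated test now exercises:

arr = jax.numpy.array([1, 2, 3])
sharding_info(arr)         # plain border:  ╭──────────────╮
sharding_info(arr, "arr")  # titled border: ╭──── arr ─────╮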

@@ -73,13 +75,14 @@ def test_device_put_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│   Total size: 32                            │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
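
For context, a sketch of how such an input can be produced (the construction itself is outside this hunk; mesh is the assumed fixture from the setup sketch above):

arr = jax.numpy.zeros((32, 32, 32), dtype=jax.numpy.complex64)
# Shard axis 1 across the 8 devices of the 'gpus' mesh axis.
arr = jax.device_put(arr, NamedSharding(mesh, P(None, 'gpus')))
sharding_info(arr)  # axis 1 is sharded: CPU 0 contains 0:4 (1/8)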

@@ -102,13 +105,14 @@ def test_operator_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│   Total size: 32                            │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
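
The test name suggests the array here gets its sharding by flowing through an operation on a sharded operand; elementwise operations preserve the operand's sharding, so the report is identical. A sketch under that assumption:

x = jax.device_put(jax.numpy.zeros((32, 32, 32), dtype=jax.numpy.complex64),
                   NamedSharding(mesh, P(None, 'gpus')))
y = x * 2         # the result keeps NamedSharding: P(None, 'gpus')
sharding_info(y)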

@@ -135,13 +139,14 @@ def test_jit_out_sharding_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ NamedSharding: P(None, 'gpus')                │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│   Total size: 32                            │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
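
Here the sharding presumably comes from jit's out_shardings argument, per the test name; a sketch under that assumption (the lambda is illustrative):

f = jax.jit(lambda x: x * 2,
            out_shardings=NamedSharding(mesh, P(None, 'gpus')))
out = f(jax.numpy.zeros((32, 32, 32), dtype=jax.numpy.complex64))
sharding_info(out)  # axis 1 is sharded: CPU 0 contains 0:4 (1/8)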

@@ -169,7 +174,8 @@ def test_positional_sharded(capsys):
 │ size: 256.0 B                                                     │
 │ PositionalSharding:                                               │
 │ [{CPU 0} {CPU 1} {CPU 2} {CPU 3} {CPU 4} {CPU 5} {CPU 6} {CPU 7}] │
-│ axis 0 is sharded: CPU 0 contains 0:4 (of 32)                     │
+│ axis 0 is sharded: CPU 0 contains 0:4 (1/8)                       │
+│   Total size: 32                                                  │
 ╰───────────────────────────────────────────────────────────────────╯
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
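
PositionalSharding arranges devices by position rather than by mesh-axis name. A sketch matching the panel (the array construction is assumed; size 256.0 B corresponds to 32 complex64 elements):

sharding = PositionalSharding(jax.devices())  # all 8 devices in a row
arr = jax.device_put(jax.numpy.zeros(32, dtype=jax.numpy.complex64), sharding)
sharding_info(arr)  # axis 0 is sharded: CPU 0 contains 0:4 (1/8)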

@@ -189,22 +195,23 @@ def test_in_jit(capsys):
     func = jax.jit(func)
     func(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                           │
-│ dtype: complex64                              │
-│ size: 256.0 KiB                               │
-│ called in jit                                 │
-│ PositionalSharding:                           │
-│ [[[{CPU 0}]                                   │
-│   [{CPU 1}]                                   │
-│   [{CPU 2}]                                   │
-│   [{CPU 3}]                                   │
-│   [{CPU 4}]                                   │
-│   [{CPU 5}]                                   │
-│   [{CPU 6}]                                   │
-│   [{CPU 7}]]]                                 │
-│ axis 1 is sharded: CPU 0 contains 0:4 (of 32) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                         │
+│ dtype: complex64                            │
+│ size: 256.0 KiB                             │
+│ called in jit                               │
+│ PositionalSharding:                         │
+│ [[[{CPU 0}]                                 │
+│   [{CPU 1}]                                 │
+│   [{CPU 2}]                                 │
+│   [{CPU 3}]                                 │
+│   [{CPU 4}]                                 │
+│   [{CPU 5}]                                 │
+│   [{CPU 6}]                                 │
+│   [{CPU 7}]]]                               │
+│ axis 1 is sharded: CPU 0 contains 0:4 (1/8) │
+│   Total size: 32                            │
+╰─────────────────────────────────────────────╯
 ───────────── showing dims [0, 1] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
 │       │       │       │       │       │       │       │       │
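
sharding_info also works on traced values inside a jitted function, adding the "called in jit" line; a sketch of the pattern the context lines imply:

def func(x):
    sharding_info(x)  # runs during tracing; the panel notes "called in jit"
    return x * 2

func = jax.jit(func)
func(arr)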

@@ -277,14 +284,16 @@ def test_2d_sharded(capsys):
     sharding_info(arr)
     sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭────────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                            │
-│ dtype: complex64                               │
-│ size: 256.0 KiB                                │
-│ NamedSharding: P(None, 'a', 'b')               │
-│ axis 1 is sharded: CPU 0 contains 0:8 (of 32)  │
-│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
-╰────────────────────────────────────────────────╯
+╭──────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                          │
+│ dtype: complex64                             │
+│ size: 256.0 KiB                              │
+│ NamedSharding: P(None, 'a', 'b')             │
+│ axis 1 is sharded: CPU 0 contains 0:8 (1/4)  │
+│   Total size: 32                             │
+│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
+│   Total size: 32                             │
+╰──────────────────────────────────────────────╯
 ───────────── showing dims [1, 2] from original shape (32, 32, 32) ─────────────
 ┌───────┬───────┐
 │ CPU 0 │ CPU 1 │
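
The 1/4 and 1/2 splits imply a 4x2 device mesh over the two named axes; an assumed construction (the names devices_2d and mesh_2d are hypothetical):

devices_2d = mesh_utils.create_device_mesh((4, 2))
mesh_2d = Mesh(devices_2d, axis_names=('a', 'b'))
arr = jax.device_put(jax.numpy.zeros((32, 32, 32), dtype=jax.numpy.complex64),
                     NamedSharding(mesh_2d, P(None, 'a', 'b')))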

@@ -306,15 +315,44 @@ def test_3d_sharded(capsys):
                        match=r"can only visualize up to 2 sharded dimension. \[0, 1, 2\] are sharded."):
         sharding_vis(arr)
     assert capsys.readouterr().out == """
-╭────────────────────────────────────────────────╮
-│ shape: (32, 32, 32)                            │
-│ dtype: complex64                               │
-│ size: 256.0 KiB                                │
-│ NamedSharding: P('a', 'b', 'c')                │
-│ axis 0 is sharded: CPU 0 contains 0:16 (of 32) │
-│ axis 1 is sharded: CPU 0 contains 0:16 (of 32) │
-│ axis 2 is sharded: CPU 0 contains 0:16 (of 32) │
-╰────────────────────────────────────────────────╯
+╭──────────────────────────────────────────────╮
+│ shape: (32, 32, 32)                          │
+│ dtype: complex64                             │
+│ size: 256.0 KiB                              │
+│ NamedSharding: P('a', 'b', 'c')              │
+│ axis 0 is sharded: CPU 0 contains 0:16 (1/2) │
+│   Total size: 32                             │
+│ axis 1 is sharded: CPU 0 contains 0:16 (1/2) │
+│   Total size: 32                             │
+│ axis 2 is sharded: CPU 0 contains 0:16 (1/2) │
+│   Total size: 32                             │
+╰──────────────────────────────────────────────╯
 """.lstrip()


+def test_shard_map(capsys):
+    """
+    https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
+    """
+    arr = jax.numpy.zeros(shape=(16, 16))
+
+    @partial(shard_map, mesh=mesh, in_specs=P(None, 'gpus'), out_specs=P(None, 'gpus'))
+    def test(a):
+        # sharding_info(a, "input")  # doesn't seem to work inside a shard_map
+        return a ** 2
+
+    out = test(arr)
+
+    sharding_info(out)
+    assert capsys.readouterr().out == """
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16)                             │
+│ dtype: float32                              │
+│ size: 1.0 KiB                               │
+│ NamedSharding: P(None, 'gpus')              │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│   Total size: 16                            │
+╰─────────────────────────────────────────────╯
+""".lstrip()

@@ -332,22 +370,23 @@ def test_indirectly_sharded(capsys):
     func = jax.jit(func, out_shardings=simple_sharding)
     arr = func(arr)
     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (16, 16, 16)                           │
-│ dtype: float32                                │
-│ size: 16.0 KiB                                │
-│ called in jit                                 │
-│ PositionalSharding:                           │
-│ [[[{CPU 0}]                                   │
-│   [{CPU 1}]                                   │
-│   [{CPU 2}]                                   │
-│   [{CPU 3}]                                   │
-│   [{CPU 4}]                                   │
-│   [{CPU 5}]                                   │
-│   [{CPU 6}]                                   │
-│   [{CPU 7}]]]                                 │
-│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16, 16)                         │
+│ dtype: float32                              │
+│ size: 16.0 KiB                              │
+│ called in jit                               │
+│ PositionalSharding:                         │
+│ [[[{CPU 0}]                                 │
+│   [{CPU 1}]                                 │
+│   [{CPU 2}]                                 │
+│   [{CPU 3}]                                 │
+│   [{CPU 4}]                                 │
+│   [{CPU 5}]                                 │
+│   [{CPU 6}]                                 │
+│   [{CPU 7}]]]                               │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│   Total size: 16                            │
+╰─────────────────────────────────────────────╯
 """.lstrip()

@@ -362,13 +401,14 @@ def test_with_sharding_constraint(capsys):
     sharding_info(arr)

     assert capsys.readouterr().out == """
-╭───────────────────────────────────────────────╮
-│ shape: (16, 16, 16)                           │
-│ dtype: float32                                │
-│ size: 16.0 KiB                                │
-│ GSPMDSharding({devices=[1,8,1]<=[8]})         │
-│ axis 1 is sharded: CPU 0 contains 0:2 (of 16) │
-╰───────────────────────────────────────────────╯
+╭─────────────────────────────────────────────╮
+│ shape: (16, 16, 16)                         │
+│ dtype: float32                              │
+│ size: 16.0 KiB                              │
+│ GSPMDSharding({devices=[1,8,1]<=[8]})       │
+│ axis 1 is sharded: CPU 0 contains 0:2 (1/8) │
+│   Total size: 16                            │
+╰─────────────────────────────────────────────╯
 """.lstrip()
|
Loading…
Reference in a new issue