# jax-array-info
This package contains three functions for debugging jax `Array`s. Install it straight from the repository:

```bash
pip install git+https://github.com/Findus23/jax-array-info.git
```

```python
from jax_array_info import sharding_info, sharding_vis, print_array_stats
```
## `sharding_info(arr)`
`sharding_info(arr)` prints general information about a jax or numpy array, with a special focus on sharding (supporting `SingleDeviceSharding`, `GSPMDSharding`, `PositionalSharding`, `NamedSharding` and `PmapSharding`).
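The examples below assume a mesh whose single axis is named `"gpus"` and `N = 128`. One possible setup is sketched here; the XLA flag and mesh construction are assumptions for running the examples on a CPU-only host, not part of this package:

```python
import os

# Split the host CPU into 8 fake devices so the sharding examples run
# without real accelerators (assumed setup, not part of jax_array_info).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

N = 128
# One mesh axis named "gpus" spanning all 8 devices.
mesh = Mesh(np.array(jax.devices()).reshape(8), axis_names=("gpus",))
```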
```python
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(array, "some_array")
```
```
╭────────────────── some_array ───────────────────╮
│ shape: (128, 128, 128)                          │
│ dtype: float32                                  │
│ size: 8.0 MiB                                   │
│ NamedSharding: P(None, 'gpus')                  │
│ axis 1 is sharded: CPU 0 contains 0:16 (of 128) │
╰─────────────────────────────────────────────────╯
```
## `sharding_vis(arr)`
A modified version of `jax.debug.visualize_array_sharding()` that also supports arrays with more than 2 dimensions (by ignoring non-sharded dimensions in the visualisation until reaching 2 dimensions).
```python
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_vis(array)
```
```
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
```
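How the visualised dimensions are picked can be illustrated with a small sketch (a hypothetical re-implementation of the idea described above, not the package's actual code): unsharded dimensions are dropped, last first, until at most two remain.

```python
# Hypothetical sketch of the "ignore non-sharded dimensions" idea,
# not the package's actual implementation.
def dims_to_show(shape: tuple, sharded_dims: set) -> list:
    dims = list(range(len(shape)))
    # Drop unsharded dimensions (last first) until at most 2 remain.
    for d in reversed(range(len(shape))):
        if len(dims) <= 2:
            break
        if d not in sharded_dims:
            dims.remove(d)
    return dims

# For the example above: axis 1 is sharded, so dims [0, 1] are shown.
print(dims_to_show((128, 128, 128), sharded_dims={1}))  # -> [0, 1]
```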
## `print_array_stats()`
Shows a nice overview of all currently allocated arrays, ordered by size.

Disclaimer: This uses `jax.live_arrays()` to get its information, so there might be allocated arrays that are missing from this view. Also, for sharded arrays the size column reports the per-device share, with the total size shown in the sharded column (as in the example below).
```python
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
print_array_stats()
```
```
             allocated jax arrays
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size     ┃ shape        ┃ sharded           ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │                   │
│ 64.0 B   │ (2, 16, 4)   │ ✔ (512.0 B total) │
├──────────┼──────────────┼───────────────────┤
│ 16.1 KiB │              │                   │
└──────────┴──────────────┴───────────────────┘
```
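The underlying mechanism can be sketched with plain JAX APIs (a minimal illustration of the idea, not the package's implementation): iterate over `jax.live_arrays()` and derive the per-device and total size of each array from its shards.

```python
import jax

# Minimal sketch of the idea behind print_array_stats() (not its actual
# code): list live arrays, largest first, with per-device and total sizes.
for arr in sorted(jax.live_arrays(), key=lambda a: a.nbytes, reverse=True):
    shards = arr.addressable_shards
    per_device = shards[0].data.nbytes  # bytes held by a single device
    note = f"sharded, {arr.nbytes} B total" if len(shards) > 1 else ""
    print(f"{per_device:>8} B  {arr.shape}  {note}")
```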
## Examples
See [`tests/`](tests/) for more examples.