mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00
Debugging tool to print information (especially sharding) about jax arrays
jax_array_info minor fixes 2023-09-04 16:33:38 +02:00
tests one more test 2023-09-04 16:47:53 +02:00
.gitignore initial version 2023-09-04 15:34:05 +02:00
LICENSE initial version 2023-09-04 15:34:05 +02:00
pyproject.toml initial version 2023-09-04 15:34:05 +02:00
README.md add installation 2023-09-04 16:26:03 +02:00

jax-array-info

This package contains two functions for debugging jax Arrays, sharding_info and sharding_vis. Install it from the repository and import both functions:

pip install git+https://github.com/Findus23/jax-array-info.git
from jax_array_info import sharding_info, sharding_vis

sharding_info(arr)

sharding_info(arr) prints general information about a jax or numpy array, with a special focus on sharding (supporting SingleDeviceSharding, GSPMDSharding, PositionalSharding, NamedSharding and PmapSharding).

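import jax
from jax.sharding import NamedSharding, PartitionSpec as P
# N = 128; mesh is a 1D jax.sharding.Mesh over 8 devices with the axis name "gpus"
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(array, "some_array")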
╭────────────────── some_array ───────────────────╮
│ shape: (128, 128, 128)                          │
│ dtype: float32                                  │
│ size: 8.0 MiB                                   │
│ NamedSharding: P(None, 'gpus')                  │
│ axis 1 is sharded: CPU 0 contains 0:16 (of 128) │
╰─────────────────────────────────────────────────╯
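
The snippet above does not show how N and mesh are defined. The following is a minimal setup sketch that should produce a comparable array, not the author's own test setup. It assumes eight simulated CPU devices, requested through the XLA_FLAGS environment variable (an assumption of this sketch, not something the README states), and a one-dimensional mesh whose single axis is named "gpus". It uses only jax itself and inspects the first shard with the built-in addressable_shards attribute.

```python
import os

# Assumption: ask XLA for 8 simulated CPU devices.
# This only takes effect if it is set before jax is first imported
# and if XLA_FLAGS is not already set to something else.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

N = 128

# 1D mesh over all CPU devices; the axis name matches P(None, "gpus")
devices = np.array(jax.devices("cpu"))
mesh = Mesh(devices, axis_names=("gpus",))

# same two lines as in the README example: axis 1 is split across the mesh
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))

# look at the shard held by the first device
first = array.addressable_shards[0]
print(first.device, first.index, first.data.shape)
```

With jax_array_info installed, sharding_info(array, "some_array") and sharding_vis(array) can then be called on this array.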

sharding_vis(arr)

sharding_vis(arr) is a modified version of jax.debug.visualize_array_sharding() that also supports arrays with more than two dimensions. It leaves non-sharded dimensions out of the visualisation until only two dimensions remain.

# same setup as in the sharding_info example: N = 128 and a 1D mesh with the axis "gpus"
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_vis(array)
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
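
The idea of dropping non-sharded dimensions until two remain can be sketched in a few lines. This is an illustration only and not the package's implementation: the function name dims_to_show is hypothetical, and the choice to drop unsharded dimensions starting from the last one is an assumption (it happens to reproduce the "showing dims [0, 1]" header above). A dimension is treated as sharded when the per-device shard is smaller than the global array along it.

```python
def dims_to_show(global_shape, shard_shape, max_dims=2):
    """Pick which dimensions to keep for a 2D sharding plot.

    Hypothetical helper: drops unsharded dimensions, starting from the
    last one, until at most `max_dims` dimensions remain.
    """
    if len(global_shape) != len(shard_shape):
        raise ValueError("global_shape and shard_shape must have the same rank")

    # a dimension is sharded if each device holds only part of it
    sharded = [g != s for g, s in zip(global_shape, shard_shape)]

    keep = list(range(len(global_shape)))
    for dim in reversed(range(len(global_shape))):
        if len(keep) <= max_dims:
            break
        if not sharded[dim]:
            keep.remove(dim)

    if len(keep) > max_dims:
        # every remaining dimension is sharded, so nothing more can be dropped
        raise ValueError(f"more than {max_dims} sharded dimensions")
    return keep


# global shape (128, 128, 128), each device holds a (128, 16, 128) block
print(dims_to_show((128, 128, 128), (128, 16, 128)))
```

For a real jax array, the per-device shape can be obtained with array.sharding.shard_shape(array.shape).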

Examples

See the tests/ directory for more examples.