1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-16 12:23:47 +02:00
jax-array-info/jax_array_info/sharding_info.py
2024-08-02 17:02:39 +02:00

83 lines
3 KiB
Python

from typing import Optional
import numpy as np
import rich
from jax import Array
from jax._src.debugging import inspect_array_sharding
from jax.core import Tracer
from jax.sharding import Sharding, NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, PositionalSharding
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from .utils import pretty_byte_size
# Array types accepted by the helpers in this module: a NumPy ndarray or a JAX Array.
SupportedArray = np.ndarray | Array
def sharding_info(arr: SupportedArray, name: str = None):
    """Print a sharding report for ``arr``.

    NumPy arrays are reported immediately, without any sharding. JAX arrays
    are handed to ``inspect_array_sharding``, which invokes a callback with
    the array's sharding; the callback then prints the report.

    Args:
        arr: The NumPy or JAX array to describe.
        name: Optional label forwarded to ``print_sharding_info``.

    Raises:
        ValueError: If ``arr`` is neither a NumPy array nor a JAX array.
    """
    if isinstance(arr, np.ndarray):
        # NumPy arrays carry no sharding, so report them right away.
        return print_sharding_info(arr, None, name)

    if isinstance(arr, Array):
        def _report(resolved_sharding):
            print_sharding_info(arr, resolved_sharding, name)

        inspect_array_sharding(arr, callback=_report)
        return None

    raise ValueError(f"is not a jax array, got {type(arr)}")
def _print_sharding_info_raw(arr: SupportedArray, sharding: Optional[Sharding], console: Console):
    """Write the plain sharding report for ``arr`` to ``console``.

    Shape, dtype and byte size are always printed. For NumPy arrays the
    report stops there. For JAX arrays it continues with a line describing
    the kind of sharding and, for every axis whose slice in the first entry
    of the sharding's device-to-indices map has an explicit start, the range
    that entry holds and the full axis length.

    Args:
        arr: The NumPy or JAX array to describe.
        sharding: Sharding of ``arr``; only used when ``arr`` is a JAX array.
        console: Rich console that receives the output.

    Raises:
        ValueError: If ``arr`` is neither a NumPy array nor a JAX array.
    """
    dims = arr.shape
    console.print(f"shape: {dims}")
    console.print(f"dtype: {arr.dtype}")
    console.print(f"size: {pretty_byte_size(arr.nbytes)}")

    if isinstance(arr, np.ndarray):
        console.print("[bold]numpy array")
        return
    if not isinstance(arr, Array):
        raise ValueError(f"is not a jax array, got {type(arr)}")

    # Platform name of an arbitrary device of the sharding, used as a label below.
    some_device = next(iter(sharding.device_set))
    platform_label = some_device.platform.upper()

    # The two notices are mutually exclusive: addressability is only
    # queried when the array is not a tracer.
    if isinstance(arr, Tracer):
        console.print("[bright_black]called in jit")
    elif not arr.is_fully_addressable:
        console.print("!is_fully_addressable")
    # if not arr.is_fully_replicated:
    #     console.print("!is_fully_replicated")

    # One descriptive line (or two) per recognised sharding class.
    if isinstance(sharding, SingleDeviceSharding):
        console.print("[red]not sharded")
    if isinstance(sharding, GSPMDSharding):
        console.print(sharding)
    if isinstance(sharding, PositionalSharding):
        console.print("PositionalSharding:")
        console.print(sharding._ids)
    if isinstance(sharding, NamedSharding):
        console.print(f"NamedSharding: P{tuple(sharding.spec)}")
    if isinstance(sharding, PmapSharding):
        # No per-axis breakdown is printed for pmap shardings.
        console.print(sharding)
        return

    indices_by_device = sharding.devices_indices_map(tuple(dims))
    reference_slices = next(iter(indices_by_device.values()))
    for axis, span in enumerate(reference_slices):
        lo = span.start
        if lo is None:
            # A slice without a start is not reported as sharded.
            continue
        hi = span.stop
        shard_len = hi - lo
        axis_len = dims[axis]
        shard_count = axis_len // shard_len
        console.print(f"axis {axis} is sharded: {platform_label} 0 contains {lo}:{hi} (1/{shard_count})")
        console.print(f" Total size: {axis_len}")
def print_sharding_info(arr: SupportedArray, sharding: Optional[Sharding], name: Optional[str] = None):
    """Print the sharding report for ``arr`` inside a Rich panel.

    The raw report is first rendered into a capture buffer, converted back
    from ANSI into a ``Text`` object, and then printed wrapped in a panel
    that is titled with ``name`` when one is given.

    Args:
        arr: The NumPy or JAX array to describe.
        sharding: Sharding of ``arr``, or ``None`` for NumPy arrays.
        name: Optional panel title. It is shown literally in bold; square
            brackets in it are not interpreted as Rich markup.
    """
    console = rich.console.Console()
    with console.capture() as capture:
        _print_sharding_info_raw(arr, sharding, console)
    report = Text.from_ansi(capture.get())

    # Build the title as a styled Text instead of interpolating the name into
    # a markup string: a name such as "weights[layer]" would otherwise lose
    # its bracketed part, and one such as "a[/b]" would raise a MarkupError.
    title = Text(str(name), style="bold") if name is not None else None
    console.print(Panel(report, expand=False, title=title))