1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

add typing

This commit is contained in:
Lukas Winkler 2024-08-02 17:02:39 +02:00
parent 4381a7998e
commit a80cbafe1a
Signed by: lukas
GPG key ID: 54DE4D798D244853
2 changed files with 8 additions and 4 deletions

View file

@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
import rich import rich
from jax import Array from jax import Array
@ -10,8 +12,10 @@ from rich.text import Text
from .utils import pretty_byte_size from .utils import pretty_byte_size
SupportedArray = np.ndarray | Array
def sharding_info(arr, name=None):
def sharding_info(arr: SupportedArray, name: str = None):
if isinstance(arr, np.ndarray): if isinstance(arr, np.ndarray):
return print_sharding_info(arr, None, name) return print_sharding_info(arr, None, name)
if not isinstance(arr, Array): if not isinstance(arr, Array):
@ -23,7 +27,7 @@ def sharding_info(arr, name=None):
inspect_array_sharding(arr, callback=_info) inspect_array_sharding(arr, callback=_info)
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console): def _print_sharding_info_raw(arr: SupportedArray, sharding: Optional[Sharding], console: Console):
shape = arr.shape shape = arr.shape
console.print(f"shape: {shape}") console.print(f"shape: {shape}")
console.print(f"dtype: {arr.dtype}") console.print(f"dtype: {arr.dtype}")
@ -70,7 +74,7 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
console.print(f" Total size: {global_size}") console.print(f" Total size: {global_size}")
def print_sharding_info(arr: Array, sharding: Sharding, name=None): def print_sharding_info(arr: SupportedArray, sharding: Optional[Sharding], name: str = None):
console = rich.console.Console() console = rich.console.Console()
with console.capture() as capture: with console.capture() as capture:
_print_sharding_info_raw(arr, sharding, console) _print_sharding_info_raw(arr, sharding, console)

View file

@ -27,7 +27,7 @@ from jax._src.debugging import _raise_to_slice, _slice_to_chunk_idx, inspect_arr
from jax.sharding import Sharding, PmapSharding from jax.sharding import Sharding, PmapSharding
def sharding_vis(arr, **kwargs): def sharding_vis(arr: Array, **kwargs):
if not isinstance(arr, Array): if not isinstance(arr, Array):
raise ValueError(f"is not a jax array, got {type(arr)}") raise ValueError(f"is not a jax array, got {type(arr)}")