mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-18 14:43:48 +02:00
add typing
This commit is contained in:
parent
4381a7998e
commit
a80cbafe1a
2 changed files with 8 additions and 4 deletions
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import rich
|
||||
from jax import Array
|
||||
|
@ -10,8 +12,10 @@ from rich.text import Text
|
|||
|
||||
from .utils import pretty_byte_size
|
||||
|
||||
SupportedArray = np.ndarray | Array
|
||||
|
||||
def sharding_info(arr, name=None):
|
||||
|
||||
def sharding_info(arr: SupportedArray, name: str = None):
|
||||
if isinstance(arr, np.ndarray):
|
||||
return print_sharding_info(arr, None, name)
|
||||
if not isinstance(arr, Array):
|
||||
|
@ -23,7 +27,7 @@ def sharding_info(arr, name=None):
|
|||
inspect_array_sharding(arr, callback=_info)
|
||||
|
||||
|
||||
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
||||
def _print_sharding_info_raw(arr: SupportedArray, sharding: Optional[Sharding], console: Console):
|
||||
shape = arr.shape
|
||||
console.print(f"shape: {shape}")
|
||||
console.print(f"dtype: {arr.dtype}")
|
||||
|
@ -70,7 +74,7 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
|||
console.print(f" Total size: {global_size}")
|
||||
|
||||
|
||||
def print_sharding_info(arr: Array, sharding: Sharding, name=None):
|
||||
def print_sharding_info(arr: SupportedArray, sharding: Optional[Sharding], name: str = None):
|
||||
console = rich.console.Console()
|
||||
with console.capture() as capture:
|
||||
_print_sharding_info_raw(arr, sharding, console)
|
||||
|
|
|
@ -27,7 +27,7 @@ from jax._src.debugging import _raise_to_slice, _slice_to_chunk_idx, inspect_arr
|
|||
from jax.sharding import Sharding, PmapSharding
|
||||
|
||||
|
||||
def sharding_vis(arr, **kwargs):
|
||||
def sharding_vis(arr: Array, **kwargs):
|
||||
if not isinstance(arr, Array):
|
||||
raise ValueError(f"is not a jax array, got {type(arr)}")
|
||||
|
||||
|
|
Loading…
Reference in a new issue