From a80cbafe1aec712113577ef813a2dce5e621395c Mon Sep 17 00:00:00 2001 From: Lukas Winkler Date: Fri, 2 Aug 2024 17:02:39 +0200 Subject: [PATCH] add typing --- jax_array_info/sharding_info.py | 10 +++++++--- jax_array_info/sharding_vis.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/jax_array_info/sharding_info.py b/jax_array_info/sharding_info.py index b8cf7f1..4609e8b 100644 --- a/jax_array_info/sharding_info.py +++ b/jax_array_info/sharding_info.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import rich from jax import Array @@ -10,8 +12,10 @@ from rich.text import Text from .utils import pretty_byte_size +SupportedArray = np.ndarray | Array -def sharding_info(arr, name=None): + +def sharding_info(arr: SupportedArray, name: str = None): if isinstance(arr, np.ndarray): return print_sharding_info(arr, None, name) if not isinstance(arr, Array): @@ -23,7 +27,7 @@ def sharding_info(arr, name=None): inspect_array_sharding(arr, callback=_info) -def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console): +def _print_sharding_info_raw(arr: SupportedArray, sharding: Optional[Sharding], console: Console): shape = arr.shape console.print(f"shape: {shape}") console.print(f"dtype: {arr.dtype}") @@ -70,7 +74,7 @@ def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console): console.print(f" Total size: {global_size}") -def print_sharding_info(arr: Array, sharding: Sharding, name=None): +def print_sharding_info(arr: SupportedArray, sharding: Optional[Sharding], name: str = None): console = rich.console.Console() with console.capture() as capture: _print_sharding_info_raw(arr, sharding, console) diff --git a/jax_array_info/sharding_vis.py b/jax_array_info/sharding_vis.py index 07dce7d..c84589b 100644 --- a/jax_array_info/sharding_vis.py +++ b/jax_array_info/sharding_vis.py @@ -27,7 +27,7 @@ from jax._src.debugging import _raise_to_slice, _slice_to_chunk_idx, inspect_arr from jax.sharding import Sharding, PmapSharding -def sharding_vis(arr, **kwargs): +def sharding_vis(arr: Array, **kwargs): if not isinstance(arr, Array): raise ValueError(f"is not a jax array, got {type(arr)}")