mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-16 12:23:47 +02:00
add stats over all currently allocated arrays
This commit is contained in:
parent
eaf968a9c6
commit
ad1c2d5e91
5 changed files with 60 additions and 12 deletions
|
@ -1,2 +1,3 @@
|
|||
from .sharding_info import sharding_info
|
||||
from .sharding_vis import sharding_vis
|
||||
from .array_stats import print_array_stats
|
||||
|
|
32
jax_array_info/array_stats.py
Normal file
32
jax_array_info/array_stats.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import jax
|
||||
import rich
|
||||
from jax.sharding import SingleDeviceSharding
|
||||
from rich.table import Table
|
||||
|
||||
from .utils import pretty_byte_size
|
||||
|
||||
|
||||
def array_stats_data() -> list[jax.Array]:
|
||||
arrs = jax.live_arrays()
|
||||
arrs.sort(key=lambda a: -a.nbytes)
|
||||
return arrs
|
||||
|
||||
|
||||
def print_array_stats():
|
||||
console = rich.console.Console()
|
||||
table = Table(title="allocated jax arrays")
|
||||
table.add_column("size")
|
||||
table.add_column("shape")
|
||||
table.add_column("sharded", justify="center")
|
||||
total_size = 0
|
||||
for arr in array_stats_data():
|
||||
file_size = arr.nbytes
|
||||
is_sharded = False
|
||||
if len(arr.sharding.device_set)>1:
|
||||
file_size /= len(arr.sharding.device_set)
|
||||
is_sharded = True
|
||||
total_size += file_size
|
||||
table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
|
||||
table.add_section()
|
||||
table.add_row(pretty_byte_size(total_size))
|
||||
console.print(table)
|
|
@ -8,6 +8,8 @@ from rich.console import Console
|
|||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from .utils import pretty_byte_size
|
||||
|
||||
|
||||
def sharding_info(arr, name=None):
|
||||
if isinstance(arr, np.ndarray):
|
||||
|
@ -21,13 +23,6 @@ def sharding_info(arr, name=None):
|
|||
inspect_array_sharding(arr, callback=_info)
|
||||
|
||||
|
||||
def pretty_byte_size(nbytes: int):
|
||||
for unit in ("", "Ki", "Mi", "Gi", "Ti"):
|
||||
if abs(nbytes) < 1024.0:
|
||||
return f"{nbytes:3.1f} {unit}B"
|
||||
nbytes /= 1024.0
|
||||
|
||||
|
||||
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
||||
shape = arr.shape
|
||||
console.print(f"shape: {shape}")
|
||||
|
|
5
jax_array_info/utils.py
Normal file
5
jax_array_info/utils.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
def pretty_byte_size(nbytes: int):
|
||||
for unit in ("", "Ki", "Mi", "Gi", "Ti"):
|
||||
if abs(nbytes) < 1024.0:
|
||||
return f"{nbytes:3.1f} {unit}B"
|
||||
nbytes /= 1024.0
|
|
@ -6,8 +6,7 @@ import pytest
|
|||
from jax._src.sharding_impls import PositionalSharding
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||
|
||||
from jax_array_info import sharding_info, sharding_vis
|
||||
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
||||
|
||||
num_gpus = 8
|
||||
|
||||
|
@ -337,9 +336,6 @@ def test_indirectly_sharded(capsys):
|
|||
|
||||
|
||||
def test_with_sharding_constraint(capsys):
|
||||
"""
|
||||
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
||||
"""
|
||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
|
||||
def func(x):
|
||||
|
@ -360,6 +356,25 @@ def test_with_sharding_constraint(capsys):
|
|||
""".lstrip()
|
||||
|
||||
|
||||
def test_array_stats(capsys):
|
||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), simple_sharding)
|
||||
|
||||
print_array_stats()
|
||||
|
||||
assert capsys.readouterr().out == """
|
||||
allocated jax arrays
|
||||
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ size ┃ shape ┃ sharded ┃
|
||||
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||
│ 16.0 KiB │ (16, 16, 16) │ │
|
||||
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
|
||||
├──────────┼──────────────┼───────────────────┤
|
||||
│ 16.1 KiB │ │ │
|
||||
└──────────┴──────────────┴───────────────────┘
|
||||
""".lstrip("\n")
|
||||
|
||||
|
||||
def test_non_array(capsys):
|
||||
arr = [1, 2, 3]
|
||||
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||
|
|
Loading…
Reference in a new issue