1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

add stats over all currently allocated arrays

This commit is contained in:
Lukas Winkler 2023-10-09 17:15:23 +02:00
parent eaf968a9c6
commit ad1c2d5e91
Signed by: lukas
GPG key ID: 54DE4D798D244853
5 changed files with 60 additions and 12 deletions

View file

@ -1,2 +1,3 @@
from .sharding_info import sharding_info
from .sharding_vis import sharding_vis
from .array_stats import print_array_stats

View file

@ -0,0 +1,32 @@
import jax
import rich
from jax.sharding import SingleDeviceSharding
from rich.table import Table
from .utils import pretty_byte_size
def array_stats_data() -> list[jax.Array]:
    """Return all currently allocated jax arrays, largest first.

    Arrays are ordered by descending ``nbytes`` so the biggest
    allocations show up at the top of the stats table.
    """
    return sorted(jax.live_arrays(), key=lambda arr: arr.nbytes, reverse=True)
def print_array_stats():
    """Print a rich table of all currently allocated jax arrays.

    Each row shows the per-device size, the shape, and whether the array
    is sharded across multiple devices; a footer row gives the total
    per-device memory of all live arrays.
    """
    console = rich.console.Console()
    table = Table(title="allocated jax arrays")
    table.add_column("size")
    table.add_column("shape")
    table.add_column("sharded", justify="center")

    total_size = 0
    for arr in array_stats_data():
        num_devices = len(arr.sharding.device_set)
        is_sharded = num_devices > 1
        # NOTE(review): assumes a sharded array is evenly partitioned across
        # its devices (no replication), so per-device size is nbytes / devices
        # — confirm for replicated shardings.
        per_device_size = arr.nbytes / num_devices if is_sharded else arr.nbytes
        total_size += per_device_size
        sharded_cell = f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else ""
        table.add_row(pretty_byte_size(per_device_size), str(arr.shape), sharded_cell)

    table.add_section()
    table.add_row(pretty_byte_size(total_size))
    console.print(table)

View file

@ -8,6 +8,8 @@ from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from .utils import pretty_byte_size
def sharding_info(arr, name=None):
if isinstance(arr, np.ndarray):
@ -21,13 +23,6 @@ def sharding_info(arr, name=None):
inspect_array_sharding(arr, callback=_info)
def pretty_byte_size(nbytes: int):
    """Format a byte count as a human-readable, binary-prefixed string.

    Args:
        nbytes: number of bytes (may be negative)

    Returns:
        A string such as ``"16.0 KiB"``; values of 1024 TiB or more are
        expressed in PiB.
    """
    for unit in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(nbytes) < 1024.0:
            return f"{nbytes:3.1f} {unit}B"
        nbytes /= 1024.0
    # Bug fix: previously fell off the end of the loop and implicitly
    # returned None for values >= 1024 TiB.
    return f"{nbytes:3.1f} PiB"
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
shape = arr.shape
console.print(f"shape: {shape}")

5
jax_array_info/utils.py Normal file
View file

@ -0,0 +1,5 @@
def pretty_byte_size(nbytes: int):
    """Format a byte count as a human-readable, binary-prefixed string.

    Args:
        nbytes: number of bytes (may be negative)

    Returns:
        A string such as ``"16.0 KiB"``; values of 1024 TiB or more are
        expressed in PiB.
    """
    for unit in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(nbytes) < 1024.0:
            return f"{nbytes:3.1f} {unit}B"
        nbytes /= 1024.0
    # Bug fix: previously fell off the end of the loop and implicitly
    # returned None for values >= 1024 TiB.
    return f"{nbytes:3.1f} PiB"

View file

@ -6,8 +6,7 @@ import pytest
from jax._src.sharding_impls import PositionalSharding
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax_array_info import sharding_info, sharding_vis
from jax_array_info import sharding_info, sharding_vis, print_array_stats
num_gpus = 8
@ -337,9 +336,6 @@ def test_indirectly_sharded(capsys):
def test_with_sharding_constraint(capsys):
"""
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
"""
arr = jax.numpy.zeros(shape=(16, 16, 16))
def func(x):
@ -360,6 +356,25 @@ def test_with_sharding_constraint(capsys):
""".lstrip()
def test_array_stats(capsys):
    """print_array_stats lists all live arrays plus a total-size footer row.

    NOTE(review): the expected string must match rich's table rendering
    byte-for-byte, including alignment whitespace — confirm the literal
    below against actual output (the diff view may have mangled spacing).
    """
    # One unsharded array (16.0 KiB) and one array sharded across the
    # test devices (512 B total; with num_gpus = 8 that is 64 B per device).
    arr = jax.numpy.zeros(shape=(16, 16, 16))
    arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), simple_sharding)
    print_array_stats()
    assert capsys.readouterr().out == """
allocated jax arrays
size shape sharded
16.0 KiB (16, 16, 16)
64.0 B (2, 16, 4) (512.0 B total)
16.1 KiB
""".lstrip("\n")
def test_non_array(capsys):
arr = [1, 2, 3]
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):