mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
add stats over all currently allocated arrays
This commit is contained in:
parent
eaf968a9c6
commit
ad1c2d5e91
5 changed files with 60 additions and 12 deletions
|
@ -1,2 +1,3 @@
|
||||||
from .sharding_info import sharding_info
|
from .sharding_info import sharding_info
|
||||||
from .sharding_vis import sharding_vis
|
from .sharding_vis import sharding_vis
|
||||||
|
from .array_stats import print_array_stats
|
||||||
|
|
32
jax_array_info/array_stats.py
Normal file
32
jax_array_info/array_stats.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
import jax
|
||||||
|
import rich
|
||||||
|
from jax.sharding import SingleDeviceSharding
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from .utils import pretty_byte_size
|
||||||
|
|
||||||
|
|
||||||
|
def array_stats_data() -> list[jax.Array]:
    """Return every currently allocated jax array, largest first.

    Wraps ``jax.live_arrays()`` and orders the result by ``nbytes``
    descending so callers can show the biggest allocations first.
    """
    live = list(jax.live_arrays())
    return sorted(live, key=lambda arr: arr.nbytes, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
def print_array_stats():
    """Print a rich table summarizing all currently allocated jax arrays.

    For sharded arrays only the per-device share of the memory is counted
    (and added to the total), while the full size is shown in the
    "sharded" column.
    """
    table = Table(title="allocated jax arrays")
    table.add_column("size")
    table.add_column("shape")
    table.add_column("sharded", justify="center")

    total = 0
    for arr in array_stats_data():
        num_devices = len(arr.sharding.device_set)
        sharded = num_devices > 1
        # per-device footprint: a sharded array only occupies a fraction
        # of its total size on each device
        size = arr.nbytes / num_devices if sharded else arr.nbytes
        total += size
        note = f"✔ ({pretty_byte_size(arr.nbytes)} total)" if sharded else ""
        table.add_row(pretty_byte_size(size), str(arr.shape), note)

    table.add_section()
    table.add_row(pretty_byte_size(total))
    rich.console.Console().print(table)
|
|
@ -8,6 +8,8 @@ from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
|
from .utils import pretty_byte_size
|
||||||
|
|
||||||
|
|
||||||
def sharding_info(arr, name=None):
|
def sharding_info(arr, name=None):
|
||||||
if isinstance(arr, np.ndarray):
|
if isinstance(arr, np.ndarray):
|
||||||
|
@ -21,13 +23,6 @@ def sharding_info(arr, name=None):
|
||||||
inspect_array_sharding(arr, callback=_info)
|
inspect_array_sharding(arr, callback=_info)
|
||||||
|
|
||||||
|
|
||||||
def pretty_byte_size(nbytes: int):
    """Format a byte count with a binary (IEC) prefix, e.g. ``16.0 KiB``."""
    value = nbytes
    for prefix in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(value) < 1024.0:
            return f"{value:3.1f} {prefix}B"
        value /= 1024.0
|
|
||||||
|
|
||||||
|
|
||||||
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
||||||
shape = arr.shape
|
shape = arr.shape
|
||||||
console.print(f"shape: {shape}")
|
console.print(f"shape: {shape}")
|
||||||
|
|
5
jax_array_info/utils.py
Normal file
5
jax_array_info/utils.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
def pretty_byte_size(nbytes: int):
    """Format a byte count with a binary (IEC) prefix, e.g. ``16.0 KiB``.

    Fix: the original returned ``None`` once ``abs(nbytes)`` reached
    1024 TiB because the loop exhausted without a fallback; values that
    large are now formatted with a final ``PiB`` prefix.
    """
    for unit in ("", "Ki", "Mi", "Gi", "Ti"):
        if abs(nbytes) < 1024.0:
            return f"{nbytes:3.1f} {unit}B"
        nbytes /= 1024.0
    # anything >= 1024 TiB falls through to pebibytes
    return f"{nbytes:3.1f} PiB"
|
|
@ -6,8 +6,7 @@ import pytest
|
||||||
from jax._src.sharding_impls import PositionalSharding
|
from jax._src.sharding_impls import PositionalSharding
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
|
||||||
|
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
||||||
from jax_array_info import sharding_info, sharding_vis
|
|
||||||
|
|
||||||
num_gpus = 8
|
num_gpus = 8
|
||||||
|
|
||||||
|
@ -337,9 +336,6 @@ def test_indirectly_sharded(capsys):
|
||||||
|
|
||||||
|
|
||||||
def test_with_sharding_constraint(capsys):
|
def test_with_sharding_constraint(capsys):
|
||||||
"""
|
|
||||||
y is never explicitly sharded, but it seems like the sharding is back-propagated through the jit compiled function
|
|
||||||
"""
|
|
||||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||||
|
|
||||||
def func(x):
|
def func(x):
|
||||||
|
@ -360,6 +356,25 @@ def test_with_sharding_constraint(capsys):
|
||||||
""".lstrip()
|
""".lstrip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_stats(capsys):
    """Check that print_array_stats() renders the expected table for one
    unsharded and one sharded array."""
    # 16*16*16 float32 zeros = 16 KiB, unsharded
    arr = jax.numpy.zeros(shape=(16, 16, 16))
    # 2*16*4 float32 zeros = 512 B total, split across devices via
    # simple_sharding (defined elsewhere in this test file) -> 64 B/device
    arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), simple_sharding)

    print_array_stats()

    # NOTE(review): the expected table below came through a mirror that
    # collapsed the cell-padding whitespace rich emits; the exact
    # alignment should be restored from the upstream repository before
    # relying on this comparison.
    assert capsys.readouterr().out == """
allocated jax arrays
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size ┃ shape ┃ sharded ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │ │
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
├──────────┼──────────────┼───────────────────┤
│ 16.1 KiB │ │ │
└──────────┴──────────────┴───────────────────┘
""".lstrip("\n")
|
||||||
|
|
||||||
|
|
||||||
def test_non_array(capsys):
|
def test_non_array(capsys):
|
||||||
arr = [1, 2, 3]
|
arr = [1, 2, 3]
|
||||||
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
with pytest.raises(ValueError, match="is not a jax array, got <class 'list'>"):
|
||||||
|
|
Loading…
Reference in a new issue