diff --git a/README.md b/README.md index 7e97258..7c8a18d 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ pip install git+https://github.com/Findus23/jax-array-info.git ``` ```python -from jax_array_info import sharding_info, sharding_vis +from jax_array_info import sharding_info, sharding_vis, print_array_stats ``` ## `sharding_info(arr)` @@ -59,6 +59,31 @@ sharding_vis(array) └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ ``` +## `print_array_stats()` + +Shows a nice overview over the all currently allocated arrays ordered by size. + +**Disclaimer**: This uses `jax.live_arrays()` to get its information. There might be allocated arrays that are missing in this view. Also + +```python +arr = jax.numpy.zeros(shape=(16, 16, 16)) +arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus"))) + +print_array_stats() +``` + +```text + allocated jax arrays +┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ size ┃ shape ┃ sharded ┃ +┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 16.0 KiB │ (16, 16, 16) │ │ +│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │ +├──────────┼──────────────┼───────────────────┤ +│ 16.1 KiB │ │ │ +└──────────┴──────────────┴───────────────────┘ +``` + ### Examples See [`tests/`](./tests/jaxtest.py) diff --git a/jax_array_info/array_stats.py b/jax_array_info/array_stats.py index 273fa10..f08953e 100644 --- a/jax_array_info/array_stats.py +++ b/jax_array_info/array_stats.py @@ -17,6 +17,7 @@ def print_array_stats(): table = Table(title="allocated jax arrays") table.add_column("size") table.add_column("shape") + table.add_column("dtype") table.add_column("sharded", justify="center") total_size = 0 for arr in array_stats_data(): @@ -26,7 +27,11 @@ def print_array_stats(): file_size /= len(arr.sharding.device_set) is_sharded = True total_size += file_size - table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "") + table.add_row( + pretty_byte_size(file_size), + str(arr.shape), + str(arr.dtype), + f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "") table.add_section() table.add_row(pretty_byte_size(total_size)) console.print(table) diff --git a/tests/jaxtest.py b/tests/jaxtest.py index ef9d711..945427e 100644 --- a/tests/jaxtest.py +++ b/tests/jaxtest.py @@ -379,15 +379,15 @@ def test_array_stats(capsys): print_array_stats() assert capsys.readouterr().out == """ - allocated jax arrays -┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ -┃ size ┃ shape ┃ sharded ┃ -┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ -│ 16.0 KiB │ (16, 16, 16) │ │ -│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │ -├──────────┼──────────────┼───────────────────┤ -│ 16.1 KiB │ │ │ -└──────────┴──────────────┴───────────────────┘ + allocated jax arrays +┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ size ┃ shape ┃ dtype ┃ sharded ┃ +┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 16.0 KiB │ (16, 16, 16) │ float32 │ │ +│ 64.0 B │ (2, 16, 4) │ float32 │ ✔ (512.0 B total) │ +├──────────┼──────────────┼─────────┼───────────────────┤ +│ 16.1 KiB │ │ │ │ +└──────────┴──────────────┴─────────┴───────────────────┘ """.lstrip("\n")