1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

document print_array_stats and add dtype to it

This commit is contained in:
Lukas Winkler 2024-06-21 12:18:31 +02:00
parent e1dc02aee2
commit 9ac3ea6715
Signed by: lukas
GPG key ID: 54DE4D798D244853
3 changed files with 41 additions and 11 deletions

View file

@ -7,7 +7,7 @@ pip install git+https://github.com/Findus23/jax-array-info.git
``` ```
```python ```python
from jax_array_info import sharding_info, sharding_vis from jax_array_info import sharding_info, sharding_vis, print_array_stats
``` ```
## `sharding_info(arr)` ## `sharding_info(arr)`
@ -59,6 +59,31 @@ sharding_vis(array)
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
``` ```
## `print_array_stats()`
Shows a nice overview over the all currently allocated arrays ordered by size.
**Disclaimer**: This uses `jax.live_arrays()` to get its information. There might be allocated arrays that are missing in this view. Also
```python
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
print_array_stats()
```
```text
allocated jax arrays
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size ┃ shape ┃ sharded ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │ │
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
├──────────┼──────────────┼───────────────────┤
│ 16.1 KiB │ │ │
└──────────┴──────────────┴───────────────────┘
```
### Examples ### Examples
See [`tests/`](./tests/jaxtest.py) See [`tests/`](./tests/jaxtest.py)

View file

@ -17,6 +17,7 @@ def print_array_stats():
table = Table(title="allocated jax arrays") table = Table(title="allocated jax arrays")
table.add_column("size") table.add_column("size")
table.add_column("shape") table.add_column("shape")
table.add_column("dtype")
table.add_column("sharded", justify="center") table.add_column("sharded", justify="center")
total_size = 0 total_size = 0
for arr in array_stats_data(): for arr in array_stats_data():
@ -26,7 +27,11 @@ def print_array_stats():
file_size /= len(arr.sharding.device_set) file_size /= len(arr.sharding.device_set)
is_sharded = True is_sharded = True
total_size += file_size total_size += file_size
table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "") table.add_row(
pretty_byte_size(file_size),
str(arr.shape),
str(arr.dtype),
f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
table.add_section() table.add_section()
table.add_row(pretty_byte_size(total_size)) table.add_row(pretty_byte_size(total_size))
console.print(table) console.print(table)

View file

@ -380,14 +380,14 @@ def test_array_stats(capsys):
assert capsys.readouterr().out == """ assert capsys.readouterr().out == """
allocated jax arrays allocated jax arrays
size shape sharded size shape dtype sharded
16.0 KiB (16, 16, 16) 16.0 KiB (16, 16, 16) float32
64.0 B (2, 16, 4) (512.0 B total) 64.0 B (2, 16, 4) float32 (512.0 B total)
16.1 KiB 16.1 KiB
""".lstrip("\n") """.lstrip("\n")