1
0
Fork 0
mirror of https://github.com/Findus23/jax-array-info.git synced 2024-09-19 15:53:47 +02:00

document print_array_stats and add dtype to it

This commit is contained in:
Lukas Winkler 2024-06-21 12:18:31 +02:00
parent e1dc02aee2
commit 9ac3ea6715
Signed by: lukas
GPG key ID: 54DE4D798D244853
3 changed files with 41 additions and 11 deletions

View file

@ -7,7 +7,7 @@ pip install git+https://github.com/Findus23/jax-array-info.git
```
```python
from jax_array_info import sharding_info, sharding_vis
from jax_array_info import sharding_info, sharding_vis, print_array_stats
```
## `sharding_info(arr)`
@ -59,6 +59,31 @@ sharding_vis(array)
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
```
## `print_array_stats()`
Shows a nice overview over the all currently allocated arrays ordered by size.
**Disclaimer**: This uses `jax.live_arrays()` to get its information. There might be allocated arrays that are missing in this view. Also
```python
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
print_array_stats()
```
```text
allocated jax arrays
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size ┃ shape ┃ sharded ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │ │
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
├──────────┼──────────────┼───────────────────┤
│ 16.1 KiB │ │ │
└──────────┴──────────────┴───────────────────┘
```
### Examples
See [`tests/`](./tests/jaxtest.py)

View file

@ -17,6 +17,7 @@ def print_array_stats():
table = Table(title="allocated jax arrays")
table.add_column("size")
table.add_column("shape")
table.add_column("dtype")
table.add_column("sharded", justify="center")
total_size = 0
for arr in array_stats_data():
@ -26,7 +27,11 @@ def print_array_stats():
file_size /= len(arr.sharding.device_set)
is_sharded = True
total_size += file_size
table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
table.add_row(
pretty_byte_size(file_size),
str(arr.shape),
str(arr.dtype),
f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
table.add_section()
table.add_row(pretty_byte_size(total_size))
console.print(table)

View file

@ -379,15 +379,15 @@ def test_array_stats(capsys):
print_array_stats()
assert capsys.readouterr().out == """
allocated jax arrays
size shape sharded
16.0 KiB (16, 16, 16)
64.0 B (2, 16, 4) (512.0 B total)
16.1 KiB
allocated jax arrays
size shape dtype sharded
16.0 KiB (16, 16, 16) float32
64.0 B (2, 16, 4) float32 (512.0 B total)
16.1 KiB
""".lstrip("\n")