mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
document print_array_stats and add dtype to it
This commit is contained in:
parent
e1dc02aee2
commit
9ac3ea6715
3 changed files with 41 additions and 11 deletions
27
README.md
27
README.md
|
@ -7,7 +7,7 @@ pip install git+https://github.com/Findus23/jax-array-info.git
|
|||
```
|
||||
|
||||
```python
|
||||
from jax_array_info import sharding_info, sharding_vis
|
||||
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
||||
```
|
||||
|
||||
## `sharding_info(arr)`
|
||||
|
@ -59,6 +59,31 @@ sharding_vis(array)
|
|||
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
|
||||
```
|
||||
|
||||
## `print_array_stats()`
|
||||
|
||||
Shows a nice overview over the all currently allocated arrays ordered by size.
|
||||
|
||||
**Disclaimer**: This uses `jax.live_arrays()` to get its information. There might be allocated arrays that are missing in this view. Also
|
||||
|
||||
```python
|
||||
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
|
||||
|
||||
print_array_stats()
|
||||
```
|
||||
|
||||
```text
|
||||
allocated jax arrays
|
||||
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ size ┃ shape ┃ sharded ┃
|
||||
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||
│ 16.0 KiB │ (16, 16, 16) │ │
|
||||
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
|
||||
├──────────┼──────────────┼───────────────────┤
|
||||
│ 16.1 KiB │ │ │
|
||||
└──────────┴──────────────┴───────────────────┘
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
See [`tests/`](./tests/jaxtest.py)
|
||||
|
|
|
@ -17,6 +17,7 @@ def print_array_stats():
|
|||
table = Table(title="allocated jax arrays")
|
||||
table.add_column("size")
|
||||
table.add_column("shape")
|
||||
table.add_column("dtype")
|
||||
table.add_column("sharded", justify="center")
|
||||
total_size = 0
|
||||
for arr in array_stats_data():
|
||||
|
@ -26,7 +27,11 @@ def print_array_stats():
|
|||
file_size /= len(arr.sharding.device_set)
|
||||
is_sharded = True
|
||||
total_size += file_size
|
||||
table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
|
||||
table.add_row(
|
||||
pretty_byte_size(file_size),
|
||||
str(arr.shape),
|
||||
str(arr.dtype),
|
||||
f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
|
||||
table.add_section()
|
||||
table.add_row(pretty_byte_size(total_size))
|
||||
console.print(table)
|
||||
|
|
|
@ -380,14 +380,14 @@ def test_array_stats(capsys):
|
|||
|
||||
assert capsys.readouterr().out == """
|
||||
allocated jax arrays
|
||||
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ size ┃ shape ┃ sharded ┃
|
||||
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||
│ 16.0 KiB │ (16, 16, 16) │ │
|
||||
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
|
||||
├──────────┼──────────────┼───────────────────┤
|
||||
│ 16.1 KiB │ │ │
|
||||
└──────────┴──────────────┴───────────────────┘
|
||||
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ size ┃ shape ┃ dtype ┃ sharded ┃
|
||||
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||
│ 16.0 KiB │ (16, 16, 16) │ float32 │ │
|
||||
│ 64.0 B │ (2, 16, 4) │ float32 │ ✔ (512.0 B total) │
|
||||
├──────────┼──────────────┼─────────┼───────────────────┤
|
||||
│ 16.1 KiB │ │ │ │
|
||||
└──────────┴──────────────┴─────────┴───────────────────┘
|
||||
""".lstrip("\n")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue