mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
document print_array_stats and add dtype to it
This commit is contained in:
parent
e1dc02aee2
commit
9ac3ea6715
3 changed files with 41 additions and 11 deletions
27
README.md
27
README.md
|
@ -7,7 +7,7 @@ pip install git+https://github.com/Findus23/jax-array-info.git
|
||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from jax_array_info import sharding_info, sharding_vis
|
from jax_array_info import sharding_info, sharding_vis, print_array_stats
|
||||||
```
|
```
|
||||||
|
|
||||||
## `sharding_info(arr)`
|
## `sharding_info(arr)`
|
||||||
|
@ -59,6 +59,31 @@ sharding_vis(array)
|
||||||
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
|
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## `print_array_stats()`
|
||||||
|
|
||||||
|
Shows a nice overview over the all currently allocated arrays ordered by size.
|
||||||
|
|
||||||
|
**Disclaimer**: This uses `jax.live_arrays()` to get its information. There might be allocated arrays that are missing in this view. Also
|
||||||
|
|
||||||
|
```python
|
||||||
|
arr = jax.numpy.zeros(shape=(16, 16, 16))
|
||||||
|
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
|
||||||
|
|
||||||
|
print_array_stats()
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
allocated jax arrays
|
||||||
|
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||||
|
┃ size ┃ shape ┃ sharded ┃
|
||||||
|
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||||
|
│ 16.0 KiB │ (16, 16, 16) │ │
|
||||||
|
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
|
||||||
|
├──────────┼──────────────┼───────────────────┤
|
||||||
|
│ 16.1 KiB │ │ │
|
||||||
|
└──────────┴──────────────┴───────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
See [`tests/`](./tests/jaxtest.py)
|
See [`tests/`](./tests/jaxtest.py)
|
||||||
|
|
|
@ -17,6 +17,7 @@ def print_array_stats():
|
||||||
table = Table(title="allocated jax arrays")
|
table = Table(title="allocated jax arrays")
|
||||||
table.add_column("size")
|
table.add_column("size")
|
||||||
table.add_column("shape")
|
table.add_column("shape")
|
||||||
|
table.add_column("dtype")
|
||||||
table.add_column("sharded", justify="center")
|
table.add_column("sharded", justify="center")
|
||||||
total_size = 0
|
total_size = 0
|
||||||
for arr in array_stats_data():
|
for arr in array_stats_data():
|
||||||
|
@ -26,7 +27,11 @@ def print_array_stats():
|
||||||
file_size /= len(arr.sharding.device_set)
|
file_size /= len(arr.sharding.device_set)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
total_size += file_size
|
total_size += file_size
|
||||||
table.add_row(pretty_byte_size(file_size), str(arr.shape), f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
|
table.add_row(
|
||||||
|
pretty_byte_size(file_size),
|
||||||
|
str(arr.shape),
|
||||||
|
str(arr.dtype),
|
||||||
|
f"✔ ({pretty_byte_size(arr.nbytes)} total)" if is_sharded else "")
|
||||||
table.add_section()
|
table.add_section()
|
||||||
table.add_row(pretty_byte_size(total_size))
|
table.add_row(pretty_byte_size(total_size))
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
|
@ -380,14 +380,14 @@ def test_array_stats(capsys):
|
||||||
|
|
||||||
assert capsys.readouterr().out == """
|
assert capsys.readouterr().out == """
|
||||||
allocated jax arrays
|
allocated jax arrays
|
||||||
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
|
||||||
┃ size ┃ shape ┃ sharded ┃
|
┃ size ┃ shape ┃ dtype ┃ sharded ┃
|
||||||
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
|
||||||
│ 16.0 KiB │ (16, 16, 16) │ │
|
│ 16.0 KiB │ (16, 16, 16) │ float32 │ │
|
||||||
│ 64.0 B │ (2, 16, 4) │ ✔ (512.0 B total) │
|
│ 64.0 B │ (2, 16, 4) │ float32 │ ✔ (512.0 B total) │
|
||||||
├──────────┼──────────────┼───────────────────┤
|
├──────────┼──────────────┼─────────┼───────────────────┤
|
||||||
│ 16.1 KiB │ │ │
|
│ 16.1 KiB │ │ │ │
|
||||||
└──────────┴──────────────┴───────────────────┘
|
└──────────┴──────────────┴─────────┴───────────────────┘
|
||||||
""".lstrip("\n")
|
""".lstrip("\n")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue