mirror of
https://github.com/Findus23/jax-array-info.git
synced 2024-09-19 15:53:47 +02:00
initial version
This commit is contained in:
commit
58694cb800
7 changed files with 539 additions and 0 deletions
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
.idea/
|
||||||
|
*.egg-info
|
||||||
|
dist/
|
202
LICENSE
Normal file
202
LICENSE
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
56
README.md
Normal file
56
README.md
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# jax-array-info
|
||||||
|
|
||||||
|
This package contains two functions for debugging jax `Array`s:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from jax_array_info import sharding_info, sharding_vis
|
||||||
|
```
|
||||||
|
|
||||||
|
## `sharding_info(arr)`
|
||||||
|
|
||||||
|
`sharding_info(arr)` prints general information about a jax or numpy array with special focus on sharding (
|
||||||
|
supporting `SingleDeviceSharding`, `GSPMDSharding`, `PositionalSharding`, `NamedSharding` and `PmapSharding`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
|
||||||
|
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
|
||||||
|
sharding_info(array, "some_array")
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
╭────────────────── some_array ───────────────────╮
|
||||||
|
│ shape: (128, 128, 128) │
|
||||||
|
│ dtype: float32 │
|
||||||
|
│ size: 8.0 MiB │
|
||||||
|
│ NamedSharding: P(None, 'gpus') │
|
||||||
|
│ axis 1 is sharded: CPU 0 contains 0:16 (of 128) │
|
||||||
|
╰─────────────────────────────────────────────────╯
|
||||||
|
```
|
||||||
|
|
||||||
|
## `sharding_vis(arr)`
|
||||||
|
|
||||||
|
A modified version
|
||||||
|
of [`jax.debug.visualize_array_sharding()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.debug.visualize_array_sharding.html)
|
||||||
|
that also supports arrays with more than 2 dimensions (by ignoring non-sharded dimensions in the visualisation until
|
||||||
|
reaching 2 dimensions)
|
||||||
|
|
||||||
|
```python
|
||||||
|
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
|
||||||
|
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
|
||||||
|
sharding_vis(array)
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
|
||||||
|
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
│ │ │ │ │ │ │ │ │
|
||||||
|
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
|
||||||
|
```
|
2
jax_array_info/__init__.py
Normal file
2
jax_array_info/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .sharding_info import sharding_info
|
||||||
|
from .sharding_vis import sharding_vis
|
78
jax_array_info/sharding_info.py
Normal file
78
jax_array_info/sharding_info.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
import numpy as np
|
||||||
|
import rich
|
||||||
|
from jax import Array
|
||||||
|
from jax._src.debugging import inspect_array_sharding
|
||||||
|
from jax.core import Tracer
|
||||||
|
from jax.sharding import Sharding, NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, PositionalSharding
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
def sharding_info(arr, name=None):
|
||||||
|
if isinstance(arr, np.ndarray):
|
||||||
|
return print_sharding_info(arr, None, name)
|
||||||
|
|
||||||
|
def _info(sharding):
|
||||||
|
print_sharding_info(arr, sharding, name)
|
||||||
|
|
||||||
|
inspect_array_sharding(arr, callback=_info)
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_byte_size(nbytes: int):
|
||||||
|
for unit in ("", "Ki", "Mi", "Gi", "Ti"):
|
||||||
|
if abs(nbytes) < 1024.0:
|
||||||
|
return f"{nbytes:3.1f} {unit}B"
|
||||||
|
nbytes /= 1024.0
|
||||||
|
|
||||||
|
|
||||||
|
def _print_sharding_info_raw(arr: Array, sharding: Sharding, console: Console):
|
||||||
|
shape = arr.shape
|
||||||
|
console.print(f"shape: {shape}")
|
||||||
|
console.print(f"dtype: {arr.dtype}")
|
||||||
|
console.print(f"size: {pretty_byte_size(arr.nbytes)}")
|
||||||
|
|
||||||
|
if isinstance(arr, np.ndarray):
|
||||||
|
console.print("[bold]numpy array")
|
||||||
|
return
|
||||||
|
if not isinstance(arr, Array):
|
||||||
|
raise ValueError(f"is not a jax array, got {type(arr)}")
|
||||||
|
|
||||||
|
device_kind = next(iter(sharding.device_set)).platform.upper()
|
||||||
|
is_in_jit = isinstance(arr, Tracer)
|
||||||
|
if not is_in_jit and not arr.is_fully_addressable:
|
||||||
|
console.print("!is_fully_addressable")
|
||||||
|
if is_in_jit:
|
||||||
|
console.print("[bright_black]called in jit")
|
||||||
|
# if not arr.is_fully_replicated:
|
||||||
|
# console.print("!is_fully_replicated")
|
||||||
|
|
||||||
|
if isinstance(sharding, SingleDeviceSharding):
|
||||||
|
console.print("[red]not sharded")
|
||||||
|
if isinstance(sharding, GSPMDSharding):
|
||||||
|
console.print(sharding)
|
||||||
|
if isinstance(sharding, PositionalSharding):
|
||||||
|
console.print(f"PositionalSharding:")
|
||||||
|
console.print(sharding._ids)
|
||||||
|
if isinstance(sharding, NamedSharding):
|
||||||
|
console.print(f"NamedSharding: P{tuple(sharding.spec)}")
|
||||||
|
|
||||||
|
if isinstance(sharding, PmapSharding):
|
||||||
|
console.print(sharding)
|
||||||
|
return
|
||||||
|
device_indices_map = sharding.devices_indices_map(tuple(shape))
|
||||||
|
slcs = next(iter(device_indices_map.values()))
|
||||||
|
sl: slice
|
||||||
|
for i, sl in enumerate(slcs):
|
||||||
|
if sl.start is None:
|
||||||
|
continue
|
||||||
|
console.print(f"axis {i} is sharded: {device_kind} 0 contains {sl.start}:{sl.stop} (of {shape[i]})")
|
||||||
|
|
||||||
|
|
||||||
|
def print_sharding_info(arr: Array, sharding: Sharding, name=None):
|
||||||
|
console = rich.console.Console()
|
||||||
|
with console.capture() as capture:
|
||||||
|
_print_sharding_info_raw(arr, sharding, console)
|
||||||
|
str_output = capture.get()
|
||||||
|
text = Text.from_ansi(str_output)
|
||||||
|
console.print(Panel(text, expand=False, title=f"[bold]{name}" if name is not None else None))
|
172
jax_array_info/sharding_vis.py
Normal file
172
jax_array_info/sharding_vis.py
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
"""
|
||||||
|
based on visualize_sharding() from jax/_src/debugging.py
|
||||||
|
|
||||||
|
https://github.com/google/jax/blob/main/jax/_src/debugging.py
|
||||||
|
|
||||||
|
# Copyright 2022 The JAX Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Optional, Dict, Tuple, Set
|
||||||
|
|
||||||
|
import rich
|
||||||
|
from jax import Array
|
||||||
|
from jax._src.debugging import _raise_to_slice, _slice_to_chunk_idx, inspect_array_sharding, ColorMap, make_color_iter, \
|
||||||
|
_canonicalize_color, _get_text_color
|
||||||
|
from jax.sharding import Sharding, PmapSharding
|
||||||
|
|
||||||
|
|
||||||
|
def sharding_vis(arr, **kwargs):
|
||||||
|
if not isinstance(arr, Array):
|
||||||
|
raise ValueError(f"is not a jax array, got {type(arr)}")
|
||||||
|
|
||||||
|
def _visualize(sharding):
|
||||||
|
return visualize_sharding(arr.shape, sharding, **kwargs)
|
||||||
|
|
||||||
|
inspect_array_sharding(arr, callback=_visualize)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sharded_dims(shape: Sequence[int], sharding: Sharding) -> list[int]:
|
||||||
|
device_indices_map = sharding.devices_indices_map(tuple(shape))
|
||||||
|
slcs = next(iter(device_indices_map.values()))
|
||||||
|
sharded_dims = []
|
||||||
|
sl: slice
|
||||||
|
for i, sl in enumerate(slcs):
|
||||||
|
if sl.start is not None:
|
||||||
|
sharded_dims.append(i)
|
||||||
|
print(sl)
|
||||||
|
return sharded_dims
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
|
||||||
|
use_color: bool = False, scale: float = 1.,
|
||||||
|
min_width: int = 9, max_width: int = 80,
|
||||||
|
color_map: Optional[ColorMap] = None):
|
||||||
|
"""
|
||||||
|
based on `jax.debug.visualize_array_sharding` and `jax.debug.visualize_sharding`
|
||||||
|
"""
|
||||||
|
console = rich.console.Console(width=max_width)
|
||||||
|
use_color = use_color and console.color_system is not None
|
||||||
|
if use_color and not color_map:
|
||||||
|
try:
|
||||||
|
import matplotlib as mpl # pytype: disable=import-error
|
||||||
|
color_map = mpl.colormaps["tab20b"]
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
use_color = False
|
||||||
|
|
||||||
|
base_height = int(10 * scale)
|
||||||
|
aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
|
||||||
|
base_width = int(base_height * aspect_ratio)
|
||||||
|
height_to_width_ratio = 2.5
|
||||||
|
|
||||||
|
# Grab the device kind from the first device
|
||||||
|
device_kind = next(iter(sharding.device_set)).platform.upper()
|
||||||
|
|
||||||
|
device_indices_map = sharding.devices_indices_map(tuple(shape))
|
||||||
|
slices: Dict[Tuple[int, ...], Set[int]] = {}
|
||||||
|
heights: Dict[Tuple[int, ...], Optional[float]] = {}
|
||||||
|
widths: Dict[Tuple[int, ...], float] = {}
|
||||||
|
|
||||||
|
dims = list(range(len(shape)))
|
||||||
|
if isinstance(sharding, PmapSharding):
|
||||||
|
console.print("[red bold] Output for PmapSharding")
|
||||||
|
if len(shape) > 2:
|
||||||
|
raise NotImplementedError("can only visualize PmapSharding with shapes with less than 3 dimensions")
|
||||||
|
if len(shape) > 2 and not isinstance(sharding, PmapSharding):
|
||||||
|
sharded_dims = get_sharded_dims(shape, sharding)
|
||||||
|
if len(sharded_dims) > 2:
|
||||||
|
raise NotImplementedError(f"can only visualize up to 2 sharded dimension. {sharded_dims} are sharded.")
|
||||||
|
chosen_dims = sharded_dims.copy()
|
||||||
|
while len(chosen_dims) < 2:
|
||||||
|
for i in dims:
|
||||||
|
if i not in chosen_dims:
|
||||||
|
chosen_dims.append(i)
|
||||||
|
break
|
||||||
|
chosen_dims.sort()
|
||||||
|
console.rule(title=f"showing dims {chosen_dims} from original shape {shape}")
|
||||||
|
|
||||||
|
for i, (dev, slcs) in enumerate(device_indices_map.items()):
|
||||||
|
assert slcs is not None
|
||||||
|
slcs = tuple(map(_raise_to_slice, slcs))
|
||||||
|
chunk_idxs = tuple(map(_slice_to_chunk_idx, shape, slcs))
|
||||||
|
|
||||||
|
if slcs is None:
|
||||||
|
raise NotImplementedError
|
||||||
|
if len(slcs) > 1:
|
||||||
|
if len(slcs) > 2:
|
||||||
|
slcs = tuple([slcs[i] for i in chosen_dims])
|
||||||
|
chunk_idxs = tuple([chunk_idxs[i] for i in chosen_dims])
|
||||||
|
vert, horiz = slcs
|
||||||
|
vert_size = ((vert.stop - vert.start) if vert.stop is not None
|
||||||
|
else shape[0])
|
||||||
|
horiz_size = ((horiz.stop - horiz.start) if horiz.stop is not None
|
||||||
|
else shape[1])
|
||||||
|
chunk_height = vert_size / shape[0]
|
||||||
|
chunk_width = horiz_size / shape[1]
|
||||||
|
heights[chunk_idxs] = chunk_height
|
||||||
|
widths[chunk_idxs] = chunk_width
|
||||||
|
else:
|
||||||
|
# In the 1D case, we set the height to 1.
|
||||||
|
horiz, = slcs
|
||||||
|
vert = slice(0, 1, None)
|
||||||
|
horiz_size = (
|
||||||
|
(horiz.stop - horiz.start) if horiz.stop is not None else shape[0])
|
||||||
|
chunk_idxs = (0, *chunk_idxs)
|
||||||
|
heights[chunk_idxs] = None
|
||||||
|
widths[chunk_idxs] = horiz_size / shape[0]
|
||||||
|
slices.setdefault(chunk_idxs, set()).add(dev.id)
|
||||||
|
num_rows = max([a[0] for a in slices.keys()]) + 1
|
||||||
|
if len(list(slices.keys())[0]) == 1:
|
||||||
|
num_cols = 1
|
||||||
|
else:
|
||||||
|
num_cols = max([a[1] for a in slices.keys()]) + 1
|
||||||
|
color_iter = make_color_iter(color_map, num_rows, num_cols)
|
||||||
|
table = rich.table.Table(show_header=False, show_lines=not use_color,
|
||||||
|
padding=0,
|
||||||
|
highlight=not use_color, pad_edge=False,
|
||||||
|
box=rich.box.SQUARE if not use_color else None)
|
||||||
|
for i in range(num_rows):
|
||||||
|
col = []
|
||||||
|
for j in range(num_cols):
|
||||||
|
entry = f"{device_kind} " + ",".join([str(s) for s in sorted(slices[i, j])])
|
||||||
|
width, maybe_height = widths[i, j], heights[i, j]
|
||||||
|
width = int(width * base_width * height_to_width_ratio)
|
||||||
|
if maybe_height is None:
|
||||||
|
height = 1
|
||||||
|
else:
|
||||||
|
height = int(maybe_height * base_height)
|
||||||
|
width = min(max(width, min_width), max_width)
|
||||||
|
left_padding, remainder = divmod(width - len(entry) - 2, 2)
|
||||||
|
right_padding = left_padding + remainder
|
||||||
|
top_padding, remainder = divmod(height - 2, 2)
|
||||||
|
bottom_padding = top_padding + remainder
|
||||||
|
if use_color:
|
||||||
|
color = _canonicalize_color(next(color_iter)[:3])
|
||||||
|
text_color = _get_text_color(color)
|
||||||
|
top_padding += 1
|
||||||
|
bottom_padding += 1
|
||||||
|
left_padding += 1
|
||||||
|
right_padding += 1
|
||||||
|
else:
|
||||||
|
color = None
|
||||||
|
text_color = None
|
||||||
|
padding = (top_padding, right_padding, bottom_padding, left_padding)
|
||||||
|
padding = tuple(max(x, 0) for x in padding) # type: ignore
|
||||||
|
col.append(
|
||||||
|
rich.padding.Padding(
|
||||||
|
rich.align.Align(entry, "center", vertical="middle"), padding,
|
||||||
|
style=rich.style.Style(bgcolor=color,
|
||||||
|
color=text_color)))
|
||||||
|
table.add_row(*col)
|
||||||
|
console.print(table, end='\n\n')
|
26
pyproject.toml
Normal file
26
pyproject.toml
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "jax-array-info"
|
||||||
|
version = "0.0.1"
|
||||||
|
authors = [
|
||||||
|
{ name = "Lukas Winkler", email = "python@lw1.at" },
|
||||||
|
]
|
||||||
|
description = "Debugging tool to print information (especially sharding) about jax arrays"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"License :: OSI Approved :: Apache Software License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"jax>=0.4.8",
|
||||||
|
"rich"
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
"Homepage" = "https://github.com/Findus23/jax-array-info"
|
||||||
|
"Bug Tracker" = "https://github.com/Findus23/jax-array-info/issues"
|
Loading…
Reference in a new issue