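# Scraped repository-page header, kept as a comment so the file stays valid Python:
#   counters: 1 / 0, Fork 0
#   mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-13 09:03:49 +02:00
#   halo_comparison/halo_plot.py
#   2023-01-11 12:22:07 +01:00
#   139 lines, 5.1 KiB, Python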

from pathlib import Path
from sys import argv
from typing import List
import h5py
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from halo_vis import Coords
from paths import base_dir, vis_datafile, has_1024_simulations
from read_vr_files import read_velo_halos
from utils import figsize_from_page_fraction, rowcolumn_labels, waveforms
def coord_to_2d_extent(coords: Coords):
    """Return the ``(left, right, bottom, top)`` extent of a square cutout.

    ``coords`` is ``(radius, X, Y, Z)``. The square is centred on ``(X, Y)``
    with half-width ``radius``; ``Z`` is ignored. The tuple is in the order
    ``imshow(extent=...)`` expects.
    """
    radius, x_centre, y_centre, _z_centre = coords
    left, right = x_centre - radius, x_centre + radius
    bottom, top = y_centre - radius, y_centre + radius
    return left, right, bottom, top


def in_area(coords: Coords, xobj, yobj, zobj, factor=1.3) -> bool:
    """Tell whether a point lies strictly inside an enlarged cutout cube.

    ``coords`` is ``(radius, X, Y, Z)``. The test region is the axis-aligned
    cube (not a sphere) centred on ``(X, Y, Z)`` with half-width
    ``radius * factor``. The faces themselves count as outside.
    """
    radius, xcenter, ycenter, zcenter = coords
    half_width = radius * factor
    centre = (xcenter, ycenter, zcenter)
    point = (xobj, yobj, zobj)
    # Checked axis by axis in x, y, z order, stopping at the first miss.
    return all(
        c - half_width < p < c + half_width for c, p in zip(centre, point)
    )


def _draw_halo_circles(ax: "Axes", halos, coords: "Coords", main_halo_id, min_vmax=50) -> bool:
    """Outline the catalogue halos that fall near a cutout drawn on ``ax``.

    Draws a faint circle of radius ``R_200crit`` for every halo whose
    ``Vmax`` is strictly above ``min_vmax`` and whose centre passes
    ``in_area(coords, ...)``. The main halo is drawn in ``C2``. Halos with
    ``Structuretype > 10`` are drawn in ``C1``, and all others in ``C0``.

    Args:
        ax: Axes the cutout image was drawn on.
        halos: halo catalogue from ``read_velo_halos``. It is iterated with
            ``iterrows()``; presumably a pandas DataFrame indexed by halo id.
        coords: ``(radius, X, Y, Z)`` of the cutout.
        main_halo_id: id of the halo the cutout is centred on.
        min_vmax: only halos with ``Vmax`` strictly above this are drawn.

    Returns:
        True if the main halo was among the drawn halos.
    """
    _, X, Y, _ = coords
    found_main_halo = False
    for halo_id, halo in halos.iterrows():
        # `not (a > b)` rather than `a <= b` so rows with a NaN Vmax are skipped.
        if not (halo["Vmax"] > min_vmax and in_area(coords, halo.X, halo.Y, halo.Z)):
            continue
        if halo_id == main_halo_id:
            color = "C2"
            found_main_halo = True
            print("plotting main halo")
        elif halo["Structuretype"] > 10:
            color = "C1"
        else:
            color = "C0"
        # The horizontal position uses the halo's Y offset and the vertical
        # position its X offset, both relative to the cutout centre.
        # NOTE(review): presumably this matches rho's axis order (imshow draws
        # array axis 0 vertically). Confirm against how halo_vis builds rho.
        circle = Circle(
            (halo.Y - Y + X, halo.X - X + Y),
            halo["R_200crit"],
            zorder=10,
            linewidth=1,
            edgecolor=color,
            fill=None,
            alpha=0.2,
        )
        ax.add_artist(circle)
    return found_main_halo


def main():
    """Plot density cutouts of one halo (or the whole box) for all simulations.

    Usage: ``halo_plot.py <halo_id|box>``. ``box`` maps to id 0, so an
    explicit id of 0 also selects the box view, where no halo circles are drawn.
    Builds one row per waveform and one column per resolution from the
    precomputed ``vis_datafile``. The figure is saved to
    ``~/tmp/halo_plot_<id>.{pdf,png}`` and then shown.

    Raises:
        SystemExit: if no halo id / ``box`` argument is given.
        RuntimeError: if the main halo is not found in a halo cutout.
    """
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} <halo_id|box>")
    offset = 2
    resolutions = [128, 256, 512]
    if has_1024_simulations:
        resolutions.append(1024)
    initial_halo_id = 0 if argv[1] == "box" else int(argv[1])
    is_box = not initial_halo_id

    fig: Figure = plt.figure(
        figsize=figsize_from_page_fraction(columns=2, height_to_width=1.05),
        layout="constrained",
    )
    # One column per available resolution, so no empty panel is left when the
    # 1024 runs are missing. squeeze=False keeps `axes` 2-D for axes[i][j].
    axes: List[List[Axes]] = fig.subplots(
        len(waveforms), len(resolutions), sharex="row", sharey="row", squeeze=False
    )
    with h5py.File(vis_datafile, "r") as vis_out:
        halo_group = vis_out[str(initial_halo_id)]
        vmin, vmax = halo_group["vmin_vmax"]
        print(vmin, vmax)
        for i, waveform in enumerate(waveforms):
            for j, resolution in enumerate(resolutions):
                ax = axes[i][j]
                dataset_group = halo_group[f"{waveform}_{resolution}"]
                rho = np.asarray(dataset_group["rho"])
                # radius, X, Y, Z
                coords: Coords = tuple(dataset_group["coords"])
                mass = dataset_group["mass"][()]  # get scalar value from Dataset
                main_halo_id = dataset_group["halo_id"][()] if initial_halo_id else None

                # Shift by `offset` and scale by `mass`. The same transform is
                # applied to the colour limits so every panel shares one scale.
                # NOTE(review): presumably the shift keeps values > 0 for LogNorm;
                # confirm the range of the stored rho.
                vmin_scaled = (vmin + offset) * mass
                vmax_scaled = (vmax + offset) * mass
                rho = (rho + offset) * mass
                ax.imshow(
                    rho,
                    norm=LogNorm(vmin=vmin_scaled, vmax=vmax_scaled),
                    extent=coord_to_2d_extent(coords),
                    origin="lower",
                    cmap="Greys",
                    interpolation="none",
                )
                ax.set_xticks([])
                ax.set_yticks([])
                if j == 0:
                    # Bar of length 1 in data units, labelled as 1 Mpc.
                    scalebar = AnchoredSizeBar(
                        ax.transData,
                        1,
                        "1 Mpc",
                        "lower left",
                        color="black",
                        frameon=False,
                    )
                    ax.add_artist(scalebar)
                if not is_box:
                    # The halo catalogue is only needed when circles are drawn.
                    sim_dir = base_dir / f"{waveform}_{resolution}_100"
                    halos = read_velo_halos(sim_dir)
                    if not _draw_halo_circles(ax, halos, coords, main_halo_id):
                        raise RuntimeError(
                            f"main halo {main_halo_id} not found in {waveform}_{resolution}"
                        )

    rowcolumn_labels(axes, waveforms, isrow=True)
    rowcolumn_labels(axes, resolutions, isrow=False)
    # NOTE(review): the figure was created with layout="constrained". Both
    # tight_layout() and subplots_adjust() may conflict with that layout engine
    # (matplotlib warns). Confirm which layout is intended.
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)
    out_dir = Path("~/tmp").expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"halo_plot_{initial_halo_id}.pdf")
    fig.savefig(out_dir / f"halo_plot_{initial_halo_id}.png", dpi=300)
    plt.show()


if __name__ == "__main__":
    # Run the plotting entry point only when executed as a script, not on import.
    main()