1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/halo_plot.py

147 lines
5.5 KiB
Python
Raw Normal View History

2022-07-14 15:14:26 +02:00
from pathlib import Path
from sys import argv
from typing import List
import h5py
import numpy as np
from matplotlib import pyplot as plt
2022-06-01 12:12:45 +02:00
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
2022-06-01 11:31:42 +02:00
from matplotlib.patches import Circle
2022-07-14 15:14:26 +02:00
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
2022-07-05 15:40:17 +02:00
from halo_vis import Coords
2022-07-14 15:14:26 +02:00
from paths import base_dir, vis_datafile, has_1024_simulations
2022-06-01 11:31:42 +02:00
from read_vr_files import read_velo_halos
2022-07-21 12:34:57 +02:00
from utils import figsize_from_page_fraction, rowcolumn_labels, waveforms
2022-06-01 11:31:42 +02:00
2022-07-05 15:40:17 +02:00
def coord_to_2d_extent(coords: Coords):
    """Return the 2-D ``(x_min, x_max, y_min, y_max)`` extent around a centre.

    ``coords`` is unpacked as ``(radius, X, Y, Z)``. The square spans
    ``radius`` on either side of ``(X, Y)``, and Z is ignored. The tuple
    order matches the ``extent=`` argument of ``Axes.imshow``.
    """
    half_size, center_x, center_y, _ = coords
    x_min, x_max = center_x - half_size, center_x + half_size
    y_min, y_max = center_y - half_size, center_y + half_size
    return x_min, x_max, y_min, y_max


def in_area(coords: Coords, xobj, yobj, zobj, factor=1.3) -> bool:
    """Check whether a point lies strictly inside an enlarged cube around a centre.

    ``coords`` is unpacked as ``(radius, xcenter, ycenter, zcenter)``. The
    cube's half-width along every axis is ``radius * factor``. Points lying
    exactly on a face count as outside, because the comparisons are strict.
    """
    half_width, xcenter, ycenter, zcenter = coords
    half_width = half_width * factor
    axis_pairs = ((xcenter, xobj), (ycenter, yobj), (zcenter, zobj))
    return all(
        center - half_width < value < center + half_width
        for center, value in axis_pairs
    )


def _draw_halo_circles(ax: Axes, halos, coords: Coords, main_halo_id) -> bool:
    """Outline every halo with Vmax > 75 lying near the plotted region.

    Colours: "C2" for the main halo, "C0" for halos with ``Structuretype > 10``,
    "C1" otherwise.

    :param ax: axes to draw the circles into
    :param halos: halo catalogue from ``read_velo_halos``; iterated with
        ``iterrows`` and read for Vmax, X, Y, Z, Rvir and Structuretype
    :param coords: ``(radius, X, Y, Z)`` of the plotted region
    :param main_halo_id: catalogue index of the main halo
    :return: True if the main halo was among the drawn circles
    """
    _, X, Y, _ = coords
    found_main_halo = False
    for halo_id, halo in halos.iterrows():
        # Written as "not (... > 75)" so that a NaN Vmax is skipped, exactly as before.
        if not (halo["Vmax"] > 75 and in_area(coords, halo.X, halo.Y, halo.Z)):
            continue
        if halo_id == main_halo_id:
            color = "C2"
            found_main_halo = True
            print("plotting main halo")
        elif halo["Structuretype"] > 10:
            color = "C0"
        else:
            color = "C1"
        # The offsets from the region centre are swapped between the plot axes
        # (catalogue Y offset horizontally, X offset vertically), presumably to
        # match the orientation of the stored rho image -- TODO confirm.
        circle = Circle(
            (halo.Y - Y + X, halo.X - X + Y),
            halo["Rvir"],
            zorder=10,
            linewidth=1,
            edgecolor=color,
            fill=None,
            alpha=0.2,
        )
        ax.add_artist(circle)
    return found_main_halo


def main():
    """Plot density images for every waveform/resolution combination.

    Usage: ``halo_plot.py <halo_id|box>``. ``box`` selects group "0" of the
    visualisation file, i.e. the whole box. In box mode there is one row per
    waveform. Otherwise there are two rows per waveform: the upper row also
    gets circles for the catalogued halos, and the lower row shows the image
    only. There is one column per resolution. The figure is saved to
    ``~/tmp/halo_plot_<id>.pdf`` and ``.png`` and then shown.

    :raises SystemExit: if no command-line argument is given
    :raises RuntimeError: if the main halo is not found in a halo catalogue
    """
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} <halo_id|box>")
    initial_halo_id = 0 if argv[1] == "box" else int(argv[1])
    # Group id 0 doubles as the key for the whole-box visualisation.
    is_box = not initial_halo_id
    rows_per_waveform = 1 if is_box else 2
    offset = 2
    resolutions = [128, 256, 512]
    if has_1024_simulations:
        resolutions.append(1024)

    fig: Figure = plt.figure(
        figsize=figsize_from_page_fraction(columns=2, height_to_width=1 if is_box else 1.2)
    )
    # Use one column per resolution, so that no empty column appears when the
    # 1024 runs are unavailable. squeeze=False keeps the grid 2-D in every case.
    axes: List[List[Axes]] = fig.subplots(
        len(waveforms) * rows_per_waveform,
        len(resolutions),
        sharex="row",
        sharey="row",
        squeeze=False,
    )

    with h5py.File(vis_datafile, "r") as vis_out:
        halo_group = vis_out[str(initial_halo_id)]
        # Colour limits shared by every panel of this figure.
        vmin, vmax = halo_group["vmin_vmax"]
        print(vmin, vmax)
        for i, waveform in enumerate(waveforms):
            for j, resolution in enumerate(resolutions):
                sim_dir = base_dir / f"{waveform}_{resolution}_100"
                ax_both = axes[i * rows_per_waveform][j]
                panels = [ax_both] if is_box else [ax_both, axes[i * 2 + 1][j]]

                dataset_group = halo_group[f"{waveform}_{resolution}"]
                coords: Coords = tuple(dataset_group["coords"])  # radius, X, Y, Z
                radius, X, Y, _ = coords
                mass = dataset_group["mass"][()]  # get scalar value from Dataset
                # The offset is added before scaling by mass, presumably to keep
                # values positive for LogNorm -- TODO confirm.
                rho = (np.asarray(dataset_group["rho"]) + offset) * mass
                vmin_scaled = (vmin + offset) * mass
                vmax_scaled = (vmax + offset) * mass
                extent = coord_to_2d_extent(coords)

                for ax in panels:
                    ax.imshow(
                        rho,
                        norm=LogNorm(vmin=vmin_scaled, vmax=vmax_scaled),
                        extent=extent,
                        origin="lower",
                        cmap="Greys",
                        interpolation="none",
                    )
                    ax.set_xticks([])
                    ax.set_yticks([])
                    if not is_box:
                        # Show only the central half of the region vertically.
                        ax.set_ylim(Y - radius * 0.5, Y + radius * 0.5)
                    if j == 0:
                        scalebar = AnchoredSizeBar(
                            ax.transData,
                            1,
                            "1 Mpc",
                            "lower left",
                            color="black",
                            frameon=False,
                        )
                        ax.add_artist(scalebar)

                if not is_box:
                    # The halo catalogue is only needed for the circles, so it is
                    # not read in box mode.
                    halos = read_velo_halos(sim_dir)
                    main_halo_id = dataset_group["halo_id"][()]
                    if not _draw_halo_circles(ax_both, halos, coords, main_halo_id):
                        raise RuntimeError(
                            f"main halo {main_halo_id} not found in {sim_dir}"
                        )

    # One label per row: each waveform appears twice in the two-row layout.
    if is_box:
        ylabels = list(waveforms)
    else:
        ylabels = [item for item in waveforms for _ in range(2)]
    rowcolumn_labels(axes, ylabels, isrow=True)
    rowcolumn_labels(axes, resolutions, isrow=False)
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0)
    out_dir = Path("~/tmp").expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"halo_plot_{initial_halo_id}.pdf")
    fig.savefig(out_dir / f"halo_plot_{initial_halo_id}.png", dpi=300)
    plt.show()


if __name__ == "__main__":
    # Only build the figure when run as a script, not when imported.
    main()