1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/halo_plot.py
2022-07-05 15:40:17 +02:00

96 lines
3.8 KiB
Python

from typing import List
import h5py
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from halo_vis import Coords
from paths import base_dir, vis_datafile
from read_vr_files import read_velo_halos
def coord_to_2d_extent(coords: Coords):
radius, X, Y, Z = coords
return X - radius, X + radius, Y - radius, Y + radius
def in_area(coords: Coords, xobj, yobj, zobj, factor=1.3) -> bool:
radius, xcenter, ycenter, zcenter = coords
radius *= factor
return (
(xcenter - radius < xobj < xcenter + radius)
and
(ycenter - radius < yobj < ycenter + radius)
and
(zcenter - radius < zobj < zcenter + radius)
)
def main():
rows = ["shannon", "DB8", "DB4", "DB2"]
offset = 2
columns = [128, 256, 512]
fig: Figure = plt.figure(figsize=(9, 9))
axes: List[List[Axes]] = fig.subplots(len(rows), len(columns), sharex="row", sharey="row")
with h5py.File(vis_datafile) as vis_out:
vmin, vmax = vis_out["vmin_vmax"]
print(vmin, vmax)
for i, waveform in enumerate(rows):
for j, resolution in enumerate(columns):
dir = base_dir / f"{waveform}_{resolution}_100"
halos = read_velo_halos(dir)
ax = axes[i][j]
rho = np.asarray(vis_out[f"{waveform}_{resolution}_rho"])
# radius, X, Y, Z
coords: Coords = tuple(vis_out[f"{waveform}_{resolution}_coords"])
mass = vis_out[f"{waveform}_{resolution}_mass"][()] # get scalar value from Dataset
main_halo_id = vis_out[f"{waveform}_{resolution}_halo_id"][()]
vmin_scaled = (vmin + offset) * mass
vmax_scaled = (vmax + offset) * mass
rho = (rho + offset) * mass
extent = coord_to_2d_extent(coords)
img = ax.imshow(rho.T, norm=LogNorm(vmin=vmin_scaled, vmax=vmax_scaled), extent=extent,
origin="lower")
found_main_halo = False
for halo_id, halo in halos.iterrows():
if halo["Vmax"] > 100:
if in_area(coords, halo.X, halo.Y, halo.Z):
color = "red" if halo_id == main_halo_id else "white"
if halo_id == main_halo_id:
found_main_halo = True
print("plotting main halo")
circle = Circle(
(halo.X, halo.Y),
halo["Rvir"], zorder=10,
linewidth=1, edgecolor=color, fill=None, alpha=.2
)
ax.add_artist(circle)
assert found_main_halo
print(img)
# break
# break
pad = 5
# based on https://stackoverflow.com/a/25814386/4398037
for ax, col in zip(axes[0], columns):
ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
xycoords='axes fraction', textcoords='offset points',
size='large', ha='center', va='baseline')
for ax, row in zip(axes[:, 0], rows):
ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
xycoords=ax.yaxis.label, textcoords='offset points',
size='large', ha='right', va='center')
fig.tight_layout()
fig.subplots_adjust(right=0.825)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(img, cax=cbar_ax)
fig.savefig("halo_plot.png", dpi=600)
plt.show()
if __name__ == '__main__':
main()