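# halo_comparison/halo_plot.py
# Mirror of https://github.com/Findus23/halo_comparison.git (synced 2024-09-19 16:03:50 +02:00).
# The repository-viewer page header (counters "1", "0", "Fork 0", "96 lines", "3.7 KiB",
# "Python", "Raw / Normal View / History") was captured with the source; it is not part
# of the program.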

from typing import List
import h5py
import numpy as np
from matplotlib import pyplot as plt
# (page residue: revision timestamp 2022-06-01 12:12:45 +02:00)
from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
# (page residue: revision timestamp 2022-06-01 11:31:42 +02:00)
from matplotlib.patches import Circle
# (page residue: revision timestamp 2022-06-01 11:31:42 +02:00)
from cic import Extent
from paths import base_dir, vis_datafile
from read_vr_files import read_velo_halos
def increase_extent_1d(xmin: float, xmax: float, factor: float):
    """Scale the interval ``[xmin, xmax]`` about its midpoint by ``factor``.

    Args:
        xmin: Lower bound of the interval.
        xmax: Upper bound of the interval.
        factor: Multiplier applied to the interval's width (``2`` doubles it,
            ``1`` keeps the same width).

    Returns:
        A ``(new_min, new_max)`` tuple with the same midpoint as the input.
    """
    # Half of the rescaled width; same arithmetic order as width / 2 * factor.
    half_width = (xmax - xmin) / 2 * factor
    center = (xmax + xmin) / 2
    return center - half_width, center + half_width
# (page residue: revision timestamp 2022-06-01 11:31:42 +02:00)
def increase_extent(extent: Extent, factor: float) -> Extent:
    """Scale a 2D extent about its centre by ``factor`` along both axes.

    Args:
        extent: Bounds unpacked in the order ``(xmin, xmax, ymin, ymax)``.
        factor: Multiplier for the width and the height; each axis is
            rescaled independently with ``increase_extent_1d``.

    Returns:
        The rescaled ``(xmin, xmax, ymin, ymax)`` tuple.
    """
    xmin, xmax, ymin, ymax = extent
    new_xmin, new_xmax = increase_extent_1d(xmin, xmax, factor)
    new_ymin, new_ymax = increase_extent_1d(ymin, ymax, factor)
    return new_xmin, new_xmax, new_ymin, new_ymax
# (page residue: revision timestamp 2022-06-01 11:31:42 +02:00)
def in_extent(extent: Extent, X, Y, factor=2) -> bool:
    """Return whether the point ``(X, Y)`` lies strictly inside the enlarged extent.

    The extent is first rescaled about its centre with ``increase_extent``,
    so with the default ``factor=2`` points up to one extra half-width /
    half-height outside the original bounds still count as inside.

    Args:
        extent: Bounds in the order ``(xmin, xmax, ymin, ymax)``.
        X: x coordinate of the point.
        Y: y coordinate of the point.
        factor: Scale factor applied to the extent before testing.

    Returns:
        True if ``xmin < X < xmax`` and ``ymin < Y < ymax`` for the enlarged
        bounds (points exactly on the boundary are excluded).
    """
    xmin, xmax, ymin, ymax = increase_extent(extent, factor)
    return (xmin < X < xmax) and (ymin < Y < ymax)
def _draw_halo_circles(ax: Axes, halos, extent, main_halo_id, min_vmax=135) -> None:
    """Outline the halos of one panel as circles of radius ``Rvir``.

    Only halos whose ``Vmax`` exceeds ``min_vmax`` and whose ``(X, Y)``
    position passes ``in_extent`` are drawn. The halo whose index equals
    ``main_halo_id`` is drawn in red, all others in white.
    """
    for halo_id, halo in halos.iterrows():
        if not (halo["Vmax"] > min_vmax and in_extent(extent, halo.X, halo.Y)):
            continue
        is_main_halo = halo_id == main_halo_id
        if is_main_halo:
            print(is_main_halo, halo_id, main_halo_id, halo["Rvir"])
            print("plotting main halo")
        circle = Circle(
            (halo.X, halo.Y),
            halo["Rvir"], zorder=10,
            linewidth=1, edgecolor="red" if is_main_halo else "white",
            fill=False, alpha=.2
        )
        ax.add_artist(circle)


def _label_grid(axes, rows, columns, pad=5) -> None:
    """Write column headers above the top row and row labels left of the first column."""
    # based on https://stackoverflow.com/a/25814386/4398037
    for ax, col in zip(axes[0], columns):
        ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                    xycoords='axes fraction', textcoords='offset points',
                    size='large', ha='center', va='baseline')
    for ax, row in zip(axes[:, 0], rows):
        ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size='large', ha='right', va='center')


def main():
    """Plot a grid of density images (waveform rows x resolution columns) with halo outlines.

    For every waveform/resolution pair the density image, its extent, a mass
    scale and the main halo id are read from ``vis_datafile``, and the halo
    catalogue is read from the matching directory below ``base_dir``. The
    figure is written to ``halo_plot.png`` and then shown.
    """
    rows = ["shannon", "DB8", "DB4", "DB2"]
    offset = 2
    columns = [128, 256, 512]
    fig: Figure = plt.figure(figsize=(9, 9))
    # fig.subplots returns a 2D array of Axes here; the array type is needed
    # for the axes[:, 0] column slice used when labelling the rows.
    axes: np.ndarray = fig.subplots(len(rows), len(columns), sharex=True, sharey=True)
    # The file is only read, so open it read-only explicitly.
    with h5py.File(vis_datafile, "r") as vis_out:
        vmin, vmax = vis_out["vmin_vmax"]
        print(vmin, vmax)
        for i, waveform in enumerate(rows):
            for j, resolution in enumerate(columns):
                prefix = f"{waveform}_{resolution}"
                halo_dir = base_dir / f"{prefix}_100"
                halos = read_velo_halos(halo_dir)
                ax = axes[i, j]
                rho = np.asarray(vis_out[f"{prefix}_rho"])
                extent = tuple(vis_out[f"{prefix}_extent"])
                mass = vis_out[f"{prefix}_mass"][()]  # get scalar value from Dataset
                main_halo_id = vis_out[f"{prefix}_halo_id"][()]
                # The colour limits get the same shift-and-scale as the data.
                # NOTE(review): the offset presumably keeps values positive for
                # the logarithmic norm - confirm.
                vmin_scaled = (vmin + offset) * mass
                vmax_scaled = (vmax + offset) * mass
                rho = (rho + offset) * mass
                img = ax.imshow(rho.T, norm=LogNorm(vmin=vmin_scaled, vmax=vmax_scaled), extent=extent,
                                origin="lower")
                _draw_halo_circles(ax, halos, extent, main_halo_id)
                print(img)
    _label_grid(axes, rows, columns, pad=5)
    fig.tight_layout()
    fig.subplots_adjust(right=0.825)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    # NOTE(review): the colorbar is built from the last panel's image only,
    # while each panel's norm is scaled by its own mass - confirm this is intended.
    fig.colorbar(img, cax=cbar_ax)
    fig.savefig("halo_plot.png", dpi=600)
    plt.show()
# Run the plot only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()