1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/comparison_plot.py

391 lines
14 KiB
Python
Raw Normal View History

2022-05-24 17:06:49 +02:00
from pathlib import Path
2022-07-21 16:05:27 +02:00
from sys import argv
2022-08-01 14:03:46 +02:00
from typing import List
2022-05-09 15:20:10 +02:00
2022-06-10 11:06:32 +02:00
import numpy as np
2022-05-04 13:42:57 +02:00
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
2022-08-04 17:30:39 +02:00
from matplotlib.axis import XTick, YTick
2022-07-18 19:27:56 +02:00
from matplotlib.collections import QuadMesh
2022-06-10 11:06:32 +02:00
from matplotlib.colors import LogNorm
2022-05-04 13:42:57 +02:00
from matplotlib.figure import Figure
2022-08-01 14:03:46 +02:00
from matplotlib.patches import Polygon
2022-08-01 11:33:50 +02:00
from numpy import inf
2022-07-28 00:43:26 +02:00
2022-07-18 19:27:56 +02:00
from halo_vis import get_comp_id
from paths import base_dir
2022-07-29 13:08:05 +02:00
from utils import figsize_from_page_fraction, rowcolumn_labels, waveforms, tex_fmt
# density like in Vr:
2022-07-18 19:27:56 +02:00
2022-07-12 16:09:52 +02:00
# Gravitational constant in Mpc (km/s)^2 / (10^10 Msun).
G = 43.022682

# Upper limit of the logarithmic colour scale for each 2D-histogram property.
vmaxs = dict(Mvir=52, Vmax=93, cNFW=31)

# LaTeX unit strings per property, used when building axis labels.
units = dict(
    distance="Mpc",
    Mvir=r"10^{10} \textrm{ M}_\odot",
    Vmax=r"\textrm{km } \textrm{s}^{-1}",  # TODO
)
def concentration(row, halo_type: str) -> bool:
    """Report whether a halo qualifies for a regularly computed cNFW value.

    Args:
        row: Mapping-like record (e.g. a pandas row) holding the columns
            ``<halo_type>_R_200crit``, ``<halo_type>_Mass_200crit``,
            ``<halo_type>_Vmax`` and ``<halo_type>_npart``.
        halo_type: Column prefix selecting the halo, e.g. ``"ref"`` or
            ``"comp"``.

    Returns:
        ``True`` only if R_200crit is positive, (Vmax / Vvir)^2 exceeds 1.05
        and the halo has at least 100 particles. In every other case the
        concentration would merely be a radius ratio (or is undefined) and
        ``False`` is returned.
    """
    r_200crit = row[f"{halo_type}_R_200crit"]
    if r_200crit <= 0:
        # No usable R_200crit: the concentration is undefined for this halo.
        return False

    m_200crit = row[f"{halo_type}_Mass_200crit"]
    # Largest velocity coming from the enclosed mass profile calculation.
    vmax = row[f"{halo_type}_Vmax"]
    npart = row[f"{halo_type}_npart"]

    # (Vmax / Vvir)^2 with Vvir^2 = G * M_200crit / R_200crit.
    # NOTE(review): for M_200crit == 0 this is inf/nan with NumPy scalars (and
    # raises ZeroDivisionError with plain Python numbers); such rows are not
    # rejected by the comparison below -- confirm that this is intended.
    vmax_vvir2 = vmax ** 2 * r_200crit / (G * m_200crit)
    if vmax_vvir2 <= 1.05:
        return False

    # Only groups with at least 100 particles get a fitted concentration.
    return bool(npart >= 100)
2022-07-12 16:09:52 +02:00
2022-08-01 11:33:50 +02:00
def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
    """Draw a 2D histogram of the comp/ref ratio of ``property`` onto ``ax``.

    The CSV ``file`` must provide the columns ``ref_<property>`` and
    ``comp_<property>``. The x axis shows the reference value on
    logarithmically spaced bins, the y axis the ratio ``comp / ref`` on linear
    bins between 0 and 2.

    Property-specific extras:
      * ``"cNFW"``: rows are kept only if ``concentration`` returns ``True``
        for both the ``comp`` and the ``ref`` halo.
      * ``"Mvir"``: a band showing mean +/- standard deviation of the ratio is
        overlaid; windows holding fewer than 30 halos are left empty.

    Args:
        ax: Axes to draw on.
        file: Path of the comparison CSV file.
        property: Property name without the ``ref_`` / ``comp_`` prefix.

    Returns:
        The tuple ``(x_col, y_col)`` of column names that were plotted.
    """
    print("WARNING: Can only plot hist2d of properties with comp_ or ref_ right now!")
    print(f" Selected property: {property}")
    x_col = f"ref_{property}"
    y_col = f"comp_{property}"
    df = pd.read_csv(file)

    # The bin edges span both columns and are computed before any filtering.
    min_x = min([min(df[x_col]), min(df[y_col])])
    max_x = max([max(df[x_col]), max(df[y_col])])
    num_bins = 100
    bins = np.geomspace(min_x, max_x, num_bins)

    if property == "cNFW":
        # Select by boolean mask: this keeps the column dtypes and also works
        # when no row passes the check.
        keep = [
            concentration(row, halo_type="comp")
            and concentration(row, halo_type="ref")
            for _, row in df.iterrows()
        ]
        df = df.loc[keep]
        print(df)

    if property == "Mvir":
        means = []
        stds = []
        for rep_x_left in bins:
            # NOTE(review): the window reaches from the bin edge to edge + 1
            # (in data units), not to the next geometric bin edge -- confirm.
            rep_x_right = rep_x_left + 1
            rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
            if rep_bin.sum() < 30:
                # Too few halos for meaningful statistics. Store NaN so that
                # means/stds stay aligned with ``bins`` and the band shows a
                # gap here.
                means.append(np.nan)
                stds.append(np.nan)
                continue
            rep_values = df.loc[rep_bin, y_col] / df.loc[rep_bin, x_col]
            means.append(rep_values.mean())
            stds.append(rep_values.std())
        means = np.array(means)
        stds = np.array(stds)
        args = {"color": "C2", "zorder": 10}
        ax.fill_between(bins, means - stds, means + stds, alpha=0.2, **args)
        ax.plot(bins, means + stds, alpha=0.5, **args)
        ax.plot(bins, means - stds, alpha=0.5, **args)

    vmax = vmaxs.get(property)
    if vmax is None:
        print("WARNING: vmax not set")

    image: QuadMesh
    _, _, _, image = ax.hist2d(
        df[x_col],
        df[y_col] / df[x_col],
        bins=(bins, np.linspace(0, 2, num_bins)),
        norm=LogNorm(vmax=vmax),
        rasterized=True,
    )
    print("vmin/vmax", image.norm.vmin, image.norm.vmax)

    ax.set_xscale("log")
    # NOTE(review): the upper limit is taken from y_col although the axis shows
    # x_col values -- confirm whether max(df[x_col]) was meant.
    x_start = min(df[x_col])
    x_end = max(df[y_col])
    ax.set_xlim(x_start, x_end)
    # Reference line at ratio 1, i.e. comp == ref.
    ax.plot([x_start, x_end], [1, 1], linewidth=1, color="C1", zorder=10)

    return x_col, y_col
2022-07-12 16:09:52 +02:00
2022-08-01 11:33:50 +02:00
def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=None):
    """Draw a 1D histogram of the column ``property`` from ``file`` onto ``ax``.

    * ``"match"``: a normalised histogram (100 bins) is drawn as a line through
      the bin centres and labelled according to the mass range.
    * ``"distance"``: a bar histogram on 100 logarithmically spaced bins, plus
      vertical lines marking mean and median.
    * anything else: a bar histogram with 100 equal-width bins.

    Args:
        ax: Axes to draw on.
        file: Path of the comparison CSV file.
        property: Name of the column to histogram.
        m_min: Optional exclusive lower bound on ``ref_Mvir``.
        m_max: Optional exclusive upper bound on ``ref_Mvir``.

    Raises:
        KeyError: for ``"match"`` if ``(m_min, m_max)`` is not one of the
            labelled mass ranges.
    """
    df = pd.read_csv(file)
    # Test against None explicitly so that a bound of 0 is honoured, and treat
    # both bounds independently.
    if m_min is not None:
        df = df.loc[m_min < df["ref_Mvir"]]
    if m_max is not None:
        df = df.loc[df["ref_Mvir"] < m_max]

    num_bins = 100
    values = df[property]

    if property == "match":
        labels = {
            (-inf, 30): "$M<30$",
            (None, None): "$M$",
            (30, 100): "$30<M<100$",
            (100, inf): "$100<M$",
        }
        label = labels[(m_min, m_max)]
        hist_val, bin_edges = np.histogram(values, bins=num_bins, density=True)
        # Plot against the bin centres so several mass ranges can share a panel.
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        ax.plot(bin_centers, hist_val, label=label)
        return

    if property == "distance":
        bins = np.geomspace(min(values), max(values), 100)
        ax.axvline(values.mean(), label="mean", color="C1")
        ax.axvline(values.median(), label="median", color="C2")
    else:
        bins = num_bins

    ax.hist(values, bins=bins, histtype="bar", label=None, density=False)
2022-07-12 16:09:52 +02:00
2022-07-21 16:05:27 +02:00
# Directory containing the comparison CSV files read by the plotting functions.
comparisons_dir = base_dir / "comparisons"

# Properties drawn as 1D histograms; all others are drawn as 2D histograms.
hist_properties = ["distance", "match", "num_skipped_for_mass"]

# (reference resolution, comparison resolution) pairs, one subplot column each.
comparisons = [
    (256, 512),
    (256, 1024),
    # (512, 1024),
]
2022-05-04 13:42:57 +02:00
2022-07-21 16:05:27 +02:00
2022-08-01 11:33:50 +02:00
def _particle_count_transforms(particle_mass: float):
    """Return the (mass -> particle count, particle count -> mass) converters.

    Creating the pair through this factory binds ``particle_mass`` per call,
    so every secondary axis keeps its own value when it is evaluated later.
    """

    def mass2partnum(mass):
        return mass / particle_mass

    def partnum2mass(partnum):
        return partnum * particle_mass

    return mass2partnum, partnum2mass


def compare_property(property, show: bool):
    """Plot ``property`` for every waveform/resolution pair and save a PDF.

    A grid with one row per entry of ``waveforms`` and one column per entry of
    ``comparisons`` is created. Each panel reads the comparison file
    ``comparisons_dir / get_comp_id(waveform, ref_res, waveform, comp_res)``.
    Properties listed in ``hist_properties`` are drawn as 1D histograms, all
    others as 2D histograms of the comp/ref ratio. The figure is written to
    ``~/tmp/comparison_<property>.pdf``.

    Args:
        property: Name of the property to plot.
        show: If true, display the figure with ``plt.show()`` after saving;
            otherwise the figure is closed once it has been written.

    Raises:
        KeyError: if no axis label is defined for ``property``.
    """
    is_hist_property = property in hist_properties

    # Lookup tables do not depend on the panel, so build them once.
    ratio_labels = {
        "Mvir": ("M", "vir"),
        "Vmax": ("V", "max"),
        "cNFW": ("C", None),
    }
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    hist_labels = {"match": "$J$", "distance": r"$D$ [$R_\mathrm{{vir}}$]"}
    mass_bins = [-inf, 30, 100, inf]
    particle_masses = {256: 0.23524624, 512: 0.02940578, 1024: 0.0036757225}
    unit_tex = units.get(property)
    unit = f"[{unit_tex}]" if unit_tex else ""

    fig: Figure
    fig, axes = plt.subplots(
        len(waveforms),
        len(comparisons),
        sharey="all",
        sharex="all",
        squeeze=False,  # keep a 2D array even for a single row or column
        figsize=figsize_from_page_fraction(columns=2),
    )
    for i, waveform in enumerate(waveforms):
        for j, (ref_res, comp_res) in enumerate(comparisons):
            file_id = get_comp_id(waveform, ref_res, waveform, comp_res)
            file = comparisons_dir / file_id
            print(file)
            ax: Axes = axes[i, j]
            is_bottom_row = i == len(waveforms) - 1
            is_top_row = i == 0
            is_left_col = j == 0
            is_legend_panel = is_bottom_row and is_left_col

            if not is_hist_property:
                plot_comparison_hist2d(ax, file, property)
                lab_a, lab_b = ratio_labels[property]
                if is_bottom_row:
                    if lab_b:
                        ax.set_xlabel(
                            tex_fmt(
                                r"$AA_{\textrm{BB},\textrm{ CC}} \textrm{ } DD$",
                                lab_a,
                                lab_b,
                                ref_res,
                                unit,
                            )
                        )
                    else:
                        ax.set_xlabel(
                            tex_fmt(
                                r"$AA_{\textrm{BB}} \textrm{ } CC$",
                                lab_a,
                                ref_res,
                                unit,
                            )
                        )
                if is_left_col:
                    # One shared y label for the whole figure.
                    if lab_b:
                        fig.supylabel(
                            tex_fmt(
                                r"$AA_{\textrm{BB},\textrm{ comp}} \textrm{ } / \textrm{ } AA_{\textrm{BB},\textrm{ CC}}$",
                                lab_a,
                                lab_b,
                                ref_res,
                            ),
                            fontsize="medium",
                            fontvariant="small-caps",
                        )
                    else:
                        fig.supylabel(
                            tex_fmt(
                                r"$AA_{\textrm{comp}} \textrm{ } / \textrm{ } AA_{\textrm{BB}}$",
                                lab_a,
                                ref_res,
                            ),
                            fontsize="medium",
                        )
                ax.text(
                    0.975,
                    0.9,
                    f"comp = {comp_res}",
                    horizontalalignment="right",
                    verticalalignment="top",
                    transform=ax.transAxes,
                )
            else:
                if property == "match":
                    # The legend panel has no room for the resolution text.
                    if not is_legend_panel:
                        ax.text(
                            0.05,
                            0.9,
                            f"comp = {comp_res}",
                            horizontalalignment="left",
                            verticalalignment="top",
                            transform=ax.transAxes,
                        )
                    # All halos first, then one curve per mass range.
                    plot_comparison_hist(ax, file, property)
                    for m_min, m_max in zip(mass_bins, mass_bins[1:]):
                        plot_comparison_hist(ax, file, property, m_min, m_max)
                    if is_legend_panel:
                        ax.legend()
                else:
                    ax.text(
                        0.05,
                        0.9,
                        f"comp = {comp_res}",
                        horizontalalignment="left",
                        verticalalignment="top",
                        transform=ax.transAxes,
                    )
                    plot_comparison_hist(ax, file, property)

                if is_bottom_row:
                    ax.set_xlabel(hist_labels[property])
                if is_left_col:
                    if property == "match":
                        fig.supylabel(r"$p(J)$", fontsize="medium")
                    else:
                        fig.supylabel(r"\# Halos", fontsize="medium")
                if property == "distance":
                    ax.set_xscale("log")
                    ax.set_yscale("log")
                    if is_legend_panel:
                        ax.legend()

            if not is_top_row:
                # Rows touch (hspace=0): hide the topmost tick so it does not
                # collide with the panel above.
                last_ytick: YTick = ax.yaxis.get_major_ticks()[-1]
                last_ytick.set_visible(False)

            if property == "Mvir" and is_top_row:
                # Secondary x axis translating halo mass into particle number.
                sec_ax = ax.secondary_xaxis(
                    "top",
                    functions=_particle_count_transforms(particle_masses[ref_res]),
                )
                sec_ax.set_xlabel(r"\textrm{Halo Size }[\# \textrm{particles}]")

    rowcolumn_labels(axes, waveforms, isrow=True)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)
    fig.savefig(Path(f"~/tmp/comparison_{property}.pdf").expanduser())
    if show:
        plt.show()
    else:
        # pyplot keeps figures alive until closed; release it once saved.
        plt.close(fig)
2022-07-21 16:05:27 +02:00
def main():
    """Plot the properties named on the command line, or a default set."""
    # properties = ['group_size', 'Mass_200crit', 'Mass_tot', 'Mvir', 'R_200crit', 'Rvir', 'Vmax', 'cNFW', 'q',
    #               's']
    default_properties = ["Mvir", "Vmax", "cNFW", "distance", "match"]
    requested = argv[1:] or default_properties
    # The figure is only displayed when exactly one argument was given.
    show_figure = len(argv) == 2
    for name in requested:
        compare_property(name, show=show_figure)


if __name__ == "__main__":
    main()