1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00

move comparison improvements

This commit is contained in:
Lukas Winkler 2022-08-01 11:33:50 +02:00
parent 5f63da1b8a
commit 18f99c946d
Signed by: lukas
GPG key ID: 54DE4D798D244853

View file

@ -8,6 +8,7 @@ from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from numpy import inf
from halo_vis import get_comp_id
from paths import base_dir
@ -25,8 +26,8 @@ vmaxs = {
units = {
"distance": "Mpc",
"Mvir": r"10^{10} M_\odot",
"Vmax": "???" # TODO
"Mvir": r"10^{10} \textrm{M}_\odot",
"Vmax": r"\textrm{km} \textrm{s}^{-1}" # TODO
}
@ -72,19 +73,19 @@ def concentration(row, halo_type: str) -> bool:
# return cnfw, colour
def plot_comparison_hist2d(ax: Axes, file: Path, property: str, mode: str):
def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
print("WARNING: Can only plot hist2d of properties with comp_ or ref_ right now!")
print(f" Selected property: {property}")
x_col = f"ref_{property}"
y_col = f"comp_{property}"
df = pd.read_csv(file)
if mode == 'concentration_analysis':
min_x = min([min(df[x_col]), min(df[y_col])])
max_x = max([max(df[x_col]), max(df[y_col])])
df = df.loc[2 * df.ref_cNFW < df.comp_cNFW]
else:
min_x = min([min(df[x_col]), min(df[y_col])])
max_x = max([max(df[x_col]), max(df[y_col])])
# if mode == 'concentration_analysis':
# min_x = min([min(df[x_col]), min(df[y_col])])
# max_x = max([max(df[x_col]), max(df[y_col])])
# df = df.loc[2 * df.ref_cNFW < df.comp_cNFW]
# else:
min_x = min([min(df[x_col]), min(df[y_col])])
max_x = max([max(df[x_col]), max(df[y_col])])
num_bins = 100
bins = np.geomspace(min_x, max_x, num_bins)
if property == 'cNFW':
@ -152,12 +153,30 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str, mode: str):
# fig.suptitle
def plot_comparison_hist(ax: Axes, file: Path, property: str, mode: str):
def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=None):
df = pd.read_csv(file)
if mode == 'concentration_analysis':
df = df.loc[2 * df.ref_cNFW < df.comp_cNFW]
if m_min:
df = df.loc[(m_min < df["ref_Mvir"]) & (df["ref_Mvir"] < m_max)]
ax.hist(df[property][df[property] < 50], bins=100)
num_bins = 100
histtype = "bar"
label = None
density = False
if property == "distance":
bins = np.geomspace(min(df[property]), max(df[property]), 100)
mean = df[property].mean()
median = df[property].median()
ax.axvline(mean, label="mean", color="C1")
ax.axvline(median, label="median", color="C2")
else:
bins = num_bins
if property == "match":
histtype = "step"
label = f"${m_min} < M < {m_max}$"
density = True
ax.hist(df[property], bins=bins, histtype=histtype, label=label, density=density)
comparisons_dir = base_dir / "comparisons"
@ -166,7 +185,7 @@ hist_properties = ["distance", "match", "num_skipped_for_mass"]
comparisons = [(256, 512), (256, 1024)] # , (512, 1024)
def compare_property(property, mode, show: bool):
def compare_property(property, show: bool):
is_hist_property = property in hist_properties
fig: Figure
fig, axes = plt.subplots(
@ -186,9 +205,9 @@ def compare_property(property, mode, show: bool):
x_labels = {
"Mvir": ("M", "vir"),
"Vmax": ("V", "max"),
"cNFW": ("c", None),
"cNFW": ("C", None),
}
x_col, y_col = plot_comparison_hist2d(ax, file, property, mode)
x_col, y_col = plot_comparison_hist2d(ax, file, property)
lab_a, lab_b = x_labels[property]
unit = f"[{units[property]}]" if property in units and units[property] else ""
if is_bottom_row:
@ -207,15 +226,33 @@ def compare_property(property, mode, show: bool):
lab_a, ref_res))
# ax.set_ylabel(f"{property}_{{comp}}/{property}_{ref_res}")
else:
plot_comparison_hist(ax, file, property, mode)
if property == "match":
# mass_bins = np.geomspace(10, 30000, num_mass_bins)
plot_comparison_hist(ax, file, property)
mass_bins = [-inf, 30, 50, 100, inf]
for k in range(len(mass_bins) - 1):
m_min = mass_bins[k]
m_max = mass_bins[k + 1]
plot_comparison_hist(ax, file, property, m_min, m_max)
if is_bottom_row and is_left_col:
ax.legend()
else:
plot_comparison_hist(ax, file, property)
x_labels = {
"match": "$J$",
"distance": "$D$"
}
if is_bottom_row:
x_labels = {
"match": "$J$",
"distance": "$R$"
}
ax.set_xlabel(x_labels[property])
ax.set_xlabel(x_labels[property])
if is_left_col:
ax.set_ylabel(r"\# Halos")
if property == "distance":
ax.set_xscale("log")
ax.set_yscale("log")
if is_bottom_row and is_left_col:
ax.legend()
rowcolumn_labels(axes, comparisons, isrow=False)
rowcolumn_labels(axes, waveforms, isrow=True)
@ -227,17 +264,14 @@ def compare_property(property, mode, show: bool):
def main():
# properties = ['group_size', 'Mass_200crit', 'Mass_tot', 'Mvir', 'R_200crit', 'Rvir', 'Vmax', 'cNFW', 'q',
# 's'] # Mass_FOF and cNFW_200crit don't work, rest looks normal except for cNFW
# 's']
if len(argv) > 1:
properties = argv[1:]
else:
properties = ["Mvir", "Vmax", "cNFW"]
# mode = 'concentration_analysis'
mode = 'concentration_bla'
for property in properties:
compare_property(property, mode, show=len(argv) == 2)
compare_property(property, show=len(argv) == 2)
if __name__ == '__main__':