mirror of
https://github.com/Findus23/halo_comparison.git
synced 2024-09-19 16:03:50 +02:00
strongly improve comparison plots
This commit is contained in:
parent
54f3b24a9b
commit
5f63da1b8a
2 changed files with 114 additions and 67 deletions
154
sizes.py
154
sizes.py
|
@ -8,22 +8,35 @@ from matplotlib.axes import Axes
|
|||
from matplotlib.collections import QuadMesh
|
||||
from matplotlib.colors import LogNorm
|
||||
from matplotlib.figure import Figure
|
||||
# density like in Vr:
|
||||
from numpy import log10
|
||||
|
||||
from halo_vis import get_comp_id
|
||||
from paths import base_dir
|
||||
from utils import figsize_from_page_fraction, rowcolumn_labels, waveforms
|
||||
from utils import figsize_from_page_fraction, rowcolumn_labels, waveforms, tex_fmt
|
||||
|
||||
# density like in Vr:
|
||||
|
||||
G = 43.022682 # in Mpc (km/s)^2 / (10^10 Msun)
|
||||
|
||||
vmaxs = {
|
||||
"Mvir": 52,
|
||||
"Vmax": 93,
|
||||
"cNFW": 31
|
||||
}
|
||||
|
||||
def concentration(row, halo_type: str):
|
||||
units = {
|
||||
"distance": "Mpc",
|
||||
"Mvir": r"10^{10} M_\odot",
|
||||
"Vmax": "???" # TODO
|
||||
}
|
||||
|
||||
|
||||
def concentration(row, halo_type: str) -> bool:
|
||||
r_200crit = row[f'{halo_type}_R_200crit']
|
||||
if r_200crit <= 0:
|
||||
cnfw = -1
|
||||
colour = 'orange'
|
||||
return cnfw, colour
|
||||
return False
|
||||
# return cnfw, colour
|
||||
|
||||
r_size = row[f'{halo_type}_R_size'] # largest difference from center of mass to any halo particle
|
||||
m_200crit = row[f'{halo_type}_Mass_200crit']
|
||||
|
@ -34,27 +47,32 @@ def concentration(row, halo_type: str):
|
|||
if VmaxVvir2 <= 1.05:
|
||||
if m_200crit == 0:
|
||||
cnfw = r_size / rmax
|
||||
colour = 'white'
|
||||
return False
|
||||
# colour = 'white'
|
||||
else:
|
||||
cnfw = r_200crit / rmax
|
||||
colour = 'white'
|
||||
return False
|
||||
# colour = 'white'
|
||||
else:
|
||||
if npart >= 100: # only calculate cnfw for groups with more than 100 particles
|
||||
cnfw = row[f'{halo_type}_cNFW']
|
||||
colour = 'black'
|
||||
return True
|
||||
# colour = 'black'
|
||||
else:
|
||||
if m_200crit == 0:
|
||||
cnfw = r_size / rmax
|
||||
colour = 'white'
|
||||
return False
|
||||
# colour = 'white'
|
||||
else:
|
||||
cnfw = r_200crit / rmax
|
||||
colour = 'white'
|
||||
assert np.isclose(cnfw, row[f'{halo_type}_cNFW'])
|
||||
|
||||
return cnfw, colour
|
||||
return False
|
||||
# colour = 'white'
|
||||
# assert np.isclose(cnfw, row[f'{halo_type}_cNFW'])
|
||||
#
|
||||
# return cnfw, colour
|
||||
|
||||
|
||||
def plot_comparison_hist2d(ax: Axes, ax_scatter: Axes, file: Path, property: str, mode: str):
|
||||
def plot_comparison_hist2d(ax: Axes, file: Path, property: str, mode: str):
|
||||
print("WARNING: Can only plot hist2d of properties with comp_ or ref_ right now!")
|
||||
print(f" Selected property: {property}")
|
||||
x_col = f"ref_{property}"
|
||||
|
@ -69,39 +87,50 @@ def plot_comparison_hist2d(ax: Axes, ax_scatter: Axes, file: Path, property: str
|
|||
max_x = max([max(df[x_col]), max(df[y_col])])
|
||||
num_bins = 100
|
||||
bins = np.geomspace(min_x, max_x, num_bins)
|
||||
if mode == "concentration_bla" and property == 'cNFW':
|
||||
colors = []
|
||||
if property == 'cNFW':
|
||||
rows = []
|
||||
for i, row in df.iterrows():
|
||||
comp_cnfw, comp_colour = concentration(row, halo_type="comp") # ref or comp
|
||||
ref_cnfw, ref_colour = concentration(row, halo_type='ref')
|
||||
if comp_colour == 'white' or ref_colour == 'white':
|
||||
colors.append('white')
|
||||
else:
|
||||
colors.append('black')
|
||||
ax.scatter(df[x_col], df[y_col], c=colors, s=1, alpha=.3)
|
||||
else:
|
||||
comp_cnfw_normal = concentration(row, halo_type="comp")
|
||||
|
||||
ref_cnfw_normal = concentration(row, halo_type='ref')
|
||||
cnfw_normal = comp_cnfw_normal and ref_cnfw_normal
|
||||
if cnfw_normal:
|
||||
rows.append(row)
|
||||
df = pd.concat(rows, axis=1).T
|
||||
print(df)
|
||||
if property == "Mvir":
|
||||
stds = []
|
||||
means = []
|
||||
for rep_row in range(num_bins):
|
||||
rep_x_left = bins[rep_row]
|
||||
rep_x_right = bins[rep_row] + 1
|
||||
rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
|
||||
rep_values = log10(df.loc[rep_bin][y_col])
|
||||
# if len(rep_values) > 30:
|
||||
rep_values = df.loc[rep_bin][y_col] / df.loc[rep_bin][x_col]
|
||||
if len(rep_bin) < 30:
|
||||
continue
|
||||
mean = rep_values.mean()
|
||||
std = rep_values.std()
|
||||
means.append(mean)
|
||||
stds.append(len(rep_values))
|
||||
# else:
|
||||
# stds.append(np.nan)
|
||||
stds.append(std)
|
||||
means = np.array(means)
|
||||
stds = np.array(stds)
|
||||
print(10 ** (means - stds))
|
||||
ax.fill_between(bins, 10 ** (means - stds), 10 ** (means + stds), color="red", zorder=10, alpha=.6)
|
||||
ax_scatter.step(bins, stds, label=f"{file.stem}")
|
||||
args = {
|
||||
"color": "C2",
|
||||
"zorder": 10
|
||||
}
|
||||
ax.fill_between(bins, means - stds, means + stds, alpha=.2, **args)
|
||||
ax.plot(bins, means + stds, alpha=.5, **args)
|
||||
ax.plot(bins, means - stds, alpha=.5, **args)
|
||||
# ax_scatter.plot(bins, stds, label=f"{file.stem}")
|
||||
|
||||
if property in vmaxs:
|
||||
vmax = vmaxs[property]
|
||||
else:
|
||||
vmax = None
|
||||
print("WARNING: vmax not set")
|
||||
image: QuadMesh
|
||||
_, _, _, image = ax.hist2d(df[x_col], df[y_col], bins=(bins, bins), norm=LogNorm()) # TODO: set vmin/vmax
|
||||
_, _, _, image = ax.hist2d(df[x_col], df[y_col] / df[x_col], bins=(bins, np.linspace(0, 2, num_bins)),
|
||||
norm=LogNorm(vmax=vmax))
|
||||
# ax.plot([rep_x_left, rep_x_left], [mean - std, mean + std], c="C1")
|
||||
# ax.annotate(
|
||||
# text=f"std={std:.2f}", xy=(rep_x_left, mean + std),
|
||||
|
@ -111,11 +140,11 @@ def plot_comparison_hist2d(ax: Axes, ax_scatter: Axes, file: Path, property: str
|
|||
print("vmin/vmax", image.norm.vmin, image.norm.vmax)
|
||||
# fig.colorbar(hist)
|
||||
|
||||
# ax.set_xscale("log")
|
||||
ax.set_xscale("log")
|
||||
# ax.set_yscale("log")
|
||||
ax.set_xlim(min(df[x_col]), max(df[y_col]))
|
||||
|
||||
ax.loglog([min_x, max_x], [min_x, max_x], linewidth=1, color="C2")
|
||||
# ax.axis('scaled')
|
||||
ax.plot([min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C1", zorder=10)
|
||||
|
||||
return x_col, y_col
|
||||
# ax.set_title(file.name)
|
||||
|
@ -124,16 +153,11 @@ def plot_comparison_hist2d(ax: Axes, ax_scatter: Axes, file: Path, property: str
|
|||
|
||||
|
||||
def plot_comparison_hist(ax: Axes, file: Path, property: str, mode: str):
|
||||
print("WARNING: Can only plot hist of properties w/o comp_ or ref_ right now!")
|
||||
print(f" Selected property: {property}")
|
||||
df = pd.read_csv(file)
|
||||
if mode == 'concentration_analysis':
|
||||
df = df.loc[2 * df.ref_cNFW < df.comp_cNFW]
|
||||
|
||||
ax.hist(df[property][df[property] < 50], bins=100)
|
||||
ax.set_xlabel(property)
|
||||
# ax.set_title(file.name)
|
||||
# plt.show()
|
||||
|
||||
|
||||
comparisons_dir = base_dir / "comparisons"
|
||||
|
@ -142,7 +166,7 @@ hist_properties = ["distance", "match", "num_skipped_for_mass"]
|
|||
comparisons = [(256, 512), (256, 1024)] # , (512, 1024)
|
||||
|
||||
|
||||
def compare_property(property, mode):
|
||||
def compare_property(property, mode, show: bool):
|
||||
is_hist_property = property in hist_properties
|
||||
fig: Figure
|
||||
fig, axes = plt.subplots(
|
||||
|
@ -150,10 +174,6 @@ def compare_property(property, mode):
|
|||
sharey="all", sharex="all",
|
||||
figsize=figsize_from_page_fraction(columns=2)
|
||||
)
|
||||
if not is_hist_property:
|
||||
fig_scatter: Figure = plt.figure(figsize=figsize_from_page_fraction())
|
||||
ax_scatter: Axes = fig_scatter.gca()
|
||||
ax_scatter.set_xscale("log")
|
||||
for i, waveform in enumerate(waveforms):
|
||||
for j, (ref_res, comp_res) in enumerate(comparisons):
|
||||
file_id = get_comp_id(waveform, ref_res, waveform, comp_res)
|
||||
|
@ -163,25 +183,45 @@ def compare_property(property, mode):
|
|||
is_bottom_row = i == len(waveforms) - 1
|
||||
is_left_col = j == 0
|
||||
if not is_hist_property:
|
||||
x_col, y_col = plot_comparison_hist2d(ax, ax_scatter, file, property, mode)
|
||||
x_labels = {
|
||||
"Mvir": ("M", "vir"),
|
||||
"Vmax": ("V", "max"),
|
||||
"cNFW": ("c", None),
|
||||
}
|
||||
x_col, y_col = plot_comparison_hist2d(ax, file, property, mode)
|
||||
lab_a, lab_b = x_labels[property]
|
||||
unit = f"[{units[property]}]" if property in units and units[property] else ""
|
||||
if is_bottom_row:
|
||||
ax.set_xlabel(x_col)
|
||||
if lab_b:
|
||||
ax.set_xlabel(tex_fmt(r"$AA_{\textrm{BB},CC} DD$", lab_a, lab_b, ref_res, unit))
|
||||
else:
|
||||
ax.set_xlabel(tex_fmt(r"$AA_{BB} CC$", lab_a, ref_res, unit))
|
||||
if is_left_col:
|
||||
ax.set_ylabel(y_col)
|
||||
if lab_b:
|
||||
ax.set_ylabel(
|
||||
tex_fmt(r"$AA_{\textrm{BB},\textrm{comp}} / AA_{\textrm{BB},\textrm{CC}}$",
|
||||
lab_a, lab_b, ref_res))
|
||||
else:
|
||||
ax.set_ylabel(
|
||||
tex_fmt(r"$AA_{\textrm{comp}} / AA_{\textrm{BB}}$",
|
||||
lab_a, ref_res))
|
||||
# ax.set_ylabel(f"{property}_{{comp}}/{property}_{ref_res}")
|
||||
else:
|
||||
plot_comparison_hist(ax, file, property, mode)
|
||||
if is_bottom_row:
|
||||
ax.set_xlabel(property)
|
||||
x_labels = {
|
||||
"match": "$J$",
|
||||
"distance": "$R$"
|
||||
}
|
||||
ax.set_xlabel(x_labels[property])
|
||||
if is_left_col:
|
||||
ax.set_ylabel(r"\#")
|
||||
ax.set_ylabel(r"\# Halos")
|
||||
|
||||
rowcolumn_labels(axes, comparisons, isrow=False)
|
||||
rowcolumn_labels(axes, waveforms, isrow=True)
|
||||
fig.tight_layout()
|
||||
fig.savefig(Path(f"~/tmp/comparison_{property}.pdf").expanduser())
|
||||
if not is_hist_property:
|
||||
ax_scatter.legend()
|
||||
fig_scatter.tight_layout()
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
|
||||
|
@ -191,13 +231,13 @@ def main():
|
|||
if len(argv) > 1:
|
||||
properties = argv[1:]
|
||||
else:
|
||||
properties = ['match']
|
||||
properties = ["Mvir", "Vmax", "cNFW"]
|
||||
# mode = 'concentration_analysis'
|
||||
|
||||
mode = 'normal'
|
||||
mode = 'concentration_bla'
|
||||
|
||||
for property in properties:
|
||||
compare_property(property, mode)
|
||||
compare_property(property, mode, show=len(argv) == 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
7
utils.py
7
utils.py
|
@ -1,4 +1,5 @@
|
|||
from pathlib import Path
|
||||
from string import ascii_uppercase
|
||||
from typing import Tuple
|
||||
|
||||
import pandas as pd
|
||||
|
@ -76,3 +77,9 @@ def rowcolumn_labels(axes, labels, isrow: bool, pad=5) -> None:
|
|||
ax.annotate(label, xy=xy, xytext=xytext,
|
||||
xycoords=xycoords, textcoords='offset points',
|
||||
size='large', ha=ha, va=va)
|
||||
|
||||
|
||||
def tex_fmt(format_str: str, *args) -> str:
|
||||
for i, arg in enumerate(args):
|
||||
format_str = format_str.replace(ascii_uppercase[i] * 2, str(arg))
|
||||
return format_str
|
||||
|
|
Loading…
Reference in a new issue