1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/legacy_comparison.py

66 lines
1.7 KiB
Python
Raw Normal View History

2022-05-09 15:20:10 +02:00
import matplotlib.pyplot as plt
2022-05-06 13:23:31 +02:00
import numpy as np
import pandas as pd
2022-05-09 15:20:10 +02:00
from matplotlib.axes import Axes
from matplotlib.figure import Figure
2022-05-06 13:23:31 +02:00
2022-07-21 16:05:16 +02:00
from halo_vis import get_comp_id
from paths import base_dir
2022-05-09 15:20:10 +02:00
# Identifier handed to get_comp_id for both the reference and the compared run.
waveform = "shannon"
# Directory holding the per-pair comparison CSV files.
comparisons_dir = base_dir / "comparisons"

# Halos are grouped into `num_bins` logarithmically spaced bins of reference
# virial mass; `bins` holds the num_bins + 1 bin edges from 450 to 80000
# (same units as the `ref_Mvir` column — TODO confirm units).
num_bins = 5
bins = np.geomspace(450, 80000, num=num_bins + 1)
2022-05-06 13:23:31 +02:00
2022-05-09 15:20:10 +02:00
def read(mode, ref_res, comp_res):
    """Return the mean ``match`` value of each halo-mass bin for one comparison.

    Loads the comparison CSV named by ``get_comp_id(mode, ref_res, mode,
    comp_res)`` from ``comparisons_dir``, sorts the halos into the module-level
    ``bins`` by their ``ref_Mvir`` column and averages the ``match`` column
    inside every bin. Halos below ``bins[0]`` or at/above ``bins[-1]`` are
    ignored. The ``ref_Mvir`` range of the file is printed as a diagnostic.

    Args:
        mode: waveform identifier used for both the reference and the
            compared run (previously ignored in favour of the global
            ``waveform``).
        ref_res: resolution of the reference run.
        comp_res: resolution of the compared run.

    Returns:
        A list of ``num_bins`` floats, lowest-mass bin first. A bin that
        contains no halos yields ``nan``.
    """
    df = pd.read_csv(comparisons_dir / get_comp_id(mode, ref_res, mode, comp_res))
    masses = df["ref_Mvir"].to_numpy()
    matches = df["match"].to_numpy()
    print(masses.min(), masses.max())

    # np.digitize gives 1..num_bins for values inside the edges; 0 (below the
    # first edge) and len(bins) (at/above the last edge) are never selected.
    digits = np.digitize(masses, bins)

    bin_means = []
    for bin_index in range(1, num_bins + 1):
        in_bin = matches[digits == bin_index]
        # TODO: or instead fraction of halos that are matching?
        # Explicit nan avoids the "mean of empty slice" RuntimeWarning.
        bin_means.append(in_bin.mean() if in_bin.size else np.nan)
    return bin_means
# Resolutions shown on the x axis, and the one every other run is compared to.
resolutions = [128, 256, 512, 1024]
ref_res = 128

# One row of per-bin mean match values per resolution, in `resolutions` order,
# so row i always lines up with x tick label i. The reference run compared with
# itself is taken to match perfectly, hence a row of ones.
rows = [
    [1] * num_bins if res == ref_res else read(waveform, ref_res, res)
    for res in resolutions
]
# Transpose: rows become mass bins (y axis), columns become resolutions (x axis).
data = np.array(rows).T
# Heat map of `data`: x axis = resolution, y axis = mass bin.
fig: Figure = plt.figure()
ax: Axes = fig.gca()

ax.set_xticks(range(len(resolutions)))
ax.set_xticklabels(resolutions)

# Bin edges sit on the borders between pixel rows, i.e. at half-integer
# positions, so label them there.
edge_positions = np.arange(len(bins)) - 0.5
ax.set_yticks(edge_positions)
ax.set_yticklabels([f"{edge:.2f}" for edge in bins])

# Write every cell's value on top of it; data[row, col] is drawn at
# (x=col, y=row) by imshow below.
for (row, col), value in np.ndenumerate(data):
    ax.text(col, row, f"{value:.2f}", ha="center", va="center", color="w")

image = ax.imshow(data, origin="lower", vmin=0.5, vmax=1)
fig.colorbar(image)
ax.set_title(waveform)
plt.show()