1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/spectra_plot.py

221 lines
7.2 KiB
Python
Raw Normal View History

import itertools
2022-06-20 16:51:14 +02:00
from pathlib import Path
from sys import argv
2022-06-20 16:51:14 +02:00
import matplotlib.pyplot as plt
2022-07-05 13:30:57 +02:00
import numpy as np
2022-06-20 16:51:14 +02:00
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from paths import base_dir
2022-07-12 14:39:01 +02:00
from utils import figsize_from_page_fraction
2022-06-20 16:51:14 +02:00
# Box side length; create_plot() passes it to spectra_data(), where it is part
# of the spectra directory and file names.
Lbox = 100
# NOTE(review): presumably the dimensionless Hubble parameter; it is not
# referenced by the code shown in this file.
h = 0.690021
# k0 * res is drawn as the Nyquist wavenumber "k_ny, res" of resolution res,
# and spectra_data() discards every row with k < k0.
k0 = np.pi / Lbox
resolutions = [128, 256, 512, 1024]
waveforms = ["DB2", "DB4", "DB8", "shannon"]
# Names given to the whitespace-separated columns of the *_cross_spectrum.txt
# tables read by spectra_data().
# Careful: k is actually in Mpc^-1, the column is just named weirdly.
columns = [
    "k [Mpc]",
    "Pcross",
    "P1",
    "err. P1",
    "P2",
    "err. P2",
    "P2-1",
    "err. P2-1",
    "modes in bin",
]
# Matplotlib colour-cycle references "C0" ... "C9".
colors = [f"C{n}" for n in range(10)]
2022-06-20 16:51:14 +02:00
def spectra_data(
    waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
):
    """Load the cross-spectrum table for one pair of resolutions and one snapshot.

    The file read is
    ``base_dir/spectra/{waveform}_{Lbox}/{waveform}_{Lbox}_{tag}_{resolution_1}_{resolution_2}_cross_spectrum.txt``,
    where ``tag`` is "ics" for time "ics", "a2" for time "z=1" and "a4" for
    time "end". The first line of the file is skipped and the remaining
    space-separated columns are named after the module-level ``columns``.

    Args:
        waveform: waveform name used in the directory and file names.
        resolution_1: first resolution in the file name.
        resolution_2: second resolution in the file name.
        Lbox: box size as used in the directory and file names (only used
            there; the k cut below uses the module-level ``k0``).
        time: one of "ics", "z=1", "end".

    Returns:
        A DataFrame containing only the rows with ``k [Mpc] >= k0``. The
        original row labels are kept (the index is NOT reset), so use
        ``.iloc`` for positional access.

    Raises:
        ValueError: if ``time`` is not one of the supported labels.
    """
    # Snapshot label -> tag used in the file name.
    snapshot_tags = {"ics": "ics", "z=1": "a2", "end": "a4"}
    if time not in snapshot_tags:
        raise ValueError(f"invalid time ({time}) should be (ics|z=1|end)")
    directory = base_dir / f"spectra/{waveform}_{Lbox}"
    filename = (
        f"{waveform}_{Lbox}_{snapshot_tags[time]}_"
        f"{resolution_1}_{resolution_2}_cross_spectrum.txt"
    )
    data = pd.read_csv(
        directory / filename,
        sep=" ",
        skipinitialspace=True,
        header=None,
        names=columns,
        skiprows=1,
    )
    # only consider rows above resolution limit
    return data[data["k [Mpc]"] >= k0]
2022-06-28 12:05:50 +02:00
def _annotate_axis(ax: "Axes", waveform, time, mode, bottom_row):
    """Label one panel with its waveform and snapshot and mark Nyquist wavenumbers.

    A dashed vertical line is drawn at ``k0 * res`` for every resolution (all
    but the highest in "cross" mode); the lines get legend labels only in
    "power" mode.
    """
    if bottom_row:
        ax.set_xlabel("k [Mpc$^{-1}$]")
    ax.text(
        0.02,
        0.85,
        f"{waveform}",
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
    )
    ax.text(
        0.98,
        0.85,
        time,
        horizontalalignment="right",
        verticalalignment="top",
        transform=ax.transAxes,
    )
    marked = resolutions[:-1] if mode == "cross" else resolutions
    for j, res in enumerate(marked):
        ax.axvline(
            k0 * res,
            color=colors[j],
            linestyle="dashed",
            label=f"$k_\\mathrm{{ny, {res}}}$" if mode == "power" else None,
        )


def _plot_power_row(row_axes, waveform):
    """Plot, per snapshot, each resolution's P1 divided by the highest resolution's P1."""
    row_axes["ics"].set_ylabel("P")
    top = resolutions[-1]
    # The reference spectra do not depend on the loop below, so read them once.
    reference = {
        time: spectra_data(waveform, top, top, Lbox, time)["P1"] for time in row_axes
    }
    for j, resolution in enumerate(resolutions):
        for time, ax in row_axes.items():
            data = spectra_data(waveform, resolution, resolution, Lbox, time)
            # Align the reference to this table's row labels so the ratio has
            # exactly one value per k of this table (labels missing from the
            # reference become NaN). A plain Series division would instead
            # return the union of both indexes and no longer match the k column.
            ratio = data["P1"] / reference[time].reindex(data.index)
            ax.semilogx(data["k [Mpc]"], ratio, color=colors[j])
    for ax in row_axes.values():
        ax.set_ylim(0.9, 1.10)


def _plot_cross_row(row_axes, waveform, combination_list, bottom_row):
    """Plot Pcross for every resolution pair and return the ics crossing values.

    Returns a 1-D array holding, for each ``(res1, res2)`` in
    *combination_list*, the ics Pcross at the first k bin at or above
    ``k0 * min(res1, res2)``.
    """
    ax_ics = row_axes["ics"]
    ax_end = row_axes["end"]
    ax_ics.set_ylabel("C")
    crossings = np.zeros(len(combination_list))
    for j, (res1, res2) in enumerate(combination_list):
        for time, ax in row_axes.items():
            data = spectra_data(waveform, res1, res2, Lbox, time)
            k = data["k [Mpc]"]
            pcross = data["Pcross"]
            if time == "ics":
                position = np.searchsorted(k.to_numpy(), k0 * min(res1, res2))
                # searchsorted returns a *position*, but spectra_data keeps the
                # original row labels after dropping rows below k0, so a label
                # lookup (pcross[position]) would read a shifted row.
                crossings[j] = pcross.iloc[position]
            ax.semilogx(k, pcross, color=colors[j + 3], label=f"{res1} vs {res2}")
    ax_end.set_xlim(right=k0 * resolutions[-1])
    ax_end.set_ylim(0.8, 1.02)
    if bottom_row:
        ax_ics.legend(loc="lower left")
    return crossings


def create_plot(mode):
    """Draw the spectra comparison grid and save it as ``~/tmp/spectra_{mode}.pdf``.

    The grid has one row per entry of ``waveforms`` and three columns for the
    "ics", "z=1" and "end" snapshots.

    Args:
        mode: "power" plots, for every resolution, P1 divided by P1 of the
            highest resolution. "cross" plots Pcross for every pair from
            ``itertools.combinations(resolutions, 2)`` and prints, per
            waveform and pair, the ics Pcross at the Nyquist wavenumber of
            the smaller resolution (as a raw array and as a LaTeX table).

    Raises:
        ValueError: if *mode* is neither "power" nor "cross".
    """
    if mode not in ("power", "cross"):
        raise ValueError(f"invalid mode ({mode}) should be (power|cross)")
    fig: Figure
    combination_list = list(itertools.combinations(resolutions, 2))
    fig, axes = plt.subplots(
        len(waveforms),
        3,
        sharex=True,
        sharey=True,
        constrained_layout=True,
        figsize=figsize_from_page_fraction(columns=2),
    )
    crossings = np.zeros((len(waveforms), len(combination_list)))
    for i, waveform in enumerate(waveforms):
        # Snapshot label (also the panel annotation) -> panel. TODO: better names
        row_axes = {"ics": axes[i][0], "z=1": axes[i][1], "end": axes[i][2]}
        bottom_row = i == len(waveforms) - 1
        for time, ax in row_axes.items():
            _annotate_axis(ax, waveform, time, mode, bottom_row)
        if mode == "power":
            _plot_power_row(row_axes, waveform)
        else:
            crossings[i] = _plot_cross_row(
                row_axes, waveform, combination_list, bottom_row
            )
    if mode == "cross":
        # The crossing table is only filled in cross mode; in power mode it
        # would be all zeros.
        print(crossings)
        crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
        print(crossings_df.to_latex())
    fig.savefig(Path(f"~/tmp/spectra_{mode}.pdf").expanduser())
2022-06-20 16:51:14 +02:00
def main():
    """Command-line entry point: ``spectra_plot.py [power|cross]`` or ``spectra_plot.py all``.

    Creates the requested plot ("all" creates both the power and the cross
    plot) and then shows the figures. With no mode argument, it prints a usage
    message to stderr and exits with status 1.
    """
    if len(argv) < 2:
        # SystemExit with a message writes it to stderr and exits with status 1,
        # without depending on the site-provided exit() builtin.
        raise SystemExit("run spectra_plot.py [power|cross] or spectra_plot.py all")
    modes = ["power", "cross"] if argv[1] == "all" else [argv[1]]
    for mode in modes:
        create_plot(mode)
    plt.show()
2022-06-20 16:51:14 +02:00
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()