1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/spectra_plot.py

152 lines
4.3 KiB
Python
Raw Normal View History

import itertools
2022-06-20 16:51:14 +02:00
from pathlib import Path
from sys import argv
2022-06-20 16:51:14 +02:00
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from paths import base_dir
2022-06-20 16:51:14 +02:00
# Box size; it is part of the spectra directory and file names.
# NOTE(review): presumably in Mpc, given that k is in Mpc^-1 -- confirm.
Lbox = 100

# Wavenumber 2*pi / Lbox; rows with smaller k are discarded when loading.
k0 = 2 * 3.14159265358979323846264338327950 / Lbox

# Resolutions and waveform names for which spectra are plotted.
resolutions = [128, 256, 512]
waveforms = ["DB2", "DB4", "DB8", "shannon"]

# Column names assigned to the spectra tables when they are read.
# Careful: k is actually in Mpc^-1, the column is just named weirdly.
columns = [
    "k [Mpc]",
    "Pcross",
    "P1",
    "err. P1",
    "P2",
    "err. P2",
    "P2-1",
    "err. P2-1",
    "modes in bin",
]

# One matplotlib colour-cycle entry ("C1" ... "C4") per plotted line.
colors = [f"C{index}" for index in range(1, 5)]
def spectra_data(
    waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
) -> pd.DataFrame:
    """Load one cross-spectrum table and drop the rows below the box mode.

    The table is read from
    ``base_dir / "spectra" / f"{waveform}_{Lbox}"`` using the file name
    ``{waveform}_{Lbox}_{tag}_{resolution_1}_{resolution_2}_cross_spectrum.txt``
    where ``tag`` is ``"ics"`` for ``time == "ics"`` and ``"a4"`` for
    ``time == "end"``.

    Args:
        waveform: Waveform name as used in the directory and file names.
        resolution_1: First resolution in the file name.
        resolution_2: Second resolution in the file name.
        Lbox: Box size; used in the path and to compute the lowest
            wavenumber (``2 * pi / Lbox``) that is kept.
        time: Either ``"ics"`` or ``"end"``.

    Returns:
        The table with the module-level ``columns`` as column names,
        restricted to rows whose ``"k [Mpc]"`` value is at least
        ``2 * pi / Lbox``.

    Raises:
        ValueError: If ``time`` is neither ``"ics"`` nor ``"end"``.
    """
    # Function-scope import so this block does not depend on a module-level
    # ``import math``.
    import math

    # Translate the requested time into the tag used in the file names.
    if time == "ics":
        time_tag = "ics"
    elif time == "end":
        time_tag = "a4"
    else:
        raise ValueError(f"invalid time ({time}) should be (ics|end)")

    spectra_dir = base_dir / f"spectra/{waveform}_{Lbox}"
    spectra_file = spectra_dir / (
        f"{waveform}_{Lbox}_{time_tag}_{resolution_1}_{resolution_2}"
        "_cross_spectrum.txt"
    )
    table = pd.read_csv(
        spectra_file,
        sep=" ",
        skipinitialspace=True,
        header=None,
        names=columns,
        skiprows=1,  # the first line of the file is not data
    )

    # Derive the cut from the Lbox argument rather than the module-level k0,
    # so the filter stays correct for box sizes other than the global one.
    # For the global Lbox this is exactly the module-level k0.
    k_fundamental = 2 * math.pi / Lbox
    return table[table["k [Mpc]"] >= k_fundamental]
2022-06-20 16:51:14 +02:00
def _stacked_axes(fig: Figure, rows: int):
    """Return ``rows`` vertically stacked axes sharing x and y as a flat array."""
    # squeeze=False keeps the result an array even for a single row, so
    # flatten() also works when only one panel is requested.
    return fig.subplots(rows, 1, sharex=True, sharey=True, squeeze=False).flatten()


def _decorate_panel(ax: Axes, ylabel: str, caption: str) -> None:
    """Set the axis labels and write ``caption`` in the top-left corner of ``ax``."""
    ax.set_xlabel("k [Mpc$^{-1}$]")
    ax.set_ylabel(ylabel)
    ax.text(
        0.02,
        0.93,
        caption,
        size=10,
        horizontalalignment="left",
        verticalalignment="top",
        transform=ax.transAxes,
    )


def _plot_power(fig: Figure, time: str) -> None:
    """Draw one log-log panel per waveform with the P1 spectrum of every resolution."""
    panels = _stacked_axes(fig, len(waveforms))
    for ax, waveform in zip(panels, waveforms):
        _decorate_panel(ax, "P", waveform)
        for j, resolution in enumerate(resolutions):
            data = spectra_data(waveform, resolution, resolution, Lbox, time)
            ax.loglog(data["k [Mpc]"], data["P1"], color=colors[j])
            # Dashed marker at k0 * resolution; it carries the legend label.
            ax.axvline(
                k0 * resolution,
                color=colors[j],
                linestyle="dashed",
                label=f"{resolution}",
            )
        ax.legend()
    fig.suptitle(f"Power Spectra {time}")
    fig.tight_layout()


def _plot_cross(fig: Figure, time: str) -> None:
    """Draw one panel per resolution pair with the Pcross column of every waveform."""
    resolution_pairs = list(itertools.combinations(resolutions, 2))
    panels = _stacked_axes(fig, len(resolution_pairs))
    for ax, (res1, res2) in zip(panels, resolution_pairs):
        _decorate_panel(ax, "C", f"{res1} vs {res2}")
        for i, waveform in enumerate(waveforms):
            data = spectra_data(waveform, res1, res2, Lbox, time)
            ax.semilogx(
                data["k [Mpc]"], data["Pcross"], color=colors[i], label=waveform
            )
        # Cut the x axis at k0 * res1, the first resolution of the pair.
        ax.set_xlim(right=k0 * res1)
        ax.set_ylim(0.8, 1.02)
        ax.legend()
    fig.suptitle(f"Cross Spectra {time}")
    fig.tight_layout()


def create_plot(mode: str, time: str, show: bool = True) -> None:
    """Create the requested spectra figure and save it as a PDF.

    The figure is written to ``~/tmp/spectra_{time}_{mode}.pdf``; the
    directory is created if it does not exist yet.

    Args:
        mode: ``"power"`` for one panel per waveform, or ``"cross"`` for one
            panel per pair of resolutions.
        time: Passed on to ``spectra_data`` (``"ics"`` or ``"end"``).
        show: If true, display the figure with ``plt.show()`` after saving;
            otherwise close it so repeated calls do not accumulate figures.

    Raises:
        ValueError: If ``mode`` is neither ``"power"`` nor ``"cross"``.
    """
    # Reject unknown modes up front instead of silently saving an empty
    # figure (mirrors how spectra_data treats an invalid time).
    if mode == "power":
        draw = _plot_power
    elif mode == "cross":
        draw = _plot_cross
    else:
        raise ValueError(f"invalid mode ({mode}) should be (power|cross)")

    fig: Figure = plt.figure(figsize=(9, 9))
    draw(fig, time)

    output_file = Path(f"~/tmp/spectra_{time}_{mode}.pdf").expanduser()
    output_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_file)

    if show:
        plt.show()
    else:
        # Nothing will display this figure any more; release it.
        plt.close(fig)
2022-06-20 16:51:14 +02:00
def main():
    """Command-line entry point.

    Usage:
        ``spectra_plot.py all`` creates every time/mode combination without
        showing the figures; ``spectra_plot.py [ics|end] [power|cross]``
        creates and shows a single figure.

    Raises:
        SystemExit: With exit code 1 (after printing the usage line) when the
            required arguments are missing.
    """
    args = argv[1:]

    if args and args[0] == "all":
        for time, mode in itertools.product(["ics", "end"], ["power", "cross"]):
            create_plot(mode, time, show=False)
        return

    # A single plot needs both the time and the mode; previously a lone
    # argument other than "all" crashed with an IndexError.
    if len(args) < 2:
        print("run spectra_plot.py [ics|end] [power|cross] or spectra_plot.py all")
        # SystemExit instead of the site-provided exit(), which is meant for
        # the interactive interpreter and is not guaranteed to exist.
        raise SystemExit(1)

    time, mode = args[0], args[1]
    create_plot(mode, time)


if __name__ == "__main__":
    main()