1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00

simplify spectra code

This commit is contained in:
Lukas Winkler 2022-06-28 11:36:15 +02:00
parent da1c5b3e29
commit 4fa8ee0ac2
Signed by: lukas
GPG key ID: 54DE4D798D244853

View file

@ -64,11 +64,15 @@ def spectra_data(
def create_plot(mode, time, show=True):
fig: Figure = plt.figure(figsize=(9, 9))
fig: Figure
if mode == "power":
subfigs = fig.subplots(len(waveforms), 1, sharex=True, sharey=True).flatten()
fig, axes = plt.subplots(
len(waveforms), 1, sharex=True, sharey=True,
constrained_layout=True, figsize=(9, 9)
)
axes = axes.flatten()
for i, waveform in enumerate(waveforms):
ax: Axes = subfigs[i]
ax: Axes = axes[i]
ax.set_xlabel("k [Mpc$^{-1}$]")
ax.set_ylabel("P")
ax.text(
@ -96,46 +100,46 @@ def create_plot(mode, time, show=True):
ax.legend()
# fig.suptitle(f"Power Spectra {time}") #Not needed for paper
fig.tight_layout()
# fig.tight_layout()
elif mode == "cross":
combination_list = list(itertools.combinations(resolutions, 2))
subfigs = fig.subplots(
len(waveforms), 2, sharex=True, sharey=True
).flatten()
fig.subplots_adjust(wspace = 0, hspace = 0)
fig, axes = plt.subplots(
len(waveforms), 2, sharex=True, sharey=True,
constrained_layout=True, figsize=(9, 9),
)
# fig.subplots_adjust(wspace=0, hspace=0)
for i, waveform in enumerate(waveforms):
ax_ics: Axes = subfigs[2 * i]
ax_end: Axes = subfigs[2 * i + 1]
if i == len(waveforms) - 1:
ax_ics: Axes = axes[i][0]
ax_end: Axes = axes[i][1]
bottom_row = i == len(waveforms) - 1
if bottom_row:
ax_ics.set_xlabel("k [Mpc$^{-1}$]")
ax_ics.set_ylabel("C")
ax_ics.text(
0.02,
0.85,
f"{waveform}",
size=13,
horizontalalignment="left",
verticalalignment="top",
transform=ax_ics.transAxes,
)
ax_ics.set_xticklabels([])
ax_ics.set_yticklabels([])
if i == len(waveforms) - 1:
for is_end, ax in enumerate([ax_ics, ax_end]):
ax.text(
0.02,
0.85,
f"{waveform}",
size=13,
horizontalalignment="left",
verticalalignment="top",
transform=ax.transAxes,
)
ax.text(
0.98,
0.85,
"end" if is_end else "ics",
size=13,
horizontalalignment="right",
verticalalignment="top",
transform=ax.transAxes,
)
# ax.set_xticklabels([])
# ax.set_yticklabels([])
if bottom_row:
ax_end.set_xlabel("k [Mpc$^{-1}$]")
# ax_end.set_ylabel("C")
ax_end.text(
0.02,
0.85,
f"{waveform}",
size=13,
horizontalalignment="left",
verticalalignment="top",
transform=ax_end.transAxes,
)
ax_end.set_xticklabels([])
ax_end.set_yticklabels([])
for j, (res1, res2) in enumerate(combination_list):
ics_data = spectra_data(waveform, res1, res2, Lbox, 'ics')
ics_k = ics_data["k [Mpc]"]
@ -151,10 +155,11 @@ def create_plot(mode, time, show=True):
ax_end.set_xlim(right=k0 * res1)
ax_end.set_ylim(0.8, 1.02)
ax_end.legend()
if bottom_row:
ax_end.legend()
# fig.suptitle(f"Cross Spectra {time}") #Not needed for paper
fig.tight_layout()
# fig.tight_layout()
fig.savefig(Path(f"~/tmp/spectra_{time}_{mode}.pdf").expanduser())
if show:
plt.show()