diff --git a/spectra_plot.py b/spectra_plot.py index aa39f81..9f25ca3 100644 --- a/spectra_plot.py +++ b/spectra_plot.py @@ -64,11 +64,15 @@ def spectra_data( def create_plot(mode, time, show=True): - fig: Figure = plt.figure(figsize=(9, 9)) + fig: Figure if mode == "power": - subfigs = fig.subplots(len(waveforms), 1, sharex=True, sharey=True).flatten() + fig, axes = plt.subplots( + len(waveforms), 1, sharex=True, sharey=True, + constrained_layout=True, figsize=(9, 9) + ) + axes = axes.flatten() for i, waveform in enumerate(waveforms): - ax: Axes = subfigs[i] + ax: Axes = axes[i] ax.set_xlabel("k [Mpc$^{-1}$]") ax.set_ylabel("P") ax.text( @@ -96,46 +100,46 @@ def create_plot(mode, time, show=True): ax.legend() # fig.suptitle(f"Power Spectra {time}") #Not needed for paper - fig.tight_layout() + # fig.tight_layout() elif mode == "cross": combination_list = list(itertools.combinations(resolutions, 2)) - subfigs = fig.subplots( - len(waveforms), 2, sharex=True, sharey=True - ).flatten() - fig.subplots_adjust(wspace = 0, hspace = 0) + fig, axes = plt.subplots( + len(waveforms), 2, sharex=True, sharey=True, + constrained_layout=True, figsize=(9, 9), + ) + # fig.subplots_adjust(wspace=0, hspace=0) for i, waveform in enumerate(waveforms): - ax_ics: Axes = subfigs[2 * i] - ax_end: Axes = subfigs[2 * i + 1] - - if i == len(waveforms) - 1: + ax_ics: Axes = axes[i][0] + ax_end: Axes = axes[i][1] + bottom_row = i == len(waveforms) - 1 + if bottom_row: ax_ics.set_xlabel("k [Mpc$^{-1}$]") ax_ics.set_ylabel("C") - ax_ics.text( - 0.02, - 0.85, - f"{waveform}", - size=13, - horizontalalignment="left", - verticalalignment="top", - transform=ax_ics.transAxes, - ) - ax_ics.set_xticklabels([]) - ax_ics.set_yticklabels([]) - if i == len(waveforms) - 1: + for is_end, ax in enumerate([ax_ics, ax_end]): + ax.text( + 0.02, + 0.85, + f"{waveform}", + size=13, + horizontalalignment="left", + verticalalignment="top", + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.85, + "end" if is_end else "ics", + size=13, + horizontalalignment="right", + verticalalignment="top", + transform=ax.transAxes, + ) + # ax.set_xticklabels([]) + # ax.set_yticklabels([]) + if bottom_row: ax_end.set_xlabel("k [Mpc$^{-1}$]") # ax_end.set_ylabel("C") - ax_end.text( - 0.02, - 0.85, - f"{waveform}", - size=13, - horizontalalignment="left", - verticalalignment="top", - transform=ax_end.transAxes, - ) - ax_end.set_xticklabels([]) - ax_end.set_yticklabels([]) for j, (res1, res2) in enumerate(combination_list): ics_data = spectra_data(waveform, res1, res2, Lbox, 'ics') ics_k = ics_data["k [Mpc]"] @@ -151,10 +155,11 @@ def create_plot(mode, time, show=True): ax_end.set_xlim(right=k0 * res1) ax_end.set_ylim(0.8, 1.02) - ax_end.legend() + if bottom_row: + ax_end.legend() # fig.suptitle(f"Cross Spectra {time}") #Not needed for paper - fig.tight_layout() + # fig.tight_layout() fig.savefig(Path(f"~/tmp/spectra_{time}_{mode}.pdf").expanduser()) if show: plt.show()