mirror of
https://github.com/Findus23/halo_comparison.git
synced 2024-09-13 09:03:49 +02:00
improve all plots
This commit is contained in:
parent
d425e35c3b
commit
fe336969a3
3 changed files with 46 additions and 30 deletions
|
@ -6,12 +6,13 @@ import numpy as np
|
|||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.axis import XTick, YTick
|
||||
from matplotlib.axis import YTick
|
||||
from matplotlib.collections import QuadMesh
|
||||
from matplotlib.colors import LogNorm
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.patches import Polygon
|
||||
from numpy import inf
|
||||
from scipy.stats import norm
|
||||
|
||||
from halo_vis import get_comp_id
|
||||
from paths import base_dir
|
||||
|
@ -102,27 +103,25 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
|||
rows.append(row)
|
||||
df = pd.concat(rows, axis=1).T
|
||||
print(df)
|
||||
if property == "Mvir":
|
||||
stds = []
|
||||
means = []
|
||||
if property in ["Mvir","Vmax"]:
|
||||
percentiles = []
|
||||
for rep_row in range(num_bins):
|
||||
rep_x_left = bins[rep_row]
|
||||
rep_x_right = bins[rep_row] + 1
|
||||
rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
|
||||
rep_values = df.loc[rep_bin][y_col] / df.loc[rep_bin][x_col]
|
||||
if len(rep_bin) < 30:
|
||||
if len(rep_values) < 10:
|
||||
percentiles.append([np.nan, np.nan, np.nan])
|
||||
continue
|
||||
mean = rep_values.mean()
|
||||
std = rep_values.std()
|
||||
means.append(mean)
|
||||
stds.append(std)
|
||||
means = np.array(means)
|
||||
stds = np.array(stds)
|
||||
args = {"color": "C2", "zorder": 10}
|
||||
ax.fill_between(bins, means - stds, means + stds, alpha=0.2, **args)
|
||||
ax.plot(bins, means + stds, alpha=0.5, **args)
|
||||
ax.plot(bins, means - stds, alpha=0.5, **args)
|
||||
# ax_scatter.plot(bins, stds, label=f"{file.stem}")
|
||||
percentiles.append(np.quantile(rep_values, [norm.cdf(-1), norm.cdf(0), norm.cdf(1)]))
|
||||
|
||||
percentiles = np.asarray(percentiles)
|
||||
print(percentiles.shape)
|
||||
args = {"color": "C1", "zorder": 10}
|
||||
# ax.fill_between(bins, percentiles[::, 0], percentiles[::, 2], alpha=0.1, **args)
|
||||
ax.plot(bins, percentiles[::, 0], alpha=0.9, **args)
|
||||
ax.plot(bins, percentiles[::, 1], alpha=0.9, **args)
|
||||
ax.plot(bins, percentiles[::, 2], alpha=0.9, **args)
|
||||
|
||||
if property in vmaxs:
|
||||
vmax = vmaxs[property]
|
||||
|
@ -135,6 +134,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
|||
df[y_col] / df[x_col],
|
||||
bins=(bins, np.linspace(0, 2, num_bins)),
|
||||
norm=LogNorm(vmax=vmax),
|
||||
cmap="gray_r",
|
||||
rasterized=True,
|
||||
)
|
||||
# ax.plot([rep_x_left, rep_x_left], [mean - std, mean + std], c="C1")
|
||||
|
@ -151,7 +151,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
|||
ax.set_xlim(min(df[x_col]), max(df[y_col]))
|
||||
|
||||
ax.plot(
|
||||
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C1", zorder=10
|
||||
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C0", zorder=10, linestyle="dotted"
|
||||
)
|
||||
|
||||
return x_col, y_col
|
||||
|
@ -179,7 +179,6 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
|
|||
else:
|
||||
bins = num_bins
|
||||
if property == "match":
|
||||
histtype = "step"
|
||||
labels = {
|
||||
(-inf, 30): "$M<30$",
|
||||
(None, None): "$M$",
|
||||
|
@ -198,7 +197,7 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
|
|||
else:
|
||||
patches: List[Polygon]
|
||||
hist_val, bin_edges, patches = ax.hist(
|
||||
df[property], bins=bins, histtype=histtype, label=label, density=density
|
||||
df[property], bins=bins, histtype="step", label=label, density=density
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,4 +5,6 @@ font.serif : Computer Modern Roman
|
|||
font.sans-serif: Computer Modern Sans Serif
|
||||
font.monospace : Computer Modern Typewriter
|
||||
|
||||
lines.linewidth: 1.2 # 1.5 is the default
|
||||
|
||||
axes.prop_cycle: cycler('color', ['3f90da', 'ffa90e', 'bd1f01', '94a4a2', '832db6', 'a96b59', 'e76300', 'b9ac70', '717581', '92dadd'])
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import itertools
|
||||
from pathlib import Path
|
||||
from sys import argv
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.axis import XTick, YTick
|
||||
from matplotlib.axis import XTick
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
from paths import base_dir
|
||||
from utils import figsize_from_page_fraction, waveforms
|
||||
|
@ -38,7 +40,7 @@ colors = [f"C{i}" for i in range(10)]
|
|||
|
||||
|
||||
def spectra_data(
|
||||
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
|
||||
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
|
||||
):
|
||||
dir = base_dir / f"spectra/{waveform}_{Lbox}"
|
||||
|
||||
|
@ -86,7 +88,7 @@ def create_plot(mode):
|
|||
3,
|
||||
sharex=True,
|
||||
sharey=True,
|
||||
figsize=figsize_from_page_fraction(columns=2),
|
||||
figsize=figsize_from_page_fraction(columns=2, height_to_width=.5),
|
||||
)
|
||||
crossings = np.zeros((len(waveforms), len(combination_list)))
|
||||
for i, waveform in enumerate(waveforms):
|
||||
|
@ -125,7 +127,7 @@ def create_plot(mode):
|
|||
ax.grid(color="black", linestyle=":", linewidth=0.5, alpha=0.5)
|
||||
|
||||
for j, res in enumerate(
|
||||
resolutions[:-1] if mode == "cross" else resolutions
|
||||
resolutions[:-1] if mode == "cross" else resolutions
|
||||
):
|
||||
ax.axvline(
|
||||
k0 * res,
|
||||
|
@ -213,22 +215,35 @@ def create_plot(mode):
|
|||
crossing_value = end_pcross[crossing_index] # and here
|
||||
crossings[i][j] = crossing_value
|
||||
|
||||
|
||||
ax_end.set_xlim(right=k0 * resolutions[-1])
|
||||
ax_end.set_ylim(0.8, 1.02)
|
||||
if bottom_row:
|
||||
# ax_z1.legend()
|
||||
ax_ics.legend(loc="lower left")
|
||||
if mode == "power":
|
||||
ax_ics.legend(loc="lower left")
|
||||
else:
|
||||
lines: List[Line2D] = ax_ics.get_lines()
|
||||
half_lines1 = []
|
||||
half_lines2 = []
|
||||
for line in lines:
|
||||
if line.get_label().startswith("128"):
|
||||
half_lines1.append(line)
|
||||
else:
|
||||
half_lines2.append(line)
|
||||
|
||||
ax_ics.legend(handles=half_lines1, loc="lower left")
|
||||
ax_z1.legend(handles=half_lines2, loc="lower left")
|
||||
|
||||
if not bottom_row:
|
||||
last_xtick: XTick = ax_ics.yaxis.get_major_ticks()[0]
|
||||
last_xtick.set_visible(False)
|
||||
|
||||
# fig.suptitle(f"Cross Spectra {time}") #Not needed for paper
|
||||
# fig.tight_layout()
|
||||
print(crossings)
|
||||
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
|
||||
# print(crossings_df.to_markdown())
|
||||
print(crossings_df.to_latex())
|
||||
if mode=="cross":
|
||||
print(crossings)
|
||||
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
|
||||
# print(crossings_df.to_markdown())
|
||||
print(crossings_df.to_latex())
|
||||
fig.tight_layout()
|
||||
fig.subplots_adjust(wspace=0, hspace=0)
|
||||
|
||||
|
|
Loading…
Reference in a new issue