mirror of
https://github.com/Findus23/halo_comparison.git
synced 2024-09-19 16:03:50 +02:00
improve all plots
This commit is contained in:
parent
d425e35c3b
commit
fe336969a3
3 changed files with 46 additions and 30 deletions
|
@ -6,12 +6,13 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from matplotlib.axis import XTick, YTick
|
from matplotlib.axis import YTick
|
||||||
from matplotlib.collections import QuadMesh
|
from matplotlib.collections import QuadMesh
|
||||||
from matplotlib.colors import LogNorm
|
from matplotlib.colors import LogNorm
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from matplotlib.patches import Polygon
|
from matplotlib.patches import Polygon
|
||||||
from numpy import inf
|
from numpy import inf
|
||||||
|
from scipy.stats import norm
|
||||||
|
|
||||||
from halo_vis import get_comp_id
|
from halo_vis import get_comp_id
|
||||||
from paths import base_dir
|
from paths import base_dir
|
||||||
|
@ -102,27 +103,25 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
df = pd.concat(rows, axis=1).T
|
df = pd.concat(rows, axis=1).T
|
||||||
print(df)
|
print(df)
|
||||||
if property == "Mvir":
|
if property in ["Mvir","Vmax"]:
|
||||||
stds = []
|
percentiles = []
|
||||||
means = []
|
|
||||||
for rep_row in range(num_bins):
|
for rep_row in range(num_bins):
|
||||||
rep_x_left = bins[rep_row]
|
rep_x_left = bins[rep_row]
|
||||||
rep_x_right = bins[rep_row] + 1
|
rep_x_right = bins[rep_row] + 1
|
||||||
rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
|
rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
|
||||||
rep_values = df.loc[rep_bin][y_col] / df.loc[rep_bin][x_col]
|
rep_values = df.loc[rep_bin][y_col] / df.loc[rep_bin][x_col]
|
||||||
if len(rep_bin) < 30:
|
if len(rep_values) < 10:
|
||||||
|
percentiles.append([np.nan, np.nan, np.nan])
|
||||||
continue
|
continue
|
||||||
mean = rep_values.mean()
|
percentiles.append(np.quantile(rep_values, [norm.cdf(-1), norm.cdf(0), norm.cdf(1)]))
|
||||||
std = rep_values.std()
|
|
||||||
means.append(mean)
|
percentiles = np.asarray(percentiles)
|
||||||
stds.append(std)
|
print(percentiles.shape)
|
||||||
means = np.array(means)
|
args = {"color": "C1", "zorder": 10}
|
||||||
stds = np.array(stds)
|
# ax.fill_between(bins, percentiles[::, 0], percentiles[::, 2], alpha=0.1, **args)
|
||||||
args = {"color": "C2", "zorder": 10}
|
ax.plot(bins, percentiles[::, 0], alpha=0.9, **args)
|
||||||
ax.fill_between(bins, means - stds, means + stds, alpha=0.2, **args)
|
ax.plot(bins, percentiles[::, 1], alpha=0.9, **args)
|
||||||
ax.plot(bins, means + stds, alpha=0.5, **args)
|
ax.plot(bins, percentiles[::, 2], alpha=0.9, **args)
|
||||||
ax.plot(bins, means - stds, alpha=0.5, **args)
|
|
||||||
# ax_scatter.plot(bins, stds, label=f"{file.stem}")
|
|
||||||
|
|
||||||
if property in vmaxs:
|
if property in vmaxs:
|
||||||
vmax = vmaxs[property]
|
vmax = vmaxs[property]
|
||||||
|
@ -135,6 +134,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
||||||
df[y_col] / df[x_col],
|
df[y_col] / df[x_col],
|
||||||
bins=(bins, np.linspace(0, 2, num_bins)),
|
bins=(bins, np.linspace(0, 2, num_bins)),
|
||||||
norm=LogNorm(vmax=vmax),
|
norm=LogNorm(vmax=vmax),
|
||||||
|
cmap="gray_r",
|
||||||
rasterized=True,
|
rasterized=True,
|
||||||
)
|
)
|
||||||
# ax.plot([rep_x_left, rep_x_left], [mean - std, mean + std], c="C1")
|
# ax.plot([rep_x_left, rep_x_left], [mean - std, mean + std], c="C1")
|
||||||
|
@ -151,7 +151,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
|
||||||
ax.set_xlim(min(df[x_col]), max(df[y_col]))
|
ax.set_xlim(min(df[x_col]), max(df[y_col]))
|
||||||
|
|
||||||
ax.plot(
|
ax.plot(
|
||||||
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C1", zorder=10
|
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C0", zorder=10, linestyle="dotted"
|
||||||
)
|
)
|
||||||
|
|
||||||
return x_col, y_col
|
return x_col, y_col
|
||||||
|
@ -179,7 +179,6 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
|
||||||
else:
|
else:
|
||||||
bins = num_bins
|
bins = num_bins
|
||||||
if property == "match":
|
if property == "match":
|
||||||
histtype = "step"
|
|
||||||
labels = {
|
labels = {
|
||||||
(-inf, 30): "$M<30$",
|
(-inf, 30): "$M<30$",
|
||||||
(None, None): "$M$",
|
(None, None): "$M$",
|
||||||
|
@ -198,7 +197,7 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
|
||||||
else:
|
else:
|
||||||
patches: List[Polygon]
|
patches: List[Polygon]
|
||||||
hist_val, bin_edges, patches = ax.hist(
|
hist_val, bin_edges, patches = ax.hist(
|
||||||
df[property], bins=bins, histtype=histtype, label=label, density=density
|
df[property], bins=bins, histtype="step", label=label, density=density
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,4 +5,6 @@ font.serif : Computer Modern Roman
|
||||||
font.sans-serif: Computer Modern Sans Serif
|
font.sans-serif: Computer Modern Sans Serif
|
||||||
font.monospace : Computer Modern Typewriter
|
font.monospace : Computer Modern Typewriter
|
||||||
|
|
||||||
|
lines.linewidth: 1.2 # 1.5 is the default
|
||||||
|
|
||||||
axes.prop_cycle: cycler('color', ['3f90da', 'ffa90e', 'bd1f01', '94a4a2', '832db6', 'a96b59', 'e76300', 'b9ac70', '717581', '92dadd'])
|
axes.prop_cycle: cycler('color', ['3f90da', 'ffa90e', 'bd1f01', '94a4a2', '832db6', 'a96b59', 'e76300', 'b9ac70', '717581', '92dadd'])
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
import itertools
|
import itertools
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from sys import argv
|
from sys import argv
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from matplotlib.axis import XTick, YTick
|
from matplotlib.axis import XTick
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
|
||||||
from paths import base_dir
|
from paths import base_dir
|
||||||
from utils import figsize_from_page_fraction, waveforms
|
from utils import figsize_from_page_fraction, waveforms
|
||||||
|
@ -38,7 +40,7 @@ colors = [f"C{i}" for i in range(10)]
|
||||||
|
|
||||||
|
|
||||||
def spectra_data(
|
def spectra_data(
|
||||||
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
|
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
|
||||||
):
|
):
|
||||||
dir = base_dir / f"spectra/{waveform}_{Lbox}"
|
dir = base_dir / f"spectra/{waveform}_{Lbox}"
|
||||||
|
|
||||||
|
@ -86,7 +88,7 @@ def create_plot(mode):
|
||||||
3,
|
3,
|
||||||
sharex=True,
|
sharex=True,
|
||||||
sharey=True,
|
sharey=True,
|
||||||
figsize=figsize_from_page_fraction(columns=2),
|
figsize=figsize_from_page_fraction(columns=2, height_to_width=.5),
|
||||||
)
|
)
|
||||||
crossings = np.zeros((len(waveforms), len(combination_list)))
|
crossings = np.zeros((len(waveforms), len(combination_list)))
|
||||||
for i, waveform in enumerate(waveforms):
|
for i, waveform in enumerate(waveforms):
|
||||||
|
@ -125,7 +127,7 @@ def create_plot(mode):
|
||||||
ax.grid(color="black", linestyle=":", linewidth=0.5, alpha=0.5)
|
ax.grid(color="black", linestyle=":", linewidth=0.5, alpha=0.5)
|
||||||
|
|
||||||
for j, res in enumerate(
|
for j, res in enumerate(
|
||||||
resolutions[:-1] if mode == "cross" else resolutions
|
resolutions[:-1] if mode == "cross" else resolutions
|
||||||
):
|
):
|
||||||
ax.axvline(
|
ax.axvline(
|
||||||
k0 * res,
|
k0 * res,
|
||||||
|
@ -213,22 +215,35 @@ def create_plot(mode):
|
||||||
crossing_value = end_pcross[crossing_index] # and here
|
crossing_value = end_pcross[crossing_index] # and here
|
||||||
crossings[i][j] = crossing_value
|
crossings[i][j] = crossing_value
|
||||||
|
|
||||||
|
|
||||||
ax_end.set_xlim(right=k0 * resolutions[-1])
|
ax_end.set_xlim(right=k0 * resolutions[-1])
|
||||||
ax_end.set_ylim(0.8, 1.02)
|
ax_end.set_ylim(0.8, 1.02)
|
||||||
if bottom_row:
|
if bottom_row:
|
||||||
# ax_z1.legend()
|
if mode == "power":
|
||||||
ax_ics.legend(loc="lower left")
|
ax_ics.legend(loc="lower left")
|
||||||
|
else:
|
||||||
|
lines: List[Line2D] = ax_ics.get_lines()
|
||||||
|
half_lines1 = []
|
||||||
|
half_lines2 = []
|
||||||
|
for line in lines:
|
||||||
|
if line.get_label().startswith("128"):
|
||||||
|
half_lines1.append(line)
|
||||||
|
else:
|
||||||
|
half_lines2.append(line)
|
||||||
|
|
||||||
|
ax_ics.legend(handles=half_lines1, loc="lower left")
|
||||||
|
ax_z1.legend(handles=half_lines2, loc="lower left")
|
||||||
|
|
||||||
if not bottom_row:
|
if not bottom_row:
|
||||||
last_xtick: XTick = ax_ics.yaxis.get_major_ticks()[0]
|
last_xtick: XTick = ax_ics.yaxis.get_major_ticks()[0]
|
||||||
last_xtick.set_visible(False)
|
last_xtick.set_visible(False)
|
||||||
|
|
||||||
# fig.suptitle(f"Cross Spectra {time}") #Not needed for paper
|
# fig.suptitle(f"Cross Spectra {time}") #Not needed for paper
|
||||||
# fig.tight_layout()
|
# fig.tight_layout()
|
||||||
print(crossings)
|
if mode=="cross":
|
||||||
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
|
print(crossings)
|
||||||
# print(crossings_df.to_markdown())
|
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
|
||||||
print(crossings_df.to_latex())
|
# print(crossings_df.to_markdown())
|
||||||
|
print(crossings_df.to_latex())
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
fig.subplots_adjust(wspace=0, hspace=0)
|
fig.subplots_adjust(wspace=0, hspace=0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue