1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-13 09:03:49 +02:00

improve all plots

This commit is contained in:
Lukas Winkler 2022-08-18 12:27:25 +02:00
parent d425e35c3b
commit fe336969a3
Signed by: lukas
GPG key ID: 54DE4D798D244853
3 changed files with 46 additions and 30 deletions

View file

@ -6,12 +6,13 @@ import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.axis import XTick, YTick
from matplotlib.axis import YTick
from matplotlib.collections import QuadMesh
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from matplotlib.patches import Polygon
from numpy import inf
from scipy.stats import norm
from halo_vis import get_comp_id
from paths import base_dir
@ -102,27 +103,25 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
rows.append(row)
df = pd.concat(rows, axis=1).T
print(df)
if property == "Mvir":
stds = []
means = []
if property in ["Mvir","Vmax"]:
percentiles = []
for rep_row in range(num_bins):
rep_x_left = bins[rep_row]
rep_x_right = bins[rep_row] + 1
rep_bin = (rep_x_left < df[x_col]) & (df[x_col] < rep_x_right)
rep_values = df.loc[rep_bin][y_col] / df.loc[rep_bin][x_col]
if len(rep_bin) < 30:
if len(rep_values) < 10:
percentiles.append([np.nan, np.nan, np.nan])
continue
mean = rep_values.mean()
std = rep_values.std()
means.append(mean)
stds.append(std)
means = np.array(means)
stds = np.array(stds)
args = {"color": "C2", "zorder": 10}
ax.fill_between(bins, means - stds, means + stds, alpha=0.2, **args)
ax.plot(bins, means + stds, alpha=0.5, **args)
ax.plot(bins, means - stds, alpha=0.5, **args)
# ax_scatter.plot(bins, stds, label=f"{file.stem}")
percentiles.append(np.quantile(rep_values, [norm.cdf(-1), norm.cdf(0), norm.cdf(1)]))
percentiles = np.asarray(percentiles)
print(percentiles.shape)
args = {"color": "C1", "zorder": 10}
# ax.fill_between(bins, percentiles[::, 0], percentiles[::, 2], alpha=0.1, **args)
ax.plot(bins, percentiles[::, 0], alpha=0.9, **args)
ax.plot(bins, percentiles[::, 1], alpha=0.9, **args)
ax.plot(bins, percentiles[::, 2], alpha=0.9, **args)
if property in vmaxs:
vmax = vmaxs[property]
@ -135,6 +134,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
df[y_col] / df[x_col],
bins=(bins, np.linspace(0, 2, num_bins)),
norm=LogNorm(vmax=vmax),
cmap="gray_r",
rasterized=True,
)
# ax.plot([rep_x_left, rep_x_left], [mean - std, mean + std], c="C1")
@ -151,7 +151,7 @@ def plot_comparison_hist2d(ax: Axes, file: Path, property: str):
ax.set_xlim(min(df[x_col]), max(df[y_col]))
ax.plot(
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C1", zorder=10
[min(df[x_col]), max(df[y_col])], [1, 1], linewidth=1, color="C0", zorder=10, linestyle="dotted"
)
return x_col, y_col
@ -179,7 +179,6 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
else:
bins = num_bins
if property == "match":
histtype = "step"
labels = {
(-inf, 30): "$M<30$",
(None, None): "$M$",
@ -198,7 +197,7 @@ def plot_comparison_hist(ax: Axes, file: Path, property: str, m_min=None, m_max=
else:
patches: List[Polygon]
hist_val, bin_edges, patches = ax.hist(
df[property], bins=bins, histtype=histtype, label=label, density=density
df[property], bins=bins, histtype="step", label=label, density=density
)

View file

@ -5,4 +5,6 @@ font.serif : Computer Modern Roman
font.sans-serif: Computer Modern Sans Serif
font.monospace : Computer Modern Typewriter
lines.linewidth: 1.2 # 1.5 is the default
axes.prop_cycle: cycler('color', ['3f90da', 'ffa90e', 'bd1f01', '94a4a2', '832db6', 'a96b59', 'e76300', 'b9ac70', '717581', '92dadd'])

View file

@ -1,13 +1,15 @@
import itertools
from pathlib import Path
from sys import argv
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.axis import XTick, YTick
from matplotlib.axis import XTick
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from paths import base_dir
from utils import figsize_from_page_fraction, waveforms
@ -38,7 +40,7 @@ colors = [f"C{i}" for i in range(10)]
def spectra_data(
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
waveform: str, resolution_1: int, resolution_2: int, Lbox: int, time: str
):
dir = base_dir / f"spectra/{waveform}_{Lbox}"
@ -86,7 +88,7 @@ def create_plot(mode):
3,
sharex=True,
sharey=True,
figsize=figsize_from_page_fraction(columns=2),
figsize=figsize_from_page_fraction(columns=2, height_to_width=.5),
)
crossings = np.zeros((len(waveforms), len(combination_list)))
for i, waveform in enumerate(waveforms):
@ -125,7 +127,7 @@ def create_plot(mode):
ax.grid(color="black", linestyle=":", linewidth=0.5, alpha=0.5)
for j, res in enumerate(
resolutions[:-1] if mode == "cross" else resolutions
resolutions[:-1] if mode == "cross" else resolutions
):
ax.axvline(
k0 * res,
@ -213,22 +215,35 @@ def create_plot(mode):
crossing_value = end_pcross[crossing_index] # and here
crossings[i][j] = crossing_value
ax_end.set_xlim(right=k0 * resolutions[-1])
ax_end.set_ylim(0.8, 1.02)
if bottom_row:
# ax_z1.legend()
ax_ics.legend(loc="lower left")
if mode == "power":
ax_ics.legend(loc="lower left")
else:
lines: List[Line2D] = ax_ics.get_lines()
half_lines1 = []
half_lines2 = []
for line in lines:
if line.get_label().startswith("128"):
half_lines1.append(line)
else:
half_lines2.append(line)
ax_ics.legend(handles=half_lines1, loc="lower left")
ax_z1.legend(handles=half_lines2, loc="lower left")
if not bottom_row:
last_xtick: XTick = ax_ics.yaxis.get_major_ticks()[0]
last_xtick.set_visible(False)
# fig.suptitle(f"Cross Spectra {time}") #Not needed for paper
# fig.tight_layout()
print(crossings)
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
# print(crossings_df.to_markdown())
print(crossings_df.to_latex())
if mode=="cross":
print(crossings)
crossings_df = pd.DataFrame(crossings, columns=combination_list, index=waveforms)
# print(crossings_df.to_markdown())
print(crossings_df.to_latex())
fig.tight_layout()
fig.subplots_adjust(wspace=0, hspace=0)