1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/slices.py

80 lines
2.4 KiB
Python
Raw Normal View History

2022-08-23 17:29:44 +02:00
from typing import List, Tuple
2022-08-05 16:13:05 +02:00
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from scipy.interpolate import griddata
2022-08-23 16:38:33 +02:00
from temperatures import calculate_T
2022-08-05 16:13:05 +02:00
from utils import create_figure
2022-08-23 17:29:44 +02:00
def filter_3d(
coords: np.ndarray, extent: List[float], data: np.ndarray = None, zlimit=None
) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
2022-08-23 17:29:44 +02:00
filter = (
(extent[0] < coords[::, 0]) &
(coords[::, 0] < extent[1]) &
(extent[2] < coords[::, 1]) &
(coords[::, 1] < extent[3])
)
if zlimit:
filter = filter & (
(zlimit[0] < coords[::, 2]) &
(coords[::, 2] < zlimit[1])
)
2022-08-23 17:29:44 +02:00
print("before", coords.shape)
if data is not None:
data = data[filter]
2022-08-23 17:29:44 +02:00
coords = coords[filter]
print("after", coords.shape)
if data is not None:
return coords, data
return coords
2022-08-23 17:29:44 +02:00
def create_2d_slice(center: List[float], extent, coords: np.ndarray, property_name: str, property_data: np.ndarray,
                    resolution: int,
                    method="nearest") -> np.ndarray:
    """Interpolate a per-point property onto a regular 2D grid in a plane of constant z.

    Points are first restricted to the x/y box given by ``extent`` via
    ``filter_3d`` (no z cut is applied). The remaining values are then
    interpolated with ``scipy.interpolate.griddata`` onto a
    ``resolution`` x ``resolution`` grid lying in the plane ``z = center[2]``.

    :param center: Only its z component (index 2) is used; it selects the slice plane.
    :param extent: ``[xmin, xmax, ymin, ymax]`` of the slice.
    :param coords: Point coordinates, indexed ``coords[i, axis]``.
    :param property_name: Name of the property. ``"Temperatures"`` converts each
        value with ``calculate_T`` before interpolation.
    :param property_data: One value per point in ``coords``.
    :param resolution: Number of grid samples along each of x and y.
    :param method: Interpolation method passed to ``griddata``. For 3D sample
        points SciPy supports ``"nearest"`` and ``"linear"``. ``"linear"`` yields
        NaN outside the convex hull of the points.
    :return: Array of shape ``(resolution, resolution)``. Row ``i`` corresponds
        to the i-th y sample and column ``j`` to the j-th x sample.
    """
    cut_axis = 2  # Z
    coords, property_data = filter_3d(coords, extent, property_data)
    if property_name == "Temperatures":
        print("calculating temperatures")
        # NOTE(review): assumes the raw values are what calculate_T expects
        # (presumably specific internal energies) — confirm in temperatures.py.
        property_data = np.array([calculate_T(u) for u in property_data])
    xrange = np.linspace(extent[0], extent[1], resolution)
    yrange = np.linspace(extent[2], extent[3], resolution)
    # With meshgrid's default "xy" indexing the outputs have shape
    # (len(yrange), len(xrange), 1); the singleton z axis is dropped below.
    gx, gy, gz = np.meshgrid(xrange, yrange, center[cut_axis])
    print("interpolating")
    grid = griddata(coords, property_data, (gx, gy, gz), method=method)[:, :, 0]
    return grid