diff --git a/auriga_comparison.py b/auriga_comparison.py index 0ab5658..99127f9 100644 --- a/auriga_comparison.py +++ b/auriga_comparison.py @@ -13,6 +13,7 @@ from matplotlib.axes import Axes from matplotlib.colors import LogNorm from matplotlib.figure import Figure +from cache import HDFCache from cic import cic_from_radius, cic_range from halo_mass_profile import halo_mass_profile from nfw import fit_nfw @@ -30,6 +31,8 @@ class Mode(Enum): mode = Mode.richings +cache = HDFCache(Path("auriga_cache.hdf5")) + def dir_name_to_parameter(dir_name: str): return map( @@ -129,7 +132,6 @@ def main(): h = 0.6777 hr_coordinates, particles_meta, center = load_ramses_data(dir / "output_00007") df = pd.DataFrame(hr_coordinates, columns=["X", "Y", "Z"]) - center = center softening_length = None else: df, particles_meta = read_file(input_file) @@ -228,17 +230,23 @@ def main(): i += 1 if has_baryons: + interpolation_method = "nearest" # "linear" fig3, axs_baryon = plt.subplots(nrows=1, ncols=5, sharex="all", sharey="all", figsize=(10, 4)) extent = [46, 52, 54, 60] # xrange[0], xrange[-1], yrange[0], yrange[-1] + extent = [42, 62, 50, 70] for ii, property in enumerate(["cic", "Densities", "Entropies", "InternalEnergies", "Temperatures"]): - print(property) - if property == "cic": - grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False) - grid = grid.T + key = f"grid_{property}_{interpolation_method}" + cached_grid = cache.get(key, str(input_file)) + if cached_grid is not None and False: + grid = cached_grid else: - grid = create_2d_slice(input_file, center, property=property, extent=extent) - print("minmax", grid.min(), grid.max()) - assert grid.min() != grid.max() + if property == "cic": + grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False) + grid = grid.T + else: + grid = create_2d_slice(input_file, center, property=property, + extent=extent, method=interpolation_method) + cache.set(key, grid, str(input_file), compressed=True) ax_baryon: Axes = axs_baryon[ii] img = ax_baryon.imshow( grid, @@ -254,6 +262,7 @@ def main(): fig3.suptitle(input_file.parent.stem) fig3.tight_layout() fig3.savefig(Path("~/tmp/slice.png").expanduser(), dpi=300) + # exit() plt.show() # plot_cic( diff --git a/cache.py b/cache.py new file mode 100644 index 0000000..01ef265 --- /dev/null +++ b/cache.py @@ -0,0 +1,51 @@ +from pathlib import Path +from typing import Optional + +import numpy as np +from h5py import File + + +class HDFCache: + def __init__(self, filename: Path): + self.f = File(filename, "a") + + def get(self, key: str, group: str = None) -> Optional[np.ndarray]: + try: + if group: + g = self.f[group] + else: + g = self.f + return np.asarray(g[key]) + except KeyError: + return None + + def set(self, key: str, data: np.ndarray, group: str = None, compressed: bool = False) -> None: + if self.get(key, group) is not None: + self.delete(key, group) + if not group: + g = self.f + elif group not in self.f: + g = self.f.create_group(group) + else: + g = self.f[group] + if compressed: + kwargs = { + "compression": "gzip", + "compression_opts": 5 + } + else: + kwargs = {} + g.create_dataset(key, data=data, **kwargs) + + def delete(self, key: str, group: str = None): + if not group: + g = self.f + else: + g = self.f[group] + del g[key] + + def delgroup(self, group: str): + raise NotImplemented() + + def __del__(self): + self.f.close() diff --git a/find_center.py b/find_center.py index ba1e1b0..605ecb6 100644 --- a/find_center.py +++ b/find_center.py @@ -1,26 +1,22 @@ import hashlib -import json +from pathlib import Path import numpy as np import pandas as pd +from cache import HDFCache from utils import print_progress -cache_file = "center_cache.json" - -try: - with open(cache_file, "r") as f: - center_cache = json.load(f) -except FileNotFoundError: - center_cache = {} +cache = HDFCache(Path("center_cache.hdf5")) def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1): - # plt.figure() + plt.figure() all_particles = df[["X", "Y", "Z"]].to_numpy() hash = hashlib.sha256(np.ascontiguousarray(all_particles).data).hexdigest() - if hash in center_cache: - return np.array(center_cache[hash]) + cached_center = cache.get(hash) + if cached_center is not None: + return np.array(cached_center) radius = initial_radius center_history = [] i = 0 @@ -44,7 +40,5 @@ def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1): # plt.colorbar(label="step") # plt.show() print() - center_cache[hash] = center.tolist() - with open(cache_file, "w") as f: - json.dump(center_cache, f) + cache.set(hash, center) return center