1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-13 09:03:49 +02:00

simpler caching

This commit is contained in:
Lukas Winkler 2022-08-24 11:42:28 +02:00
parent 6a3685daf7
commit 1baaae41b4
Signed by: lukas
GPG key ID: 54DE4D798D244853
3 changed files with 76 additions and 22 deletions

View file

@ -13,6 +13,7 @@ from matplotlib.axes import Axes
from matplotlib.colors import LogNorm
from matplotlib.figure import Figure
from cache import HDFCache
from cic import cic_from_radius, cic_range
from halo_mass_profile import halo_mass_profile
from nfw import fit_nfw
@ -30,6 +31,8 @@ class Mode(Enum):
mode = Mode.richings
cache = HDFCache(Path("auriga_cache.hdf5"))
def dir_name_to_parameter(dir_name: str):
return map(
@ -129,7 +132,6 @@ def main():
h = 0.6777
hr_coordinates, particles_meta, center = load_ramses_data(dir / "output_00007")
df = pd.DataFrame(hr_coordinates, columns=["X", "Y", "Z"])
center = center
softening_length = None
else:
df, particles_meta = read_file(input_file)
@ -228,17 +230,23 @@ def main():
i += 1
if has_baryons:
interpolation_method = "nearest" # "linear"
fig3, axs_baryon = plt.subplots(nrows=1, ncols=5, sharex="all", sharey="all", figsize=(10, 4))
extent = [46, 52, 54, 60] # xrange[0], xrange[-1], yrange[0], yrange[-1]
extent = [42, 62, 50, 70]
for ii, property in enumerate(["cic", "Densities", "Entropies", "InternalEnergies", "Temperatures"]):
print(property)
if property == "cic":
grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False)
grid = grid.T
key = f"grid_{property}_{interpolation_method}"
cached_grid = cache.get(key, str(input_file))
if cached_grid is not None and False:
grid = cached_grid
else:
grid = create_2d_slice(input_file, center, property=property, extent=extent)
print("minmax", grid.min(), grid.max())
assert grid.min() != grid.max()
if property == "cic":
grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False)
grid = grid.T
else:
grid = create_2d_slice(input_file, center, property=property,
extent=extent, method=interpolation_method)
cache.set(key, grid, str(input_file), compressed=True)
ax_baryon: Axes = axs_baryon[ii]
img = ax_baryon.imshow(
grid,
@ -254,6 +262,7 @@ def main():
fig3.suptitle(input_file.parent.stem)
fig3.tight_layout()
fig3.savefig(Path("~/tmp/slice.png").expanduser(), dpi=300)
# exit()
plt.show()
# plot_cic(

51
cache.py Normal file
View file

@ -0,0 +1,51 @@
from pathlib import Path
from typing import Optional
import numpy as np
from h5py import File
class HDFCache:
def __init__(self, filename: Path):
self.f = File(filename, "a")
def get(self, key: str, group: str = None) -> Optional[np.ndarray]:
try:
if group:
g = self.f[group]
else:
g = self.f
return np.asarray(g[key])
except KeyError:
return None
def set(self, key: str, data: np.ndarray, group: str = None, compressed: bool = False) -> None:
if self.get(key, group) is not None:
self.delete(key, group)
if not group:
g = self.f
elif group not in self.f:
g = self.f.create_group(group)
else:
g = self.f[group]
if compressed:
kwargs = {
"compression": "gzip",
"compression_opts": 5
}
else:
kwargs = {}
g.create_dataset(key, data=data, **kwargs)
def delete(self, key: str, group: str = None):
if not group:
g = self.f
else:
g = self.f[group]
del g[key]
def delgroup(self, group: str):
raise NotImplemented()
def __del__(self):
self.f.close()

View file

@ -1,26 +1,22 @@
import hashlib
import json
from pathlib import Path
import numpy as np
import pandas as pd
from cache import HDFCache
from utils import print_progress
cache_file = "center_cache.json"
try:
with open(cache_file, "r") as f:
center_cache = json.load(f)
except FileNotFoundError:
center_cache = {}
cache = HDFCache(Path("center_cache.hdf5"))
def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1):
# plt.figure()
plt.figure()
all_particles = df[["X", "Y", "Z"]].to_numpy()
hash = hashlib.sha256(np.ascontiguousarray(all_particles).data).hexdigest()
if hash in center_cache:
return np.array(center_cache[hash])
cached_center = cache.get(hash)
if cached_center is not None:
return np.array(cached_center)
radius = initial_radius
center_history = []
i = 0
@ -44,7 +40,5 @@ def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1):
# plt.colorbar(label="step")
# plt.show()
print()
center_cache[hash] = center.tolist()
with open(cache_file, "w") as f:
json.dump(center_cache, f)
cache.set(hash, center)
return center