mirror of
https://github.com/Findus23/halo_comparison.git
synced 2024-09-13 09:03:49 +02:00
simpler caching
This commit is contained in:
parent
6a3685daf7
commit
1baaae41b4
3 changed files with 76 additions and 22 deletions
|
@ -13,6 +13,7 @@ from matplotlib.axes import Axes
|
|||
from matplotlib.colors import LogNorm
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from cache import HDFCache
|
||||
from cic import cic_from_radius, cic_range
|
||||
from halo_mass_profile import halo_mass_profile
|
||||
from nfw import fit_nfw
|
||||
|
@ -30,6 +31,8 @@ class Mode(Enum):
|
|||
|
||||
mode = Mode.richings
|
||||
|
||||
cache = HDFCache(Path("auriga_cache.hdf5"))
|
||||
|
||||
|
||||
def dir_name_to_parameter(dir_name: str):
|
||||
return map(
|
||||
|
@ -129,7 +132,6 @@ def main():
|
|||
h = 0.6777
|
||||
hr_coordinates, particles_meta, center = load_ramses_data(dir / "output_00007")
|
||||
df = pd.DataFrame(hr_coordinates, columns=["X", "Y", "Z"])
|
||||
center = center
|
||||
softening_length = None
|
||||
else:
|
||||
df, particles_meta = read_file(input_file)
|
||||
|
@ -228,17 +230,23 @@ def main():
|
|||
i += 1
|
||||
|
||||
if has_baryons:
|
||||
interpolation_method = "nearest" # "linear"
|
||||
fig3, axs_baryon = plt.subplots(nrows=1, ncols=5, sharex="all", sharey="all", figsize=(10, 4))
|
||||
extent = [46, 52, 54, 60] # xrange[0], xrange[-1], yrange[0], yrange[-1]
|
||||
extent = [42, 62, 50, 70]
|
||||
for ii, property in enumerate(["cic", "Densities", "Entropies", "InternalEnergies", "Temperatures"]):
|
||||
print(property)
|
||||
if property == "cic":
|
||||
grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False)
|
||||
grid = grid.T
|
||||
key = f"grid_{property}_{interpolation_method}"
|
||||
cached_grid = cache.get(key, str(input_file))
|
||||
if cached_grid is not None and False:
|
||||
grid = cached_grid
|
||||
else:
|
||||
grid = create_2d_slice(input_file, center, property=property, extent=extent)
|
||||
print("minmax", grid.min(), grid.max())
|
||||
assert grid.min() != grid.max()
|
||||
if property == "cic":
|
||||
grid, _ = cic_range(X + center[0], Y + center[1], 1000, *extent, periodic=False)
|
||||
grid = grid.T
|
||||
else:
|
||||
grid = create_2d_slice(input_file, center, property=property,
|
||||
extent=extent, method=interpolation_method)
|
||||
cache.set(key, grid, str(input_file), compressed=True)
|
||||
ax_baryon: Axes = axs_baryon[ii]
|
||||
img = ax_baryon.imshow(
|
||||
grid,
|
||||
|
@ -254,6 +262,7 @@ def main():
|
|||
fig3.suptitle(input_file.parent.stem)
|
||||
fig3.tight_layout()
|
||||
fig3.savefig(Path("~/tmp/slice.png").expanduser(), dpi=300)
|
||||
# exit()
|
||||
plt.show()
|
||||
|
||||
# plot_cic(
|
||||
|
|
51
cache.py
Normal file
51
cache.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from h5py import File
|
||||
|
||||
|
||||
class HDFCache:
|
||||
def __init__(self, filename: Path):
|
||||
self.f = File(filename, "a")
|
||||
|
||||
def get(self, key: str, group: str = None) -> Optional[np.ndarray]:
|
||||
try:
|
||||
if group:
|
||||
g = self.f[group]
|
||||
else:
|
||||
g = self.f
|
||||
return np.asarray(g[key])
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def set(self, key: str, data: np.ndarray, group: str = None, compressed: bool = False) -> None:
|
||||
if self.get(key, group) is not None:
|
||||
self.delete(key, group)
|
||||
if not group:
|
||||
g = self.f
|
||||
elif group not in self.f:
|
||||
g = self.f.create_group(group)
|
||||
else:
|
||||
g = self.f[group]
|
||||
if compressed:
|
||||
kwargs = {
|
||||
"compression": "gzip",
|
||||
"compression_opts": 5
|
||||
}
|
||||
else:
|
||||
kwargs = {}
|
||||
g.create_dataset(key, data=data, **kwargs)
|
||||
|
||||
def delete(self, key: str, group: str = None):
|
||||
if not group:
|
||||
g = self.f
|
||||
else:
|
||||
g = self.f[group]
|
||||
del g[key]
|
||||
|
||||
def delgroup(self, group: str):
|
||||
raise NotImplemented()
|
||||
|
||||
def __del__(self):
|
||||
self.f.close()
|
|
@ -1,26 +1,22 @@
|
|||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from cache import HDFCache
|
||||
from utils import print_progress
|
||||
|
||||
cache_file = "center_cache.json"
|
||||
|
||||
try:
|
||||
with open(cache_file, "r") as f:
|
||||
center_cache = json.load(f)
|
||||
except FileNotFoundError:
|
||||
center_cache = {}
|
||||
cache = HDFCache(Path("center_cache.hdf5"))
|
||||
|
||||
|
||||
def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1):
|
||||
# plt.figure()
|
||||
plt.figure()
|
||||
all_particles = df[["X", "Y", "Z"]].to_numpy()
|
||||
hash = hashlib.sha256(np.ascontiguousarray(all_particles).data).hexdigest()
|
||||
if hash in center_cache:
|
||||
return np.array(center_cache[hash])
|
||||
cached_center = cache.get(hash)
|
||||
if cached_center is not None:
|
||||
return np.array(cached_center)
|
||||
radius = initial_radius
|
||||
center_history = []
|
||||
i = 0
|
||||
|
@ -44,7 +40,5 @@ def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1):
|
|||
# plt.colorbar(label="step")
|
||||
# plt.show()
|
||||
print()
|
||||
center_cache[hash] = center.tolist()
|
||||
with open(cache_file, "w") as f:
|
||||
json.dump(center_cache, f)
|
||||
cache.set(hash, center)
|
||||
return center
|
||||
|
|
Loading…
Reference in a new issue