1
0
Fork 0
mirror of https://github.com/Findus23/halo_comparison.git synced 2024-09-19 16:03:50 +02:00
halo_comparison/find_center.py

49 lines
1.6 KiB
Python
Raw Permalink Normal View History

import hashlib
2022-08-24 11:42:28 +02:00
from pathlib import Path
2022-06-14 10:53:19 +02:00
import numpy as np
import pandas as pd
2022-08-24 11:42:28 +02:00
from cache import HDFCache
2022-06-14 16:38:50 +02:00
from utils import print_progress
2022-08-24 11:42:28 +02:00
# Module-level result cache used by find_center(). HDFCache comes from the
# project-local `cache` module; presumably it persists entries in the file
# below (TODO confirm). The path is relative, so which file is used depends on
# the process's current working directory.
cache = HDFCache(Path("center_cache.hdf5"))
2022-06-14 10:53:19 +02:00
def _center_cache_key(all_particles: np.ndarray, center: np.ndarray, initial_radius) -> str:
    """Return the hex SHA-256 digest identifying one ``find_center`` call.

    The hash is fed, in order: the C-contiguous particle positions, the
    C-contiguous starting center, and ``np.array(initial_radius)``. This byte
    sequence must stay stable so that entries already stored in the cache
    remain reachable.
    """
    digest = hashlib.sha256()
    digest.update(np.ascontiguousarray(all_particles).data)
    digest.update(np.ascontiguousarray(center).data)
    # NOTE: the radius is hashed with whatever dtype NumPy infers for it, so
    # ``1`` and ``1.0`` produce different keys (a cache miss, not a wrong result).
    digest.update(np.array(initial_radius))
    return digest.hexdigest()


def find_center(df: pd.DataFrame, center: np.ndarray, initial_radius=1) -> np.ndarray:
    """Refine a center estimate by repeatedly averaging over a shrinking sphere.

    Starting from ``center`` with a sphere of radius ``initial_radius``, each
    step:

    1. takes the mean position of the particles strictly inside the sphere;
    2. moves the center halfway towards that mean;
    3. shrinks the radius.

    Iteration stops as soon as fewer than 10 particles lie inside the sphere.
    The center in use at that moment is returned.

    Results are memoised in the module-level ``cache``. The key is a SHA-256
    hash of the particle positions, the starting center and ``initial_radius``.

    :param df: particle table; only its ``X``, ``Y`` and ``Z`` columns are used.
    :param center: starting guess, a position vector matching the ``X``/``Y``/``Z``
        columns (assumed to be in the same units — confirm with callers).
    :param initial_radius: radius of the first sphere, in position units.
    :return: the refined center as a NumPy array.
    """
    all_particles = df[["X", "Y", "Z"]].to_numpy()
    # Normalise the start center to an array before hashing. The hash bytes are
    # identical either way, and every return path now yields an ndarray, even
    # when the loop below exits on its first check.
    center = np.ascontiguousarray(center)

    cache_key = _center_cache_key(all_particles, center, initial_radius)
    cached_center = cache.get(cache_key)
    if cached_center is not None:
        return np.array(cached_center)

    radius = initial_radius
    step = 0
    while True:
        distances = np.linalg.norm(all_particles - center, axis=1)
        in_radius_particles = all_particles[distances < radius]
        num_particles = in_radius_particles.shape[0]
        print_progress(step, "?", f"n={num_particles}, r={radius}, c={center}")
        if num_particles < 10:
            break
        center_of_mass = in_radius_particles.mean(axis=0)
        # Damped update: move only halfway towards the enclosed center of mass.
        new_center = (center_of_mass + center) / 2
        shift = np.linalg.norm(center - new_center)
        # NOTE: the center of mass lies inside the sphere, so shift < radius / 2
        # and 0.8 * shift < 0.9 * radius. The radius therefore shrinks by 10 %
        # every step, which guarantees termination.
        radius = max(0.8 * shift, radius * 0.9)
        center = new_center
        step += 1
    # Terminate the progress line written by print_progress (assumed to omit
    # the trailing newline).
    print()
    cache.set(cache_key, center)
    return center