mirror of
https://github.com/Findus23/collision-analyisis-and-interpolation.git
synced 2024-09-18 14:03:51 +02:00
improve network
This commit is contained in:
parent
0021906370
commit
ff0dbd83b2
15 changed files with 217 additions and 121 deletions
9
cli.py
9
cli.py
|
@ -76,7 +76,6 @@ if gamma > 1:
|
|||
alpha = clamp(alpha, 0, 60)
|
||||
velocity = clamp(velocity, 1, 5)
|
||||
|
||||
|
||||
m_ceres = 9.393e+20
|
||||
m_earth = 5.9722e+24
|
||||
projectile_mass = clamp(projectile_mass, 2 * m_ceres, 2 * m_earth)
|
||||
|
@ -88,14 +87,10 @@ scaler.fit(simulations.X)
|
|||
|
||||
scaled_data = scaler.transform_data(simulations.X)
|
||||
water_interpolator = RbfInterpolator(scaled_data, simulations.Y_water)
|
||||
mass_interpolator = RbfInterpolator(scaled_data, simulations.Y_mass)
|
||||
mass_interpolator = RbfInterpolator(scaled_data, simulations.Y_mantle)
|
||||
|
||||
testinput = [alpha, velocity, projectile_mass, gamma,
|
||||
target_water_fraction, projectile_water_fraction]
|
||||
testinput = [32, 1, 7.6e22, 0.16, 0.15, 0.15]
|
||||
|
||||
with open("xl3", "w") as f:
|
||||
f.write("# alpha velocity projectile_mass gamma target_water_fraction projectile_water_fraction\n")
|
||||
f.write(" ".join(map(str, testinput)))
|
||||
|
||||
scaled_input = list(scaler.transform_parameters(testinput))
|
||||
water_retention = water_interpolator.interpolate(*scaled_input)
|
||||
|
|
|
@ -1 +1 @@
|
|||
water_fraction = False
|
||||
water_fraction = True
|
||||
|
|
10
cov.py
10
cov.py
|
@ -1,10 +1,14 @@
|
|||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from pathlib import Path
|
||||
|
||||
from simulation_list import SimulationList
|
||||
|
||||
simulations = SimulationList.jsonlines_load()
|
||||
plt.style.use('dark_background')
|
||||
|
||||
|
||||
simulations = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
|
||||
np.set_printoptions(linewidth=1000, edgeitems=4)
|
||||
|
||||
x = simulations.as_matrix
|
||||
|
@ -35,12 +39,12 @@ plt.close()
|
|||
ax = plt.gca() # type:Axes
|
||||
print(len(labels), len(simple_cov))
|
||||
print(simple_cov)
|
||||
plt.barh(range(len(simple_cov)), simple_cov)
|
||||
plt.barh(range(len(simple_cov)), simple_cov,color="#B3DE69")
|
||||
# ax.set_xticks(index + bar_width / 2)
|
||||
ax.set_yticklabels([0] + labels)
|
||||
ax2 = ax.twinx() # type:Axes
|
||||
ax2.set_yticklabels([0] + ["({:.2f})".format(a) for a in simple_cov])
|
||||
ax2.set_ylim(ax.get_ylim())
|
||||
plt.tight_layout()
|
||||
plt.savefig("../arbeit/images/cov.pdf")
|
||||
plt.savefig("../arbeit/images/cov.pdf",transparent=True)
|
||||
plt.show()
|
||||
|
|
11
example.py
11
example.py
|
@ -1,14 +1,17 @@
|
|||
"""
|
||||
Just a demo file on how to quickly read the dataset
|
||||
"""
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
from simulation_list import SimulationList
|
||||
|
||||
simlist = SimulationList.jsonlines_load()
|
||||
|
||||
simlist = SimulationList.jsonlines_load(Path("save.jsonl"))
|
||||
|
||||
for s in simlist.simlist:
|
||||
if not s.testcase:
|
||||
continue
|
||||
print(vars(s))
|
||||
if s.water_retention_both < 0:
|
||||
print(s.runid)
|
||||
if s.water_retention_both < 0.2:
|
||||
pprint(vars(s))
|
||||
print(s.water_retention_both)
|
||||
|
|
|
@ -4,8 +4,8 @@ from torch import nn
|
|||
class Network(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.hidden = nn.Linear(6, 50)
|
||||
self.output = nn.Linear(50, 4)
|
||||
self.hidden = nn.Linear(6, 70)
|
||||
self.output = nn.Linear(70, 4)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU()
|
||||
|
|
|
@ -73,7 +73,7 @@ def train():
|
|||
loss_train = []
|
||||
loss_vali = []
|
||||
|
||||
max_epochs = 500
|
||||
max_epochs = 200
|
||||
epochs = 0
|
||||
|
||||
fig: Figure = plt.figure()
|
||||
|
@ -128,6 +128,16 @@ def train():
|
|||
# if a > b: # overfitting on training data, stop training
|
||||
# print("early stopping")
|
||||
# break
|
||||
np.savetxt("loss.txt", np.array([x_axis, loss_train, loss_vali]).T)
|
||||
torch.save(network.state_dict(), "pytorch_model.zip")
|
||||
with open("pytorch_model.json", "w") as f:
|
||||
export_dict = {}
|
||||
value_tensor: Tensor
|
||||
for key, value_tensor in network.state_dict().items():
|
||||
export_dict[key] = value_tensor.detach().tolist()
|
||||
export_dict["means"] = scaler.means.tolist()
|
||||
export_dict["stds"] = scaler.stds.tolist()
|
||||
json.dump(export_dict, f)
|
||||
plt.ioff()
|
||||
model_test_y = []
|
||||
for x in x_test:
|
||||
|
@ -138,53 +148,10 @@ def train():
|
|||
plt.figure()
|
||||
plt.xlabel("model output")
|
||||
plt.ylabel("real data")
|
||||
for i, name in enumerate(["shell","mantle","core","mass fraction"]):
|
||||
plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2,label=name)
|
||||
for i, name in enumerate(["shell", "mantle", "core", "mass fraction"]):
|
||||
plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2, label=name)
|
||||
plt.legend()
|
||||
plt.show()
|
||||
torch.save(network.state_dict(), "pytorch_model.zip")
|
||||
with open("pytorch_model.json", "w") as f:
|
||||
export_dict = {}
|
||||
value_tensor: Tensor
|
||||
for key, value_tensor in network.state_dict().items():
|
||||
export_dict[key] = value_tensor.detach().tolist()
|
||||
export_dict["means"] = scaler.means.tolist()
|
||||
export_dict["stds"] = scaler.stds.tolist()
|
||||
json.dump(export_dict, f)
|
||||
|
||||
xrange = np.linspace(-0.5, 60.5, 300)
|
||||
yrange = np.linspace(0.5, 5.5, 300)
|
||||
xgrid, ygrid = np.meshgrid(xrange, yrange)
|
||||
mcode = 1e24
|
||||
wpcode = 1e-4
|
||||
|
||||
wtcode = 1e-4
|
||||
gammacode = 0.6
|
||||
testinput = np.array([[np.nan, np.nan, mcode, gammacode, wtcode, wpcode]] * 300 * 300)
|
||||
testinput[::, 0] = xgrid.flatten()
|
||||
testinput[::, 1] = ygrid.flatten()
|
||||
testinput = scaler.transform_data(testinput)
|
||||
|
||||
print(testinput)
|
||||
print(testinput.shape)
|
||||
testoutput: Tensor = network(from_numpy(testinput).to(torch.float))
|
||||
data = testoutput.detach().numpy()
|
||||
outgrid = np.reshape(data[::, 0], (300, 300))
|
||||
print("minmax")
|
||||
print(np.nanmin(outgrid), np.nanmax(outgrid))
|
||||
cmap = "Blues"
|
||||
plt.figure()
|
||||
plt.title(
|
||||
"m={:3.0e}, gamma={:3.1f}, wt={:2.0f}%, wp={:2.0f}%\n".format(mcode, gammacode, wtcode * 100, wpcode * 100))
|
||||
plt.imshow(outgrid, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1,
|
||||
extent=[xgrid.min(), xgrid.max(), ygrid.min(), ygrid.max()])
|
||||
|
||||
plt.colorbar().set_label("water retention fraction")
|
||||
plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]")
|
||||
plt.ylabel("velocity $v$ [$v_{esc}$]")
|
||||
plt.tight_layout()
|
||||
# plt.savefig("/home/lukas/tmp/nn.svg", transparent=True)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
27
nn_single.py
Normal file
27
nn_single.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from CustomScaler import CustomScaler
|
||||
from network import Network
|
||||
from simulation_list import SimulationList
|
||||
|
||||
resolution = 100
|
||||
|
||||
with open("pytorch_model.json") as f:
|
||||
data = json.load(f)
|
||||
scaler = CustomScaler()
|
||||
scaler.means = np.array(data["means"])
|
||||
scaler.stds = np.array(data["stds"])
|
||||
|
||||
model = Network()
|
||||
model.load_state_dict(torch.load("pytorch_model.zip"))
|
||||
|
||||
ang = 30
|
||||
v = 2
|
||||
m = 1e24
|
||||
gamma = 0.6
|
||||
wp = wt = 1e-4
|
||||
print(model(torch.Tensor(list(scaler.transform_parameters([ang, v, m, gamma, wt, wp])))))
|
|
@ -6,15 +6,15 @@ authors = ["Lukas Winkler <git@lw1.at>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
scipy = "^1.5.4"
|
||||
numpy = "^1.19.4"
|
||||
matplotlib = "^3.3.3"
|
||||
scipy = "^1.6.1"
|
||||
numpy = "^1.19.0"
|
||||
matplotlib = "^3.3.4"
|
||||
tabulate = "^0.8.7"
|
||||
#Keras = "^2.4.3"
|
||||
#tensorflow = "^2.3.1"
|
||||
pydot = "^1.4.1"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
PyQt5 = "5.12.1"
|
||||
#"vext.pyqt5" = "^0.7.4"
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
25
readfiles.py
25
readfiles.py
|
@ -1,4 +1,3 @@
|
|||
from glob import glob
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -6,33 +5,27 @@ from simulation import Simulation
|
|||
from simulation_list import SimulationList
|
||||
|
||||
simulation_sets = {
|
||||
"original": sorted(glob("../data/*")),
|
||||
"cloud": sorted(glob("../../Bachelorarbeit_data/results/*"))
|
||||
"winter": sorted(Path("../../tmp/winter/").glob("*"))
|
||||
# "original": sorted(glob("../data/*")),
|
||||
# "cloud": sorted(glob("../../Bachelorarbeit_data/results/*"))
|
||||
# "benchmark": sorted(glob("../../Bachelorarbeit_benchmark/results/*"))
|
||||
}
|
||||
simulations = SimulationList()
|
||||
|
||||
for set_type, directories in simulation_sets.items():
|
||||
for dir in directories:
|
||||
original = set_type == "original"
|
||||
spheres_file = dir + "/spheres_ini_log"
|
||||
timings_file = dir + "/pythontimings.json"
|
||||
aggregates_file = dir + ("/sim/aggregates.txt" if original else "/aggregates.txt")
|
||||
print(dir)
|
||||
spheres_file = dir / "spheres_ini.log"
|
||||
aggregates_file = sorted(dir.glob("frames/aggregates.*"))[-1]
|
||||
if not path.exists(spheres_file) or not path.exists(aggregates_file):
|
||||
print(f"skipping {dir}")
|
||||
continue
|
||||
if "id" not in dir and original:
|
||||
continue
|
||||
sim = Simulation()
|
||||
if set_type == "original":
|
||||
sim.load_params_from_dirname(path.basename(dir))
|
||||
else:
|
||||
sim.load_params_from_json(dir + "/parameters.json")
|
||||
sim.load_params_from_setup_txt(dir / "cfg.txt")
|
||||
sim.type = set_type
|
||||
sim.load_params_from_spheres_ini_log(spheres_file)
|
||||
sim.load_params_from_aggregates_txt(aggregates_file)
|
||||
sim.load_params_from_pythontiming_json(timings_file)
|
||||
sim.assert_all_loaded()
|
||||
# sim.assert_all_loaded()
|
||||
if sim.rel_velocity < 0 or sim.distance < 0:
|
||||
# Sometimes in the old dataset the second object wasn't detected.
|
||||
# To be save, we'll exclude them
|
||||
|
@ -52,4 +45,4 @@ for set_type, directories in simulation_sets.items():
|
|||
|
||||
print(len(simulations.simlist))
|
||||
|
||||
simulations.jsonlines_save(Path("output.jsonl"))
|
||||
simulations.jsonlines_save(Path("winter.jsonl"))
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class Simulation:
|
||||
|
||||
|
@ -132,7 +135,10 @@ class Simulation:
|
|||
def output_mass_fraction(self) -> Optional[float]:
|
||||
if not self.largest_aggregate_mass:
|
||||
return 0 # FIXME
|
||||
return self.second_largest_aggregate_mass / self.largest_aggregate_mass
|
||||
massive_size_relation=self.largest_aggregate_mass/self.target_mass
|
||||
less_massive_size_relation=self.second_largest_aggregate_mass/self.projectile_mass
|
||||
print("mr",massive_size_relation, less_massive_size_relation)
|
||||
return less_massive_size_relation / massive_size_relation
|
||||
|
||||
@property
|
||||
def original_simulation(self) -> bool:
|
||||
|
@ -149,7 +155,7 @@ class Simulation:
|
|||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Simulation '{self.simulation_key}'>"
|
||||
return f"<Simulation '{vars(self)}'>"
|
||||
|
||||
def load_params_from_dirname(self, dirname: str) -> None:
|
||||
params = dirname.split("_")
|
||||
|
@ -173,8 +179,18 @@ class Simulation:
|
|||
self.wtcode = data["wt_code"]
|
||||
self.wpcode = data["wp_code"]
|
||||
|
||||
def load_params_from_spheres_ini_log(self, filename: str) -> None:
|
||||
with open(filename) as f:
|
||||
def load_params_from_setup_txt(self, file: Path) -> None:
|
||||
with file.open() as f:
|
||||
data = yaml.safe_load(f)
|
||||
# self.runid=data["ID"]
|
||||
# self.vcode=data["vel_vesc_touching_ball"]
|
||||
# self.alphacode=data["impact_angle_touching_ball"]
|
||||
# self.mcode=data["M_tot"]
|
||||
# self.gammacode=data["gamma"]
|
||||
# TODO: maybe more needed?
|
||||
|
||||
def load_params_from_spheres_ini_log(self, filename: Path) -> None:
|
||||
with filename.open() as f:
|
||||
lines = [line.rstrip("\n") for line in f]
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
|
@ -191,20 +207,25 @@ class Simulation:
|
|||
self.projectile_water_fraction = float(lines[i + 1].split()[7])
|
||||
self.target_water_fraction = float(lines[i + 3].split()[7])
|
||||
if "Particle numbers" in line:
|
||||
self.desired_N = int(lines[i + 1].split()[3])
|
||||
self.desired_N = int(lines[i + 1].split()[4])
|
||||
self.actual_N = int(lines[i + 1].split()[-1])
|
||||
|
||||
def load_params_from_aggregates_txt(self, filename: str) -> None:
|
||||
with open(filename) as f:
|
||||
def load_params_from_aggregates_txt(self, filename: Path) -> None:
|
||||
with filename.open() as f:
|
||||
lines = [line.rstrip("\n") for line in f]
|
||||
for i in range(len(lines)):
|
||||
line = lines[i]
|
||||
if "# largest aggregate" in line:
|
||||
self.largest_aggregate_mass = float(lines[i + 2].split()[0])
|
||||
self.largest_aggregate_water_fraction = float(lines[i + 2].split()[2])
|
||||
cols = lines[i + 2].split()
|
||||
self.largest_aggregate_mass = float(cols[0])
|
||||
self.largest_aggregate_core_fraction = float(cols[1])
|
||||
self.largest_aggregate_water_fraction = 1 - float(cols[2]) - self.largest_aggregate_core_fraction
|
||||
if "# 2nd-largest aggregate:" in line:
|
||||
self.second_largest_aggregate_mass = float(lines[i + 2].split()[0])
|
||||
self.second_largest_aggregate_water_fraction = float(lines[i + 2].split()[2])
|
||||
cols = lines[i + 2].split()
|
||||
self.second_largest_aggregate_mass = float(cols[0])
|
||||
self.second_largest_aggregate_core_fraction = float(cols[1])
|
||||
self.second_largest_aggregate_water_fraction = 1 - float(
|
||||
cols[2]) - self.second_largest_aggregate_core_fraction
|
||||
if "# distance" in line: # TODO: not sure if correct anymore
|
||||
self.distance = float(lines[i + 1].split()[0])
|
||||
self.rel_velocity = float(lines[i + 1].split()[1])
|
||||
|
|
|
@ -78,3 +78,7 @@ class SimulationList:
|
|||
@property
|
||||
def Y_mantle(self):
|
||||
return np.array([s.mantle_retention_both for s in self.simlist if not s.testcase])
|
||||
|
||||
@property
|
||||
def Y_mass_fraction(self):
|
||||
return np.array([s.output_mass_fraction for s in self.simlist if not s.testcase])
|
||||
|
|
22
sliders.py
22
sliders.py
|
@ -10,17 +10,18 @@ from interpolators.rbf import RbfInterpolator
|
|||
from simulation_list import SimulationList
|
||||
|
||||
simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
|
||||
resolution = 100
|
||||
|
||||
data = simlist.X
|
||||
values = simlist.Y
|
||||
values = simlist.Y_mass_fraction
|
||||
|
||||
scaler = CustomScaler()
|
||||
scaler.fit(data)
|
||||
scaled_data = scaler.transform_data(data)
|
||||
interpolator = RbfInterpolator(scaled_data, values)
|
||||
|
||||
alpharange = np.linspace(-0.5, 60.5, 100)
|
||||
vrange = np.linspace(0.5, 5.5, 100)
|
||||
alpharange = np.linspace(-0.5, 60.5, resolution)
|
||||
vrange = np.linspace(0.5, 5.5, resolution)
|
||||
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
@ -30,7 +31,7 @@ mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0]
|
|||
|
||||
datagrid = np.zeros_like(grid_alpha)
|
||||
|
||||
mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1) # type:QuadMesh
|
||||
mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="nearest") # type:QuadMesh
|
||||
plt.colorbar()
|
||||
|
||||
# axcolor = 'lightgoldenrodyellow'
|
||||
|
@ -44,8 +45,8 @@ button = Button(buttonax, 'Update', hovercolor='0.975')
|
|||
|
||||
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
|
||||
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
|
||||
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-4, valinit=wt_default)
|
||||
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-4, valinit=wp_default)
|
||||
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default)
|
||||
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default)
|
||||
|
||||
|
||||
def update(val):
|
||||
|
@ -57,23 +58,24 @@ def update(val):
|
|||
gamma = s_gamma.val
|
||||
wt = s_wt.val
|
||||
wp = s_wp.val
|
||||
parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt / 100, wp / 100]
|
||||
parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt, wp]
|
||||
scaled_parameters = list(scaler.transform_parameters(parameters))
|
||||
|
||||
datagrid = interpolator.interpolate(*scaled_parameters)
|
||||
print(datagrid)
|
||||
print(np.isnan(datagrid).sum() / (100 * 100))
|
||||
print(np.isnan(datagrid).sum() / (resolution * resolution))
|
||||
if not isinstance(datagrid, np.ndarray):
|
||||
return False
|
||||
formatedgrid = datagrid[:-1, :-1]
|
||||
|
||||
mesh.set_array(formatedgrid.ravel())
|
||||
mesh.set_array(datagrid.ravel())
|
||||
print("finished updating")
|
||||
# thetext.set_text("finished")
|
||||
|
||||
fig.canvas.draw_idle()
|
||||
|
||||
|
||||
update(None)
|
||||
|
||||
# s_gamma.on_changed(update)
|
||||
# s_mcode.on_changed(update)
|
||||
# s_wp.on_changed(update)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -8,21 +8,21 @@ from matplotlib.widgets import Slider
|
|||
|
||||
from CustomScaler import CustomScaler
|
||||
from network import Network
|
||||
from simulation_list import SimulationList
|
||||
|
||||
simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
|
||||
resolution = 100
|
||||
|
||||
data = simlist.X
|
||||
scaler = CustomScaler()
|
||||
scaler.fit(data)
|
||||
with open("pytorch_model.json") as f:
|
||||
data = json.load(f)
|
||||
scaler = CustomScaler()
|
||||
scaler.means = np.array(data["means"])
|
||||
scaler.stds = np.array(data["stds"])
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
plt.subplots_adjust(bottom=0.35)
|
||||
t = np.arange(0.0, 1.0, 0.001)
|
||||
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0]
|
||||
|
||||
alpharange = np.linspace(-0.5, 60.5, resolution)
|
||||
alpharange = np.linspace(0, 60, resolution)
|
||||
vrange = np.linspace(0.5, 5.5, resolution)
|
||||
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
|
||||
|
||||
|
|
30
visualize.py
30
visualize.py
|
@ -6,15 +6,16 @@ from matplotlib import pyplot as plt, cm
|
|||
|
||||
from CustomScaler import CustomScaler
|
||||
from config import water_fraction
|
||||
from interpolators.griddata import GriddataInterpolator
|
||||
from interpolators.rbf import RbfInterpolator
|
||||
from simulation import Simulation
|
||||
from simulation_list import SimulationList
|
||||
|
||||
|
||||
# plt.style.use('dark_background')
|
||||
plt.style.use('dark_background')
|
||||
|
||||
|
||||
def main():
|
||||
mcode, gamma, wt, wp = [10 ** 21, 0.6, 1e-5, 1e-5]
|
||||
mcode, gamma, wt, wp = [10 ** 22, 0.6, 1e-5, 1e-5]
|
||||
simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
|
||||
# for s in simlist.simlist:
|
||||
# if s.type!="original":
|
||||
|
@ -26,7 +27,7 @@ def main():
|
|||
print(len(data))
|
||||
# print(data[0])
|
||||
# exit()
|
||||
values = simlist.Y
|
||||
values = simlist.Y_mass_fraction
|
||||
|
||||
scaler = CustomScaler()
|
||||
scaler.fit(data)
|
||||
|
@ -34,7 +35,7 @@ def main():
|
|||
interpolator = RbfInterpolator(scaled_data, values)
|
||||
# interpolator = GriddataInterpolator(scaled_data, values)
|
||||
|
||||
alpharange = np.linspace(-0.5, 60.5, 300)
|
||||
alpharange = np.linspace(0, 60, 300)
|
||||
vrange = np.linspace(0.5, 5.5, 300)
|
||||
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
|
||||
|
||||
|
@ -53,13 +54,28 @@ def main():
|
|||
# plt.pcolormesh(grid_alpha, grid_v, grid_result, cmap="Blues", vmin=0, vmax=1)
|
||||
plt.imshow(grid_result, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1,
|
||||
extent=[grid_alpha.min(), grid_alpha.max(), grid_v.min(), grid_v.max()])
|
||||
plt.colorbar().set_label("water retention fraction" if water_fraction else "core mass retention fraction")
|
||||
# plt.scatter(data[:, 0], data[:, 1], c=values, cmap="Blues")
|
||||
plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]")
|
||||
plt.ylabel("velocity $v$ [$v_{esc}$]")
|
||||
s: Simulation
|
||||
xs = []
|
||||
ys = []
|
||||
zs = []
|
||||
for s in simlist.simlist:
|
||||
# if not (0.4 < s.gamma < 0.6) or not (1e23 < s.total_mass < 5e24):
|
||||
# continue
|
||||
# if s.alpha < 60 or s.v > 5 or s.v < 2:
|
||||
# continue
|
||||
z = s.output_mass_fraction
|
||||
zs.append(z)
|
||||
xs.append(s.alpha)
|
||||
ys.append(s.v)
|
||||
print(z, s.runid)
|
||||
plt.scatter(xs, ys, c=zs, cmap=cmap, vmin=0, vmax=1)
|
||||
plt.colorbar().set_label("stone retention fraction" if water_fraction else "core mass retention fraction")
|
||||
plt.tight_layout()
|
||||
# plt.savefig("vis.png", transparent=True)
|
||||
plt.savefig("/home/lukas/tmp/test.svg", transparent=True)
|
||||
plt.savefig("/home/lukas/tmp/test.pdf", transparent=True)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
|
64
visualize_nn.py
Normal file
64
visualize_nn.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from torch import from_numpy, Tensor
|
||||
|
||||
from CustomScaler import CustomScaler
|
||||
from network import Network
|
||||
|
||||
resolution = 300
|
||||
|
||||
|
||||
def main():
|
||||
mcode, gamma, wt, wp = [10 ** 24, 0.4, 1e-5, 1e-5]
|
||||
with open("pytorch_model.json") as f:
|
||||
data = json.load(f)
|
||||
scaler = CustomScaler()
|
||||
scaler.means = np.array(data["means"])
|
||||
scaler.stds = np.array(data["stds"])
|
||||
|
||||
model = Network()
|
||||
model.load_state_dict(torch.load("pytorch_model.zip"))
|
||||
|
||||
alpharange = np.linspace(0, 60, resolution)
|
||||
vrange = np.linspace(0.5, 5.5, resolution)
|
||||
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
|
||||
mcode = 1e24
|
||||
wpcode = 1e-5
|
||||
|
||||
wtcode = 1e-5
|
||||
gammacode = 0.6
|
||||
testinput = np.array([[np.nan, np.nan, mcode, gammacode, wtcode, wpcode]] * resolution * resolution)
|
||||
testinput[::, 0] = grid_alpha.flatten()
|
||||
testinput[::, 1] = grid_v.flatten()
|
||||
testinput = scaler.transform_data(testinput)
|
||||
|
||||
print(testinput)
|
||||
print(testinput.shape)
|
||||
network = Network()
|
||||
network.load_state_dict(torch.load("pytorch_model.zip"))
|
||||
|
||||
print(testinput)
|
||||
testoutput: Tensor = network(from_numpy(testinput).to(torch.float))
|
||||
data = testoutput.detach().numpy()
|
||||
grid_result = np.reshape(data[::, 0], (300, 300))
|
||||
print("minmax")
|
||||
print(np.nanmin(grid_result), np.nanmax(grid_result))
|
||||
cmap = "Blues"
|
||||
plt.figure()
|
||||
plt.title(
|
||||
"m={:3.0e}, gamma={:3.1f}, wt={:2.0e}%, wp={:2.0e}%\n".format(mcode, gammacode, wtcode, wpcode))
|
||||
plt.imshow(grid_result, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1,
|
||||
extent=[grid_alpha.min(), grid_alpha.max(), grid_v.min(), grid_v.max()])
|
||||
|
||||
plt.colorbar().set_label("water retention fraction")
|
||||
plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]")
|
||||
plt.ylabel("velocity $v$ [$v_{esc}$]")
|
||||
plt.tight_layout()
|
||||
plt.savefig("/home/lukas/tmp/nn.pdf", transparent=True)
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in a new issue