1
0
Fork 0
This repository has been archived on 2024-06-28. You can view files and clone it, but cannot push or open issues or pull requests.
collision-analysis-and-inte.../sliders_nn.py

85 lines
2.4 KiB
Python
Raw Normal View History

2021-10-12 15:45:43 +02:00
import json
2021-03-29 15:02:46 +02:00
2019-04-18 10:52:13 +02:00
import matplotlib.pyplot as plt
import numpy as np
2021-03-29 15:02:46 +02:00
import torch
2019-04-18 10:52:13 +02:00
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider
2021-03-29 15:02:46 +02:00
from CustomScaler import CustomScaler
from network import Network
2019-04-18 10:52:13 +02:00
2021-03-29 15:02:46 +02:00
# Number of sample points along each axis of the (alpha, v) grid.
resolution = 100

# Read the stored "means" / "stds" and install them on a fresh scaler;
# update() uses this scaler to transform the network input.
with open("pytorch_model.json") as stats_file:
    data = json.load(stats_file)

scaler = CustomScaler()
scaler.means = np.array(data["means"])
scaler.stds = np.array(data["stds"])
2019-04-18 10:52:13 +02:00
# Main plot; the bottom 35 % of the figure is reserved (the slider axes
# are placed there further down).
fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.35)
t = np.arange(0.0, 1.0, 0.001)

# Initial slider positions.
mcode_default = 24.0
gamma_default = 1
wt_default = 15.0
wp_default = 15.0

# Regular grid over the two plotted inputs: alpha on the x axis, v on the
# y axis, with `resolution` samples each.
alpharange = np.linspace(0, 60, num=resolution)
vrange = np.linspace(0.5, 5.5, num=resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
2019-04-18 10:52:13 +02:00
2021-03-29 15:02:46 +02:00
# Rebuild the network and restore its trained weights.
model = Network()
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; update() converts the network output with .numpy(),
# so the model has to live on the CPU anyway.
model.load_state_dict(torch.load("pytorch_model.zip", map_location="cpu"))
# The model is only used for inference here: put dropout / batch-norm
# layers (if the network has any) into evaluation mode.
model.eval()
2019-04-18 10:52:13 +02:00
2021-03-29 15:02:46 +02:00
# Placeholder values; update() fills in the real prediction afterwards.
datagrid = np.zeros(grid_alpha.shape, dtype=grid_alpha.dtype)

# Colour mesh over the (alpha, v) grid with a fixed colour scale of [0, 1].
mesh: QuadMesh = plt.pcolormesh(
    grid_alpha,
    grid_v,
    datagrid,
    cmap="Blues",
    vmin=0,
    vmax=1,
    shading="auto",
)
plt.colorbar()
axcolor = 'lightgoldenrodyellow'

# One thin horizontal axes per slider, stacked in the strip below the plot.
# Rectangles are [left, bottom, width, height] in figure coordinates.
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03])
ax_mode = plt.axes([0.25, 0.05, 0.65, 0.03])

# mcode is used as a base-10 exponent in update().
s_mcode = Slider(ax_mcode, 'mcode', valmin=21, valmax=25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', valmin=0.1, valmax=1, valinit=gamma_default)
# NOTE(review): wt_default / wp_default (15.0) lie outside the slider range
# [1e-5, 1e-3]; the Slider is expected to clamp the initial value to the
# upper bound -- confirm which default is actually intended.
s_wt = Slider(ax_wt, 'wt', valmin=1e-5, valmax=1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', valmin=1e-5, valmax=1e-3, valinit=wp_default)
# Integer-stepped selector (1..4) for the network output column to plot.
s_mode = Slider(
    ax_mode,
    'shell/mantle/core/mass_fraction',
    valmin=1,
    valmax=4,
    valinit=1,
    valstep=1,
)
2019-04-18 10:52:13 +02:00
def update(val):
    """Re-evaluate the network for the current slider state and redraw the mesh.

    Used as the ``on_changed`` callback of every slider and also called once
    directly to draw the initial image.

    Args:
        val: Value passed by the slider that fired the callback. It is
            ignored; all slider values are re-read from the widgets.
    """
    mcode = s_mcode.val
    gamma = s_gamma.val
    wt = s_wt.val
    wp = s_wp.val
    # Slider values can be floats (or NumPy scalars), which are not valid
    # array indices; convert the 1-based mode to a 0-based integer column.
    column = int(round(s_mode.val)) - 1

    # One input row per grid point: [alpha, v, 10**mcode, gamma, wt, wp].
    n_points = grid_alpha.size
    testinput = np.empty((n_points, 6))
    testinput[:, 0] = grid_alpha.ravel()
    testinput[:, 1] = grid_v.ravel()
    testinput[:, 2:] = [10 ** mcode, gamma, wt, wp]
    testinput = scaler.transform_data(testinput)

    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            testoutput: torch.Tensor = model(
                torch.from_numpy(testinput).to(torch.float)
            )
        data = testoutput.detach().numpy()
    except TypeError:  # can't convert np.ndarray of type numpy.object_.
        # Fall back to an all-zero image. The fallback needs one column per
        # selectable mode, otherwise the highest mode would raise IndexError.
        data = np.zeros((n_points, int(s_mode.valmax)))

    datagrid = np.reshape(data[:, column], grid_alpha.shape)

    mesh.set_array(datagrid.ravel())
    fig.canvas.draw_idle()
2021-03-29 15:02:46 +02:00
# Draw the initial prediction before any slider has been touched.
update(None)

# Every slider change triggers a full re-evaluation of the plotted grid.
for _slider in (s_gamma, s_mcode, s_wp, s_wt, s_mode):
    _slider.on_changed(update)

plt.show()