1
0
Fork 0
mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git synced 2024-09-18 14:03:51 +02:00
collision-analyisis-and-int.../sliders_nn.py
2021-10-12 15:45:43 +02:00

84 lines
2.4 KiB
Python

import json
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider
from CustomScaler import CustomScaler
from network import Network
# Side length of the (alpha, v) evaluation grid: resolution x resolution points.
resolution = 100

# Statistics used by CustomScaler to normalise the network inputs (applied in
# update() via scaler.transform_data). JSON is UTF-8 by specification, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open("pytorch_model.json", encoding="utf-8") as f:
    data = json.load(f)
scaler = CustomScaler()
scaler.means = np.array(data["means"])
scaler.stds = np.array(data["stds"])
fig, ax = plt.subplots()
# Leave room below the plot for the slider axes.
plt.subplots_adjust(bottom=0.35)

# Initial slider positions. Each value must lie inside its slider's range
# (mcode 21..25, gamma 0.1..1, wt/wp 1e-5..1e-3); 1e-4 is the logarithmic
# midpoint of the wt/wp range.
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 1e-4, 1e-4]

# Evaluation grid: alpha along the x axis, v along the y axis.
alpharange = np.linspace(0, 60, resolution)
vrange = np.linspace(0.5, 5.5, resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)

model = Network()
# map_location="cpu" lets weights saved from a GPU session load on a CPU-only
# machine; update() converts the outputs with .numpy(), which needs CPU tensors.
model.load_state_dict(torch.load("pytorch_model.zip", map_location="cpu"))
# Inference only: puts layers such as dropout/batch-norm (if Network contains
# any) into deterministic evaluation mode.
model.eval()

# Placeholder field; update() fills it with predictions in the range [0, 1].
datagrid = np.zeros_like(grid_alpha)
mesh: QuadMesh = ax.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="auto")
fig.colorbar(mesh, ax=ax)
# Axes for the sliders, each a thin strip below the main plot
# ([left, bottom, width, height] in figure coordinates).
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03])
ax_mode = plt.axes([0.25, 0.05, 0.65, 0.03])

# Network input parameters; mcode is passed to the network as 10 ** mcode.
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default)
# Selects which network output column (1-based) is displayed.
s_mode = Slider(ax_mode, 'shell/mantle/core/mass_fraction', 1, 4, valinit=1, valstep=1)
def update(val):
    """Re-evaluate the network on the (alpha, v) grid and redraw the mesh.

    Used as the ``on_changed`` callback of every slider, so *val* (the new
    value of whichever slider moved) is ignored and all slider values are
    re-read instead.
    """
    mcode = s_mcode.val
    gamma = s_gamma.val
    wt = s_wt.val
    wp = s_wp.val
    # The mode slider has valstep=1, but its value may still come back as a
    # float (e.g. 2.0), which NumPy rejects as an array index.
    mode = int(round(s_mode.val))
    # Number of selectable output columns, taken from the slider's range.
    n_modes = int(round(s_mode.valmax))

    # One row per grid point: [alpha, v, 10**mcode, gamma, wt, wp].
    # An explicit float array avoids accidental object dtypes.
    testinput = np.empty((grid_alpha.size, 6), dtype=float)
    testinput[:, 0] = grid_alpha.ravel()
    testinput[:, 1] = grid_v.ravel()
    testinput[:, 2:] = (10 ** mcode, gamma, wt, wp)
    testinput = scaler.transform_data(testinput)
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            testoutput: torch.Tensor = model(torch.from_numpy(testinput).to(torch.float))
        prediction = testoutput.detach().numpy()
    except TypeError:  # can't convert np.ndarray of type numpy.object_.
        # Fall back to an all-zero field with one column per selectable mode,
        # so that every slider position can still be indexed.
        prediction = np.zeros((grid_alpha.size, n_modes))
    datagrid = np.reshape(prediction[:, mode - 1], grid_alpha.shape)
    mesh.set_array(datagrid.ravel())
    fig.canvas.draw_idle()
# Render the initial prediction, then recompute it whenever any slider moves.
update(None)
for _slider in (s_gamma, s_mcode, s_wp, s_wt, s_mode):
    _slider.on_changed(update)
# Open the interactive window.
plt.show()