1
0
Fork 0
This repository has been archived on 2024-06-28. You can view files and clone it, but cannot push or open issues or pull requests.
collision-analysis-and-inte.../sliders_nn.py
2021-10-12 15:45:43 +02:00

84 lines
2.4 KiB
Python

import json
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider
from CustomScaler import CustomScaler
from network import Network
# Number of sample points along each axis of the plotted grid.
resolution = 100

# The JSON file provides the "means" and "stds" entries that are copied
# onto the scaler used to transform the network inputs.
with open("pytorch_model.json") as f:
    data = json.load(f)

scaler = CustomScaler()
scaler.means, scaler.stds = np.array(data["means"]), np.array(data["stds"])
# Main figure; the bottom 35 % is kept free for the stacked sliders.
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.35)
t = np.arange(0.0, 1.0, 0.001)  # not referenced anywhere else in this script

# Initial positions of the mcode, gamma, wt and wp sliders.
mcode_default, gamma_default, wt_default, wp_default = 24.0, 1, 15.0, 15.0

# Regular grid spanned by the two plot axes: alpha on x, v on y.
alpharange = np.linspace(0, 60, resolution)
vrange = np.linspace(0.5, 5.5, resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)

# Restore the trained weights. map_location="cpu" lets a checkpoint that was
# saved from a GPU load on a machine without one; the script converts the
# network output to NumPy and therefore works on CPU tensors anyway.
model = Network()
model.load_state_dict(torch.load("pytorch_model.zip", map_location="cpu"))
# The network is only used for inference here, so switch layers such as
# dropout / batch norm (if the network has any) to evaluation behaviour.
model.eval()

# Start with an all-zero mesh; update() replaces the values later.
datagrid = np.zeros_like(grid_alpha)
mesh: QuadMesh = plt.pcolormesh(
    grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="auto"
)
plt.colorbar()
axcolor = 'lightgoldenrodyellow'


def _slider_axes(bottom):
    """Add one slider strip to the figure at the given bottom position."""
    return plt.axes([0.25, bottom, 0.65, 0.03])


# Slider strips, stacked in the free space below the main plot.
ax_mcode = _slider_axes(0.1)
ax_gamma = _slider_axes(0.15)
ax_wt = _slider_axes(0.20)
ax_wp = _slider_axes(0.25)
ax_mode = _slider_axes(0.05)

s_mcode = Slider(ax_mcode, 'mcode', valmin=21, valmax=25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', valmin=0.1, valmax=1, valinit=gamma_default)
# NOTE(review): wt_default and wp_default are set to 15.0 elsewhere in this
# file, which is outside the 1e-5..1e-3 range used here -- confirm which of
# the two is intended.
s_wt = Slider(ax_wt, 'wt', valmin=1e-5, valmax=1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', valmin=1e-5, valmax=1e-3, valinit=wp_default)
# Integer-stepped selector (1..4) for the output column that is displayed.
s_mode = Slider(
    ax_mode,
    'shell/mantle/core/mass_fraction',
    valmin=1,
    valmax=4,
    valinit=1,
    valstep=1,
)
def update(val):
    """Re-evaluate the network for the current slider values and redraw.

    Builds one input row per (alpha, v) grid point, with the remaining four
    inputs taken from the sliders, runs the scaled inputs through the model
    and shows the output column selected by the mode slider in the mesh.

    Args:
        val: Value handed over by the slider callback. It is ignored; all
            slider values are read directly from the widgets.
    """
    mcode = s_mcode.val
    gamma = s_gamma.val
    wt = s_wt.val
    wp = s_wp.val
    # The slider value may be a float, but an array index must be an integer.
    column = int(round(s_mode.val)) - 1

    # One row per grid point: alpha, v, then the slider-controlled inputs.
    row = np.array([np.nan, np.nan, 10 ** mcode, gamma, wt, wp])
    testinput = np.tile(row, (resolution ** 2, 1))
    testinput[:, 0] = grid_alpha.flatten()
    testinput[:, 1] = grid_v.flatten()
    testinput = scaler.transform_data(testinput)

    try:
        # Pure inference: no gradients are needed.
        with torch.no_grad():
            testoutput: torch.Tensor = model(torch.from_numpy(testinput).to(torch.float))
        data = testoutput.numpy()
    except TypeError:  # can't convert np.ndarray of type numpy.object_.
        # Fall back to an empty picture. Provide one column per selectable
        # mode so that every mode slider position stays a valid index.
        data = np.zeros((resolution ** 2, int(s_mode.valmax)))

    datagrid = np.reshape(data[:, column], (resolution, resolution))
    mesh.set_array(datagrid.ravel())
    fig.canvas.draw_idle()
# Draw the mesh once for the initial slider positions.
update(None)

# Recompute and redraw whenever any of the sliders is moved.
for _slider in (s_gamma, s_mcode, s_wp, s_wt, s_mode):
    _slider.on_changed(update)

plt.show()