# collision-analyisis-and-interpolation/sliders_nn.py
# Mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git
# (mirror synced 2024-09-19 15:13:50 +02:00)
import matplotlib.pyplot as plt
import numpy as np
from keras.engine.saving import load_model
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider
from sklearn.preprocessing import StandardScaler
from simulation_list import SimulationList
# Load the simulation records and take the data matrix exposed as ``simlist.X``.
simlist = SimulationList.jsonlines_load()
X = simlist.X

# Fit the standardisation on X; the same scaler object is reused later to
# transform every batch of model inputs.
scaler = StandardScaler().fit(X)

# Lift the bottom edge of the plot to 0.35 so the slider axes fit underneath.
fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.35)

# NOTE(review): ``t`` is not referenced anywhere else in the visible script.
t = np.arange(0.0, 1.0, 0.001)
# Initial slider positions. The mcode slider value is used as a base-10
# exponent and the wt/wp slider values are divided by 100 before they are
# passed to the model (see the conversions in update()).
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0]

# 100 x 100 evaluation grid spanning the first two model inputs.
xrange = np.linspace(-0.5, 60.5, 100)
yrange = np.linspace(0.5, 5.5, 100)
xgrid, ygrid = np.meshgrid(xrange, yrange)

# Derive the initial model inputs from the slider defaults with exactly the
# same conversions update() applies, so the first frame shows what the sliders
# indicate. (Previously mcode was passed as the raw exponent 24. here while
# update() passed 10 ** exponent, so the plot jumped on the first slider move.)
mcode = 10 ** mcode_default
wpcode = wp_default / 100
wtcode = wt_default / 100
gammacode = gamma_default

# One row per grid point; columns are (x, y, mcode, gamma, wt, wp).
testinput = np.empty((xgrid.size, 6))
testinput[:, 0] = xgrid.ravel()
testinput[:, 1] = ygrid.ravel()
testinput[:, 2:] = [mcode, gammacode, wtcode, wpcode]

# Standardise with the scaler fitted on the simulation data, then evaluate the
# stored network on the whole grid at once.
testinput = scaler.transform(testinput)
model = load_model("model.hd5")
testoutput = model.predict(testinput)
outgrid = np.reshape(testoutput, xgrid.shape)

# Colour range is pinned to [0, 1]; ``mesh`` is kept so update() can swap in
# new values without rebuilding the plot.
mesh = plt.pcolormesh(xgrid, ygrid, outgrid, cmap="Blues", vmin=0, vmax=1)  # type:QuadMesh
plt.colorbar()
# Four horizontal slider strips stacked below the plot; they share the same
# left edge, width, height and background colour and differ only in their
# vertical position.
axcolor = 'lightgoldenrodyellow'
slider_left, slider_width, slider_height = 0.25, 0.65, 0.03

ax_mcode = plt.axes([slider_left, 0.1, slider_width, slider_height], facecolor=axcolor)
ax_gamma = plt.axes([slider_left, 0.15, slider_width, slider_height], facecolor=axcolor)
ax_wt = plt.axes([slider_left, 0.20, slider_width, slider_height], facecolor=axcolor)
ax_wp = plt.axes([slider_left, 0.25, slider_width, slider_height], facecolor=axcolor)

# Each slider starts at its ``*_default`` value.
s_mcode = Slider(ax_mcode, 'mcode', valmin=21, valmax=25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', valmin=0.1, valmax=1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', valmin=10, valmax=20, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', valmin=10, valmax=20, valinit=wp_default)
def update(val):
    """Re-evaluate the model for the current slider values and redraw the mesh.

    Used as the ``on_changed`` callback of the sliders. ``val`` (the value
    passed in by the slider that changed) is ignored: all four sliders are
    read directly so the prediction always reflects their combined state.
    """
    # The mcode slider holds a base-10 exponent; the wt and wp slider values
    # are divided by 100 before being fed to the model.
    mcode = 10 ** s_mcode.val
    gamma = s_gamma.val
    wt = s_wt.val / 100
    wp = s_wp.val / 100

    # One row per grid point; columns are (x, y, mcode, gamma, wt, wp).
    testinput = np.empty((xgrid.size, 6))
    testinput[:, 0] = xgrid.ravel()
    testinput[:, 1] = ygrid.ravel()
    testinput[:, 2:] = [mcode, gamma, wt, wp]

    testinput = scaler.transform(testinput)
    testoutput = model.predict(testinput)
    outgrid = np.reshape(testoutput, xgrid.shape)

    # How many values the QuadMesh holds depends on the matplotlib version:
    # older releases dropped the last row and column when X, Y and C had the
    # same shape (flat shading), newer ones keep the full grid. Match whatever
    # the mesh currently stores instead of always trimming.
    current = mesh.get_array()
    if current is not None and np.size(current) == outgrid.size:
        values = outgrid.reshape(np.shape(current))
    else:
        values = outgrid[:-1, :-1].ravel()
    mesh.set_array(values)
    fig.canvas.draw_idle()
# All four sliders share the same callback; registration order is kept.
for slider in (s_gamma, s_mcode, s_wp, s_wt):
    slider.on_changed(update)

# Display the figure with its slider widgets.
plt.show()