1
0
Fork 0
mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git synced 2024-09-19 15:13:50 +02:00

add sliders and graph

This commit is contained in:
Lukas Winkler 2019-04-18 10:52:13 +02:00
parent e6089c434c
commit bf164d3cba
Signed by: lukas
GPG key ID: 54DE4D798D244853
3 changed files with 121 additions and 1 deletions

3
.gitignore vendored
View file

@ -6,4 +6,5 @@ __pycache__/
*.png *.png
logs/ logs/
.ipynb_checkpoints/ .ipynb_checkpoints/
*.pdf *.pdf
*.npy

34
graph.dot Normal file
View file

@ -0,0 +1,34 @@
# https://gist.github.com/thigm85/5760134
digraph G {
rankdir=LR
splines=line
nodesep=.05;
bgcolor="transparent";
penwidth = 0
node [label=""];
subgraph cluster_0 {
color=white;
node [style=solid,color=blue4, shape=circle];
x1 x2 x3 x4 x5 x6;
label = "input layer";
}
subgraph cluster_1 {
color=white;
node [style=solid,color=red2, shape=circle];
a12 a22 a32 a42;
label = "hidden layer";
}
subgraph cluster_2 {
color=white;
node [style=solid,color=seagreen2, shape=circle];
O1;
label="output layer";
}
{x1; x2; x3; x4; x5; x6} -> {a12;a22;a32;a42};
{a12; a22; a32; a42} -> O1
}

85
sliders_nn.py Normal file
View file

@ -0,0 +1,85 @@
import matplotlib.pyplot as plt
import numpy as np
from keras.engine.saving import load_model
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider
from sklearn.preprocessing import StandardScaler
from simulation_list import SimulationList
simlist = SimulationList.jsonlines_load()
train_data = simlist.simlist
X = np.array([[s.mcode, s.wpcode, s.wtcode, s.gammacode, s.alphacode, s.vcode] for s in train_data])
scaler = StandardScaler()
scaler.fit(X)
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.35)
t = np.arange(0.0, 1.0, 0.001)
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 10.0, 10.0]
xrange = np.linspace(-0.5, 60.5, 100)
yrange = np.linspace(0.5, 5.5, 100)
xgrid, ygrid = np.meshgrid(xrange, yrange)
mcode = 24.
wpcode = 10
wtcode = 10
gammacode = 1
testinput = np.array([[mcode, wpcode, wtcode, gammacode, np.nan, np.nan]] * 100 * 100)
testinput[::, 4] = xgrid.flatten()
testinput[::, 5] = ygrid.flatten()
testinput = scaler.transform(testinput)
model = load_model("model.hd5")
testoutput = model.predict(testinput)
outgrid = np.reshape(testoutput, (100, 100))
mesh = plt.pcolormesh(xgrid, ygrid, outgrid, cmap="Blues", vmin=0, vmax=1) # type:QuadMesh
plt.colorbar()
axcolor = 'lightgoldenrodyellow'
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03], facecolor=axcolor)
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03], facecolor=axcolor)
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', 10, 20, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', 10, 20, valinit=wp_default)
def update(val):
mcode = s_mcode.val
gamma = s_gamma.val
wt = s_wt.val
wp = s_wp.val
testinput = np.array([[mcode, wp, wt, gamma, np.nan, np.nan]] * 100 * 100)
testinput[::, 4] = xgrid.flatten()
testinput[::, 5] = ygrid.flatten()
testinput = scaler.transform(testinput)
testoutput = model.predict(testinput)
outgrid = np.reshape(testoutput, (100, 100))
# if not isinstance(datagrid, np.ndarray):
# return False
formatedgrid = outgrid[:-1, :-1]
mesh.set_array(formatedgrid.ravel())
fig.canvas.draw_idle()
s_gamma.on_changed(update)
s_mcode.on_changed(update)
s_wp.on_changed(update)
s_wt.on_changed(update)
plt.show()