mirror of
https://github.com/Findus23/collision-analyisis-and-interpolation.git
synced 2024-09-19 15:13:50 +02:00
add sliders and graph
This commit is contained in:
parent
e6089c434c
commit
bf164d3cba
3 changed files with 121 additions and 1 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -6,4 +6,5 @@ __pycache__/
|
|||
*.png
|
||||
logs/
|
||||
.ipynb_checkpoints/
|
||||
*.pdf
|
||||
*.pdf
|
||||
*.npy
|
||||
|
|
34
graph.dot
Normal file
34
graph.dot
Normal file
|
@ -0,0 +1,34 @@
|
|||
# https://gist.github.com/thigm85/5760134
|
||||
digraph G {
|
||||
rankdir=LR
|
||||
splines=line
|
||||
nodesep=.05;
|
||||
bgcolor="transparent";
|
||||
penwidth = 0
|
||||
|
||||
node [label=""];
|
||||
|
||||
subgraph cluster_0 {
|
||||
color=white;
|
||||
node [style=solid,color=blue4, shape=circle];
|
||||
x1 x2 x3 x4 x5 x6;
|
||||
label = "input layer";
|
||||
}
|
||||
|
||||
subgraph cluster_1 {
|
||||
color=white;
|
||||
node [style=solid,color=red2, shape=circle];
|
||||
a12 a22 a32 a42;
|
||||
label = "hidden layer";
|
||||
}
|
||||
|
||||
subgraph cluster_2 {
|
||||
color=white;
|
||||
node [style=solid,color=seagreen2, shape=circle];
|
||||
O1;
|
||||
label="output layer";
|
||||
}
|
||||
|
||||
{x1; x2; x3; x4; x5; x6} -> {a12;a22;a32;a42};
|
||||
{a12; a22; a32; a42} -> O1
|
||||
}
|
85
sliders_nn.py
Normal file
85
sliders_nn.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from keras.engine.saving import load_model
|
||||
from matplotlib.collections import QuadMesh
|
||||
from matplotlib.widgets import Slider
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from simulation_list import SimulationList
|
||||
|
||||
simlist = SimulationList.jsonlines_load()
|
||||
train_data = simlist.simlist
|
||||
|
||||
X = np.array([[s.mcode, s.wpcode, s.wtcode, s.gammacode, s.alphacode, s.vcode] for s in train_data])
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(X)
|
||||
|
||||
|
||||
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
plt.subplots_adjust(bottom=0.35)
|
||||
t = np.arange(0.0, 1.0, 0.001)
|
||||
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 10.0, 10.0]
|
||||
|
||||
|
||||
xrange = np.linspace(-0.5, 60.5, 100)
|
||||
yrange = np.linspace(0.5, 5.5, 100)
|
||||
xgrid, ygrid = np.meshgrid(xrange, yrange)
|
||||
mcode = 24.
|
||||
wpcode = 10
|
||||
wtcode = 10
|
||||
gammacode = 1
|
||||
|
||||
testinput = np.array([[mcode, wpcode, wtcode, gammacode, np.nan, np.nan]] * 100 * 100)
|
||||
testinput[::, 4] = xgrid.flatten()
|
||||
testinput[::, 5] = ygrid.flatten()
|
||||
testinput = scaler.transform(testinput)
|
||||
|
||||
model = load_model("model.hd5")
|
||||
|
||||
testoutput = model.predict(testinput)
|
||||
outgrid = np.reshape(testoutput, (100, 100))
|
||||
mesh = plt.pcolormesh(xgrid, ygrid, outgrid, cmap="Blues", vmin=0, vmax=1) # type:QuadMesh
|
||||
plt.colorbar()
|
||||
|
||||
|
||||
axcolor = 'lightgoldenrodyellow'
|
||||
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
|
||||
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)
|
||||
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03], facecolor=axcolor)
|
||||
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03], facecolor=axcolor)
|
||||
|
||||
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
|
||||
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
|
||||
s_wt = Slider(ax_wt, 'wt', 10, 20, valinit=wt_default)
|
||||
s_wp = Slider(ax_wp, 'wp', 10, 20, valinit=wp_default)
|
||||
|
||||
|
||||
def update(val):
|
||||
mcode = s_mcode.val
|
||||
gamma = s_gamma.val
|
||||
wt = s_wt.val
|
||||
wp = s_wp.val
|
||||
testinput = np.array([[mcode, wp, wt, gamma, np.nan, np.nan]] * 100 * 100)
|
||||
testinput[::, 4] = xgrid.flatten()
|
||||
testinput[::, 5] = ygrid.flatten()
|
||||
testinput = scaler.transform(testinput)
|
||||
|
||||
testoutput = model.predict(testinput)
|
||||
outgrid = np.reshape(testoutput, (100, 100))
|
||||
# if not isinstance(datagrid, np.ndarray):
|
||||
# return False
|
||||
formatedgrid = outgrid[:-1, :-1]
|
||||
|
||||
mesh.set_array(formatedgrid.ravel())
|
||||
|
||||
fig.canvas.draw_idle()
|
||||
|
||||
|
||||
s_gamma.on_changed(update)
|
||||
s_mcode.on_changed(update)
|
||||
s_wp.on_changed(update)
|
||||
s_wt.on_changed(update)
|
||||
|
||||
plt.show()
|
Loading…
Reference in a new issue