From bf164d3cbaf7a3d2a9d88cfcd1cac7605cf58620 Mon Sep 17 00:00:00 2001 From: Lukas Winkler Date: Thu, 18 Apr 2019 10:52:13 +0200 Subject: [PATCH] add sliders and graph --- .gitignore | 3 +- graph.dot | 34 +++++++++++++++++++++ sliders_nn.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 graph.dot create mode 100644 sliders_nn.py diff --git a/.gitignore b/.gitignore index d1f85bb..e7c84ed 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ __pycache__/ *.png logs/ .ipynb_checkpoints/ -*.pdf \ No newline at end of file +*.pdf +*.npy diff --git a/graph.dot b/graph.dot new file mode 100644 index 0000000..1481964 --- /dev/null +++ b/graph.dot @@ -0,0 +1,34 @@ +# https://gist.github.com/thigm85/5760134 +digraph G { + rankdir=LR + splines=line + nodesep=.05; + bgcolor="transparent"; + penwidth = 0 + + node [label=""]; + + subgraph cluster_0 { + color=white; + node [style=solid,color=blue4, shape=circle]; + x1 x2 x3 x4 x5 x6; + label = "input layer"; + } + + subgraph cluster_1 { + color=white; + node [style=solid,color=red2, shape=circle]; + a12 a22 a32 a42; + label = "hidden layer"; + } + + subgraph cluster_2 { + color=white; + node [style=solid,color=seagreen2, shape=circle]; + O1; + label="output layer"; + } + + {x1; x2; x3; x4; x5; x6} -> {a12;a22;a32;a42}; + {a12; a22; a32; a42} -> O1 +} diff --git a/sliders_nn.py b/sliders_nn.py new file mode 100644 index 0000000..ba32b42 --- /dev/null +++ b/sliders_nn.py @@ -0,0 +1,85 @@ +import matplotlib.pyplot as plt +import numpy as np +from keras.engine.saving import load_model +from matplotlib.collections import QuadMesh +from matplotlib.widgets import Slider +from sklearn.preprocessing import StandardScaler + +from simulation_list import SimulationList + +simlist = SimulationList.jsonlines_load() +train_data = simlist.simlist + +X = np.array([[s.mcode, s.wpcode, s.wtcode, s.gammacode, s.alphacode, s.vcode] for s in train_data]) +scaler = StandardScaler() +scaler.fit(X) + + + + +fig, ax = plt.subplots() +plt.subplots_adjust(bottom=0.35) +t = np.arange(0.0, 1.0, 0.001) +mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 10.0, 10.0] + + +xrange = np.linspace(-0.5, 60.5, 100) +yrange = np.linspace(0.5, 5.5, 100) +xgrid, ygrid = np.meshgrid(xrange, yrange) +mcode = 24. +wpcode = 10 +wtcode = 10 +gammacode = 1 + +testinput = np.array([[mcode, wpcode, wtcode, gammacode, np.nan, np.nan]] * 100 * 100) +testinput[::, 4] = xgrid.flatten() +testinput[::, 5] = ygrid.flatten() +testinput = scaler.transform(testinput) + +model = load_model("model.hd5") + +testoutput = model.predict(testinput) +outgrid = np.reshape(testoutput, (100, 100)) +mesh = plt.pcolormesh(xgrid, ygrid, outgrid, cmap="Blues", vmin=0, vmax=1) # type:QuadMesh +plt.colorbar() + + +axcolor = 'lightgoldenrodyellow' +ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor) +ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor) +ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03], facecolor=axcolor) +ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03], facecolor=axcolor) + +s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default) +s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default) +s_wt = Slider(ax_wt, 'wt', 10, 20, valinit=wt_default) +s_wp = Slider(ax_wp, 'wp', 10, 20, valinit=wp_default) + + +def update(val): + mcode = s_mcode.val + gamma = s_gamma.val + wt = s_wt.val + wp = s_wp.val + testinput = np.array([[mcode, wp, wt, gamma, np.nan, np.nan]] * 100 * 100) + testinput[::, 4] = xgrid.flatten() + testinput[::, 5] = ygrid.flatten() + testinput = scaler.transform(testinput) + + testoutput = model.predict(testinput) + outgrid = np.reshape(testoutput, (100, 100)) + # if not isinstance(datagrid, np.ndarray): + # return False + formatedgrid = outgrid[:-1, :-1] + + mesh.set_array(formatedgrid.ravel()) + + fig.canvas.draw_idle() + + +s_gamma.on_changed(update) +s_mcode.on_changed(update) +s_wp.on_changed(update) +s_wt.on_changed(update) + +plt.show()