1
0
Fork 0
mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git synced 2024-09-19 15:13:50 +02:00
collision-analyisis-and-int.../sliders.py
2021-10-12 15:45:43 +02:00

85 lines
2.5 KiB
Python

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider, Button
from CustomScaler import CustomScaler
from interpolators.rbf import RbfInterpolator
from simulation_list import SimulationList
simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
resolution = 100
data = simlist.X
values = simlist.Y_mass_fraction
scaler = CustomScaler()
scaler.fit(data)
scaled_data = scaler.transform_data(data)
interpolator = RbfInterpolator(scaled_data, values)
alpharange = np.linspace(-0.5, 60.5, resolution)
vrange = np.linspace(0.5, 5.5, resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.35)
t = np.arange(0.0, 1.0, 0.001)
mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0]
datagrid = np.zeros_like(grid_alpha)
mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="nearest") # type:QuadMesh
plt.colorbar()
# axcolor = 'lightgoldenrodyellow'
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03])
buttonax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(buttonax, 'Update', hovercolor='0.975')
# thetext = ax.text(-10, 0, "hello", fontsize=12) #type:Text
s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default)
def update(val):
print("start updating")
# thetext.set_text("updating")
# fig.canvas.draw()
mcode = s_mcode.val
gamma = s_gamma.val
wt = s_wt.val
wp = s_wp.val
parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt, wp]
scaled_parameters = list(scaler.transform_parameters(parameters))
datagrid = interpolator.interpolate(*scaled_parameters)
print(datagrid)
print(np.isnan(datagrid).sum() / (resolution * resolution))
if not isinstance(datagrid, np.ndarray):
return False
mesh.set_array(datagrid.ravel())
print("finished updating")
# thetext.set_text("finished")
fig.canvas.draw_idle()
update(None)
# s_gamma.on_changed(update)
# s_mcode.on_changed(update)
# s_wp.on_changed(update)
# s_wt.on_changed(update)
button.on_clicked(update)
plt.show()