"""Interactive viewer for the predictions of the trained network.

Evaluates the network on a regular (alpha, v) grid and shows one of its
output columns as a colour mesh.  Sliders below the plot set the remaining
four input features and choose which output column is displayed.
"""
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import QuadMesh
from matplotlib.widgets import Slider

from CustomScaler import CustomScaler
from network import Network
from simulation_list import SimulationList

simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl"))
resolution = 100

# The scaler is fitted on the dataset's inputs and reused for every grid
# evaluation in ``update``.
data = simlist.X
scaler = CustomScaler()
scaler.fit(data)

fig, ax = plt.subplots()
# Leave room below the plot for the five sliders.
plt.subplots_adjust(bottom=0.35)

# NOTE(review): the wt/wp defaults (15.0) lie far outside the range of their
# sliders (1e-5 .. 1e-3); Matplotlib appears to clamp ``valinit`` into the
# slider range, so the effective start value is presumably 1e-3 -- confirm
# the intended defaults.
mcode_default, gamma_default, wt_default, wp_default = 24.0, 1, 15.0, 15.0

alpharange = np.linspace(-0.5, 60.5, resolution)
vrange = np.linspace(0.5, 5.5, resolution)
grid_alpha, grid_v = np.meshgrid(alpharange, vrange)

model = Network()
# SECURITY: ``torch.load`` unpickles the file; only load trusted checkpoints.
# ``map_location`` lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load("pytorch_model.zip", map_location="cpu"))
# Inference only: switch dropout / batch-norm layers (if any) to eval mode.
model.eval()

# Placeholder values; the first ``update`` call fills in real predictions.
datagrid = np.zeros_like(grid_alpha)
mesh: QuadMesh = plt.pcolormesh(
    grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="auto"
)
plt.colorbar()

# Slider axes as [left, bottom, width, height] in figure coordinates.
ax_mcode = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_gamma = plt.axes([0.25, 0.15, 0.65, 0.03])
ax_wt = plt.axes([0.25, 0.20, 0.65, 0.03])
ax_wp = plt.axes([0.25, 0.25, 0.65, 0.03])
ax_mode = plt.axes([0.25, 0.05, 0.65, 0.03])

s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default)
s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default)
s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default)
s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default)
s_mode = Slider(
    ax_mode, 'shell/mantle/core/mass_fraction', 1, 4, valinit=1, valstep=1
)


def update(val):
    """Re-evaluate the network for the current slider state and redraw.

    Args:
        val: Value passed by the slider that fired (``None`` for the initial
            call).  It is ignored: all sliders are re-read so the mesh always
            reflects the complete state.
    """
    n_points = resolution ** 2

    # One input row per grid cell: columns 0/1 vary over the (alpha, v) grid,
    # columns 2..5 are constant and come from the sliders.  The 'mcode'
    # slider holds an exponent, so the feature itself is 10 ** mcode.
    features = np.empty((n_points, 6))
    features[:, 0] = grid_alpha.ravel()
    features[:, 1] = grid_v.ravel()
    features[:, 2] = 10 ** s_mcode.val
    features[:, 3] = s_gamma.val
    features[:, 4] = s_wt.val
    features[:, 5] = s_wp.val
    features = scaler.transform_data(features)

    try:
        # No gradients are needed for plotting.
        with torch.no_grad():
            output: torch.Tensor = model(
                torch.from_numpy(features).to(torch.float)
            )
        predictions = output.detach().numpy()
        print(predictions.shape)
    except TypeError:
        # can't convert np.ndarray of type numpy.object_.
        # Best effort: show an empty map.  One column per selectable mode so
        # that every slider position stays indexable.
        predictions = np.zeros((n_points, int(s_mode.valmax)))

    # The slider value may arrive as a float; array indices must be ints.
    column = int(round(s_mode.val)) - 1

    # Rows are in the same row-major order as the flattened grid, so the
    # selected column can be handed to the mesh as-is.
    mesh.set_array(predictions[:, column])
    fig.canvas.draw_idle()


# Draw the initial state, then redraw whenever any slider moves.
update(None)
s_gamma.on_changed(update)
s_mcode.on_changed(update)
s_wp.on_changed(update)
s_wt.on_changed(update)
s_mode.on_changed(update)

plt.show()