From ff0dbd83b2386e2d948cda3ba82e3496ad56ca2e Mon Sep 17 00:00:00 2001 From: Lukas Winkler Date: Tue, 12 Oct 2021 15:45:43 +0200 Subject: [PATCH] improve network --- cli.py | 9 ++----- config.py | 2 +- cov.py | 10 +++++--- example.py | 11 +++++--- network.py | 4 +-- neural_network.py | 59 ++++++++++-------------------------------- nn_single.py | 27 +++++++++++++++++++ pyproject.toml | 14 +++++----- readfiles.py | 25 +++++++----------- simulation.py | 43 +++++++++++++++++++++++-------- simulation_list.py | 4 +++ sliders.py | 22 ++++++++-------- sliders_nn.py | 14 +++++----- visualize.py | 30 +++++++++++++++++----- visualize_nn.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 15 files changed, 217 insertions(+), 121 deletions(-) create mode 100644 nn_single.py create mode 100644 visualize_nn.py diff --git a/cli.py b/cli.py index 6efdffe..1e39ab0 100644 --- a/cli.py +++ b/cli.py @@ -76,7 +76,6 @@ if gamma > 1: alpha = clamp(alpha, 0, 60) velocity = clamp(velocity, 1, 5) - m_ceres = 9.393e+20 m_earth = 5.9722e+24 projectile_mass = clamp(projectile_mass, 2 * m_ceres, 2 * m_earth) @@ -88,14 +87,10 @@ scaler.fit(simulations.X) scaled_data = scaler.transform_data(simulations.X) water_interpolator = RbfInterpolator(scaled_data, simulations.Y_water) -mass_interpolator = RbfInterpolator(scaled_data, simulations.Y_mass) +mass_interpolator = RbfInterpolator(scaled_data, simulations.Y_mantle) -testinput = [alpha, velocity, projectile_mass, gamma, - target_water_fraction, projectile_water_fraction] +testinput = [32, 1, 7.6e22, 0.16, 0.15, 0.15] -with open("xl3", "w") as f: - f.write("# alpha velocity projectile_mass gamma target_water_fraction projectile_water_fraction\n") - f.write(" ".join(map(str, testinput))) scaled_input = list(scaler.transform_parameters(testinput)) water_retention = water_interpolator.interpolate(*scaled_input) diff --git a/config.py b/config.py index 5905617..18dc6ba 100644 --- a/config.py +++ b/config.py @@ -1 +1 @@ -water_fraction = False +water_fraction = True diff --git a/cov.py b/cov.py index 27e579b..ed44237 100644 --- a/cov.py +++ b/cov.py @@ -1,10 +1,14 @@ import numpy as np from matplotlib import pyplot as plt from matplotlib.axes import Axes +from pathlib import Path from simulation_list import SimulationList -simulations = SimulationList.jsonlines_load() +plt.style.use('dark_background') + + +simulations = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl")) np.set_printoptions(linewidth=1000, edgeitems=4) x = simulations.as_matrix @@ -35,12 +39,12 @@ plt.close() ax = plt.gca() # type:Axes print(len(labels), len(simple_cov)) print(simple_cov) -plt.barh(range(len(simple_cov)), simple_cov) +plt.barh(range(len(simple_cov)), simple_cov,color="#B3DE69") # ax.set_xticks(index + bar_width / 2) ax.set_yticklabels([0] + labels) ax2 = ax.twinx() # type:Axes ax2.set_yticklabels([0] + ["({:.2f})".format(a) for a in simple_cov]) ax2.set_ylim(ax.get_ylim()) plt.tight_layout() -plt.savefig("../arbeit/images/cov.pdf") +plt.savefig("../arbeit/images/cov.pdf",transparent=True) plt.show() diff --git a/example.py b/example.py index fb6ae20..e851420 100644 --- a/example.py +++ b/example.py @@ -1,14 +1,17 @@ """ Just a demo file on how to quickly read the dataset """ +from pathlib import Path +from pprint import pprint from simulation_list import SimulationList -simlist = SimulationList.jsonlines_load() + +simlist = SimulationList.jsonlines_load(Path("save.jsonl")) for s in simlist.simlist: if not s.testcase: continue - print(vars(s)) - if s.water_retention_both < 0: - print(s.runid) + if s.water_retention_both < 0.2: + pprint(vars(s)) + print(s.water_retention_both) diff --git a/network.py b/network.py index 529ae02..3fccf28 100644 --- a/network.py +++ b/network.py @@ -4,8 +4,8 @@ from torch import nn class Network(nn.Module): def __init__(self): super().__init__() - self.hidden = nn.Linear(6, 50) - self.output = nn.Linear(50, 4) + self.hidden = nn.Linear(6, 70) + self.output = nn.Linear(70, 4) self.sigmoid = nn.Sigmoid() self.relu = nn.ReLU() diff --git a/neural_network.py b/neural_network.py index e221410..92b095c 100644 --- a/neural_network.py +++ b/neural_network.py @@ -73,7 +73,7 @@ def train(): loss_train = [] loss_vali = [] - max_epochs = 500 + max_epochs = 200 epochs = 0 fig: Figure = plt.figure() @@ -128,6 +128,16 @@ def train(): # if a > b: # overfitting on training data, stop training # print("early stopping") # break + np.savetxt("loss.txt", np.array([x_axis, loss_train, loss_vali]).T) + torch.save(network.state_dict(), "pytorch_model.zip") + with open("pytorch_model.json", "w") as f: + export_dict = {} + value_tensor: Tensor + for key, value_tensor in network.state_dict().items(): + export_dict[key] = value_tensor.detach().tolist() + export_dict["means"] = scaler.means.tolist() + export_dict["stds"] = scaler.stds.tolist() + json.dump(export_dict, f) plt.ioff() model_test_y = [] for x in x_test: @@ -138,53 +148,10 @@ def train(): plt.figure() plt.xlabel("model output") plt.ylabel("real data") - for i, name in enumerate(["shell","mantle","core","mass fraction"]): - plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2,label=name) + for i, name in enumerate(["shell", "mantle", "core", "mass fraction"]): + plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2, label=name) plt.legend() plt.show() - torch.save(network.state_dict(), "pytorch_model.zip") - with open("pytorch_model.json", "w") as f: - export_dict = {} - value_tensor: Tensor - for key, value_tensor in network.state_dict().items(): - export_dict[key] = value_tensor.detach().tolist() - export_dict["means"] = scaler.means.tolist() - export_dict["stds"] = scaler.stds.tolist() - json.dump(export_dict, f) - - xrange = np.linspace(-0.5, 60.5, 300) - yrange = np.linspace(0.5, 5.5, 300) - xgrid, ygrid = np.meshgrid(xrange, yrange) - mcode = 1e24 - wpcode = 1e-4 - - wtcode = 1e-4 - gammacode = 0.6 - testinput = np.array([[np.nan, np.nan, mcode, gammacode, wtcode, wpcode]] * 300 * 300) - testinput[::, 0] = xgrid.flatten() - testinput[::, 1] = ygrid.flatten() - testinput = scaler.transform_data(testinput) - - print(testinput) - print(testinput.shape) - testoutput: Tensor = network(from_numpy(testinput).to(torch.float)) - data = testoutput.detach().numpy() - outgrid = np.reshape(data[::, 0], (300, 300)) - print("minmax") - print(np.nanmin(outgrid), np.nanmax(outgrid)) - cmap = "Blues" - plt.figure() - plt.title( - "m={:3.0e}, gamma={:3.1f}, wt={:2.0f}%, wp={:2.0f}%\n".format(mcode, gammacode, wtcode * 100, wpcode * 100)) - plt.imshow(outgrid, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1, - extent=[xgrid.min(), xgrid.max(), ygrid.min(), ygrid.max()]) - - plt.colorbar().set_label("water retention fraction") - plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]") - plt.ylabel("velocity $v$ [$v_{esc}$]") - plt.tight_layout() - # plt.savefig("/home/lukas/tmp/nn.svg", transparent=True) - plt.show() if __name__ == '__main__': diff --git a/nn_single.py b/nn_single.py new file mode 100644 index 0000000..8cb50cb --- /dev/null +++ b/nn_single.py @@ -0,0 +1,27 @@ +import json +from pathlib import Path + +import numpy as np +import torch + +from CustomScaler import CustomScaler +from network import Network +from simulation_list import SimulationList + +resolution = 100 + +with open("pytorch_model.json") as f: + data = json.load(f) + scaler = CustomScaler() + scaler.means = np.array(data["means"]) + scaler.stds = np.array(data["stds"]) + +model = Network() +model.load_state_dict(torch.load("pytorch_model.zip")) + +ang = 30 +v = 2 +m = 1e24 +gamma = 0.6 +wp = wt = 1e-4 +print(model(torch.Tensor(list(scaler.transform_parameters([ang, v, m, gamma, wt, wp]))))) diff --git a/pyproject.toml b/pyproject.toml index 5b6c2e7..fe7cf7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,15 +6,15 @@ authors = ["Lukas Winkler "] [tool.poetry.dependencies] python = "^3.8" -scipy = "^1.5.4" -numpy = "^1.19.4" -matplotlib = "^3.3.3" +scipy = "^1.6.1" +numpy = "^1.19.0" +matplotlib = "^3.3.4" tabulate = "^0.8.7" +#Keras = "^2.4.3" +#tensorflow = "^2.3.1" +pydot = "^1.4.1" - - -[tool.poetry.dev-dependencies] -PyQt5 = "5.12.1" +#"vext.pyqt5" = "^0.7.4" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/readfiles.py b/readfiles.py index 6b30bb7..6d39930 100644 --- a/readfiles.py +++ b/readfiles.py @@ -1,4 +1,3 @@ -from glob import glob from os import path from pathlib import Path @@ -6,33 +5,27 @@ from simulation import Simulation from simulation_list import SimulationList simulation_sets = { - "original": sorted(glob("../data/*")), - "cloud": sorted(glob("../../Bachelorarbeit_data/results/*")) + "winter": sorted(Path("../../tmp/winter/").glob("*")) + # "original": sorted(glob("../data/*")), + # "cloud": sorted(glob("../../Bachelorarbeit_data/results/*")) # "benchmark": sorted(glob("../../Bachelorarbeit_benchmark/results/*")) } simulations = SimulationList() for set_type, directories in simulation_sets.items(): for dir in directories: - original = set_type == "original" - spheres_file = dir + "/spheres_ini_log" - timings_file = dir + "/pythontimings.json" - aggregates_file = dir + ("/sim/aggregates.txt" if original else "/aggregates.txt") + print(dir) + spheres_file = dir / "spheres_ini.log" + aggregates_file = sorted(dir.glob("frames/aggregates.*"))[-1] if not path.exists(spheres_file) or not path.exists(aggregates_file): print(f"skipping {dir}") continue - if "id" not in dir and original: - continue sim = Simulation() - if set_type == "original": - sim.load_params_from_dirname(path.basename(dir)) - else: - sim.load_params_from_json(dir + "/parameters.json") + sim.load_params_from_setup_txt(dir / "cfg.txt") sim.type = set_type sim.load_params_from_spheres_ini_log(spheres_file) sim.load_params_from_aggregates_txt(aggregates_file) - sim.load_params_from_pythontiming_json(timings_file) - sim.assert_all_loaded() + # sim.assert_all_loaded() if sim.rel_velocity < 0 or sim.distance < 0: # Sometimes in the old dataset the second object wasn't detected. # To be save, we'll exclude them @@ -52,4 +45,4 @@ for set_type, directories in simulation_sets.items(): print(len(simulations.simlist)) -simulations.jsonlines_save(Path("output.jsonl")) +simulations.jsonlines_save(Path("winter.jsonl")) diff --git a/simulation.py b/simulation.py index b993a2d..0350c9f 100644 --- a/simulation.py +++ b/simulation.py @@ -1,6 +1,9 @@ import json +from pathlib import Path from typing import Optional +import yaml + class Simulation: @@ -132,7 +135,10 @@ class Simulation: def output_mass_fraction(self) -> Optional[float]: if not self.largest_aggregate_mass: return 0 # FIXME - return self.second_largest_aggregate_mass / self.largest_aggregate_mass + massive_size_relation=self.largest_aggregate_mass/self.target_mass + less_massive_size_relation=self.second_largest_aggregate_mass/self.projectile_mass + print("mr",massive_size_relation, less_massive_size_relation) + return less_massive_size_relation / massive_size_relation @property def original_simulation(self) -> bool: @@ -149,7 +155,7 @@ class Simulation: ) def __repr__(self): - return f"" + return f"" def load_params_from_dirname(self, dirname: str) -> None: params = dirname.split("_") @@ -173,8 +179,18 @@ class Simulation: self.wtcode = data["wt_code"] self.wpcode = data["wp_code"] - def load_params_from_spheres_ini_log(self, filename: str) -> None: - with open(filename) as f: + def load_params_from_setup_txt(self, file: Path) -> None: + with file.open() as f: + data = yaml.safe_load(f) + # self.runid=data["ID"] + # self.vcode=data["vel_vesc_touching_ball"] + # self.alphacode=data["impact_angle_touching_ball"] + # self.mcode=data["M_tot"] + # self.gammacode=data["gamma"] + # TODO: maybe more needed? + + def load_params_from_spheres_ini_log(self, filename: Path) -> None: + with filename.open() as f: lines = [line.rstrip("\n") for line in f] for i in range(len(lines)): line = lines[i] @@ -191,20 +207,25 @@ class Simulation: self.projectile_water_fraction = float(lines[i + 1].split()[7]) self.target_water_fraction = float(lines[i + 3].split()[7]) if "Particle numbers" in line: - self.desired_N = int(lines[i + 1].split()[3]) + self.desired_N = int(lines[i + 1].split()[4]) self.actual_N = int(lines[i + 1].split()[-1]) - def load_params_from_aggregates_txt(self, filename: str) -> None: - with open(filename) as f: + def load_params_from_aggregates_txt(self, filename: Path) -> None: + with filename.open() as f: lines = [line.rstrip("\n") for line in f] for i in range(len(lines)): line = lines[i] if "# largest aggregate" in line: - self.largest_aggregate_mass = float(lines[i + 2].split()[0]) - self.largest_aggregate_water_fraction = float(lines[i + 2].split()[2]) + cols = lines[i + 2].split() + self.largest_aggregate_mass = float(cols[0]) + self.largest_aggregate_core_fraction = float(cols[1]) + self.largest_aggregate_water_fraction = 1 - float(cols[2]) - self.largest_aggregate_core_fraction if "# 2nd-largest aggregate:" in line: - self.second_largest_aggregate_mass = float(lines[i + 2].split()[0]) - self.second_largest_aggregate_water_fraction = float(lines[i + 2].split()[2]) + cols = lines[i + 2].split() + self.second_largest_aggregate_mass = float(cols[0]) + self.second_largest_aggregate_core_fraction = float(cols[1]) + self.second_largest_aggregate_water_fraction = 1 - float( + cols[2]) - self.second_largest_aggregate_core_fraction if "# distance" in line: # TODO: not sure if correct anymore self.distance = float(lines[i + 1].split()[0]) self.rel_velocity = float(lines[i + 1].split()[1]) diff --git a/simulation_list.py b/simulation_list.py index 181dd17..a0b96cd 100644 --- a/simulation_list.py +++ b/simulation_list.py @@ -78,3 +78,7 @@ class SimulationList: @property def Y_mantle(self): return np.array([s.mantle_retention_both for s in self.simlist if not s.testcase]) + + @property + def Y_mass_fraction(self): + return np.array([s.output_mass_fraction for s in self.simlist if not s.testcase]) diff --git a/sliders.py b/sliders.py index 967a719..92e80e0 100644 --- a/sliders.py +++ b/sliders.py @@ -10,17 +10,18 @@ from interpolators.rbf import RbfInterpolator from simulation_list import SimulationList simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl")) +resolution = 100 data = simlist.X -values = simlist.Y +values = simlist.Y_mass_fraction scaler = CustomScaler() scaler.fit(data) scaled_data = scaler.transform_data(data) interpolator = RbfInterpolator(scaled_data, values) -alpharange = np.linspace(-0.5, 60.5, 100) -vrange = np.linspace(0.5, 5.5, 100) +alpharange = np.linspace(-0.5, 60.5, resolution) +vrange = np.linspace(0.5, 5.5, resolution) grid_alpha, grid_v = np.meshgrid(alpharange, vrange) fig, ax = plt.subplots() @@ -30,7 +31,7 @@ mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0] datagrid = np.zeros_like(grid_alpha) -mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1) # type:QuadMesh +mesh = plt.pcolormesh(grid_alpha, grid_v, datagrid, cmap="Blues", vmin=0, vmax=1, shading="nearest") # type:QuadMesh plt.colorbar() # axcolor = 'lightgoldenrodyellow' @@ -44,8 +45,8 @@ button = Button(buttonax, 'Update', hovercolor='0.975') s_mcode = Slider(ax_mcode, 'mcode', 21, 25, valinit=mcode_default) s_gamma = Slider(ax_gamma, 'gamma', 0.1, 1, valinit=gamma_default) -s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-4, valinit=wt_default) -s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-4, valinit=wp_default) +s_wt = Slider(ax_wt, 'wt', 1e-5, 1e-3, valinit=wt_default) +s_wp = Slider(ax_wp, 'wp', 1e-5, 1e-3, valinit=wp_default) def update(val): @@ -57,23 +58,24 @@ def update(val): gamma = s_gamma.val wt = s_wt.val wp = s_wp.val - parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt / 100, wp / 100] + parameters = [grid_alpha, grid_v, 10 ** mcode, gamma, wt, wp] scaled_parameters = list(scaler.transform_parameters(parameters)) datagrid = interpolator.interpolate(*scaled_parameters) print(datagrid) - print(np.isnan(datagrid).sum() / (100 * 100)) + print(np.isnan(datagrid).sum() / (resolution * resolution)) if not isinstance(datagrid, np.ndarray): return False - formatedgrid = datagrid[:-1, :-1] - mesh.set_array(formatedgrid.ravel()) + mesh.set_array(datagrid.ravel()) print("finished updating") # thetext.set_text("finished") fig.canvas.draw_idle() +update(None) + # s_gamma.on_changed(update) # s_mcode.on_changed(update) # s_wp.on_changed(update) diff --git a/sliders_nn.py b/sliders_nn.py index 0b05b6a..63c4216 100644 --- a/sliders_nn.py +++ b/sliders_nn.py @@ -1,4 +1,4 @@ -from pathlib import Path +import json import matplotlib.pyplot as plt import numpy as np @@ -8,21 +8,21 @@ from matplotlib.widgets import Slider from CustomScaler import CustomScaler from network import Network -from simulation_list import SimulationList -simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl")) resolution = 100 -data = simlist.X -scaler = CustomScaler() -scaler.fit(data) +with open("pytorch_model.json") as f: + data = json.load(f) + scaler = CustomScaler() + scaler.means = np.array(data["means"]) + scaler.stds = np.array(data["stds"]) fig, ax = plt.subplots() plt.subplots_adjust(bottom=0.35) t = np.arange(0.0, 1.0, 0.001) mcode_default, gamma_default, wt_default, wp_default = [24.0, 1, 15.0, 15.0] -alpharange = np.linspace(-0.5, 60.5, resolution) +alpharange = np.linspace(0, 60, resolution) vrange = np.linspace(0.5, 5.5, resolution) grid_alpha, grid_v = np.meshgrid(alpharange, vrange) diff --git a/visualize.py b/visualize.py index cfedb39..d9b7772 100644 --- a/visualize.py +++ b/visualize.py @@ -6,15 +6,16 @@ from matplotlib import pyplot as plt, cm from CustomScaler import CustomScaler from config import water_fraction +from interpolators.griddata import GriddataInterpolator from interpolators.rbf import RbfInterpolator +from simulation import Simulation from simulation_list import SimulationList - -# plt.style.use('dark_background') +plt.style.use('dark_background') def main(): - mcode, gamma, wt, wp = [10 ** 21, 0.6, 1e-5, 1e-5] + mcode, gamma, wt, wp = [10 ** 22, 0.6, 1e-5, 1e-5] simlist = SimulationList.jsonlines_load(Path("rsmc_dataset.jsonl")) # for s in simlist.simlist: # if s.type!="original": @@ -26,7 +27,7 @@ def main(): print(len(data)) # print(data[0]) # exit() - values = simlist.Y + values = simlist.Y_mass_fraction scaler = CustomScaler() scaler.fit(data) @@ -34,7 +35,7 @@ def main(): interpolator = RbfInterpolator(scaled_data, values) # interpolator = GriddataInterpolator(scaled_data, values) - alpharange = np.linspace(-0.5, 60.5, 300) + alpharange = np.linspace(0, 60, 300) vrange = np.linspace(0.5, 5.5, 300) grid_alpha, grid_v = np.meshgrid(alpharange, vrange) @@ -53,13 +54,28 @@ def main(): # plt.pcolormesh(grid_alpha, grid_v, grid_result, cmap="Blues", vmin=0, vmax=1) plt.imshow(grid_result, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1, extent=[grid_alpha.min(), grid_alpha.max(), grid_v.min(), grid_v.max()]) - plt.colorbar().set_label("water retention fraction" if water_fraction else "core mass retention fraction") # plt.scatter(data[:, 0], data[:, 1], c=values, cmap="Blues") plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]") plt.ylabel("velocity $v$ [$v_{esc}$]") + s: Simulation + xs = [] + ys = [] + zs = [] + for s in simlist.simlist: + # if not (0.4 < s.gamma < 0.6) or not (1e23 < s.total_mass < 5e24): + # continue + # if s.alpha < 60 or s.v > 5 or s.v < 2: + # continue + z = s.output_mass_fraction + zs.append(z) + xs.append(s.alpha) + ys.append(s.v) + print(z, s.runid) + plt.scatter(xs, ys, c=zs, cmap=cmap, vmin=0, vmax=1) + plt.colorbar().set_label("stone retention fraction" if water_fraction else "core mass retention fraction") plt.tight_layout() # plt.savefig("vis.png", transparent=True) - plt.savefig("/home/lukas/tmp/test.svg", transparent=True) + plt.savefig("/home/lukas/tmp/test.pdf", transparent=True) plt.show() diff --git a/visualize_nn.py b/visualize_nn.py new file mode 100644 index 0000000..caf47f0 --- /dev/null +++ b/visualize_nn.py @@ -0,0 +1,64 @@ +import json + +import numpy as np +import torch +from matplotlib import pyplot as plt +from torch import from_numpy, Tensor + +from CustomScaler import CustomScaler +from network import Network + +resolution = 300 + + +def main(): + mcode, gamma, wt, wp = [10 ** 24, 0.4, 1e-5, 1e-5] + with open("pytorch_model.json") as f: + data = json.load(f) + scaler = CustomScaler() + scaler.means = np.array(data["means"]) + scaler.stds = np.array(data["stds"]) + + model = Network() + model.load_state_dict(torch.load("pytorch_model.zip")) + + alpharange = np.linspace(0, 60, resolution) + vrange = np.linspace(0.5, 5.5, resolution) + grid_alpha, grid_v = np.meshgrid(alpharange, vrange) + mcode = 1e24 + wpcode = 1e-5 + + wtcode = 1e-5 + gammacode = 0.6 + testinput = np.array([[np.nan, np.nan, mcode, gammacode, wtcode, wpcode]] * resolution * resolution) + testinput[::, 0] = grid_alpha.flatten() + testinput[::, 1] = grid_v.flatten() + testinput = scaler.transform_data(testinput) + + print(testinput) + print(testinput.shape) + network = Network() + network.load_state_dict(torch.load("pytorch_model.zip")) + + print(testinput) + testoutput: Tensor = network(from_numpy(testinput).to(torch.float)) + data = testoutput.detach().numpy() + grid_result = np.reshape(data[::, 0], (300, 300)) + print("minmax") + print(np.nanmin(grid_result), np.nanmax(grid_result)) + cmap = "Blues" + plt.figure() + plt.title( + "m={:3.0e}, gamma={:3.1f}, wt={:2.0e}%, wp={:2.0e}%\n".format(mcode, gammacode, wtcode, wpcode)) + plt.imshow(grid_result, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1, + extent=[grid_alpha.min(), grid_alpha.max(), grid_v.min(), grid_v.max()]) + + plt.colorbar().set_label("water retention fraction") + plt.xlabel("impact angle $\\alpha$ [$^{\circ}$]") + plt.ylabel("velocity $v$ [$v_{esc}$]") + plt.tight_layout() + plt.savefig("/home/lukas/tmp/nn.pdf", transparent=True) + plt.show() + +if __name__ == '__main__': + main()