simplify test- and training-set for nn
This commit is contained in:
parent
b6b55519fc
commit
0021906370
1 changed files with 43 additions and 19 deletions
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -13,40 +14,49 @@ from torch.utils.data import DataLoader, TensorDataset
|
|||
|
||||
from CustomScaler import CustomScaler
|
||||
from network import Network
|
||||
from simulation import Simulation
|
||||
from simulation_list import SimulationList
|
||||
|
||||
|
||||
def x_array(s: Simulation) -> List[float]:
    """Return the input features of simulation *s* as a flat list.

    The values are read from attributes of ``s`` and returned in this fixed
    order: ``alpha``, ``v``, ``projectile_mass``, ``gamma``,
    ``target_water_fraction``, ``projectile_water_fraction``.
    """
    features = [
        s.alpha,
        s.v,
        s.projectile_mass,
        s.gamma,
        s.target_water_fraction,
        s.projectile_water_fraction,
    ]
    return features
|
||||
|
||||
|
||||
def y_array(s: Simulation) -> List[float]:
    """Return the target values of simulation *s* as a flat list.

    The values are read from attributes of ``s`` and returned in this fixed
    order: ``water_retention_both``, ``mantle_retention_both``,
    ``core_retention_both``, ``output_mass_fraction``.
    """
    targets = [
        s.water_retention_both,
        s.mantle_retention_both,
        s.core_retention_both,
        s.output_mass_fraction,
    ]
    return targets
|
||||
|
||||
|
||||
def train():
|
||||
filename = "rsmc_dataset"
|
||||
|
||||
simulations = SimulationList.jsonlines_load(Path(f"{filename}.jsonl"))
|
||||
|
||||
# random.seed(1)
|
||||
test_data = random.sample(simulations.simlist, int(len(simulations.simlist) * 0.2))
|
||||
test_set = set(test_data) # use a set for faster *in* computation
|
||||
train_data = [s for s in simulations.simlist if s not in test_set]
|
||||
random.seed(1)
|
||||
random.shuffle(simulations.simlist)
|
||||
num_test = int(len(simulations.simlist) * 0.2)
|
||||
test_data = simulations.simlist[:num_test]
|
||||
train_data = simulations.simlist[num_test:]
|
||||
print(len(train_data), len(test_data))
|
||||
a = set(s.runid for s in train_data)
|
||||
b = set(s.runid for s in test_data)
|
||||
assert len(a & b) == 0, "no overlap between test data and training data"
|
||||
|
||||
X = np.array(
|
||||
[[s.alpha, s.v, s.projectile_mass, s.gamma, s.target_water_fraction, s.projectile_water_fraction] for s in
|
||||
train_data])
|
||||
X = np.array([x_array(s) for s in train_data])
|
||||
scaler = CustomScaler()
|
||||
scaler.fit(X)
|
||||
x = scaler.transform_data(X)
|
||||
del X
|
||||
print(x.shape)
|
||||
Y = np.array([[
|
||||
s.water_retention_both, s.mantle_retention_both, s.core_retention_both,
|
||||
s.output_mass_fraction
|
||||
] for s in train_data])
|
||||
Y = np.array([y_array(s) for s in train_data])
|
||||
|
||||
X_test = np.array(
|
||||
[[s.alpha, s.v, s.projectile_mass, s.gamma, s.target_water_fraction, s.projectile_water_fraction] for s in
|
||||
test_data])
|
||||
Y_test = np.array([[
|
||||
s.water_retention_both, s.mantle_retention_both, s.core_retention_both,
|
||||
s.output_mass_fraction
|
||||
] for s in test_data])
|
||||
X_test = np.array([x_array(s) for s in test_data])
|
||||
Y_test = np.array([y_array(s) for s in test_data])
|
||||
x_test = scaler.transform_data(X_test)
|
||||
del X_test
|
||||
random.seed()
|
||||
|
||||
dataset = TensorDataset(from_numpy(x).to(torch.float), from_numpy(Y).to(torch.float))
|
||||
|
@ -63,7 +73,7 @@ def train():
|
|||
loss_train = []
|
||||
loss_vali = []
|
||||
|
||||
max_epochs = 120
|
||||
max_epochs = 500
|
||||
epochs = 0
|
||||
|
||||
fig: Figure = plt.figure()
|
||||
|
@ -119,6 +129,19 @@ def train():
|
|||
# print("early stopping")
|
||||
# break
|
||||
plt.ioff()
|
||||
model_test_y = []
|
||||
for x in x_test:
|
||||
result = network(from_numpy(np.array(x)).to(torch.float))
|
||||
y = result.detach().numpy()
|
||||
model_test_y.append(y)
|
||||
model_test_y = np.asarray(model_test_y)
|
||||
plt.figure()
|
||||
plt.xlabel("model output")
|
||||
plt.ylabel("real data")
|
||||
for i, name in enumerate(["shell","mantle","core","mass fraction"]):
|
||||
plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2,label=name)
|
||||
plt.legend()
|
||||
plt.show()
|
||||
torch.save(network.state_dict(), "pytorch_model.zip")
|
||||
with open("pytorch_model.json", "w") as f:
|
||||
export_dict = {}
|
||||
|
@ -150,6 +173,7 @@ def train():
|
|||
print("minmax")
|
||||
print(np.nanmin(outgrid), np.nanmax(outgrid))
|
||||
cmap = "Blues"
|
||||
plt.figure()
|
||||
plt.title(
|
||||
"m={:3.0e}, gamma={:3.1f}, wt={:2.0f}%, wp={:2.0f}%\n".format(mcode, gammacode, wtcode * 100, wpcode * 100))
|
||||
plt.imshow(outgrid, interpolation='none', cmap=cmap, aspect="auto", origin="lower", vmin=0, vmax=1,
|
||||
|
|
Reference in a new issue