2021-03-29 15:02:46 +02:00
|
|
|
import json
|
2019-02-13 15:28:10 +01:00
|
|
|
import random
|
2021-03-29 15:02:46 +02:00
|
|
|
from pathlib import Path
|
2021-03-31 17:22:00 +02:00
|
|
|
from typing import List
|
2019-02-13 15:28:10 +01:00
|
|
|
|
|
|
|
import numpy as np
|
2021-03-29 15:02:46 +02:00
|
|
|
import torch
|
2019-02-13 15:28:10 +01:00
|
|
|
from matplotlib import pyplot as plt
|
2021-03-29 15:02:46 +02:00
|
|
|
from matplotlib.axes import Axes
|
|
|
|
from matplotlib.figure import Figure
|
|
|
|
from matplotlib.lines import Line2D
|
|
|
|
from torch import nn, optim, from_numpy, Tensor
|
|
|
|
from torch.utils.data import DataLoader, TensorDataset
|
2019-02-13 15:28:10 +01:00
|
|
|
|
2019-07-29 13:59:10 +02:00
|
|
|
from CustomScaler import CustomScaler
|
2021-03-29 15:02:46 +02:00
|
|
|
from network import Network
|
2021-03-31 17:22:00 +02:00
|
|
|
from simulation import Simulation
|
2019-02-13 15:28:10 +01:00
|
|
|
from simulation_list import SimulationList
|
|
|
|
|
2021-03-29 15:02:46 +02:00
|
|
|
|
2021-03-31 17:22:00 +02:00
|
|
|
def x_array(s: Simulation) -> List[float]:
    """Return the input feature vector of simulation ``s`` as a plain list.

    The column order is fixed: ``alpha``, ``v``, ``projectile_mass``,
    ``gamma``, ``target_water_fraction``, ``projectile_water_fraction``.
    ``train`` builds the scaler and network inputs from these rows, so the
    order must not change.
    """
    feature_names = (
        "alpha",
        "v",
        "projectile_mass",
        "gamma",
        "target_water_fraction",
        "projectile_water_fraction",
    )
    return [getattr(s, name) for name in feature_names]
|
|
|
|
|
|
|
|
|
|
|
|
def y_array(s: Simulation) -> List[float]:
    """Return the regression targets of simulation ``s`` as a plain list.

    Column order: ``water_retention_both``, ``mantle_retention_both``,
    ``core_retention_both``, ``output_mass_fraction``.
    """
    targets = (
        s.water_retention_both,
        s.mantle_retention_both,
        s.core_retention_both,
        s.output_mass_fraction,
    )
    return list(targets)
|
|
|
|
|
|
|
|
|
2021-03-29 15:02:46 +02:00
|
|
|
def _split_simulations(simlist: List[Simulation], test_fraction: float, seed: int):
    """Shuffle ``simlist`` in place and split it into ``(train_data, test_data)``.

    ``test_data`` is the first ``int(len(simlist) * test_fraction)`` entries
    after shuffling; ``train_data`` is the rest.

    Raises:
        ValueError: if some ``runid`` appears in both splits, i.e. the dataset
            contains the same run more than once.
    """
    # A private Random instance seeded with ``seed`` yields the same permutation
    # as ``random.seed(seed); random.shuffle(...)`` but leaves the global RNG alone.
    random.Random(seed).shuffle(simlist)
    num_test = int(len(simlist) * test_fraction)
    test_data = simlist[:num_test]
    train_data = simlist[num_test:]
    overlap = {s.runid for s in train_data} & {s.runid for s in test_data}
    if overlap:
        raise ValueError(f"test and training data share {len(overlap)} runid(s)")
    return train_data, test_data


def _to_dataset(inputs: np.ndarray, targets: np.ndarray) -> TensorDataset:
    """Wrap paired numpy arrays as a ``TensorDataset`` of float32 tensors."""
    return TensorDataset(
        from_numpy(inputs).to(torch.float),
        from_numpy(targets).to(torch.float),
    )


def _train_epoch(network: Network, dataloader: DataLoader, loss_fn, optimizer) -> float:
    """Run one optimisation pass over ``dataloader``; return the mean per-batch loss.

    Returns ``nan`` when the loader yields no batches instead of dividing by zero.
    """
    network.train()
    total_loss = 0.0
    for xs, ys in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(network(xs), ys)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader) if len(dataloader) else float("nan")


def _evaluate(network: Network, dataloader: DataLoader, loss_fn) -> float:
    """Return the mean per-batch loss of ``network`` on ``dataloader`` (no updates).

    Returns ``nan`` when the loader yields no batches instead of dividing by zero.
    """
    network.eval()
    total_loss = 0.0
    # Validation needs no autograd graph; building one only wastes memory and time.
    with torch.no_grad():
        for xs, ys in dataloader:
            total_loss += loss_fn(network(xs), ys).item()
    return total_loss / len(dataloader) if len(dataloader) else float("nan")


def _update_loss_plot(ax: Axes, loss_plot: Line2D, vali_plot: Line2D,
                      loss_train: List[float], loss_vali: List[float]) -> None:
    """Redraw the live training/validation loss curves with the latest values."""
    x_axis = np.arange(len(loss_train))
    loss_plot.set_xdata(x_axis)
    vali_plot.set_xdata(x_axis)
    loss_plot.set_ydata(loss_train)
    vali_plot.set_ydata(loss_vali)
    ax.relim()
    ax.autoscale_view(True, True, True)
    plt.pause(0.01)


def _export_model(network: Network, scaler: CustomScaler) -> None:
    """Write the trained model to ``pytorch_model.zip`` and ``pytorch_model.json``.

    The zip is ``torch.save`` of the ``state_dict``. The JSON maps every
    ``state_dict`` key to its tensor as nested lists, plus the scaler's
    ``means`` and ``stds``.
    """
    torch.save(network.state_dict(), "pytorch_model.zip")
    # Build the payload before opening the file so a failure cannot leave a
    # truncated JSON behind.
    export_dict = {
        key: value_tensor.detach().tolist()
        for key, value_tensor in network.state_dict().items()
    }
    export_dict["means"] = scaler.means.tolist()
    export_dict["stds"] = scaler.stds.tolist()
    with open("pytorch_model.json", "w") as f:
        json.dump(export_dict, f)


def _plot_predictions(network: Network, x_test: np.ndarray, Y_test: np.ndarray) -> None:
    """Scatter model output against real data for each target column of the test set."""
    # Explicit eval mode: do not rely on the last epoch having left it set.
    network.eval()
    # One batched forward pass (the network already handles batches during
    # training) instead of one call per sample.
    with torch.no_grad():
        model_test_y = network(from_numpy(x_test).to(torch.float)).numpy()
    plt.figure()
    plt.xlabel("model output")
    plt.ylabel("real data")
    for i, name in enumerate(["shell", "mantle", "core", "mass fraction"]):
        plt.scatter(model_test_y[:, i], Y_test[:, i], s=0.2, label=name)
    plt.legend()
    plt.show()


def train(
    *,
    filename: str = "rsmc_dataset",
    max_epochs: int = 200,
    batch_size: int = 64,
    test_fraction: float = 0.2,
    seed: int = 1,
) -> None:
    """Train ``Network`` on the simulation dataset, export it and plot the results.

    Loads ``{filename}.jsonl`` and shuffles it reproducibly with ``seed``. The
    first ``test_fraction`` of the shuffled simulations is held out for
    validation. Inputs are scaled with a ``CustomScaler`` fitted on the
    training inputs only. The network is trained with Adam on an MSE loss for
    ``max_epochs`` epochs while the loss curves are plotted live.

    Files written to the current directory:
      - ``loss.txt``: columns epoch, training loss, validation loss
      - ``pytorch_model.zip`` / ``pytorch_model.json``: the trained model

    Finally shows a model-output vs. real-data scatter plot for the held-out
    simulations. All keyword arguments default to the previously hard-coded
    values, so ``train()`` keeps its old behaviour.

    Raises:
        ValueError: if a ``runid`` ends up in both the training and test split.
    """
    simulations = SimulationList.jsonlines_load(Path(f"{filename}.jsonl"))
    train_data, test_data = _split_simulations(simulations.simlist, test_fraction, seed)
    print(len(train_data), len(test_data))

    # Fit the scaler on training inputs only so no test statistics leak in.
    X = np.array([x_array(s) for s in train_data])
    scaler = CustomScaler()
    scaler.fit(X)
    x = scaler.transform_data(X)
    del X
    print(x.shape)
    Y = np.array([y_array(s) for s in train_data])

    x_test = scaler.transform_data(np.array([x_array(s) for s in test_data]))
    Y_test = np.array([y_array(s) for s in test_data])

    dataloader = DataLoader(_to_dataset(x, Y), batch_size=batch_size, shuffle=True)
    # Held-out validation set (was misleadingly named ``train_dataset``).
    validation_dataloader = DataLoader(
        _to_dataset(x_test, Y_test), batch_size=batch_size, shuffle=False
    )

    network = Network()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(network.parameters())

    loss_train: List[float] = []
    loss_vali: List[float] = []

    fig: Figure = plt.figure()
    ax: Axes = fig.gca()
    empty_axis = np.arange(0)
    loss_plot: Line2D = ax.plot(empty_axis, loss_train, label="loss_train")[0]
    vali_plot: Line2D = ax.plot(empty_axis, loss_vali, label="loss_validation")[0]
    ax.legend()
    plt.ion()
    plt.pause(0.01)
    plt.show()

    for epoch in range(max_epochs):
        print(f"Epoch: {epoch}")

        train_loss = _train_epoch(network, dataloader, loss_fn, optimizer)
        loss_train.append(train_loss)
        print(f"Training loss: {train_loss}")

        vali_loss = _evaluate(network, validation_dataloader, loss_fn)
        loss_vali.append(vali_loss)
        print(f"Validation loss: {vali_loss}")

        _update_loss_plot(ax, loss_plot, vali_plot, loss_train, loss_vali)

    np.savetxt("loss.txt", np.array([np.arange(len(loss_train)), loss_train, loss_vali]).T)
    _export_model(network, scaler)

    plt.ioff()
    _plot_predictions(network, x_test, Y_test)
|
2021-03-29 15:02:46 +02:00
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    train()
|