1
0
Fork 0
This repository has been archived on 2024-06-28. You can view files and clone it, but cannot push or open issues or pull requests.
collision-analysis-and-inte.../neural_network.py
2021-10-12 15:45:43 +02:00

158 lines
5 KiB
Python

import json
import random
from pathlib import Path
from typing import List
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from torch import nn, optim, from_numpy, Tensor
from torch.utils.data import DataLoader, TensorDataset
from CustomScaler import CustomScaler
from network import Network
from simulation import Simulation
from simulation_list import SimulationList
def x_array(s: Simulation) -> List[float]:
return [s.alpha, s.v, s.projectile_mass, s.gamma,
s.target_water_fraction, s.projectile_water_fraction]
def y_array(s: Simulation) -> List[float]:
return [
s.water_retention_both, s.mantle_retention_both,
s.core_retention_both, s.output_mass_fraction
]
def train():
filename = "rsmc_dataset"
simulations = SimulationList.jsonlines_load(Path(f"{filename}.jsonl"))
random.seed(1)
random.shuffle(simulations.simlist)
num_test = int(len(simulations.simlist) * 0.2)
test_data = simulations.simlist[:num_test]
train_data = simulations.simlist[num_test:]
print(len(train_data), len(test_data))
a = set(s.runid for s in train_data)
b = set(s.runid for s in test_data)
assert len(a & b) == 0, "no overlap between test data and training data"
X = np.array([x_array(s) for s in train_data])
scaler = CustomScaler()
scaler.fit(X)
x = scaler.transform_data(X)
del X
print(x.shape)
Y = np.array([y_array(s) for s in train_data])
X_test = np.array([x_array(s) for s in test_data])
Y_test = np.array([y_array(s) for s in test_data])
x_test = scaler.transform_data(X_test)
del X_test
random.seed()
dataset = TensorDataset(from_numpy(x).to(torch.float), from_numpy(Y).to(torch.float))
train_dataset = TensorDataset(from_numpy(x_test).to(torch.float), from_numpy(Y_test).to(torch.float))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False)
network = Network()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(network.parameters())
loss_train = []
loss_vali = []
max_epochs = 200
epochs = 0
fig: Figure = plt.figure()
ax: Axes = fig.gca()
x_axis = np.arange(epochs)
loss_plot: Line2D = ax.plot(x_axis, loss_train, label="loss_train")[0]
vali_plot: Line2D = ax.plot(x_axis, loss_vali, label="loss_validation")[0]
ax.legend()
plt.ion()
plt.pause(0.01)
plt.show()
for e in range(max_epochs):
print(f"Epoch: {e}")
total_loss = 0
network.train()
for xs, ys in dataloader:
# Training pass
optimizer.zero_grad()
output = network(xs)
loss = loss_fn(output, ys)
loss.backward()
optimizer.step()
total_loss += loss.item()
loss_train.append(float(total_loss / len(dataloader)))
print(f"Training loss: {total_loss / len(dataloader)}")
# validation:
network.eval()
total_loss_val = 0
for xs, ys in validation_dataloader:
output = network(xs)
total_loss_val += loss_fn(output, ys).item()
loss_vali.append(float(total_loss_val / len(validation_dataloader)))
print(f"Validation loss: {total_loss_val / len(validation_dataloader)}")
epochs += 1
x_axis = np.arange(epochs)
loss_plot.set_xdata(x_axis)
vali_plot.set_xdata(x_axis)
loss_plot.set_ydata(loss_train)
vali_plot.set_ydata(loss_vali)
ax.relim()
ax.autoscale_view(True, True, True)
plt.pause(0.01)
# plt.draw()
# if epochs > 6:
# a = np.sum(np.array(loss_vali[-3:]))
# b = np.sum(np.array(loss_vali[-6:-3]))
# if a > b: # overfitting on training data, stop training
# print("early stopping")
# break
np.savetxt("loss.txt", np.array([x_axis, loss_train, loss_vali]).T)
torch.save(network.state_dict(), "pytorch_model.zip")
with open("pytorch_model.json", "w") as f:
export_dict = {}
value_tensor: Tensor
for key, value_tensor in network.state_dict().items():
export_dict[key] = value_tensor.detach().tolist()
export_dict["means"] = scaler.means.tolist()
export_dict["stds"] = scaler.stds.tolist()
json.dump(export_dict, f)
plt.ioff()
model_test_y = []
for x in x_test:
result = network(from_numpy(np.array(x)).to(torch.float))
y = result.detach().numpy()
model_test_y.append(y)
model_test_y = np.asarray(model_test_y)
plt.figure()
plt.xlabel("model output")
plt.ylabel("real data")
for i, name in enumerate(["shell", "mantle", "core", "mass fraction"]):
plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2, label=name)
plt.legend()
plt.show()
if __name__ == '__main__':
train()