1
0
Fork 0
mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git synced 2024-09-19 15:13:50 +02:00
collision-analyisis-and-int.../neural_network.py

159 lines
5 KiB
Python
Raw Normal View History

2021-03-29 15:02:46 +02:00
import json
2019-02-13 15:28:10 +01:00
import random
2021-03-29 15:02:46 +02:00
from pathlib import Path
2021-03-31 17:22:00 +02:00
from typing import List
2019-02-13 15:28:10 +01:00
import numpy as np
2021-03-29 15:02:46 +02:00
import torch
2019-02-13 15:28:10 +01:00
from matplotlib import pyplot as plt
2021-03-29 15:02:46 +02:00
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from torch import nn, optim, from_numpy, Tensor
from torch.utils.data import DataLoader, TensorDataset
2019-02-13 15:28:10 +01:00
2019-07-29 13:59:10 +02:00
from CustomScaler import CustomScaler
2021-03-29 15:02:46 +02:00
from network import Network
2021-03-31 17:22:00 +02:00
from simulation import Simulation
2019-02-13 15:28:10 +01:00
from simulation_list import SimulationList
2021-03-29 15:02:46 +02:00
2021-03-31 17:22:00 +02:00
def x_array(s: Simulation) -> List[float]:
    """Assemble the six input features of one simulation for the network."""
    features = [
        s.alpha,
        s.v,
        s.projectile_mass,
        s.gamma,
        s.target_water_fraction,
        s.projectile_water_fraction,
    ]
    return features
def y_array(s: Simulation) -> List[float]:
    """Collect the four regression targets of one simulation."""
    targets = [
        s.water_retention_both,
        s.mantle_retention_both,
        s.core_retention_both,
        s.output_mass_fraction,
    ]
    return targets
2021-03-29 15:02:46 +02:00
def train() -> None:
    """Train the collision-outcome network on the rsmc dataset.

    Loads the jsonlines dataset, makes a deterministic 80/20 test/train
    split (checked for runid overlap), trains ``Network`` with Adam + MSE
    while live-plotting training/validation loss, then saves the loss
    history, the pytorch state dict, a JSON export of the weights and
    scaler statistics, and finally shows a model-vs-data scatter plot.
    """
    filename = "rsmc_dataset"
    # BUGFIX: the dataset path must be built from `filename`; it was
    # previously a garbled literal, leaving `filename` unused.
    simulations = SimulationList.jsonlines_load(Path(f"{filename}.jsonl"))

    # Fixed seed so the train/test split is reproducible across runs.
    random.seed(1)
    random.shuffle(simulations.simlist)
    num_test = int(len(simulations.simlist) * 0.2)
    test_data = simulations.simlist[:num_test]
    train_data = simulations.simlist[num_test:]

    print(len(train_data), len(test_data))

    train_ids = set(s.runid for s in train_data)
    test_ids = set(s.runid for s in test_data)
    assert len(train_ids & test_ids) == 0, "no overlap between test data and training data"

    X = np.array([x_array(s) for s in train_data])

    scaler = CustomScaler()
    scaler.fit(X)
    x = scaler.transform_data(X)
    del X  # free the unscaled copy; only the scaled array is needed
    print(x.shape)

    Y = np.array([y_array(s) for s in train_data])
    X_test = np.array([x_array(s) for s in test_data])
    Y_test = np.array([y_array(s) for s in test_data])
    x_test = scaler.transform_data(X_test)
    del X_test

    # Re-seed from system entropy: only the *split* above must be
    # deterministic, not the DataLoader shuffling.
    random.seed()
    dataset = TensorDataset(from_numpy(x).to(torch.float), from_numpy(Y).to(torch.float))
    # Renamed from `train_dataset`: this holds the held-out test tensors
    # used for validation, not training data.
    validation_dataset = TensorDataset(from_numpy(x_test).to(torch.float), from_numpy(Y_test).to(torch.float))
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    validation_dataloader = DataLoader(validation_dataset, batch_size=64, shuffle=False)
    network = Network()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(network.parameters())
    loss_train = []
    loss_vali = []
    max_epochs = 200
    epochs = 0

    # Live loss plot, updated once per epoch via interactive mode.
    fig: Figure = plt.figure()
    ax: Axes = fig.gca()
    x_axis = np.arange(epochs)
    loss_plot: Line2D = ax.plot(x_axis, loss_train, label="loss_train")[0]
    vali_plot: Line2D = ax.plot(x_axis, loss_vali, label="loss_validation")[0]
    ax.legend()
    plt.ion()
    plt.pause(0.01)
    plt.show()

    for e in range(max_epochs):
        print(f"Epoch: {e}")
        total_loss = 0
        network.train()
        for xs, ys in dataloader:
            # Training pass
            optimizer.zero_grad()
            output = network(xs)
            loss = loss_fn(output, ys)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        loss_train.append(float(total_loss / len(dataloader)))
        print(f"Training loss: {total_loss / len(dataloader)}")

        # Validation pass (no gradient updates; eval mode).
        network.eval()
        total_loss_val = 0
        for xs, ys in validation_dataloader:
            output = network(xs)
            total_loss_val += loss_fn(output, ys).item()
        loss_vali.append(float(total_loss_val / len(validation_dataloader)))
        print(f"Validation loss: {total_loss_val / len(validation_dataloader)}")

        epochs += 1
        x_axis = np.arange(epochs)
        loss_plot.set_xdata(x_axis)
        vali_plot.set_xdata(x_axis)
        loss_plot.set_ydata(loss_train)
        vali_plot.set_ydata(loss_vali)
        ax.relim()
        ax.autoscale_view(True, True, True)
        plt.pause(0.01)
        # NOTE(review): early stopping on rising validation loss was tried
        # here and disabled; training always runs the full max_epochs.

    np.savetxt("loss.txt", np.array([x_axis, loss_train, loss_vali]).T)
    torch.save(network.state_dict(), "pytorch_model.zip")

    # JSON export of weights plus the scaler statistics so the model can
    # be evaluated without pytorch.
    with open("pytorch_model.json", "w") as f:
        export_dict = {}
        value_tensor: Tensor
        for key, value_tensor in network.state_dict().items():
            export_dict[key] = value_tensor.detach().tolist()
        export_dict["means"] = scaler.means.tolist()
        export_dict["stds"] = scaler.stds.tolist()
        json.dump(export_dict, f)

    plt.ioff()

    # Scatter model predictions against ground truth on the test set.
    model_test_y = []
    # Loop variable renamed from `x`, which shadowed the training array.
    for row in x_test:
        result = network(from_numpy(np.array(row)).to(torch.float))
        y = result.detach().numpy()
        model_test_y.append(y)
    model_test_y = np.asarray(model_test_y)
    plt.figure()
    plt.xlabel("model output")
    plt.ylabel("real data")
    for i, name in enumerate(["shell", "mantle", "core", "mass fraction"]):
        plt.scatter(model_test_y[::, i], Y_test[::, i], s=0.2, label=name)
    plt.legend()
    plt.show()
2021-03-29 15:02:46 +02:00
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    train()