1
0
Fork 0
mirror of https://github.com/Findus23/collision-analyisis-and-interpolation.git synced 2024-09-19 15:13:50 +02:00
collision-analyisis-and-int.../nn_single.py

28 lines
589 B
Python
Raw Normal View History

2021-10-12 15:45:43 +02:00
import json
from pathlib import Path
import numpy as np
import torch
from CustomScaler import CustomScaler
from network import Network
from simulation_list import SimulationList
# Evaluate the trained network on one hand-picked collision scenario.
#
# Reads the scaler statistics from pytorch_model.json and the network weights
# from pytorch_model.zip, both relative to the current working directory.

# NOTE(review): `resolution` is not used anywhere in this script; kept only in
# case other code imports it from here — confirm and remove if unused.
resolution = 100

# Scaler statistics ("means"/"stds") saved alongside the trained model.
with Path("pytorch_model.json").open(encoding="utf-8") as f:
    data = json.load(f)

scaler = CustomScaler()
scaler.means = np.array(data["means"])
scaler.stds = np.array(data["stds"])

model = Network()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the weights into this (CPU-resident) model anyway.
model.load_state_dict(torch.load("pytorch_model.zip", map_location="cpu"))
# Inference mode: makes dropout / batch-norm layers (if Network has any — not
# visible from here) behave deterministically.
model.eval()

# Input parameters of the single scenario to evaluate.
ang = 30
v = 2
m = 1e24
gamma = 0.6
wp = wt = 1e-4

# NOTE(review): the feature order (ang, v, m, gamma, wt, wp) must match the
# order used at training time — confirm against CustomScaler / training code
# (wt and wp are equal here, so a swap would currently go unnoticed).
features = torch.tensor(
    list(scaler.transform_parameters([ang, v, m, gamma, wt, wp])),
    dtype=torch.float32,
)
with torch.no_grad():  # single forward pass; no autograd graph needed
    print(model(features))