1
0
Fork 0
mirror of https://github.com/Findus23/nn_evaluate.git synced 2024-09-08 02:03:45 +02:00

add rust version

This commit is contained in:
Lukas Winkler 2021-03-21 19:09:55 +01:00
parent 48cf082b9e
commit e4b82cf4d4
Signed by: lukas
GPG key ID: 54DE4D798D244853
5 changed files with 203 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
.idea

1
rust/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

87
rust/Cargo.lock generated Normal file
View file

@ -0,0 +1,87 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "itoa"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
[[package]]
name = "nn_evaluate"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "proc-macro2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
dependencies = [
"unicode-xid",
]
[[package]]
name = "quote"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
[[package]]
name = "serde"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1800f7693e94e186f5e25a28291ae1570da908aff7d97a095dec1e56ff99069b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "syn"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f"
dependencies = [
"proc-macro2",
"quote",
"unicode-xid",
]
[[package]]
name = "unicode-xid"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"

11
rust/Cargo.toml Normal file
View file

@ -0,0 +1,11 @@
[package]
name = "nn_evaluate"
version = "0.1.0"
authors = ["Lukas Winkler <git@lw1.at>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }

103
rust/src/main.rs Normal file
View file

@ -0,0 +1,103 @@
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
use std::time::Instant;
fn scale_input(model: &Model, input: Vec<f64>) -> Vec<f64> {
let mut z: Vec<f64> = Vec::new();
for i in 0..input.len() {
z.push((input[i] - model.means[i]) / (model.stds[i]))
}
return z;
}
fn calculate_layer(
layer_size: &usize,
parent_layer: Vec<f64>,
weight: &Vec<Vec<f64>>,
bias: &Vec<f64>,
) -> Vec<f64> {
let mut new_layer: Vec<f64> = Vec::new();
for hl in 0..*layer_size {
let mut node: f64 = 0.;
for parent in 0..parent_layer.len() {
node += parent_layer[parent] * weight[hl][parent]
}
node += bias[hl];
new_layer.push(node);
}
return new_layer;
}
fn relu(x: f64) -> f64 {
if x >= 0. {
return x;
}
return 0.;
}
fn sigmoid(x: f64) -> f64 {
return 1. / (1. + (-x).exp());
}
fn evaluate(model: Model, input: Vec<f64>) -> Vec<f64> {
let scaled_input = scale_input(&model, Vec::from(input));
let hidden_layer_unfinished = calculate_layer(
&model.hidden_bias.len(), scaled_input,
&model.hidden_weight, &model.hidden_bias,
);
let mut hidden_layer: Vec<f64> = Vec::new();
for value in hidden_layer_unfinished {
hidden_layer.push(relu(value))
}
let output_layer_unfinished = calculate_layer(
&model.output_bias.len(), hidden_layer,
&model.output_weight, &model.output_bias,
);
let mut output_layer: Vec<f64> = Vec::new();
for value in output_layer_unfinished {
output_layer.push(sigmoid(value))
}
return output_layer;
}
fn main() {
let model = load_json();
let ang = 30.;
let v = 2.;
let m = 1e24;
let gamma = 0.6;
let wp = 1e-4;
let wt = wp;
let input = [ang, v, m, gamma, wt, wp];
let start = Instant::now();
let result=evaluate(model, Vec::from(input));
let duration = start.elapsed();
println!("{:?}", result);
println!("{:?}", duration);
}
#[derive(Serialize, Deserialize)]
struct Model {
means: Vec<f64>,
stds: Vec<f64>,
#[serde(alias = "hidden.weight")]
hidden_weight: Vec<Vec<f64>>,
#[serde(alias = "hidden.bias")]
hidden_bias: Vec<f64>,
#[serde(alias = "output.weight")]
output_weight: Vec<Vec<f64>>,
#[serde(alias = "output.bias")]
output_bias: Vec<f64>,
}
fn load_json() -> Model {
let file = File::open("../../pytorch_model.json").expect("can't open file");
let reader = BufReader::new(file);
let v: Model = serde_json::from_reader(reader).expect("can't parse json");
return v;
}