diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1 @@ +/target diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 0000000..a065422 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,87 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +[[package]] +name = "itoa" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" + +[[package]] +name = "nn_evaluate" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "proc-macro2" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "quote" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" + +[[package]] +name = "serde" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1800f7693e94e186f5e25a28291ae1570da908aff7d97a095dec1e56ff99069b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "1.0.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "unicode-xid" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000..81f03b8 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "nn_evaluate" +version = "0.1.0" +authors = ["Lukas Winkler "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } diff --git a/rust/src/main.rs b/rust/src/main.rs new file mode 100644 index 0000000..a1e48c7 --- /dev/null +++ b/rust/src/main.rs @@ -0,0 +1,103 @@ +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::BufReader; +use std::time::Instant; + +fn scale_input(model: &Model, input: Vec) -> Vec { + let mut z: Vec = Vec::new(); + for i in 0..input.len() { + z.push((input[i] - model.means[i]) / (model.stds[i])) + } + return z; +} + +fn calculate_layer( + layer_size: &usize, + parent_layer: Vec, + weight: &Vec>, + bias: &Vec, +) -> Vec { + let mut new_layer: Vec = Vec::new(); + for hl in 0..*layer_size { + let mut node: f64 = 0.; + for parent in 0..parent_layer.len() { + node += parent_layer[parent] * weight[hl][parent] + } + node += bias[hl]; + new_layer.push(node); + } + return new_layer; +} + +fn relu(x: f64) -> f64 { + if x >= 0. { + return x; + } + return 0.; +} + +fn sigmoid(x: f64) -> f64 { + return 1. / (1. + (-x).exp()); +} + +fn evaluate(model: Model, input: Vec) -> Vec { + let scaled_input = scale_input(&model, Vec::from(input)); + let hidden_layer_unfinished = calculate_layer( + &model.hidden_bias.len(), scaled_input, + &model.hidden_weight, &model.hidden_bias, + ); + let mut hidden_layer: Vec = Vec::new(); + + for value in hidden_layer_unfinished { + hidden_layer.push(relu(value)) + } + let output_layer_unfinished = calculate_layer( + &model.output_bias.len(), hidden_layer, + &model.output_weight, &model.output_bias, + ); + let mut output_layer: Vec = Vec::new(); + + for value in output_layer_unfinished { + output_layer.push(sigmoid(value)) + } + return output_layer; +} + +fn main() { + let model = load_json(); + let ang = 30.; + let v = 2.; + let m = 1e24; + let gamma = 0.6; + let wp = 1e-4; + let wt = wp; + let input = [ang, v, m, gamma, wt, wp]; + let start = Instant::now(); + let result=evaluate(model, Vec::from(input)); + let duration = start.elapsed(); + println!("{:?}", result); + println!("{:?}", duration); +} +#[derive(Serialize, Deserialize)] +struct Model { + means: Vec, + stds: Vec, + #[serde(alias = "hidden.weight")] + hidden_weight: Vec>, + #[serde(alias = "hidden.bias")] + hidden_bias: Vec, + #[serde(alias = "output.weight")] + output_weight: Vec>, + #[serde(alias = "output.bias")] + output_bias: Vec, +} + +fn load_json() -> Model { + let file = File::open("../../pytorch_model.json").expect("can't open file"); + let reader = BufReader::new(file); + + let v: Model = serde_json::from_reader(reader).expect("can't parse json"); + return v; +} + +