mirror of
https://github.com/Findus23/nn_evaluate.git
synced 2024-09-08 02:03:45 +02:00
add rust version
This commit is contained in:
parent
48cf082b9e
commit
e4b82cf4d4
5 changed files with 203 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
.idea
|
1
rust/.gitignore
vendored
Normal file
1
rust/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/target
|
87
rust/Cargo.lock
generated
Normal file
87
rust/Cargo.lock
generated
Normal file
|
@ -0,0 +1,87 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
|
||||
|
||||
[[package]]
|
||||
name = "nn_evaluate"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd761ff957cb2a45fbb9ab3da6512de9de55872866160b23c25f1a841e99d29f"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1800f7693e94e186f5e25a28291ae1570da908aff7d97a095dec1e56ff99069b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
|
11
rust/Cargo.toml
Normal file
11
rust/Cargo.toml
Normal file
|
@ -0,0 +1,11 @@
|
|||
[package]
|
||||
name = "nn_evaluate"
|
||||
version = "0.1.0"
|
||||
authors = ["Lukas Winkler <git@lw1.at>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
103
rust/src/main.rs
Normal file
103
rust/src/main.rs
Normal file
|
@ -0,0 +1,103 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::time::Instant;
|
||||
|
||||
fn scale_input(model: &Model, input: Vec<f64>) -> Vec<f64> {
|
||||
let mut z: Vec<f64> = Vec::new();
|
||||
for i in 0..input.len() {
|
||||
z.push((input[i] - model.means[i]) / (model.stds[i]))
|
||||
}
|
||||
return z;
|
||||
}
|
||||
|
||||
fn calculate_layer(
|
||||
layer_size: &usize,
|
||||
parent_layer: Vec<f64>,
|
||||
weight: &Vec<Vec<f64>>,
|
||||
bias: &Vec<f64>,
|
||||
) -> Vec<f64> {
|
||||
let mut new_layer: Vec<f64> = Vec::new();
|
||||
for hl in 0..*layer_size {
|
||||
let mut node: f64 = 0.;
|
||||
for parent in 0..parent_layer.len() {
|
||||
node += parent_layer[parent] * weight[hl][parent]
|
||||
}
|
||||
node += bias[hl];
|
||||
new_layer.push(node);
|
||||
}
|
||||
return new_layer;
|
||||
}
|
||||
|
||||
fn relu(x: f64) -> f64 {
|
||||
if x >= 0. {
|
||||
return x;
|
||||
}
|
||||
return 0.;
|
||||
}
|
||||
|
||||
fn sigmoid(x: f64) -> f64 {
|
||||
return 1. / (1. + (-x).exp());
|
||||
}
|
||||
|
||||
fn evaluate(model: Model, input: Vec<f64>) -> Vec<f64> {
|
||||
let scaled_input = scale_input(&model, Vec::from(input));
|
||||
let hidden_layer_unfinished = calculate_layer(
|
||||
&model.hidden_bias.len(), scaled_input,
|
||||
&model.hidden_weight, &model.hidden_bias,
|
||||
);
|
||||
let mut hidden_layer: Vec<f64> = Vec::new();
|
||||
|
||||
for value in hidden_layer_unfinished {
|
||||
hidden_layer.push(relu(value))
|
||||
}
|
||||
let output_layer_unfinished = calculate_layer(
|
||||
&model.output_bias.len(), hidden_layer,
|
||||
&model.output_weight, &model.output_bias,
|
||||
);
|
||||
let mut output_layer: Vec<f64> = Vec::new();
|
||||
|
||||
for value in output_layer_unfinished {
|
||||
output_layer.push(sigmoid(value))
|
||||
}
|
||||
return output_layer;
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let model = load_json();
|
||||
let ang = 30.;
|
||||
let v = 2.;
|
||||
let m = 1e24;
|
||||
let gamma = 0.6;
|
||||
let wp = 1e-4;
|
||||
let wt = wp;
|
||||
let input = [ang, v, m, gamma, wt, wp];
|
||||
let start = Instant::now();
|
||||
let result=evaluate(model, Vec::from(input));
|
||||
let duration = start.elapsed();
|
||||
println!("{:?}", result);
|
||||
println!("{:?}", duration);
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Model {
|
||||
means: Vec<f64>,
|
||||
stds: Vec<f64>,
|
||||
#[serde(alias = "hidden.weight")]
|
||||
hidden_weight: Vec<Vec<f64>>,
|
||||
#[serde(alias = "hidden.bias")]
|
||||
hidden_bias: Vec<f64>,
|
||||
#[serde(alias = "output.weight")]
|
||||
output_weight: Vec<Vec<f64>>,
|
||||
#[serde(alias = "output.bias")]
|
||||
output_bias: Vec<f64>,
|
||||
}
|
||||
|
||||
fn load_json() -> Model {
|
||||
let file = File::open("../../pytorch_model.json").expect("can't open file");
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let v: Model = serde_json::from_reader(reader).expect("can't parse json");
|
||||
return v;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in a new issue