From be105bd30cfc83ab2a919228fbbae3a6ae26f3ef Mon Sep 17 00:00:00 2001 From: Maxwell Flitton Date: Fri, 16 Aug 2024 23:53:43 +0100 Subject: [PATCH] Return an array / vector from `ml` functions (#4523) Co-authored-by: Tobie Morgan Hitchcock --- Cargo.lock | 1 - core/Cargo.toml | 1 - core/src/sql/array.rs | 6 ++++++ core/src/sql/model.rs | 12 ++++++------ core/src/sql/value/value.rs | 20 ++++++++++++++++++++ core/src/syn/parser/test/mod.rs | 2 -- tests/ml_integration.rs | 7 ++++--- 7 files changed, 36 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90fee11b..dffdbce1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5971,7 +5971,6 @@ dependencies = [ "nanoid", "ndarray", "ndarray-stats", - "nom", "num-traits", "num_cpus", "object_store", diff --git a/core/Cargo.toml b/core/Cargo.toml index 77db57f6..fb308d6b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -104,7 +104,6 @@ nanoid = "0.4.0" ndarray = { version = "=0.15.6" } ndarray-stats = "=0.5.1" num-traits = "0.2.18" -nom = { version = "7.1.3", features = ["alloc"] } num_cpus = "1.16.0" object_store = { version = "0.10.2", optional = false } once_cell = "1.18.0" diff --git a/core/src/sql/array.rs b/core/src/sql/array.rs index ddb4ed16..9431690e 100644 --- a/core/src/sql/array.rs +++ b/core/src/sql/array.rs @@ -42,6 +42,12 @@ impl From> for Array { } } +impl From> for Array { + fn from(v: Vec) -> Self { + Self(v.into_iter().map(Value::from).collect()) + } +} + impl From> for Array { fn from(v: Vec) -> Self { Self(v.into_iter().map(Value::from).collect()) diff --git a/core/src/sql/model.rs b/core/src/sql/model.rs index ab0286d9..e685d0d8 100644 --- a/core/src/sql/model.rs +++ b/core/src/sql/model.rs @@ -129,7 +129,7 @@ impl Model { // Get the model file as bytes let bytes = crate::obs::get(&path).await?; // Run the compute in a blocking task - let outcome = tokio::task::spawn_blocking(move || { + let outcome: Vec = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { Error::ModelComputation(err.message.to_string()) })?; @@ -143,7 +143,7 @@ impl Model { .await .unwrap()?; // Convert the output to a value - Ok(outcome[0].into()) + Ok(outcome.into()) } // Perform raw compute Value::Number(v) => { @@ -157,7 +157,7 @@ impl Model { // Convert the argument to a tensor let tensor = ndarray::arr1::(&[args]).into_dyn(); // Run the compute in a blocking task - let outcome = tokio::task::spawn_blocking(move || { + let outcome: Vec = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { Error::ModelComputation(err.message.to_string()) })?; @@ -171,7 +171,7 @@ impl Model { .await .unwrap()?; // Convert the output to a value - Ok(outcome[0].into()) + Ok(outcome.into()) } // Perform raw compute Value::Array(v) => { @@ -189,7 +189,7 @@ impl Model { // Convert the argument to a tensor let tensor = ndarray::arr1::(&args).into_dyn(); // Run the compute in a blocking task - let outcome = tokio::task::spawn_blocking(move || { + let outcome: Vec = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| { Error::ModelComputation(err.message.to_string()) })?; @@ -203,7 +203,7 @@ impl Model { .await .unwrap()?; // Convert the output to a value - Ok(outcome[0].into()) + Ok(outcome.into()) } // _ => Err(Error::InvalidArguments { diff --git a/core/src/sql/value/value.rs b/core/src/sql/value/value.rs index 1d533bb7..9e7707b0 100644 --- a/core/src/sql/value/value.rs +++ b/core/src/sql/value/value.rs @@ -458,6 +458,12 @@ impl From> for Value { } } +impl From> for Value { + fn from(v: Vec) -> Self { + Value::Array(Array::from(v)) + } +} + impl From> for Value { fn from(v: Vec) -> Self { Value::Array(Array::from(v)) @@ -3005,4 +3011,18 @@ mod tests { let dec: Value = enc.into(); assert_eq!(res, dec); } + + #[test] + fn test_value_from_vec_i32() { + let vector: Vec = vec![1, 2, 3, 4, 5, 6]; + let value = Value::from(vector); + assert!(matches!(value, Value::Array(Array(_)))); + } + + #[test] + fn test_value_from_vec_f32() { + let vector: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let value = Value::from(vector); + assert!(matches!(value, Value::Array(Array(_)))); + } } diff --git a/core/src/syn/parser/test/mod.rs b/core/src/syn/parser/test/mod.rs index a5693e25..30c36cfe 100644 --- a/core/src/syn/parser/test/mod.rs +++ b/core/src/syn/parser/test/mod.rs @@ -1,5 +1,3 @@ -use nom::AsBytes; - use crate::{ sql::{self, Id, Statement, Thing, Value}, syn::parser::mac::test_parse, diff --git a/tests/ml_integration.rs b/tests/ml_integration.rs index 5e6f1241..b84175da 100644 --- a/tests/ml_integration.rs +++ b/tests/ml_integration.rs @@ -26,7 +26,7 @@ mod ml_integration { #[derive(Serialize, Deserialize, Debug)] struct Data { - result: f64, + result: Vec, status: String, time: String, } @@ -245,8 +245,9 @@ mod ml_integration { .await?; assert!(res.status().is_success(), "body: {}", res.text().await?); let body = res.text().await?; + let deserialized_data: Vec = serde_json::from_str(&body)?; - assert_eq!(deserialized_data[0].result, 0.9998061656951904); + assert_eq!(deserialized_data[0].result[0], 0.9998061656951904); } Ok(()) } @@ -282,7 +283,7 @@ mod ml_integration { assert!(res.status().is_success(), "body: {}", res.text().await?); let body = res.text().await?; let deserialized_data: Vec = serde_json::from_str(&body)?; - assert_eq!(deserialized_data[0].result, 177206.21875); + assert_eq!(deserialized_data[0].result[0], 177206.21875); } Ok(()) }