Return an array / vector from ml functions (#4523)

Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
Maxwell Flitton 2024-08-16 23:53:43 +01:00 committed by GitHub
parent 6d26797e56
commit be105bd30c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 36 additions and 13 deletions

1
Cargo.lock generated
View file

@ -5971,7 +5971,6 @@ dependencies = [
"nanoid",
"ndarray",
"ndarray-stats",
"nom",
"num-traits",
"num_cpus",
"object_store",

View file

@ -104,7 +104,6 @@ nanoid = "0.4.0"
ndarray = { version = "=0.15.6" }
ndarray-stats = "=0.5.1"
num-traits = "0.2.18"
nom = { version = "7.1.3", features = ["alloc"] }
num_cpus = "1.16.0"
object_store = { version = "0.10.2", optional = false }
once_cell = "1.18.0"

View file

@ -42,6 +42,12 @@ impl From<Vec<i32>> for Array {
}
}
impl From<Vec<f32>> for Array {
fn from(v: Vec<f32>) -> Self {
Self(v.into_iter().map(Value::from).collect())
}
}
impl From<Vec<f64>> for Array {
fn from(v: Vec<f64>) -> Self {
Self(v.into_iter().map(Value::from).collect())

View file

@ -129,7 +129,7 @@ impl Model {
// Get the model file as bytes
let bytes = crate::obs::get(&path).await?;
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
@ -143,7 +143,7 @@ impl Model {
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
Ok(outcome.into())
}
// Perform raw compute
Value::Number(v) => {
@ -157,7 +157,7 @@ impl Model {
// Convert the argument to a tensor
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
@ -171,7 +171,7 @@ impl Model {
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
Ok(outcome.into())
}
// Perform raw compute
Value::Array(v) => {
@ -189,7 +189,7 @@ impl Model {
// Convert the argument to a tensor
let tensor = ndarray::arr1::<f32>(&args).into_dyn();
// Run the compute in a blocking task
let outcome = tokio::task::spawn_blocking(move || {
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
Error::ModelComputation(err.message.to_string())
})?;
@ -203,7 +203,7 @@ impl Model {
.await
.unwrap()?;
// Convert the output to a value
Ok(outcome[0].into())
Ok(outcome.into())
}
//
_ => Err(Error::InvalidArguments {

View file

@ -458,6 +458,12 @@ impl From<Vec<i32>> for Value {
}
}
impl From<Vec<f32>> for Value {
fn from(v: Vec<f32>) -> Self {
Value::Array(Array::from(v))
}
}
impl From<Vec<Value>> for Value {
fn from(v: Vec<Value>) -> Self {
Value::Array(Array::from(v))
@ -3005,4 +3011,18 @@ mod tests {
let dec: Value = enc.into();
assert_eq!(res, dec);
}
#[test]
fn test_value_from_vec_i32() {
let vector: Vec<i32> = vec![1, 2, 3, 4, 5, 6];
let value = Value::from(vector);
assert!(matches!(value, Value::Array(Array(_))));
}
#[test]
fn test_value_from_vec_f32() {
let vector: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let value = Value::from(vector);
assert!(matches!(value, Value::Array(Array(_))));
}
}

View file

@ -1,5 +1,3 @@
use nom::AsBytes;
use crate::{
sql::{self, Id, Statement, Thing, Value},
syn::parser::mac::test_parse,

View file

@ -26,7 +26,7 @@ mod ml_integration {
#[derive(Serialize, Deserialize, Debug)]
struct Data {
result: f64,
result: Vec<f64>,
status: String,
time: String,
}
@ -245,8 +245,9 @@ mod ml_integration {
.await?;
assert!(res.status().is_success(), "body: {}", res.text().await?);
let body = res.text().await?;
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
assert_eq!(deserialized_data[0].result, 0.9998061656951904);
assert_eq!(deserialized_data[0].result[0], 0.9998061656951904);
}
Ok(())
}
@ -282,7 +283,7 @@ mod ml_integration {
assert!(res.status().is_success(), "body: {}", res.text().await?);
let body = res.text().await?;
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
assert_eq!(deserialized_data[0].result, 177206.21875);
assert_eq!(deserialized_data[0].result[0], 177206.21875);
}
Ok(())
}