Return an array / vector from ml
functions (#4523)
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
parent
6d26797e56
commit
be105bd30c
7 changed files with 36 additions and 13 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -5971,7 +5971,6 @@ dependencies = [
|
|||
"nanoid",
|
||||
"ndarray",
|
||||
"ndarray-stats",
|
||||
"nom",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
"object_store",
|
||||
|
|
|
@ -104,7 +104,6 @@ nanoid = "0.4.0"
|
|||
ndarray = { version = "=0.15.6" }
|
||||
ndarray-stats = "=0.5.1"
|
||||
num-traits = "0.2.18"
|
||||
nom = { version = "7.1.3", features = ["alloc"] }
|
||||
num_cpus = "1.16.0"
|
||||
object_store = { version = "0.10.2", optional = false }
|
||||
once_cell = "1.18.0"
|
||||
|
|
|
@ -42,6 +42,12 @@ impl From<Vec<i32>> for Array {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Vec<f32>> for Array {
|
||||
fn from(v: Vec<f32>) -> Self {
|
||||
Self(v.into_iter().map(Value::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<f64>> for Array {
|
||||
fn from(v: Vec<f64>) -> Self {
|
||||
Self(v.into_iter().map(Value::from).collect())
|
||||
|
|
|
@ -129,7 +129,7 @@ impl Model {
|
|||
// Get the model file as bytes
|
||||
let bytes = crate::obs::get(&path).await?;
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
|
@ -143,7 +143,7 @@ impl Model {
|
|||
.await
|
||||
.unwrap()?;
|
||||
// Convert the output to a value
|
||||
Ok(outcome[0].into())
|
||||
Ok(outcome.into())
|
||||
}
|
||||
// Perform raw compute
|
||||
Value::Number(v) => {
|
||||
|
@ -157,7 +157,7 @@ impl Model {
|
|||
// Convert the argument to a tensor
|
||||
let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
|
@ -171,7 +171,7 @@ impl Model {
|
|||
.await
|
||||
.unwrap()?;
|
||||
// Convert the output to a value
|
||||
Ok(outcome[0].into())
|
||||
Ok(outcome.into())
|
||||
}
|
||||
// Perform raw compute
|
||||
Value::Array(v) => {
|
||||
|
@ -189,7 +189,7 @@ impl Model {
|
|||
// Convert the argument to a tensor
|
||||
let tensor = ndarray::arr1::<f32>(&args).into_dyn();
|
||||
// Run the compute in a blocking task
|
||||
let outcome = tokio::task::spawn_blocking(move || {
|
||||
let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
|
||||
let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
|
||||
Error::ModelComputation(err.message.to_string())
|
||||
})?;
|
||||
|
@ -203,7 +203,7 @@ impl Model {
|
|||
.await
|
||||
.unwrap()?;
|
||||
// Convert the output to a value
|
||||
Ok(outcome[0].into())
|
||||
Ok(outcome.into())
|
||||
}
|
||||
//
|
||||
_ => Err(Error::InvalidArguments {
|
||||
|
|
|
@ -458,6 +458,12 @@ impl From<Vec<i32>> for Value {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Vec<f32>> for Value {
|
||||
fn from(v: Vec<f32>) -> Self {
|
||||
Value::Array(Array::from(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Value>> for Value {
|
||||
fn from(v: Vec<Value>) -> Self {
|
||||
Value::Array(Array::from(v))
|
||||
|
@ -3005,4 +3011,18 @@ mod tests {
|
|||
let dec: Value = enc.into();
|
||||
assert_eq!(res, dec);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_value_from_vec_i32() {
|
||||
let vector: Vec<i32> = vec![1, 2, 3, 4, 5, 6];
|
||||
let value = Value::from(vector);
|
||||
assert!(matches!(value, Value::Array(Array(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_value_from_vec_f32() {
|
||||
let vector: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let value = Value::from(vector);
|
||||
assert!(matches!(value, Value::Array(Array(_))));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
use nom::AsBytes;
|
||||
|
||||
use crate::{
|
||||
sql::{self, Id, Statement, Thing, Value},
|
||||
syn::parser::mac::test_parse,
|
||||
|
|
|
@ -26,7 +26,7 @@ mod ml_integration {
|
|||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct Data {
|
||||
result: f64,
|
||||
result: Vec<f64>,
|
||||
status: String,
|
||||
time: String,
|
||||
}
|
||||
|
@ -245,8 +245,9 @@ mod ml_integration {
|
|||
.await?;
|
||||
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||
let body = res.text().await?;
|
||||
|
||||
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
|
||||
assert_eq!(deserialized_data[0].result, 0.9998061656951904);
|
||||
assert_eq!(deserialized_data[0].result[0], 0.9998061656951904);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -282,7 +283,7 @@ mod ml_integration {
|
|||
assert!(res.status().is_success(), "body: {}", res.text().await?);
|
||||
let body = res.text().await?;
|
||||
let deserialized_data: Vec<Data> = serde_json::from_str(&body)?;
|
||||
assert_eq!(deserialized_data[0].result, 177206.21875);
|
||||
assert_eq!(deserialized_data[0].result[0], 177206.21875);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue