From d002fa395895ac128afefac1ac7372e697bb87ca Mon Sep 17 00:00:00 2001 From: Rushmore Mushambi Date: Wed, 21 Aug 2024 15:35:26 +0100 Subject: [PATCH] Improve consistency of the `run` method API with existing methods (#4563) --- lib/benches/processor.rs | 2 +- lib/src/api/engine/local/mod.rs | 39 +++--- lib/src/api/method/mod.rs | 39 +++--- lib/src/api/method/run.rs | 202 ++++++++++++-------------------- lib/src/api/method/tests/mod.rs | 3 +- lib/tests/api/mod.rs | 24 ++-- 6 files changed, 116 insertions(+), 193 deletions(-) diff --git a/lib/benches/processor.rs b/lib/benches/processor.rs index b7c7e04f..9d1f7305 100644 --- a/lib/benches/processor.rs +++ b/lib/benches/processor.rs @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughpu use std::time::Duration; use surrealdb::dbs::Session; use surrealdb::kvs::Datastore; -use surrealdb::Value; +use surrealdb_core::sql::Value; use tokio::runtime::Runtime; fn bench_processor(c: &mut Criterion) { diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index 58b79cf7..2bd27520 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -52,7 +52,6 @@ use surrealdb_core::{ }; use uuid::Uuid; -#[cfg(not(target_arch = "wasm32"))] use crate::api::err::Error; #[cfg(not(target_arch = "wasm32"))] use std::path::PathBuf; @@ -1008,28 +1007,24 @@ async fn router( version: _version, args, } => { - let func: CoreValue = if let Some(name) = name.strip_prefix("fn::") { - Function::Custom(name.to_owned(), args.0).into() - } else if let Some(_name) = name.strip_prefix("ml::") { - #[cfg(feature = "ml")] - { - let mut model = Model::default(); - - model.name = _name.to_owned(); - model.args = args.0; - model.version = _version - .ok_or(Error::Query("ML functions must have a version".to_string()))?; - model.into() - } - #[cfg(not(feature = "ml"))] - { - return Err(crate::error::Db::InvalidModel { - message: "Machine learning computation is not enabled.".to_owned(), + let func: CoreValue = match name.strip_prefix("fn::") { + Some(name) => Function::Custom(name.to_owned(), args.0).into(), + None => match name.strip_prefix("ml::") { + #[cfg(feature = "ml")] + Some(name) => { + let mut tmp = Model::default(); + tmp.name = name.to_owned(); + tmp.args = args.0; + tmp.version = _version + .ok_or(Error::Query("ML functions must have a version".to_string()))?; + tmp.into() } - .into()); - } - } else { - Function::Custom(name, args.0).into() + #[cfg(not(feature = "ml"))] + Some(_) => { + return Err(Error::Query(format!("tried to call an ML function `{name}` but the `ml` feature is not enabled")).into()); + } + None => Function::Normal(name, args.0).into(), + }, }; let stmt = Statement::Value(func); diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index 61d54e36..24227b57 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -19,7 +19,7 @@ use std::pin::Pin; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; -use surrealdb_core::sql::{to_value as to_core_value, Array as CoreArray}; +use surrealdb_core::sql::to_value as to_core_value; pub(crate) mod live; pub(crate) mod query; @@ -76,8 +76,8 @@ pub use merge::Merge; pub use patch::Patch; pub use query::Query; pub use query::QueryStream; +pub use run::IntoFn; pub use run::Run; -pub use run::{IntoArgs, IntoFn}; pub use select::Select; use serde_content::Serializer; pub use set::Set; @@ -1223,42 +1223,33 @@ where } } - // TODO: Re-enable doc tests /// Runs a function /// /// # Examples /// - /// ```ignore + /// ```no_run /// # #[tokio::main] /// # async fn main() -> surrealdb::Result<()> { /// # let db = surrealdb::engine::any::connect("mem://").await?; - /// // Note that the sdk is currently undergoing some changes so the below examples might not - /// work until the sdk is somewhat more stable. - /// - /// // specify no args with an empty tuple, vec, or slice - /// let foo = db.run("fn::foo", ()).await?; // fn::foo() - /// // a single value will be turned into one arguement unless it is a tuple or vec - /// let bar = db.run("fn::bar", 42).await?; // fn::bar(42) - /// // to specify a single arguement, which is an array turn it into a value, or wrap in a singleton tuple - /// let count = db.run("count", Value::from(vec![1,2,3])).await?; - /// let count = db.run("count", (vec![1,2,3],)).await?; - /// // specify multiple args with either a tuple or vec - /// let two = db.run("math::log", (100, 10)).await?; // math::log(100, 10) - /// let two = db.run("math::log", [100, 10]).await?; // math::log(100, 10) + /// // Specify no args by not calling `.args()` + /// let foo = db.run("fn::foo").await?; // fn::foo() + /// // A single value will be turned into one argument + /// let bar = db.run("fn::bar").args(42).await?; // fn::bar(42) + /// // Arrays are treated as single arguments + /// let count = db.run("count").args(vec![1,2,3]).await?; + /// // Specify multiple args using a tuple + /// let two = db.run("math::log").args((100, 10)).await?; // math::log(100, 10) /// /// # Ok(()) /// # } /// ``` /// - pub fn run(&self, name: impl IntoFn, args: impl IntoArgs) -> Run { - let (name, version) = name.into_fn(); - let mut arguments = CoreArray::default(); - arguments.0 = crate::Value::array_to_core(args.into_args()); + pub fn run(&self, function: impl IntoFn) -> Run { Run { client: Cow::Borrowed(self), - name, - version, - args: arguments, + function: function.into_fn(), + args: Ok(serde_content::Value::Tuple(vec![])), + response_type: PhantomData, } } diff --git a/lib/src/api/method/run.rs b/lib/src/api/method/run.rs index 59f0a9e5..509e0d6d 100644 --- a/lib/src/api/method/run.rs +++ b/lib/src/api/method/run.rs @@ -1,30 +1,35 @@ use crate::api::conn::Command; +use crate::api::method::BoxFuture; use crate::api::Connection; use crate::api::Result; use crate::method::OnceLockExt; +use crate::sql::Value; use crate::Surreal; -use crate::Value; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_content::Serializer; +use serde_content::Value as Content; use std::borrow::Cow; use std::future::IntoFuture; -use surrealdb_core::sql::Array as CoreArray; - -use super::BoxFuture; +use std::marker::PhantomData; +use surrealdb_core::sql::to_value; +use surrealdb_core::sql::Array; /// A run future #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Run<'r, C: Connection> { +pub struct Run<'r, C: Connection, R> { pub(super) client: Cow<'r, Surreal>, - pub(super) name: String, - pub(super) version: Option, - pub(super) args: CoreArray, + pub(super) function: Result<(String, Option)>, + pub(super) args: serde_content::Result>, + pub(super) response_type: PhantomData, } -impl Run<'_, C> +impl Run<'_, C, R> where C: Connection, { /// Converts to an owned type which can easily be moved to a different thread - pub fn into_owned(self) -> Run<'static, C> { + pub fn into_owned(self) -> Run<'static, C, R> { Run { client: Cow::Owned(self.client.into_owned()), ..self @@ -32,24 +37,36 @@ where } } -impl<'r, Client> IntoFuture for Run<'r, Client> +impl<'r, Client, R> IntoFuture for Run<'r, Client, R> where Client: Connection, + R: DeserializeOwned, { - type Output = Result; + type Output = Result; type IntoFuture = BoxFuture<'r, Self::Output>; fn into_future(self) -> Self::IntoFuture { let Run { client, - name, - version, + function, args, + .. } = self; Box::pin(async move { let router = client.router.extract()?; + let (name, version) = function?; + let value = match args.map_err(crate::error::Db::from)? { + // Tuples are treated as multiple function arguments + Content::Tuple(tup) => tup, + // Everything else is treated as a single argument + content => vec![content], + }; + let args = match to_value(value)? { + Value::Array(array) => array, + value => Array::from(vec![value]), + }; router - .execute_value(Command::Run { + .execute(Command::Run { name, version, args, @@ -59,132 +76,57 @@ where } } -pub trait IntoArgs { - fn into_args(self) -> Vec; -} - -impl IntoArgs for Value { - fn into_args(self) -> Vec { - vec![self] - } -} - -impl IntoArgs for Vec +impl<'r, Client, R> Run<'r, Client, R> where - T: Into, + Client: Connection, { - fn into_args(self) -> Vec { - self.into_iter().map(Into::into).collect() + /// Supply arguments to the function being run. + pub fn args(mut self, args: impl Serialize) -> Self { + self.args = Serializer::new().serialize(args); + self } } -impl IntoArgs for [T; N] -where - T: Into, -{ - fn into_args(self) -> Vec { - self.into_iter().map(Into::into).collect() - } -} - -impl IntoArgs for &[T; N] -where - T: Into + Clone, -{ - fn into_args(self) -> Vec { - self.iter().cloned().map(Into::into).collect() - } -} - -impl IntoArgs for &[T] -where - T: Into + Clone, -{ - fn into_args(self) -> Vec { - self.iter().cloned().map(Into::into).collect() - } -} - -macro_rules! impl_args_tuple { - ($($i:ident), *$(,)?) => { - impl_args_tuple!(@marker $($i,)*); - }; - ($($cur:ident,)* @marker $head:ident, $($tail:ident,)*) => { - impl<$($cur: Into,)*> IntoArgs for ($($cur,)*) { - #[allow(non_snake_case)] - fn into_args(self) -> Vec { - let ($($cur,)*) = self; - vec![$($cur.into(),)*] - } - } - - impl_args_tuple!($($cur,)* $head, @marker $($tail,)*); - }; - ($($cur:ident,)* @marker ) => { - impl<$($cur: Into,)*> IntoArgs for ($($cur,)*) { - #[allow(non_snake_case)] - fn into_args(self) -> Vec { - let ($($cur,)*) = self; - vec![$($cur.into(),)*] - } - } - } -} - -impl_args_tuple!(A, B, C, D, E, F,); - -/* TODO: Removed for now. - * The detach value PR removed a lot of conversion methods with, pending later request which might - * add them back depending on how the sdk turns out. - * -macro_rules! into_impl { - ($type:ty) => { - impl IntoArgs for $type { - fn into_args(self) -> Vec { - vec![Value::from(self)] - } - } - }; -} -into_impl!(i8); -into_impl!(i16); -into_impl!(i32); -into_impl!(i64); -into_impl!(i128); -into_impl!(u8); -into_impl!(u16); -into_impl!(u32); -into_impl!(u64); -into_impl!(u128); -into_impl!(usize); -into_impl!(isize); -into_impl!(f32); -into_impl!(f64); -into_impl!(String); -into_impl!(&str); -*/ - +/// Converts a function into name and version parts pub trait IntoFn { - fn into_fn(self) -> (String, Option); + /// Handles the conversion of the function string + fn into_fn(self) -> Result<(String, Option)>; } impl IntoFn for String { - fn into_fn(self) -> (String, Option) { - (self, None) - } -} -impl IntoFn for &str { - fn into_fn(self) -> (String, Option) { - (self.to_owned(), None) + fn into_fn(self) -> Result<(String, Option)> { + match self.split_once('<') { + Some((name, rest)) => match rest.strip_suffix('>') { + Some(version) => Ok((name.to_owned(), Some(version.to_owned()))), + None => Err(crate::error::Db::InvalidFunction { + name: self, + message: "function version is missing a clossing '>'".to_owned(), + } + .into()), + }, + None => Ok((self, None)), + } } } -impl IntoFn for (S0, S1) -where - S0: Into, - S1: Into, -{ - fn into_fn(self) -> (String, Option) { - (self.0.into(), Some(self.1.into())) +impl IntoFn for &str { + fn into_fn(self) -> Result<(String, Option)> { + match self.split_once('<') { + Some((name, rest)) => match rest.strip_suffix('>') { + Some(version) => Ok((name.to_owned(), Some(version.to_owned()))), + None => Err(crate::error::Db::InvalidFunction { + name: self.to_owned(), + message: "function version is missing a clossing '>'".to_owned(), + } + .into()), + }, + None => Ok((self.to_owned(), None)), + } + } +} + +impl IntoFn for &String { + fn into_fn(self) -> Result<(String, Option)> { + self.as_str().into_fn() } } diff --git a/lib/src/api/method/tests/mod.rs b/lib/src/api/method/tests/mod.rs index 6863ca5a..702b9b2d 100644 --- a/lib/src/api/method/tests/mod.rs +++ b/lib/src/api/method/tests/mod.rs @@ -14,7 +14,6 @@ use crate::api::opt::auth::Root; use crate::api::opt::PatchOp; use crate::api::Response as QueryResponse; use crate::api::Surreal; -use crate::Value; use once_cell::sync::Lazy; use protocol::Client; use protocol::Test; @@ -170,7 +169,7 @@ async fn api() { let _: Version = DB.version().await.unwrap(); // run - let _: Value = DB.run("foo", ()).await.unwrap(); + let _: Option = DB.run("foo").await.unwrap(); } fn assert_send_sync(_: impl Send + Sync) {} diff --git a/lib/tests/api/mod.rs b/lib/tests/api/mod.rs index 3c53c09b..93e6342a 100644 --- a/lib/tests/api/mod.rs +++ b/lib/tests/api/mod.rs @@ -1363,9 +1363,6 @@ async fn return_bool() { assert_eq!(value.into_inner(), CoreValue::Bool(false)); } -/* - * TODO: Reenable test. - * Disabling run test for now as it depends on value conversions which are removed #[test_log::test(tokio::test)] async fn run() { let (permit, db) = new_db().await; @@ -1384,24 +1381,23 @@ async fn run() { "; let _ = db.query(sql).await; - let tmp = db.run("fn::foo", ()).await.unwrap(); - assert_eq!(tmp, Value::from(42)); + let tmp: i32 = db.run("fn::foo").await.unwrap(); + assert_eq!(tmp, 42); - let tmp = db.run("fn::foo", 7).await.unwrap_err(); + let tmp = db.run::("fn::foo").args(7).await.unwrap_err(); println!("fn::foo res: {tmp}"); assert!(tmp.to_string().contains("The function expects 0 arguments.")); - let tmp = db.run("fn::idnotexist", ()).await.unwrap_err(); + let tmp = db.run::<()>("fn::idnotexist").await.unwrap_err(); println!("fn::idontexist res: {tmp}"); assert!(tmp.to_string().contains("The function 'fn::idnotexist' does not exist")); - let tmp = db.run("count", Value::from(vec![1, 2, 3])).await.unwrap(); - assert_eq!(tmp, Value::from(3)); + let tmp: usize = db.run("count").args(vec![1, 2, 3]).await.unwrap(); + assert_eq!(tmp, 3); - let tmp = db.run("fn::bar", 7).await.unwrap(); - assert_eq!(tmp, Value::None); + let tmp: Option = db.run("fn::bar").args(7).await.unwrap(); + assert_eq!(tmp, None); - let tmp = db.run("fn::baz", ()).await.unwrap(); - assert_eq!(tmp, Value::from(7)); + let tmp: i32 = db.run("fn::baz").await.unwrap(); + assert_eq!(tmp, 7); } -*/