From b1e9af5d4eceb03e099f8ee28417dace4f05ec49 Mon Sep 17 00:00:00 2001 From: Raphael Darley Date: Tue, 20 Aug 2024 03:50:19 -0700 Subject: [PATCH] Add run to sdk (#4305) Co-authored-by: Micha de Vries --- lib/src/api/conn/cmd.rs | 17 +++ lib/src/api/engine/local/mod.rs | 32 +++++ lib/src/api/method/mod.rs | 37 ++++++ lib/src/api/method/run.rs | 207 +++++++++++++++++++++++++++++ lib/src/api/method/tests/mod.rs | 4 + lib/src/api/method/tests/server.rs | 3 + lib/tests/api/mod.rs | 39 ++++++ 7 files changed, 339 insertions(+) create mode 100644 lib/src/api/method/run.rs diff --git a/lib/src/api/conn/cmd.rs b/lib/src/api/conn/cmd.rs index 8ae63d9d..35290f90 100644 --- a/lib/src/api/conn/cmd.rs +++ b/lib/src/api/conn/cmd.rs @@ -6,6 +6,7 @@ use revision::Revisioned; use serde::{ser::SerializeMap as _, Serialize}; use std::path::PathBuf; use std::{collections::BTreeMap, io::Read}; +use surrealdb_core::sql::Array; use surrealdb_core::{ dbs::Notification, sql::{Object, Query, Value}, @@ -99,6 +100,11 @@ pub(crate) enum Command { Kill { uuid: Uuid, }, + Run { + name: String, + version: Option, + args: Array, + }, } impl Command { @@ -318,6 +324,17 @@ impl Command { method: "kill".into(), params: Some(vec![Value::from(uuid)].into()), }, + Command::Run { + name, + version, + args, + } => RouterRequest { + id, + method: "run".into(), + params: Some( + vec![Value::from(name), Value::from(version), Value::Array(args)].into(), + ), + }, }; Some(res) } diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index 0acfb3cd..2b4e8689 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -59,6 +59,9 @@ use surrealdb_core::{ use crate::api::err::Error; #[cfg(not(target_arch = "wasm32"))] use std::path::PathBuf; +use surrealdb_core::sql::Function; +#[cfg(feature = "ml")] +use surrealdb_core::sql::Model; #[cfg(not(target_arch = "wasm32"))] use tokio::{ fs::OpenOptions, @@ -981,5 +984,34 @@ async fn router( let value = kill_live_query(kvs, uuid.into(), session, vars.clone()).await?; Ok(DbResponse::Other(value)) } + + Command::Run { + name, + version: _version, + args, + } => { + let func: Value = match &name[0..4] { + "fn::" => Function::Custom(name.chars().skip(4).collect(), args.0).into(), + // should return error, but can't on wasm + #[cfg(feature = "ml")] + "ml::" => { + let mut tmp = Model::default(); + + tmp.name = name.chars().skip(4).collect(); + tmp.args = args.0; + tmp.version = _version + .ok_or(Error::Query("ML functions must have a version".to_string()))?; + tmp + } + .into(), + _ => Function::Normal(name, args.0).into(), + }; + let stmt = Statement::Value(func); + + let response = kvs.process(stmt.into(), &*session, Some(vars.clone())).await?; + let value = take(true, response).await?; + + Ok(DbResponse::Other(value)) + } } } diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index 3e8ce1ea..9c9a6395 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -12,6 +12,7 @@ use crate::api::Surreal; use crate::opt::IntoExportDestination; use crate::opt::WaitFor; use crate::sql::to_value; +use run::IntoArgs; use serde::Serialize; use std::borrow::Cow; use std::marker::PhantomData; @@ -38,6 +39,7 @@ mod insert; mod invalidate; mod merge; mod patch; +mod run; mod select; mod set; mod signin; @@ -75,6 +77,8 @@ pub use merge::Merge; pub use patch::Patch; pub use query::Query; pub use query::QueryStream; +use run::IntoFn; +pub use run::Run; pub use select::Select; use serde_content::Serializer; pub use set::Set; @@ -1222,6 +1226,39 @@ where } } + /// Runs a function + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> surrealdb::Result<()> { + /// # let db = surrealdb::engine::any::connect("mem://").await?; + /// // specify no args with an empty tuple, vec, or slice + /// let foo = db.run("fn::foo", ()).await?; // fn::foo() + /// // a single value will be turned into one arguement unless it is a tuple or vec + /// let bar = db.run("fn::bar", 42).await?; // fn::bar(42) + /// // to specify a single arguement, which is an array turn it into a value, or wrap in a singleton tuple + /// let count = db.run("count", Value::from(vec![1,2,3])).await?; + /// let count = db.run("count", (vec![1,2,3],)).await?; + /// // specify multiple args with either a tuple or vec + /// let two = db.run("math::log", (100, 10)).await?; // math::log(100, 10) + /// let two = db.run("math::log", [100, 10]).await?; // math::log(100, 10) + /// + /// # Ok(()) + /// # } + /// ``` + /// + pub fn run(&self, name: impl IntoFn, args: impl IntoArgs) -> Run { + let (name, version) = name.into_fn(); + Run { + client: Cow::Borrowed(self), + name, + version, + args: args.into_args(), + } + } + /// Checks whether the server is healthy or not /// /// # Examples diff --git a/lib/src/api/method/run.rs b/lib/src/api/method/run.rs new file mode 100644 index 00000000..f897b091 --- /dev/null +++ b/lib/src/api/method/run.rs @@ -0,0 +1,207 @@ +use crate::api::conn::Command; +use crate::api::Connection; +use crate::api::Result; +use crate::method::OnceLockExt; +use crate::sql::Array; +use crate::sql::Value; +use crate::Surreal; +use std::borrow::Cow; +use std::future::Future; +use std::future::IntoFuture; +use std::pin::Pin; + +/// A run future +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Run<'r, C: Connection> { + pub(super) client: Cow<'r, Surreal>, + pub(super) name: String, + pub(super) version: Option, + pub(super) args: Array, +} +impl Run<'_, C> +where + C: Connection, +{ + /// Converts to an owned type which can easily be moved to a different thread + pub fn into_owned(self) -> Run<'static, C> { + Run { + client: Cow::Owned(self.client.into_owned()), + ..self + } + } +} + +impl<'r, Client> IntoFuture for Run<'r, Client> +where + Client: Connection, +{ + type Output = Result; + type IntoFuture = Pin + Send + Sync + 'r>>; + + fn into_future(self) -> Self::IntoFuture { + let Run { + client, + name, + version, + args, + } = self; + Box::pin(async move { + let router = client.router.extract()?; + router + .execute_value(Command::Run { + name, + version, + args, + }) + .await + }) + } +} + +pub trait IntoArgs { + fn into_args(self) -> Array; +} + +impl IntoArgs for Array { + fn into_args(self) -> Array { + self + } +} + +impl IntoArgs for Value { + fn into_args(self) -> Array { + let arr: Vec = vec![self]; + Array::from(arr) + } +} + +impl IntoArgs for Vec +where + T: Into, +{ + fn into_args(self) -> Array { + let arr: Vec = self.into_iter().map(Into::into).collect(); + Array::from(arr) + } +} + +impl IntoArgs for [T; N] +where + T: Into, +{ + fn into_args(self) -> Array { + let arr: Vec = self.into_iter().map(Into::into).collect(); + Array::from(arr) + } +} + +impl IntoArgs for &[T; N] +where + T: Into + Clone, +{ + fn into_args(self) -> Array { + let arr: Vec = self.iter().cloned().map(Into::into).collect(); + Array::from(arr) + } +} + +impl IntoArgs for &[T] +where + T: Into + Clone, +{ + fn into_args(self) -> Array { + let arr: Vec = self.iter().cloned().map(Into::into).collect(); + Array::from(arr) + } +} +impl IntoArgs for () { + fn into_args(self) -> Array { + Vec::::new().into() + } +} + +impl IntoArgs for (T0,) +where + T0: Into, +{ + fn into_args(self) -> Array { + let arr: Vec = vec![self.0.into()]; + Array::from(arr) + } +} + +impl IntoArgs for (T0, T1) +where + T0: Into, + T1: Into, +{ + fn into_args(self) -> Array { + let arr: Vec = vec![self.0.into(), self.1.into()]; + Array::from(arr) + } +} + +impl IntoArgs for (T0, T1, T2) +where + T0: Into, + T1: Into, + T2: Into, +{ + fn into_args(self) -> Array { + let arr: Vec = vec![self.0.into(), self.1.into(), self.2.into()]; + Array::from(arr) + } +} + +macro_rules! into_impl { + ($type:ty) => { + impl IntoArgs for $type { + fn into_args(self) -> Array { + let val: Value = self.into(); + Array::from(val) + } + } + }; +} +into_impl!(i8); +into_impl!(i16); +into_impl!(i32); +into_impl!(i64); +into_impl!(i128); +into_impl!(u8); +into_impl!(u16); +into_impl!(u32); +into_impl!(u64); +into_impl!(u128); +into_impl!(usize); +into_impl!(isize); +into_impl!(f32); +into_impl!(f64); +into_impl!(String); +into_impl!(&str); + +pub trait IntoFn { + fn into_fn(self) -> (String, Option); +} + +impl IntoFn for String { + fn into_fn(self) -> (String, Option) { + (self, None) + } +} +impl IntoFn for &str { + fn into_fn(self) -> (String, Option) { + (self.to_owned(), None) + } +} + +impl IntoFn for (S0, S1) +where + S0: Into, + S1: Into, +{ + fn into_fn(self) -> (String, Option) { + (self.0.into(), Some(self.1.into())) + } +} diff --git a/lib/src/api/method/tests/mod.rs b/lib/src/api/method/tests/mod.rs index cc8e0bd5..08568979 100644 --- a/lib/src/api/method/tests/mod.rs +++ b/lib/src/api/method/tests/mod.rs @@ -21,6 +21,7 @@ use protocol::Client; use protocol::Test; use semver::Version; use std::ops::Bound; +use surrealdb_core::sql::Value; use types::User; use types::USER; @@ -168,6 +169,9 @@ async fn api() { // version let _: Version = DB.version().await.unwrap(); + + // run + let _: Value = DB.run("foo", ()).await.unwrap(); } fn assert_send_sync(_: impl Send + Sync) {} diff --git a/lib/src/api/method/tests/server.rs b/lib/src/api/method/tests/server.rs index eeaef5ea..e9eb4210 100644 --- a/lib/src/api/method/tests/server.rs +++ b/lib/src/api/method/tests/server.rs @@ -101,6 +101,9 @@ pub(super) fn mock(route_rx: Receiver) { } _ => unreachable!(), }, + Command::Run { + .. + } => Ok(DbResponse::Other(Value::None)), Command::ExportMl { .. } diff --git a/lib/tests/api/mod.rs b/lib/tests/api/mod.rs index bac6588b..671a4cee 100644 --- a/lib/tests/api/mod.rs +++ b/lib/tests/api/mod.rs @@ -1356,3 +1356,42 @@ async fn return_bool() { let value: Value = response.take(0).unwrap(); assert_eq!(value, Value::Bool(false)); } + +#[test_log::test(tokio::test)] +async fn run() { + let (permit, db) = new_db().await; + db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap(); + drop(permit); + let sql = " + DEFINE FUNCTION fn::foo() { + RETURN 42; + }; + DEFINE FUNCTION fn::bar($val: any) { + CREATE foo:1 set val = $val; + }; + DEFINE FUNCTION fn::baz() { + RETURN SELECT VALUE val FROM ONLY foo:1; + }; + "; + let _ = db.query(sql).await; + + let tmp = db.run("fn::foo", ()).await.unwrap(); + assert_eq!(tmp, Value::from(42)); + + let tmp = db.run("fn::foo", 7).await.unwrap_err(); + println!("fn::foo res: {tmp}"); + assert!(tmp.to_string().contains("The function expects 0 arguments.")); + + let tmp = db.run("fn::idnotexist", ()).await.unwrap_err(); + println!("fn::idontexist res: {tmp}"); + assert!(tmp.to_string().contains("The function 'fn::idnotexist' does not exist")); + + let tmp = db.run("count", Value::from(vec![1, 2, 3])).await.unwrap(); + assert_eq!(tmp, Value::from(3)); + + let tmp = db.run("fn::bar", 7).await.unwrap(); + assert_eq!(tmp, Value::None); + + let tmp = db.run("fn::baz", ()).await.unwrap(); + assert_eq!(tmp, Value::from(7)); +}