Improve consistency of the run method API with existing methods (#4563)

This commit is contained in:
Rushmore Mushambi 2024-08-21 15:35:26 +01:00 committed by GitHub
parent 13b6788540
commit d002fa3958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 116 additions and 193 deletions

View file

@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughpu
use std::time::Duration; use std::time::Duration;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore; use surrealdb::kvs::Datastore;
use surrealdb::Value; use surrealdb_core::sql::Value;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
fn bench_processor(c: &mut Criterion) { fn bench_processor(c: &mut Criterion) {

View file

@ -52,7 +52,6 @@ use surrealdb_core::{
}; };
use uuid::Uuid; use uuid::Uuid;
#[cfg(not(target_arch = "wasm32"))]
use crate::api::err::Error; use crate::api::err::Error;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf; use std::path::PathBuf;
@ -1008,28 +1007,24 @@ async fn router(
version: _version, version: _version,
args, args,
} => { } => {
let func: CoreValue = if let Some(name) = name.strip_prefix("fn::") { let func: CoreValue = match name.strip_prefix("fn::") {
Function::Custom(name.to_owned(), args.0).into() Some(name) => Function::Custom(name.to_owned(), args.0).into(),
} else if let Some(_name) = name.strip_prefix("ml::") { None => match name.strip_prefix("ml::") {
#[cfg(feature = "ml")] #[cfg(feature = "ml")]
{ Some(name) => {
let mut model = Model::default(); let mut tmp = Model::default();
tmp.name = name.to_owned();
model.name = _name.to_owned(); tmp.args = args.0;
model.args = args.0; tmp.version = _version
model.version = _version
.ok_or(Error::Query("ML functions must have a version".to_string()))?; .ok_or(Error::Query("ML functions must have a version".to_string()))?;
model.into() tmp.into()
} }
#[cfg(not(feature = "ml"))] #[cfg(not(feature = "ml"))]
{ Some(_) => {
return Err(crate::error::Db::InvalidModel { return Err(Error::Query(format!("tried to call an ML function `{name}` but the `ml` feature is not enabled")).into());
message: "Machine learning computation is not enabled.".to_owned(),
} }
.into()); None => Function::Normal(name, args.0).into(),
} },
} else {
Function::Custom(name, args.0).into()
}; };
let stmt = Statement::Value(func); let stmt = Statement::Value(func);

View file

@ -19,7 +19,7 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
use std::time::Duration; use std::time::Duration;
use surrealdb_core::sql::{to_value as to_core_value, Array as CoreArray}; use surrealdb_core::sql::to_value as to_core_value;
pub(crate) mod live; pub(crate) mod live;
pub(crate) mod query; pub(crate) mod query;
@ -76,8 +76,8 @@ pub use merge::Merge;
pub use patch::Patch; pub use patch::Patch;
pub use query::Query; pub use query::Query;
pub use query::QueryStream; pub use query::QueryStream;
pub use run::IntoFn;
pub use run::Run; pub use run::Run;
pub use run::{IntoArgs, IntoFn};
pub use select::Select; pub use select::Select;
use serde_content::Serializer; use serde_content::Serializer;
pub use set::Set; pub use set::Set;
@ -1223,42 +1223,33 @@ where
} }
} }
// TODO: Re-enable doc tests
/// Runs a function /// Runs a function
/// ///
/// # Examples /// # Examples
/// ///
/// ```ignore /// ```no_run
/// # #[tokio::main] /// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> { /// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?; /// # let db = surrealdb::engine::any::connect("mem://").await?;
/// // Note that the sdk is currently undergoing some changes so the below examples might not /// // Specify no args by not calling `.args()`
/// work until the sdk is somewhat more stable. /// let foo = db.run("fn::foo").await?; // fn::foo()
/// /// // A single value will be turned into one argument
/// // specify no args with an empty tuple, vec, or slice /// let bar = db.run("fn::bar").args(42).await?; // fn::bar(42)
/// let foo = db.run("fn::foo", ()).await?; // fn::foo() /// // Arrays are treated as single arguments
/// // a single value will be turned into one arguement unless it is a tuple or vec /// let count = db.run("count").args(vec![1,2,3]).await?;
/// let bar = db.run("fn::bar", 42).await?; // fn::bar(42) /// // Specify multiple args using a tuple
/// // to specify a single arguement, which is an array turn it into a value, or wrap in a singleton tuple /// let two = db.run("math::log").args((100, 10)).await?; // math::log(100, 10)
/// let count = db.run("count", Value::from(vec![1,2,3])).await?;
/// let count = db.run("count", (vec![1,2,3],)).await?;
/// // specify multiple args with either a tuple or vec
/// let two = db.run("math::log", (100, 10)).await?; // math::log(100, 10)
/// let two = db.run("math::log", [100, 10]).await?; // math::log(100, 10)
/// ///
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
/// ///
pub fn run(&self, name: impl IntoFn, args: impl IntoArgs) -> Run<C> { pub fn run<R>(&self, function: impl IntoFn) -> Run<C, R> {
let (name, version) = name.into_fn();
let mut arguments = CoreArray::default();
arguments.0 = crate::Value::array_to_core(args.into_args());
Run { Run {
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
name, function: function.into_fn(),
version, args: Ok(serde_content::Value::Tuple(vec![])),
args: arguments, response_type: PhantomData,
} }
} }

View file

@ -1,30 +1,35 @@
use crate::api::conn::Command; use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use crate::Value; use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_content::Serializer;
use serde_content::Value as Content;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use surrealdb_core::sql::Array as CoreArray; use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
use super::BoxFuture; use surrealdb_core::sql::Array;
/// A run future /// A run future
#[derive(Debug)] #[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Run<'r, C: Connection> { pub struct Run<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) name: String, pub(super) function: Result<(String, Option<String>)>,
pub(super) version: Option<String>, pub(super) args: serde_content::Result<serde_content::Value<'static>>,
pub(super) args: CoreArray, pub(super) response_type: PhantomData<R>,
} }
impl<C> Run<'_, C> impl<C, R> Run<'_, C, R>
where where
C: Connection, C: Connection,
{ {
/// Converts to an owned type which can easily be moved to a different thread /// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Run<'static, C> { pub fn into_owned(self) -> Run<'static, C, R> {
Run { Run {
client: Cow::Owned(self.client.into_owned()), client: Cow::Owned(self.client.into_owned()),
..self ..self
@ -32,24 +37,36 @@ where
} }
} }
impl<'r, Client> IntoFuture for Run<'r, Client> impl<'r, Client, R> IntoFuture for Run<'r, Client, R>
where where
Client: Connection, Client: Connection,
R: DeserializeOwned,
{ {
type Output = Result<Value>; type Output = Result<R>;
type IntoFuture = BoxFuture<'r, Self::Output>; type IntoFuture = BoxFuture<'r, Self::Output>;
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
let Run { let Run {
client, client,
name, function,
version,
args, args,
..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let router = client.router.extract()?; let router = client.router.extract()?;
let (name, version) = function?;
let value = match args.map_err(crate::error::Db::from)? {
// Tuples are treated as multiple function arguments
Content::Tuple(tup) => tup,
// Everything else is treated as a single argument
content => vec![content],
};
let args = match to_value(value)? {
Value::Array(array) => array,
value => Array::from(vec![value]),
};
router router
.execute_value(Command::Run { .execute(Command::Run {
name, name,
version, version,
args, args,
@ -59,132 +76,57 @@ where
} }
} }
pub trait IntoArgs { impl<'r, Client, R> Run<'r, Client, R>
fn into_args(self) -> Vec<Value>;
}
impl IntoArgs for Value {
fn into_args(self) -> Vec<Value> {
vec![self]
}
}
impl<T> IntoArgs for Vec<T>
where where
T: Into<Value>, Client: Connection,
{ {
fn into_args(self) -> Vec<Value> { /// Supply arguments to the function being run.
self.into_iter().map(Into::into).collect() pub fn args(mut self, args: impl Serialize) -> Self {
self.args = Serializer::new().serialize(args);
self
} }
} }
impl<T, const N: usize> IntoArgs for [T; N] /// Converts a function into name and version parts
where
T: Into<Value>,
{
fn into_args(self) -> Vec<Value> {
self.into_iter().map(Into::into).collect()
}
}
impl<T, const N: usize> IntoArgs for &[T; N]
where
T: Into<Value> + Clone,
{
fn into_args(self) -> Vec<Value> {
self.iter().cloned().map(Into::into).collect()
}
}
impl<T> IntoArgs for &[T]
where
T: Into<Value> + Clone,
{
fn into_args(self) -> Vec<Value> {
self.iter().cloned().map(Into::into).collect()
}
}
macro_rules! impl_args_tuple {
($($i:ident), *$(,)?) => {
impl_args_tuple!(@marker $($i,)*);
};
($($cur:ident,)* @marker $head:ident, $($tail:ident,)*) => {
impl<$($cur: Into<Value>,)*> IntoArgs for ($($cur,)*) {
#[allow(non_snake_case)]
fn into_args(self) -> Vec<Value> {
let ($($cur,)*) = self;
vec![$($cur.into(),)*]
}
}
impl_args_tuple!($($cur,)* $head, @marker $($tail,)*);
};
($($cur:ident,)* @marker ) => {
impl<$($cur: Into<Value>,)*> IntoArgs for ($($cur,)*) {
#[allow(non_snake_case)]
fn into_args(self) -> Vec<Value> {
let ($($cur,)*) = self;
vec![$($cur.into(),)*]
}
}
}
}
impl_args_tuple!(A, B, C, D, E, F,);
/* TODO: Removed for now.
* The detach value PR removed a lot of conversion methods with, pending later request which might
* add them back depending on how the sdk turns out.
*
macro_rules! into_impl {
($type:ty) => {
impl IntoArgs for $type {
fn into_args(self) -> Vec<Value> {
vec![Value::from(self)]
}
}
};
}
into_impl!(i8);
into_impl!(i16);
into_impl!(i32);
into_impl!(i64);
into_impl!(i128);
into_impl!(u8);
into_impl!(u16);
into_impl!(u32);
into_impl!(u64);
into_impl!(u128);
into_impl!(usize);
into_impl!(isize);
into_impl!(f32);
into_impl!(f64);
into_impl!(String);
into_impl!(&str);
*/
pub trait IntoFn { pub trait IntoFn {
fn into_fn(self) -> (String, Option<String>); /// Handles the conversion of the function string
fn into_fn(self) -> Result<(String, Option<String>)>;
} }
impl IntoFn for String { impl IntoFn for String {
fn into_fn(self) -> (String, Option<String>) { fn into_fn(self) -> Result<(String, Option<String>)> {
(self, None) match self.split_once('<') {
Some((name, rest)) => match rest.strip_suffix('>') {
Some(version) => Ok((name.to_owned(), Some(version.to_owned()))),
None => Err(crate::error::Db::InvalidFunction {
name: self,
message: "function version is missing a clossing '>'".to_owned(),
}
.into()),
},
None => Ok((self, None)),
} }
}
impl IntoFn for &str {
fn into_fn(self) -> (String, Option<String>) {
(self.to_owned(), None)
} }
} }
impl<S0, S1> IntoFn for (S0, S1) impl IntoFn for &str {
where fn into_fn(self) -> Result<(String, Option<String>)> {
S0: Into<String>, match self.split_once('<') {
S1: Into<String>, Some((name, rest)) => match rest.strip_suffix('>') {
{ Some(version) => Ok((name.to_owned(), Some(version.to_owned()))),
fn into_fn(self) -> (String, Option<String>) { None => Err(crate::error::Db::InvalidFunction {
(self.0.into(), Some(self.1.into())) name: self.to_owned(),
message: "function version is missing a clossing '>'".to_owned(),
}
.into()),
},
None => Ok((self.to_owned(), None)),
}
}
}
impl IntoFn for &String {
fn into_fn(self) -> Result<(String, Option<String>)> {
self.as_str().into_fn()
} }
} }

View file

@ -14,7 +14,6 @@ use crate::api::opt::auth::Root;
use crate::api::opt::PatchOp; use crate::api::opt::PatchOp;
use crate::api::Response as QueryResponse; use crate::api::Response as QueryResponse;
use crate::api::Surreal; use crate::api::Surreal;
use crate::Value;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use protocol::Client; use protocol::Client;
use protocol::Test; use protocol::Test;
@ -170,7 +169,7 @@ async fn api() {
let _: Version = DB.version().await.unwrap(); let _: Version = DB.version().await.unwrap();
// run // run
let _: Value = DB.run("foo", ()).await.unwrap(); let _: Option<User> = DB.run("foo").await.unwrap();
} }
fn assert_send_sync(_: impl Send + Sync) {} fn assert_send_sync(_: impl Send + Sync) {}

View file

@ -1363,9 +1363,6 @@ async fn return_bool() {
assert_eq!(value.into_inner(), CoreValue::Bool(false)); assert_eq!(value.into_inner(), CoreValue::Bool(false));
} }
/*
* TODO: Reenable test.
* Disabling run test for now as it depends on value conversions which are removed
#[test_log::test(tokio::test)] #[test_log::test(tokio::test)]
async fn run() { async fn run() {
let (permit, db) = new_db().await; let (permit, db) = new_db().await;
@ -1384,24 +1381,23 @@ async fn run() {
"; ";
let _ = db.query(sql).await; let _ = db.query(sql).await;
let tmp = db.run("fn::foo", ()).await.unwrap(); let tmp: i32 = db.run("fn::foo").await.unwrap();
assert_eq!(tmp, Value::from(42)); assert_eq!(tmp, 42);
let tmp = db.run("fn::foo", 7).await.unwrap_err(); let tmp = db.run::<i32>("fn::foo").args(7).await.unwrap_err();
println!("fn::foo res: {tmp}"); println!("fn::foo res: {tmp}");
assert!(tmp.to_string().contains("The function expects 0 arguments.")); assert!(tmp.to_string().contains("The function expects 0 arguments."));
let tmp = db.run("fn::idnotexist", ()).await.unwrap_err(); let tmp = db.run::<()>("fn::idnotexist").await.unwrap_err();
println!("fn::idontexist res: {tmp}"); println!("fn::idontexist res: {tmp}");
assert!(tmp.to_string().contains("The function 'fn::idnotexist' does not exist")); assert!(tmp.to_string().contains("The function 'fn::idnotexist' does not exist"));
let tmp = db.run("count", Value::from(vec![1, 2, 3])).await.unwrap(); let tmp: usize = db.run("count").args(vec![1, 2, 3]).await.unwrap();
assert_eq!(tmp, Value::from(3)); assert_eq!(tmp, 3);
let tmp = db.run("fn::bar", 7).await.unwrap(); let tmp: Option<RecordId> = db.run("fn::bar").args(7).await.unwrap();
assert_eq!(tmp, Value::None); assert_eq!(tmp, None);
let tmp = db.run("fn::baz", ()).await.unwrap(); let tmp: i32 = db.run("fn::baz").await.unwrap();
assert_eq!(tmp, Value::from(7)); assert_eq!(tmp, 7);
} }
*/