Improve consistency of the run method API with existing methods (#4563)

This commit is contained in:
Rushmore Mushambi 2024-08-21 15:35:26 +01:00 committed by GitHub
parent 13b6788540
commit d002fa3958
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 116 additions and 193 deletions

View file

@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughpu
use std::time::Duration;
use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
use surrealdb::Value;
use surrealdb_core::sql::Value;
use tokio::runtime::Runtime;
fn bench_processor(c: &mut Criterion) {

View file

@ -52,7 +52,6 @@ use surrealdb_core::{
};
use uuid::Uuid;
#[cfg(not(target_arch = "wasm32"))]
use crate::api::err::Error;
#[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf;
@ -1008,28 +1007,24 @@ async fn router(
version: _version,
args,
} => {
let func: CoreValue = if let Some(name) = name.strip_prefix("fn::") {
Function::Custom(name.to_owned(), args.0).into()
} else if let Some(_name) = name.strip_prefix("ml::") {
let func: CoreValue = match name.strip_prefix("fn::") {
Some(name) => Function::Custom(name.to_owned(), args.0).into(),
None => match name.strip_prefix("ml::") {
#[cfg(feature = "ml")]
{
let mut model = Model::default();
model.name = _name.to_owned();
model.args = args.0;
model.version = _version
Some(name) => {
let mut tmp = Model::default();
tmp.name = name.to_owned();
tmp.args = args.0;
tmp.version = _version
.ok_or(Error::Query("ML functions must have a version".to_string()))?;
model.into()
tmp.into()
}
#[cfg(not(feature = "ml"))]
{
return Err(crate::error::Db::InvalidModel {
message: "Machine learning computation is not enabled.".to_owned(),
Some(_) => {
return Err(Error::Query(format!("tried to call an ML function `{name}` but the `ml` feature is not enabled")).into());
}
.into());
}
} else {
Function::Custom(name, args.0).into()
None => Function::Normal(name, args.0).into(),
},
};
let stmt = Statement::Value(func);

View file

@ -19,7 +19,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use surrealdb_core::sql::{to_value as to_core_value, Array as CoreArray};
use surrealdb_core::sql::to_value as to_core_value;
pub(crate) mod live;
pub(crate) mod query;
@ -76,8 +76,8 @@ pub use merge::Merge;
pub use patch::Patch;
pub use query::Query;
pub use query::QueryStream;
pub use run::IntoFn;
pub use run::Run;
pub use run::{IntoArgs, IntoFn};
pub use select::Select;
use serde_content::Serializer;
pub use set::Set;
@ -1223,42 +1223,33 @@ where
}
}
// TODO: Re-enable doc tests
/// Runs a function
///
/// # Examples
///
/// ```ignore
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// // Note that the sdk is currently undergoing some changes so the below examples might not
/// work until the sdk is somewhat more stable.
///
/// // specify no args with an empty tuple, vec, or slice
/// let foo = db.run("fn::foo", ()).await?; // fn::foo()
/// // a single value will be turned into one arguement unless it is a tuple or vec
/// let bar = db.run("fn::bar", 42).await?; // fn::bar(42)
/// // to specify a single arguement, which is an array turn it into a value, or wrap in a singleton tuple
/// let count = db.run("count", Value::from(vec![1,2,3])).await?;
/// let count = db.run("count", (vec![1,2,3],)).await?;
/// // specify multiple args with either a tuple or vec
/// let two = db.run("math::log", (100, 10)).await?; // math::log(100, 10)
/// let two = db.run("math::log", [100, 10]).await?; // math::log(100, 10)
/// // Specify no args by not calling `.args()`
/// let foo = db.run("fn::foo").await?; // fn::foo()
/// // A single value will be turned into one argument
/// let bar = db.run("fn::bar").args(42).await?; // fn::bar(42)
/// // Arrays are treated as single arguments
/// let count = db.run("count").args(vec![1,2,3]).await?;
/// // Specify multiple args using a tuple
/// let two = db.run("math::log").args((100, 10)).await?; // math::log(100, 10)
///
/// # Ok(())
/// # }
/// ```
///
pub fn run(&self, name: impl IntoFn, args: impl IntoArgs) -> Run<C> {
let (name, version) = name.into_fn();
let mut arguments = CoreArray::default();
arguments.0 = crate::Value::array_to_core(args.into_args());
pub fn run<R>(&self, function: impl IntoFn) -> Run<C, R> {
Run {
client: Cow::Borrowed(self),
name,
version,
args: arguments,
function: function.into_fn(),
args: Ok(serde_content::Value::Tuple(vec![])),
response_type: PhantomData,
}
}

View file

@ -1,30 +1,35 @@
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal;
use crate::Value;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_content::Serializer;
use serde_content::Value as Content;
use std::borrow::Cow;
use std::future::IntoFuture;
use surrealdb_core::sql::Array as CoreArray;
use super::BoxFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
use surrealdb_core::sql::Array;
/// A run future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Run<'r, C: Connection> {
pub struct Run<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) name: String,
pub(super) version: Option<String>,
pub(super) args: CoreArray,
pub(super) function: Result<(String, Option<String>)>,
pub(super) args: serde_content::Result<serde_content::Value<'static>>,
pub(super) response_type: PhantomData<R>,
}
impl<C> Run<'_, C>
impl<C, R> Run<'_, C, R>
where
C: Connection,
{
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Run<'static, C> {
pub fn into_owned(self) -> Run<'static, C, R> {
Run {
client: Cow::Owned(self.client.into_owned()),
..self
@ -32,24 +37,36 @@ where
}
}
impl<'r, Client> IntoFuture for Run<'r, Client>
impl<'r, Client, R> IntoFuture for Run<'r, Client, R>
where
Client: Connection,
R: DeserializeOwned,
{
type Output = Result<Value>;
type Output = Result<R>;
type IntoFuture = BoxFuture<'r, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
let Run {
client,
name,
version,
function,
args,
..
} = self;
Box::pin(async move {
let router = client.router.extract()?;
let (name, version) = function?;
let value = match args.map_err(crate::error::Db::from)? {
// Tuples are treated as multiple function arguments
Content::Tuple(tup) => tup,
// Everything else is treated as a single argument
content => vec![content],
};
let args = match to_value(value)? {
Value::Array(array) => array,
value => Array::from(vec![value]),
};
router
.execute_value(Command::Run {
.execute(Command::Run {
name,
version,
args,
@ -59,132 +76,57 @@ where
}
}
pub trait IntoArgs {
fn into_args(self) -> Vec<Value>;
}
impl IntoArgs for Value {
fn into_args(self) -> Vec<Value> {
vec![self]
}
}
impl<T> IntoArgs for Vec<T>
impl<'r, Client, R> Run<'r, Client, R>
where
T: Into<Value>,
Client: Connection,
{
fn into_args(self) -> Vec<Value> {
self.into_iter().map(Into::into).collect()
/// Supply arguments to the function being run.
pub fn args(mut self, args: impl Serialize) -> Self {
self.args = Serializer::new().serialize(args);
self
}
}
impl<T, const N: usize> IntoArgs for [T; N]
where
T: Into<Value>,
{
fn into_args(self) -> Vec<Value> {
self.into_iter().map(Into::into).collect()
}
}
impl<T, const N: usize> IntoArgs for &[T; N]
where
T: Into<Value> + Clone,
{
fn into_args(self) -> Vec<Value> {
self.iter().cloned().map(Into::into).collect()
}
}
impl<T> IntoArgs for &[T]
where
T: Into<Value> + Clone,
{
fn into_args(self) -> Vec<Value> {
self.iter().cloned().map(Into::into).collect()
}
}
macro_rules! impl_args_tuple {
($($i:ident), *$(,)?) => {
impl_args_tuple!(@marker $($i,)*);
};
($($cur:ident,)* @marker $head:ident, $($tail:ident,)*) => {
impl<$($cur: Into<Value>,)*> IntoArgs for ($($cur,)*) {
#[allow(non_snake_case)]
fn into_args(self) -> Vec<Value> {
let ($($cur,)*) = self;
vec![$($cur.into(),)*]
}
}
impl_args_tuple!($($cur,)* $head, @marker $($tail,)*);
};
($($cur:ident,)* @marker ) => {
impl<$($cur: Into<Value>,)*> IntoArgs for ($($cur,)*) {
#[allow(non_snake_case)]
fn into_args(self) -> Vec<Value> {
let ($($cur,)*) = self;
vec![$($cur.into(),)*]
}
}
}
}
impl_args_tuple!(A, B, C, D, E, F,);
/* TODO: Removed for now.
* The detach value PR removed a lot of conversion methods with, pending later request which might
* add them back depending on how the sdk turns out.
*
macro_rules! into_impl {
($type:ty) => {
impl IntoArgs for $type {
fn into_args(self) -> Vec<Value> {
vec![Value::from(self)]
}
}
};
}
into_impl!(i8);
into_impl!(i16);
into_impl!(i32);
into_impl!(i64);
into_impl!(i128);
into_impl!(u8);
into_impl!(u16);
into_impl!(u32);
into_impl!(u64);
into_impl!(u128);
into_impl!(usize);
into_impl!(isize);
into_impl!(f32);
into_impl!(f64);
into_impl!(String);
into_impl!(&str);
*/
/// Converts a function into name and version parts
pub trait IntoFn {
fn into_fn(self) -> (String, Option<String>);
/// Handles the conversion of the function string
fn into_fn(self) -> Result<(String, Option<String>)>;
}
impl IntoFn for String {
fn into_fn(self) -> (String, Option<String>) {
(self, None)
fn into_fn(self) -> Result<(String, Option<String>)> {
match self.split_once('<') {
Some((name, rest)) => match rest.strip_suffix('>') {
Some(version) => Ok((name.to_owned(), Some(version.to_owned()))),
None => Err(crate::error::Db::InvalidFunction {
name: self,
message: "function version is missing a clossing '>'".to_owned(),
}
.into()),
},
None => Ok((self, None)),
}
impl IntoFn for &str {
fn into_fn(self) -> (String, Option<String>) {
(self.to_owned(), None)
}
}
impl<S0, S1> IntoFn for (S0, S1)
where
S0: Into<String>,
S1: Into<String>,
{
fn into_fn(self) -> (String, Option<String>) {
(self.0.into(), Some(self.1.into()))
impl IntoFn for &str {
fn into_fn(self) -> Result<(String, Option<String>)> {
match self.split_once('<') {
Some((name, rest)) => match rest.strip_suffix('>') {
Some(version) => Ok((name.to_owned(), Some(version.to_owned()))),
None => Err(crate::error::Db::InvalidFunction {
name: self.to_owned(),
message: "function version is missing a clossing '>'".to_owned(),
}
.into()),
},
None => Ok((self.to_owned(), None)),
}
}
}
impl IntoFn for &String {
fn into_fn(self) -> Result<(String, Option<String>)> {
self.as_str().into_fn()
}
}

View file

@ -14,7 +14,6 @@ use crate::api::opt::auth::Root;
use crate::api::opt::PatchOp;
use crate::api::Response as QueryResponse;
use crate::api::Surreal;
use crate::Value;
use once_cell::sync::Lazy;
use protocol::Client;
use protocol::Test;
@ -170,7 +169,7 @@ async fn api() {
let _: Version = DB.version().await.unwrap();
// run
let _: Value = DB.run("foo", ()).await.unwrap();
let _: Option<User> = DB.run("foo").await.unwrap();
}
fn assert_send_sync(_: impl Send + Sync) {}

View file

@ -1363,9 +1363,6 @@ async fn return_bool() {
assert_eq!(value.into_inner(), CoreValue::Bool(false));
}
/*
* TODO: Reenable test.
* Disabling run test for now as it depends on value conversions which are removed
#[test_log::test(tokio::test)]
async fn run() {
let (permit, db) = new_db().await;
@ -1384,24 +1381,23 @@ async fn run() {
";
let _ = db.query(sql).await;
let tmp = db.run("fn::foo", ()).await.unwrap();
assert_eq!(tmp, Value::from(42));
let tmp: i32 = db.run("fn::foo").await.unwrap();
assert_eq!(tmp, 42);
let tmp = db.run("fn::foo", 7).await.unwrap_err();
let tmp = db.run::<i32>("fn::foo").args(7).await.unwrap_err();
println!("fn::foo res: {tmp}");
assert!(tmp.to_string().contains("The function expects 0 arguments."));
let tmp = db.run("fn::idnotexist", ()).await.unwrap_err();
let tmp = db.run::<()>("fn::idnotexist").await.unwrap_err();
println!("fn::idontexist res: {tmp}");
assert!(tmp.to_string().contains("The function 'fn::idnotexist' does not exist"));
let tmp = db.run("count", Value::from(vec![1, 2, 3])).await.unwrap();
assert_eq!(tmp, Value::from(3));
let tmp: usize = db.run("count").args(vec![1, 2, 3]).await.unwrap();
assert_eq!(tmp, 3);
let tmp = db.run("fn::bar", 7).await.unwrap();
assert_eq!(tmp, Value::None);
let tmp: Option<RecordId> = db.run("fn::bar").args(7).await.unwrap();
assert_eq!(tmp, None);
let tmp = db.run("fn::baz", ()).await.unwrap();
assert_eq!(tmp, Value::from(7));
let tmp: i32 = db.run("fn::baz").await.unwrap();
assert_eq!(tmp, 7);
}
*/