Add more array functions (#4184)

Co-authored-by: Micha de Vries <micha@devrie.sh>
Co-authored-by: Micha de Vries <mt.dev@hotmail.com>
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
David Bottiau 2024-08-14 17:11:22 +02:00 committed by GitHub
parent ee8e6f00d7
commit 0069cba8a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 390 additions and 26 deletions

View file

@ -59,3 +59,11 @@ pub static EXPERIMENTAL_BEARER_ACCESS: Lazy<bool> =
// Run tests with bearer access enabled as it introduces new functionality that needs to be tested. // Run tests with bearer access enabled as it introduces new functionality that needs to be tested.
#[cfg(test)] #[cfg(test)]
pub static EXPERIMENTAL_BEARER_ACCESS: Lazy<bool> = Lazy::new(|| true); pub static EXPERIMENTAL_BEARER_ACCESS: Lazy<bool> = Lazy::new(|| true);
/// Used to limit allocation for builtin functions
pub static FUNCTION_ALLOCATION_LIMIT: Lazy<usize> = once_cell::sync::Lazy::new(|| {
let n = std::env::var("SURREAL_FUNCTION_ALLOCATION_LIMIT")
.map(|s| s.parse::<u32>().unwrap_or(20))
.unwrap_or(20);
2usize.pow(n)
});

View file

@ -1,6 +1,8 @@
use crate::err::Error; use crate::err::Error;
use crate::sql::value::Value; use crate::sql::value::Value;
use crate::sql::{Array, Bytes, Datetime, Duration, Kind, Number, Object, Regex, Strand, Thing}; use crate::sql::{
Array, Bytes, Closure, Datetime, Duration, Kind, Number, Object, Regex, Strand, Thing,
};
use std::vec::IntoIter; use std::vec::IntoIter;
/// Implemented by types that are commonly used, in a certain way, as arguments. /// Implemented by types that are commonly used, in a certain way, as arguments.
@ -14,6 +16,12 @@ impl FromArg for Value {
} }
} }
impl FromArg for Closure {
fn from_arg(arg: Value) -> Result<Self, Error> {
arg.coerce_to_function()
}
}
impl FromArg for Regex { impl FromArg for Regex {
fn from_arg(arg: Value) -> Result<Self, Error> { fn from_arg(arg: Value) -> Result<Self, Error> {
arg.coerce_to_regex() arg.coerce_to_regex()
@ -242,6 +250,41 @@ impl<A: FromArg, B: FromArg> FromArgs for (A, Option<B>) {
} }
} }
// Some functions take 4 arguments, with the 3rd and 4th being optional.
impl<A: FromArg, B: FromArg, C: FromArg, D: FromArg> FromArgs for (A, B, Option<C>, Option<D>) {
fn from_args(name: &str, args: Vec<Value>) -> Result<Self, Error> {
let err = || Error::InvalidArguments {
name: name.to_owned(),
message: String::from("Expected 2, 3 or 4 arguments."),
};
// Process the function arguments
let mut args = args.into_iter();
// Process the first argument
let a = A::from_arg(args.next().ok_or_else(err)?).map_err(|e| Error::InvalidArguments {
name: name.to_owned(),
message: format!("Argument 1 was the wrong type. {e}"),
})?;
let b = B::from_arg(args.next().ok_or_else(err)?).map_err(|e| Error::InvalidArguments {
name: name.to_owned(),
message: format!("Argument 2 was the wrong type. {e}"),
})?;
let c = match args.next() {
Some(c) => Some(C::from_arg(c)?),
None => None,
};
let d = match args.next() {
Some(d) => Some(D::from_arg(d)?),
None => None,
};
// Process additional function arguments
if args.next().is_some() {
// Too many arguments
return Err(err());
}
Ok((a, b, c, d))
}
}
#[inline] #[inline]
fn get_arg<T: FromArg, E: Fn() -> Error>( fn get_arg<T: FromArg, E: Fn() -> Error>(
name: &str, name: &str,

View file

@ -1,3 +1,7 @@
use crate::cnf::FUNCTION_ALLOCATION_LIMIT;
use crate::ctx::Context;
use crate::dbs::Options;
use crate::doc::CursorDoc;
use crate::err::Error; use crate::err::Error;
use crate::sql::array::Array; use crate::sql::array::Array;
use crate::sql::array::Clump; use crate::sql::array::Clump;
@ -12,8 +16,24 @@ use crate::sql::array::Union;
use crate::sql::array::Uniq; use crate::sql::array::Uniq;
use crate::sql::array::Windows; use crate::sql::array::Windows;
use crate::sql::value::Value; use crate::sql::value::Value;
use crate::sql::Closure;
use crate::sql::Function;
use rand::prelude::SliceRandom; use rand::prelude::SliceRandom;
use reblessive::tree::Stk;
use std::mem::size_of_val;
/// Returns an error if an array of this length is too much to allocate.
fn limit(name: &str, n: usize) -> Result<(), Error> {
if n > *FUNCTION_ALLOCATION_LIMIT {
Err(Error::InvalidArguments {
name: name.to_owned(),
message: format!("Output must not exceed {} bytes.", *FUNCTION_ALLOCATION_LIMIT),
})
} else {
Ok(())
}
}
pub fn add((mut array, value): (Array, Value)) -> Result<Value, Error> { pub fn add((mut array, value): (Array, Value)) -> Result<Value, Error> {
match value { match value {
@ -135,6 +155,45 @@ pub fn distinct((array,): (Array,)) -> Result<Value, Error> {
Ok(array.uniq().into()) Ok(array.uniq().into())
} }
pub fn fill(
(mut array, value, start, end): (Array, Value, Option<isize>, Option<isize>),
) -> Result<Value, Error> {
let min = 0;
let max = array.len();
let negative_max = -(max as isize);
let start = match start {
Some(start) if negative_max <= start && start < 0 => (start + max as isize) as usize,
Some(start) if start < negative_max => 0,
Some(start) => start as usize,
None => min,
};
let end = match end {
Some(end) if negative_max <= end && end < 0 => (end + max as isize) as usize,
Some(end) if end < negative_max => 0,
Some(end) => end as usize,
None => max,
};
if start == min && end >= max {
array.fill(value);
} else if end > start {
let end_minus_one = end - 1;
for i in start..end_minus_one {
if let Some(elem) = array.get_mut(i) {
*elem = value.clone();
}
}
if let Some(last_elem) = array.get_mut(end_minus_one) {
*last_elem = value;
}
}
Ok(array.into())
}
pub fn filter_index((array, value): (Array, Value)) -> Result<Value, Error> { pub fn filter_index((array, value): (Array, Value)) -> Result<Value, Error> {
Ok(array Ok(array
.iter() .iter()
@ -201,6 +260,10 @@ pub fn intersect((array, other): (Array, Array)) -> Result<Value, Error> {
Ok(array.intersect(other).into()) Ok(array.intersect(other).into())
} }
pub fn is_empty((array,): (Array,)) -> Result<Value, Error> {
Ok(array.is_empty().into())
}
pub fn join((arr, sep): (Array, String)) -> Result<Value, Error> { pub fn join((arr, sep): (Array, String)) -> Result<Value, Error> {
Ok(arr.into_iter().map(Value::as_raw_string).collect::<Vec<_>>().join(&sep).into()) Ok(arr.into_iter().map(Value::as_raw_string).collect::<Vec<_>>().join(&sep).into())
} }
@ -289,6 +352,20 @@ pub fn logical_xor((lh, rh): (Array, Array)) -> Result<Value, Error> {
Ok(result_arr.into()) Ok(result_arr.into())
} }
pub async fn map(
(stk, ctx, opt, doc): (&mut Stk, &Context<'_>, &Options, Option<&CursorDoc<'_>>),
(array, mapper): (Array, Closure),
) -> Result<Value, Error> {
let mut array = array;
for i in 0..array.len() {
let v = array.get(i).unwrap();
let fnc = Function::Anonymous(mapper.clone().into(), vec![v.to_owned(), i.into()]);
array[i] = fnc.compute(stk, ctx, opt, doc).await?;
}
Ok(array.into())
}
pub fn matches((array, compare_val): (Array, Value)) -> Result<Value, Error> { pub fn matches((array, compare_val): (Array, Value)) -> Result<Value, Error> {
Ok(array.matches(compare_val).into()) Ok(array.matches(compare_val).into())
} }
@ -315,6 +392,26 @@ pub fn push((mut array, value): (Array, Value)) -> Result<Value, Error> {
Ok(array.into()) Ok(array.into())
} }
pub fn range((start, count): (i64, i64)) -> Result<Value, Error> {
if count < 0 {
return Err(Error::InvalidArguments {
name: String::from("array::range"),
message: format!(
"Argument 1 was the wrong type. Expected a positive number but found {count}"
),
});
}
if let Some(end) = start.checked_add(count - 1) {
Ok(Array((start..=end).map(Value::from).collect::<Vec<_>>()).into())
} else {
Err(Error::InvalidArguments {
name: String::from("array::range"),
message: String::from("The range overflowed the maximum value for an integer"),
})
}
}
pub fn remove((mut array, mut index): (Array, i64)) -> Result<Value, Error> { pub fn remove((mut array, mut index): (Array, i64)) -> Result<Value, Error> {
// Negative index means start from the back // Negative index means start from the back
if index < 0 { if index < 0 {
@ -330,6 +427,11 @@ pub fn remove((mut array, mut index): (Array, i64)) -> Result<Value, Error> {
Ok(array.into()) Ok(array.into())
} }
pub fn repeat((value, count): (Value, usize)) -> Result<Value, Error> {
limit("array::repeat", size_of_val(&value).saturating_mul(count))?;
Ok(Array(std::iter::repeat(value).take(count).collect()).into())
}
pub fn reverse((mut array,): (Array,)) -> Result<Value, Error> { pub fn reverse((mut array,): (Array,)) -> Result<Value, Error> {
array.reverse(); array.reverse();
Ok(array.into()) Ok(array.into())
@ -337,7 +439,7 @@ pub fn reverse((mut array,): (Array,)) -> Result<Value, Error> {
pub fn shuffle((mut array,): (Array,)) -> Result<Value, Error> { pub fn shuffle((mut array,): (Array,)) -> Result<Value, Error> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
array.0.shuffle(&mut rng); array.shuffle(&mut rng);
Ok(array.into()) Ok(array.into())
} }
@ -392,6 +494,37 @@ pub fn sort((mut array, order): (Array, Option<Value>)) -> Result<Value, Error>
} }
} }
pub fn swap((mut array, from, to): (Array, isize, isize)) -> Result<Value, Error> {
let min = 0;
let max = array.len();
let negative_max = -(max as isize);
let from = match from {
from if from < negative_max || from >= max as isize => Err(Error::InvalidArguments {
name: String::from("array::swap"),
message: format!(
"Argument 1 is out of range. Expected a number between {negative_max} and {max}"
),
}),
from if negative_max <= from && from < min => Ok((from + max as isize) as usize),
from => Ok(from as usize),
}?;
let to = match to {
to if to < negative_max || to >= max as isize => Err(Error::InvalidArguments {
name: String::from("array::swap"),
message: format!(
"Argument 2 is out of range. Expected a number between {negative_max} and {max}"
),
}),
to if negative_max <= to && to < min => Ok((to + max as isize) as usize),
to => Ok(to as usize),
}?;
array.swap(from, to);
Ok(array.into())
}
pub fn transpose((array,): (Array,)) -> Result<Value, Error> { pub fn transpose((array,): (Array,)) -> Result<Value, Error> {
Ok(array.transpose().into()) Ok(array.transpose().into())
} }

View file

@ -52,8 +52,9 @@ pub async fn run(
|| name.starts_with("crypto::bcrypt") || name.starts_with("crypto::bcrypt")
|| name.starts_with("crypto::pbkdf2") || name.starts_with("crypto::pbkdf2")
|| name.starts_with("crypto::scrypt") || name.starts_with("crypto::scrypt")
|| name.starts_with("array::map")
{ {
stk.run(|stk| asynchronous(stk, ctx, Some(opt), doc, name, args)).await stk.run(|stk| asynchronous(stk, ctx, opt, doc, name, args)).await
} else { } else {
synchronous(ctx, doc, name, args) synchronous(ctx, doc, name, args)
} }
@ -110,6 +111,7 @@ pub fn synchronous(
"array::concat" => array::concat, "array::concat" => array::concat,
"array::difference" => array::difference, "array::difference" => array::difference,
"array::distinct" => array::distinct, "array::distinct" => array::distinct,
"array::fill" => array::fill,
"array::filter_index" => array::filter_index, "array::filter_index" => array::filter_index,
"array::find_index" => array::find_index, "array::find_index" => array::find_index,
"array::first" => array::first, "array::first" => array::first,
@ -117,6 +119,7 @@ pub fn synchronous(
"array::group" => array::group, "array::group" => array::group,
"array::insert" => array::insert, "array::insert" => array::insert,
"array::intersect" => array::intersect, "array::intersect" => array::intersect,
"array::is_empty" => array::is_empty,
"array::join" => array::join, "array::join" => array::join,
"array::last" => array::last, "array::last" => array::last,
"array::len" => array::len, "array::len" => array::len,
@ -129,11 +132,14 @@ pub fn synchronous(
"array::pop" => array::pop, "array::pop" => array::pop,
"array::prepend" => array::prepend, "array::prepend" => array::prepend,
"array::push" => array::push, "array::push" => array::push,
"array::range" => array::range,
"array::remove" => array::remove, "array::remove" => array::remove,
"array::repeat" => array::repeat,
"array::reverse" => array::reverse, "array::reverse" => array::reverse,
"array::shuffle" => array::shuffle, "array::shuffle" => array::shuffle,
"array::slice" => array::slice, "array::slice" => array::slice,
"array::sort" => array::sort, "array::sort" => array::sort,
"array::swap" => array::swap,
"array::transpose" => array::transpose, "array::transpose" => array::transpose,
"array::union" => array::union, "array::union" => array::union,
"array::sort::asc" => array::sort::asc, "array::sort::asc" => array::sort::asc,
@ -409,8 +415,10 @@ pub fn synchronous(
} }
/// Attempts to run any synchronous function. /// Attempts to run any synchronous function.
pub fn idiom( pub async fn idiom(
stk: &mut Stk,
ctx: &Context<'_>, ctx: &Context<'_>,
opt: &Options,
doc: Option<&CursorDoc<'_>>, doc: Option<&CursorDoc<'_>>,
value: Value, value: Value,
name: &str, name: &str,
@ -438,6 +446,7 @@ pub fn idiom(
"concat" => array::concat, "concat" => array::concat,
"difference" => array::difference, "difference" => array::difference,
"distinct" => array::distinct, "distinct" => array::distinct,
"fill" => array::fill,
"filter_index" => array::filter_index, "filter_index" => array::filter_index,
"find_index" => array::find_index, "find_index" => array::find_index,
"first" => array::first, "first" => array::first,
@ -445,6 +454,7 @@ pub fn idiom(
"group" => array::group, "group" => array::group,
"insert" => array::insert, "insert" => array::insert,
"intersect" => array::intersect, "intersect" => array::intersect,
"is_empty" => array::is_empty,
"join" => array::join, "join" => array::join,
"last" => array::last, "last" => array::last,
"len" => array::len, "len" => array::len,
@ -452,6 +462,7 @@ pub fn idiom(
"logical_or" => array::logical_or, "logical_or" => array::logical_or,
"logical_xor" => array::logical_xor, "logical_xor" => array::logical_xor,
"matches" => array::matches, "matches" => array::matches,
"map" => array::map((stk, ctx, opt, doc)).await,
"max" => array::max, "max" => array::max,
"min" => array::min, "min" => array::min,
"pop" => array::pop, "pop" => array::pop,
@ -462,6 +473,7 @@ pub fn idiom(
"shuffle" => array::shuffle, "shuffle" => array::shuffle,
"slice" => array::slice, "slice" => array::slice,
"sort" => array::sort, "sort" => array::sort,
"swap" => array::swap,
"transpose" => array::transpose, "transpose" => array::transpose,
"union" => array::union, "union" => array::union,
"sort_asc" => array::sort::asc, "sort_asc" => array::sort::asc,
@ -684,6 +696,8 @@ pub fn idiom(
"to_record" => r#type::record, "to_record" => r#type::record,
"to_string" => r#type::string, "to_string" => r#type::string,
"to_uuid" => r#type::uuid, "to_uuid" => r#type::uuid,
//
"repeat" => array::repeat,
) )
} }
v => v, v => v,
@ -694,7 +708,7 @@ pub fn idiom(
pub async fn asynchronous( pub async fn asynchronous(
stk: &mut Stk, stk: &mut Stk,
ctx: &Context<'_>, ctx: &Context<'_>,
opt: Option<&Options>, opt: &Options,
doc: Option<&CursorDoc<'_>>, doc: Option<&CursorDoc<'_>>,
name: &str, name: &str,
args: Vec<Value>, args: Vec<Value>,
@ -719,6 +733,8 @@ pub async fn asynchronous(
name, name,
args, args,
"no such builtin function found", "no such builtin function found",
"array::map" => array::map((stk, ctx, opt, doc)).await,
//
"crypto::argon2::compare" => (cpu_intensive) crypto::argon2::cmp.await, "crypto::argon2::compare" => (cpu_intensive) crypto::argon2::cmp.await,
"crypto::argon2::generate" => (cpu_intensive) crypto::argon2::gen.await, "crypto::argon2::generate" => (cpu_intensive) crypto::argon2::gen.await,
"crypto::bcrypt::compare" => (cpu_intensive) crypto::bcrypt::cmp.await, "crypto::bcrypt::compare" => (cpu_intensive) crypto::bcrypt::cmp.await,
@ -735,15 +751,15 @@ pub async fn asynchronous(
"http::patch" => http::patch(ctx).await, "http::patch" => http::patch(ctx).await,
"http::delete" => http::delete(ctx).await, "http::delete" => http::delete(ctx).await,
// //
"search::analyze" => search::analyze((stk,ctx, opt)).await, "search::analyze" => search::analyze((stk,ctx, Some(opt))).await,
"search::score" => search::score((ctx, doc)).await, "search::score" => search::score((ctx, doc)).await,
"search::highlight" => search::highlight((ctx, doc)).await, "search::highlight" => search::highlight((ctx, doc)).await,
"search::offsets" => search::offsets((ctx, doc)).await, "search::offsets" => search::offsets((ctx, doc)).await,
// //
"sleep" => sleep::sleep(ctx).await, "sleep" => sleep::sleep(ctx).await,
// //
"type::field" => r#type::field((stk,ctx, opt, doc)).await, "type::field" => r#type::field((stk,ctx, Some(opt), doc)).await,
"type::fields" => r#type::fields((stk,ctx, opt, doc)).await, "type::fields" => r#type::fields((stk,ctx, Some(opt), doc)).await,
) )
} }
@ -773,6 +789,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn implementations_are_present() { async fn implementations_are_present() {
#[cfg(all(feature = "scripting", feature = "kv-mem"))]
let excluded_from_scripting = &["array::map"];
// Accumulate and display all problems at once to avoid a test -> fix -> test -> fix cycle. // Accumulate and display all problems at once to avoid a test -> fix -> test -> fix cycle.
let mut problems = Vec::new(); let mut problems = Vec::new();
@ -780,7 +799,7 @@ mod tests {
let fnc_mod = include_str!("mod.rs"); let fnc_mod = include_str!("mod.rs");
// Patch out idiom methods // Patch out idiom methods
let re = Regex::new(r"(?ms)pub fn idiom\(.*}\n+///").unwrap(); let re = Regex::new(r"(?ms)pub async fn idiom\(.*}\n+///").unwrap();
let fnc_no_idiom = re.replace(fnc_mod, ""); let fnc_no_idiom = re.replace(fnc_mod, "");
for line in fnc_no_idiom.lines() { for line in fnc_no_idiom.lines() {
@ -823,6 +842,10 @@ mod tests {
{ {
use crate::sql::Value; use crate::sql::Value;
if excluded_from_scripting.contains(&name) {
continue;
}
let name = name.replace("::", "."); let name = name.replace("::", ".");
let sql = let sql =
format!("RETURN function() {{ return typeof surrealdb.functions.{name}; }}"); format!("RETURN function() {{ return typeof surrealdb.functions.{name}; }}");

View file

@ -23,6 +23,7 @@ impl_module_def!(
"concat" => run, "concat" => run,
"difference" => run, "difference" => run,
"distinct" => run, "distinct" => run,
"fill" => run,
"filter_index" => run, "filter_index" => run,
"find_index" => run, "find_index" => run,
"first" => run, "first" => run,
@ -30,6 +31,7 @@ impl_module_def!(
"group" => run, "group" => run,
"insert" => run, "insert" => run,
"intersect" => run, "intersect" => run,
"is_empty" => run,
"join" => run, "join" => run,
"knn" => run, "knn" => run,
"last" => run, "last" => run,
@ -43,11 +45,14 @@ impl_module_def!(
"pop" => run, "pop" => run,
"push" => run, "push" => run,
"prepend" => run, "prepend" => run,
"range" => run,
"remove" => run, "remove" => run,
"repeat" => run,
"reverse" => run, "reverse" => run,
"shuffle" => run, "shuffle" => run,
"slice" => run, "slice" => run,
"sort" => (sort::Package), "sort" => (sort::Package),
"swap" => run,
"transpose" => run, "transpose" => run,
"union" => run, "union" => run,
"windows" => run "windows" => run

View file

@ -72,9 +72,8 @@ fn run(js_ctx: js::Ctx<'_>, name: &str, args: Vec<Value>) -> Result<Value> {
async fn fut(js_ctx: js::Ctx<'_>, name: &str, args: Vec<Value>) -> Result<Value> { async fn fut(js_ctx: js::Ctx<'_>, name: &str, args: Vec<Value>) -> Result<Value> {
let this = js_ctx.globals().get::<_, OwnedBorrow<QueryContext>>(QUERY_DATA_PROP_NAME)?; let this = js_ctx.globals().get::<_, OwnedBorrow<QueryContext>>(QUERY_DATA_PROP_NAME)?;
// Process the called function // Process the called function
let res = Stk::enter_run(|stk| { let res =
fnc::asynchronous(stk, this.context, Some(this.opt), this.doc, name, args) Stk::enter_run(|stk| fnc::asynchronous(stk, this.context, this.opt, this.doc, name, args))
})
.await; .await;
// Convert any response error // Convert any response error
res.map_err(|err| { res.map_err(|err| {

View file

@ -1,3 +1,4 @@
use crate::cnf::FUNCTION_ALLOCATION_LIMIT;
use crate::err::Error; use crate::err::Error;
use crate::fnc::util::string; use crate::fnc::util::string;
use crate::sql::value::Value; use crate::sql::value::Value;
@ -5,11 +6,10 @@ use crate::sql::Regex;
/// Returns `true` if a string of this length is too much to allocate. /// Returns `true` if a string of this length is too much to allocate.
fn limit(name: &str, n: usize) -> Result<(), Error> { fn limit(name: &str, n: usize) -> Result<(), Error> {
const LIMIT: usize = 2usize.pow(20); if n > *FUNCTION_ALLOCATION_LIMIT {
if n > LIMIT {
Err(Error::InvalidArguments { Err(Error::InvalidArguments {
name: name.to_owned(), name: name.to_owned(),
message: format!("Output must not exceed {LIMIT} bytes."), message: format!("Output must not exceed {} bytes.", *FUNCTION_ALLOCATION_LIMIT),
}) })
} else { } else {
Ok(()) Ok(())

View file

@ -60,7 +60,11 @@ impl Value {
stk.run(|stk| obj.get(stk, ctx, opt, doc, path)).await stk.run(|stk| obj.get(stk, ctx, opt, doc, path)).await
} }
Part::Method(name, args) => { Part::Method(name, args) => {
let v = idiom(ctx, doc, v.clone().into(), name, args.clone())?; let v = stk
.run(|stk| {
idiom(stk, ctx, opt, doc, v.clone().into(), name, args.clone())
})
.await?;
stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await
} }
// Otherwise return none // Otherwise return none
@ -141,7 +145,11 @@ impl Value {
stk.run(|stk| obj.get(stk, ctx, opt, doc, path.next())).await stk.run(|stk| obj.get(stk, ctx, opt, doc, path.next())).await
} }
Part::Method(name, args) => { Part::Method(name, args) => {
let res = idiom(ctx, doc, v.clone().into(), name, args.clone()); let res = stk
.run(|stk| {
idiom(stk, ctx, opt, doc, v.clone().into(), name, args.clone())
})
.await;
let res = match &res { let res = match &res {
Ok(_) => res, Ok(_) => res,
Err(Error::InvalidFunction { Err(Error::InvalidFunction {
@ -214,7 +222,11 @@ impl Value {
_ => Ok(Value::None), _ => Ok(Value::None),
}, },
Part::Method(name, args) => { Part::Method(name, args) => {
let v = idiom(ctx, doc, v.clone().into(), name, args.clone())?; let v = stk
.run(|stk| {
idiom(stk, ctx, opt, doc, v.clone().into(), name, args.clone())
})
.await?;
stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await
} }
_ => stk _ => stk
@ -302,7 +314,19 @@ impl Value {
} }
} }
Part::Method(name, args) => { Part::Method(name, args) => {
let v = idiom(ctx, doc, v.clone().into(), name, args.clone())?; let v = stk
.run(|stk| {
idiom(
stk,
ctx,
opt,
doc,
v.clone().into(),
name,
args.clone(),
)
})
.await?;
stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await
} }
// This is a remote field expression // This is a remote field expression
@ -325,7 +349,9 @@ impl Value {
stk.run(|stk| v.get(stk, ctx, opt, None, path.next())).await stk.run(|stk| v.get(stk, ctx, opt, None, path.next())).await
} }
Part::Method(name, args) => { Part::Method(name, args) => {
let v = idiom(ctx, doc, v.clone(), name, args.clone())?; let v = stk
.run(|stk| idiom(stk, ctx, opt, doc, v.clone(), name, args.clone()))
.await?;
stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await
} }
// Only continue processing the path from the point that it contains a method // Only continue processing the path from the point that it contains a method

View file

@ -97,6 +97,7 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("array::concat") => PathKind::Function, UniCase::ascii("array::concat") => PathKind::Function,
UniCase::ascii("array::difference") => PathKind::Function, UniCase::ascii("array::difference") => PathKind::Function,
UniCase::ascii("array::distinct") => PathKind::Function, UniCase::ascii("array::distinct") => PathKind::Function,
UniCase::ascii("array::fill") => PathKind::Function,
UniCase::ascii("array::filter_index") => PathKind::Function, UniCase::ascii("array::filter_index") => PathKind::Function,
UniCase::ascii("array::find_index") => PathKind::Function, UniCase::ascii("array::find_index") => PathKind::Function,
UniCase::ascii("array::first") => PathKind::Function, UniCase::ascii("array::first") => PathKind::Function,
@ -104,6 +105,7 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("array::group") => PathKind::Function, UniCase::ascii("array::group") => PathKind::Function,
UniCase::ascii("array::insert") => PathKind::Function, UniCase::ascii("array::insert") => PathKind::Function,
UniCase::ascii("array::intersect") => PathKind::Function, UniCase::ascii("array::intersect") => PathKind::Function,
UniCase::ascii("array::is_empty") => PathKind::Function,
UniCase::ascii("array::join") => PathKind::Function, UniCase::ascii("array::join") => PathKind::Function,
UniCase::ascii("array::last") => PathKind::Function, UniCase::ascii("array::last") => PathKind::Function,
UniCase::ascii("array::len") => PathKind::Function, UniCase::ascii("array::len") => PathKind::Function,
@ -111,16 +113,20 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("array::logical_or") => PathKind::Function, UniCase::ascii("array::logical_or") => PathKind::Function,
UniCase::ascii("array::logical_xor") => PathKind::Function, UniCase::ascii("array::logical_xor") => PathKind::Function,
UniCase::ascii("array::matches") => PathKind::Function, UniCase::ascii("array::matches") => PathKind::Function,
UniCase::ascii("array::map") => PathKind::Function,
UniCase::ascii("array::max") => PathKind::Function, UniCase::ascii("array::max") => PathKind::Function,
UniCase::ascii("array::min") => PathKind::Function, UniCase::ascii("array::min") => PathKind::Function,
UniCase::ascii("array::pop") => PathKind::Function, UniCase::ascii("array::pop") => PathKind::Function,
UniCase::ascii("array::prepend") => PathKind::Function, UniCase::ascii("array::prepend") => PathKind::Function,
UniCase::ascii("array::push") => PathKind::Function, UniCase::ascii("array::push") => PathKind::Function,
UniCase::ascii("array::remove") => PathKind::Function, UniCase::ascii("array::remove") => PathKind::Function,
UniCase::ascii("array::repeat") => PathKind::Function,
UniCase::ascii("array::range") => PathKind::Function,
UniCase::ascii("array::reverse") => PathKind::Function, UniCase::ascii("array::reverse") => PathKind::Function,
UniCase::ascii("array::shuffle") => PathKind::Function, UniCase::ascii("array::shuffle") => PathKind::Function,
UniCase::ascii("array::slice") => PathKind::Function, UniCase::ascii("array::slice") => PathKind::Function,
UniCase::ascii("array::sort") => PathKind::Function, UniCase::ascii("array::sort") => PathKind::Function,
UniCase::ascii("array::swap") => PathKind::Function,
UniCase::ascii("array::transpose") => PathKind::Function, UniCase::ascii("array::transpose") => PathKind::Function,
UniCase::ascii("array::union") => PathKind::Function, UniCase::ascii("array::union") => PathKind::Function,
UniCase::ascii("array::sort::asc") => PathKind::Function, UniCase::ascii("array::sort::asc") => PathKind::Function,

View file

@ -166,6 +166,7 @@
"array::concat(" "array::concat("
"array::difference(" "array::difference("
"array::distinct(" "array::distinct("
"array::fill("
"array::filter_index(" "array::filter_index("
"array::find_index(" "array::find_index("
"array::first(" "array::first("
@ -173,6 +174,7 @@
"array::group(" "array::group("
"array::insert(" "array::insert("
"array::intersect(" "array::intersect("
"array::is_empty("
"array::join(" "array::join("
"array::last(" "array::last("
"array::len(" "array::len("
@ -180,18 +182,22 @@
"array::logical_or(" "array::logical_or("
"array::logical_xor(" "array::logical_xor("
"array::matches(" "array::matches("
"array::map("
"array::max(" "array::max("
"array::min(" "array::min("
"array::pop(" "array::pop("
"array::prepend(" "array::prepend("
"array::push(" "array::push("
"array::range("
"array::remove(" "array::remove("
"array::repeat("
"array::reverse(" "array::reverse("
"array::shuffle(" "array::shuffle("
"array::slice(" "array::slice("
"array::sort(" "array::sort("
"array::sort::asc(" "array::sort::asc("
"array::sort::desc(" "array::sort::desc("
"array::swap("
"array::transpose(" "array::transpose("
"array::union(" "array::union("
"count(" "count("

View file

@ -165,6 +165,7 @@
"array::concat(" "array::concat("
"array::difference(" "array::difference("
"array::distinct(" "array::distinct("
"array::fill("
"array::filter_index(" "array::filter_index("
"array::find_index(" "array::find_index("
"array::first(" "array::first("
@ -172,6 +173,7 @@
"array::group(" "array::group("
"array::insert(" "array::insert("
"array::intersect(" "array::intersect("
"array::is_empty("
"array::join(" "array::join("
"array::last(" "array::last("
"array::len(" "array::len("
@ -179,18 +181,22 @@
"array::logical_or(" "array::logical_or("
"array::logical_xor(" "array::logical_xor("
"array::matches(" "array::matches("
"array::map("
"array::max(" "array::max("
"array::min(" "array::min("
"array::pop(" "array::pop("
"array::prepend(" "array::prepend("
"array::push(" "array::push("
"array::range("
"array::remove(" "array::remove("
"array::repeat("
"array::reverse(" "array::reverse("
"array::shuffle(" "array::shuffle("
"array::slice(" "array::slice("
"array::sort(" "array::sort("
"array::sort::asc(" "array::sort::asc("
"array::sort::desc(" "array::sort::desc("
"array::swap("
"array::transpose(" "array::transpose("
"array::union(" "array::union("
"count(" "count("

View file

@ -279,6 +279,32 @@ async fn function_array_distinct() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_fill() -> Result<(), Error> {
let sql = r#"
RETURN array::fill([1,2,3,4,5], 10);
RETURN array::fill([1,2,3,4,5], 10, 0, 7);
RETURN array::fill([1,NONE,NONE,NONE,NONE], 10, 1);
RETURN array::fill([1,NONE,3,4,5], 10, 1, 2);
RETURN array::fill([1,2,3,4,5], 10, 1, 1);
RETURN array::fill([1,2,3,4,5], 10, 7, 7);
RETURN array::fill([1,2,3,4,5], 10, 7, 9);
RETURN array::fill([1,2,NONE,4,5], 10, -3, -2);
"#;
//
Test::new(sql)
.await?
.expect_val("[10,10,10,10,10]")?
.expect_val("[10,10,10,10,10]")?
.expect_val("[1,10,10,10,10]")?
.expect_val("[1,10,3,4,5]")?
.expect_val("[1,2,3,4,5]")?
.expect_val("[1,2,3,4,5]")?
.expect_val("[1,2,3,4,5]")?
.expect_val("[1,2,10,4,5]")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_array_filter_index() -> Result<(), Error> { async fn function_array_filter_index() -> Result<(), Error> {
let sql = r#"RETURN array::filter_index([0, 1, 2], 1); let sql = r#"RETURN array::filter_index([0, 1, 2], 1);
@ -448,6 +474,17 @@ async fn function_array_intersect() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_is_empty() -> Result<(), Error> {
let sql = r#"
RETURN array::is_empty([]);
RETURN array::is_empty([1,2,3,4,5]);
"#;
//
Test::new(sql).await?.expect_val("true")?.expect_val("false")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_string_join_arr() -> Result<(), Error> { async fn function_string_join_arr() -> Result<(), Error> {
let sql = r#" let sql = r#"
@ -567,6 +604,16 @@ RETURN array::logical_xor([0, 1], []);"#,
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_map() -> Result<(), Error> {
let sql = r#"
RETURN array::map([1,2,3], |$n, $i| $n + $i);
"#;
//
Test::new(sql).await?.expect_val("[1, 3, 5]")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_array_matches() -> Result<(), Error> { async fn function_array_matches() -> Result<(), Error> {
test_queries( test_queries(
@ -729,6 +776,27 @@ async fn function_array_push() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_range() -> Result<(), Error> {
let sql = r#"
RETURN array::range(1, 10);
RETURN array::range(3, 1);
RETURN array::range(44, 0);
RETURN array::range(0, -1);
RETURN array::range(0, -256);
RETURN array::range(9223372036854775800, 100);
"#;
//
Test::new(sql).await?
.expect_val("[1,2,3,4,5,6,7,8,9,10]")?
.expect_val("[3]")?
.expect_val("[]")?
.expect_error("Incorrect arguments for function array::range(). Argument 1 was the wrong type. Expected a positive number but found -1")?
.expect_error("Incorrect arguments for function array::range(). Argument 1 was the wrong type. Expected a positive number but found -256")?
.expect_error("Incorrect arguments for function array::range(). The range overflowed the maximum value for an integer")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_array_remove() -> Result<(), Error> { async fn function_array_remove() -> Result<(), Error> {
let sql = r#" let sql = r#"
@ -758,6 +826,27 @@ async fn function_array_remove() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_repeat() -> Result<(), Error> {
let sql = r#"
RETURN array::repeat(1, 10);
RETURN array::repeat("hello", 2);
RETURN array::repeat(NONE, 3);
RETURN array::repeat(44, 0);
RETURN array::repeat(0, -1);
RETURN array::repeat(0, -256);
"#;
//
Test::new(sql).await?
.expect_val("[1,1,1,1,1,1,1,1,1,1]")?
.expect_val(r#"["hello","hello"]"#)?
.expect_val("[NONE,NONE,NONE]")?
.expect_val("[]")?
.expect_error("Incorrect arguments for function array::repeat(). Output must not exceed 1048576 bytes.")?
.expect_error("Incorrect arguments for function array::repeat(). Output must not exceed 1048576 bytes.")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_array_reverse() -> Result<(), Error> { async fn function_array_reverse() -> Result<(), Error> {
let sql = r#" let sql = r#"
@ -986,6 +1075,27 @@ async fn function_array_sort_desc() -> Result<(), Error> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn function_array_swap() -> Result<(), Error> {
let sql = r#"
RETURN array::swap([1,2,3,4,5], 1, 2);
RETURN array::swap([1,2,3,4,5], 1, 1);
RETURN array::swap([1,2,3,4,5], -1, -2);
RETURN array::swap([1,2,3,4,5], -5, -4);
RETURN array::swap([1,2,3,4,5], 8, 1);
RETURN array::swap([1,2,3,4,5], 1, -8);
"#;
//
Test::new(sql).await?
.expect_val("[1,3,2,4,5]")?
.expect_val("[1,2,3,4,5]")?
.expect_val("[1,2,3,5,4]")?
.expect_val("[2,1,3,4,5]")?
.expect_error("Incorrect arguments for function array::swap(). Argument 1 is out of range. Expected a number between -5 and 5")?
.expect_error("Incorrect arguments for function array::swap(). Argument 2 is out of range. Expected a number between -5 and 5")?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn function_array_transpose() -> Result<(), Error> { async fn function_array_transpose() -> Result<(), Error> {
let sql = r#" let sql = r#"
@ -6086,7 +6196,7 @@ pub async fn function_http_get_from_script() -> Result<(), Error> {
#[cfg(not(feature = "http"))] #[cfg(not(feature = "http"))]
#[tokio::test] #[tokio::test]
pub async fn function_http_disabled() { pub async fn function_http_disabled() -> Result<(), Error> {
Test::new( Test::new(
r#" r#"
RETURN http::get({}); RETURN http::get({});
@ -6097,8 +6207,7 @@ pub async fn function_http_disabled() {
RETURN http::delete({}); RETURN http::delete({});
"#, "#,
) )
.await .await?
.unwrap()
.expect_errors(&[ .expect_errors(&[
"Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled",
@ -6106,8 +6215,8 @@ pub async fn function_http_disabled() {
"Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled",
]) ])?;
.unwrap(); Ok(())
} }
// Tests for custom defined functions // Tests for custom defined functions