diff --git a/core/src/fnc/mod.rs b/core/src/fnc/mod.rs index 2bc46e4a..15146cf0 100644 --- a/core/src/fnc/mod.rs +++ b/core/src/fnc/mod.rs @@ -321,6 +321,7 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec) -> Result r#type::string, "type::table" => r#type::table, "type::thing" => r#type::thing, + "type::range" => r#type::range, "type::is::array" => r#type::is::array, "type::is::bool" => r#type::is::bool, "type::is::bytes" => r#type::is::bytes, diff --git a/core/src/fnc/script/modules/surrealdb/functions/type.rs b/core/src/fnc/script/modules/surrealdb/functions/type.rs index de5e6fc9..c3730b0f 100644 --- a/core/src/fnc/script/modules/surrealdb/functions/type.rs +++ b/core/src/fnc/script/modules/surrealdb/functions/type.rs @@ -25,5 +25,6 @@ impl_module_def!( "regex" => run, "string" => run, "table" => run, - "thing" => run + "thing" => run, + "range" => run ); diff --git a/core/src/fnc/type.rs b/core/src/fnc/type.rs index 9b8ffc32..21a39164 100644 --- a/core/src/fnc/type.rs +++ b/core/src/fnc/type.rs @@ -1,3 +1,5 @@ +use std::ops::Bound; + use crate::ctx::Context; use crate::dbs::{Options, Transaction}; use crate::doc::CursorDoc; @@ -5,6 +7,7 @@ use crate::err::Error; use crate::sql::table::Table; use crate::sql::thing::Thing; use crate::sql::value::Value; +use crate::sql::{Id, Range, Strand}; use crate::syn; pub fn bool((val,): (Value,)) -> Result { @@ -128,6 +131,93 @@ pub fn thing((arg1, arg2): (Value, Option)) -> Result { }) } +pub fn range(args: Vec) -> Result { + if args.len() > 4 || args.is_empty() { + return Err(Error::InvalidArguments { + name: "type::range".to_owned(), + message: "Expected atleast 1 and at most 4 arguments".to_owned(), + }); + } + let mut args = args.into_iter(); + + // Unwrap will never trigger since length is checked above. + let id = args.next().unwrap().as_string(); + let start = args.next().and_then(|x| match x { + Value::Thing(v) => Some(v.id), + Value::Array(v) => Some(v.into()), + Value::Object(v) => Some(v.into()), + Value::Number(v) => Some(v.into()), + Value::Null | Value::None => None, + v => Some(Id::from(v.as_string())), + }); + let end = args.next().and_then(|x| match x { + Value::Thing(v) => Some(v.id), + Value::Array(v) => Some(v.into()), + Value::Object(v) => Some(v.into()), + Value::Number(v) => Some(v.into()), + Value::Null | Value::None => None, + v => Some(Id::from(v.as_string())), + }); + let (begin, end) = if let Some(x) = args.next() { + let Value::Object(x) = x else { + return Err(Error::ConvertTo { + from: x, + into: "object".to_owned(), + }); + }; + let begin = if let Some(x) = x.get("begin") { + let start = start.ok_or_else(|| Error::InvalidArguments { + name: "type::range".to_string(), + message: "Can't define an inclusion for begin if there is no begin bound" + .to_string(), + })?; + match x { + Value::Strand(Strand(x)) if x == "included" => Bound::Included(start), + Value::Strand(Strand(x)) if x == "excluded" => Bound::Excluded(start), + x => { + return Err(Error::ConvertTo { + from: x.clone(), + into: r#""included" | "excluded""#.to_owned(), + }) + } + } + } else { + start.map(Bound::Included).unwrap_or(Bound::Unbounded) + }; + let end = if let Some(x) = x.get("end") { + let end = end.ok_or_else(|| Error::InvalidArguments { + name: "type::range".to_string(), + message: "Can't define an inclusion for end if there is no end bound".to_string(), + })?; + match x { + Value::Strand(Strand(x)) if x == "included" => Bound::Included(end), + Value::Strand(Strand(x)) if x == "excluded" => Bound::Excluded(end), + x => { + return Err(Error::ConvertTo { + from: x.clone(), + into: r#""included" | "excluded""#.to_owned(), + }) + } + } + } else { + end.map(Bound::Excluded).unwrap_or(Bound::Unbounded) + }; + (begin, end) + } else { + ( + start.map(Bound::Included).unwrap_or(Bound::Unbounded), + end.map(Bound::Excluded).unwrap_or(Bound::Unbounded), + ) + }; + + Ok(Range { + tb: id, + beg: begin, + end, + } + .into()) +} + pub mod is { use crate::err::Error; use crate::sql::table::Table; diff --git a/core/src/syn/v1/builtin.rs b/core/src/syn/v1/builtin.rs index 20a8ea97..8d7b5a7a 100644 --- a/core/src/syn/v1/builtin.rs +++ b/core/src/syn/v1/builtin.rs @@ -458,6 +458,7 @@ pub(crate) fn builtin_name(i: &str) -> IResult<&str, BuiltinName<&str>, ParseErr string => { fn }, table => { fn }, thing => { fn }, + range => { fn }, is => { array => { fn }, r#bool = "bool" => { fn }, diff --git a/core/src/syn/v2/parser/builtin.rs b/core/src/syn/v2/parser/builtin.rs index 40cf925f..46f69d7a 100644 --- a/core/src/syn/v2/parser/builtin.rs +++ b/core/src/syn/v2/parser/builtin.rs @@ -314,6 +314,7 @@ pub(crate) static PATHS: phf::Map, PathKind> = phf_map! { UniCase::ascii("type::string") => PathKind::Function, UniCase::ascii("type::table") => PathKind::Function, UniCase::ascii("type::thing") => PathKind::Function, + UniCase::ascii("type::range") => PathKind::Function, UniCase::ascii("type::is::array") => PathKind::Function, UniCase::ascii("type::is::bool") => PathKind::Function, UniCase::ascii("type::is::bytes") => PathKind::Function, diff --git a/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict b/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict index 09e028ae..f3aee565 100644 --- a/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict +++ b/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict @@ -393,6 +393,7 @@ "type::string(" "type::table(" "type::thing(" +"type::range(" "vector::add(" "vector::angle(" "vector::cross(" diff --git a/lib/tests/function.rs b/lib/tests/function.rs index 61697ca1..12af72e8 100644 --- a/lib/tests/function.rs +++ b/lib/tests/function.rs @@ -5702,6 +5702,42 @@ async fn function_type_thing() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn function_type_range() -> Result<(), Error> { + let sql = r#" + RETURN type::range('person'); + RETURN type::range('person',1); + RETURN type::range('person',null,10); + RETURN type::range('person',1,10); + RETURN type::range('person',1,10, { begin: "excluded", end: "included"}); + "#; + let dbs = new_ds().await?; + let ses = Session::owner().with_ns("test").with_db("test"); + let res = &mut dbs.execute(sql, &ses, None).await?; + assert_eq!(res.len(), 5); + // + let tmp = res.remove(0).result?; + let val = Value::parse("person:.."); + assert_eq!(tmp, val); + + let tmp = res.remove(0).result?; + let val = Value::parse("person:1.."); + assert_eq!(tmp, val); + + let tmp = res.remove(0).result?; + let val = Value::parse("person:..10"); + assert_eq!(tmp, val); + + let tmp = res.remove(0).result?; + let val = Value::parse("person:1..10"); + assert_eq!(tmp, val); + + let tmp = res.remove(0).result?; + let val = Value::parse("person:1>..=10"); + assert_eq!(tmp, val); + Ok(()) +} + #[tokio::test] async fn function_vector_add() -> Result<(), Error> { test_queries(