From c0be139e590e7d969cf4e2db1a628e5cb7162dbc Mon Sep 17 00:00:00 2001 From: Micha de Vries Date: Sat, 10 Aug 2024 13:44:12 +0100 Subject: [PATCH] Implement anonymous functions (#4474) Co-authored-by: Tobie Morgan Hitchcock --- core/src/ctx/context.rs | 34 +++- core/src/idx/planner/rewriter.rs | 6 +- core/src/sql/closure.rs | 88 +++++++++++ core/src/sql/function.rs | 37 +++++ core/src/sql/kind.rs | 5 +- core/src/sql/mod.rs | 2 + core/src/sql/value/get.rs | 25 ++- core/src/sql/value/serde/de/mod.rs | 1 + core/src/sql/value/serde/ser/closure/mod.rs | 98 ++++++++++++ core/src/sql/value/serde/ser/function/mod.rs | 17 +- core/src/sql/value/serde/ser/kind/mod.rs | 80 ++++++---- .../serde/ser/kind/{vec.rs => vec/mod.rs} | 2 + core/src/sql/value/serde/ser/kind/vec/opt.rs | 56 +++++++ core/src/sql/value/serde/ser/mod.rs | 1 + .../serde/ser/statement/define/function.rs | 4 +- .../value/serde/ser/statement/define/mod.rs | 2 +- core/src/sql/value/serde/ser/value/mod.rs | 6 + core/src/sql/value/value.rs | 38 +++++ core/src/syn/parser/kind.rs | 33 ++-- core/src/syn/parser/prime.rs | 148 ++++++++++++++++-- lib/tests/closure.rs | 101 ++++++++++++ 21 files changed, 713 insertions(+), 71 deletions(-) create mode 100644 core/src/sql/closure.rs create mode 100644 core/src/sql/value/serde/ser/closure/mod.rs rename core/src/sql/value/serde/ser/kind/{vec.rs => vec/mod.rs} (99%) create mode 100644 core/src/sql/value/serde/ser/kind/vec/opt.rs create mode 100644 lib/tests/closure.rs diff --git a/core/src/ctx/context.rs b/core/src/ctx/context.rs index bda3d6b9..ee5b37f1 100644 --- a/core/src/ctx/context.rs +++ b/core/src/ctx/context.rs @@ -72,6 +72,8 @@ pub struct Context<'a> { temporary_directory: Option>, // An optional transaction transaction: Option>, + // Does not read from parent `values`. + isolated: bool, } impl<'a> Default for Context<'a> { @@ -131,6 +133,7 @@ impl<'a> Context<'a> { ))] temporary_directory, transaction: None, + isolated: false, }; if let Some(timeout) = time_out { ctx.add_timeout(timeout)?; @@ -159,6 +162,7 @@ impl<'a> Context<'a> { ))] temporary_directory: None, transaction: None, + isolated: false, } } @@ -184,6 +188,33 @@ impl<'a> Context<'a> { ))] temporary_directory: parent.temporary_directory.clone(), transaction: parent.transaction.clone(), + isolated: false, + } + } + + /// Create a new child from a frozen context. + pub fn new_isolated(parent: &'a Context) -> Self { + Context { + values: HashMap::default(), + parent: Some(parent), + deadline: parent.deadline, + cancelled: Arc::new(AtomicBool::new(false)), + notifications: parent.notifications.clone(), + query_planner: parent.query_planner, + query_executor: parent.query_executor.clone(), + iteration_stage: parent.iteration_stage.clone(), + capabilities: parent.capabilities.clone(), + index_stores: parent.index_stores.clone(), + #[cfg(any( + feature = "kv-mem", + feature = "kv-surrealkv", + feature = "kv-rocksdb", + feature = "kv-fdb", + feature = "kv-tikv", + ))] + temporary_directory: parent.temporary_directory.clone(), + transaction: parent.transaction.clone(), + isolated: true, } } @@ -334,10 +365,11 @@ impl<'a> Context<'a> { Cow::Borrowed(v) => Some(*v), Cow::Owned(v) => Some(v), }, - None => match self.parent { + None if !self.isolated => match self.parent { Some(p) => p.value(key), _ => None, }, + None => None, } } diff --git a/core/src/idx/planner/rewriter.rs b/core/src/idx/planner/rewriter.rs index b72a46aa..56ec19db 100644 --- a/core/src/idx/planner/rewriter.rs +++ b/core/src/idx/planner/rewriter.rs @@ -47,7 +47,8 @@ impl<'a> KnnConditionRewriter<'a> { | Value::Table(_) | Value::Mock(_) | Value::Regex(_) - | Value::Constant(_) => Some(v.clone()), + | Value::Constant(_) + | Value::Closure(_) => Some(v.clone()), } } @@ -208,6 +209,9 @@ impl<'a> KnnConditionRewriter<'a> { Function::Script(s, args) => { self.eval_values(args).map(|args| Function::Script(s.clone(), args)) } + Function::Anonymous(p, args) => { + self.eval_values(args).map(|args| Function::Anonymous(p.clone(), args)) + } } } diff --git a/core/src/sql/closure.rs b/core/src/sql/closure.rs new file mode 100644 index 00000000..14ee9b69 --- /dev/null +++ b/core/src/sql/closure.rs @@ -0,0 +1,88 @@ +use crate::{ctx::Context, dbs::Options, doc::CursorDoc, err::Error, sql::value::Value}; +use reblessive::tree::Stk; +use revision::revisioned; +use serde::{Deserialize, Serialize}; +use std::fmt; + +use super::{Ident, Kind}; + +pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Closure"; + +#[revisioned(revision = 1)] +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] +#[serde(rename = "$surrealdb::private::sql::Closure")] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[non_exhaustive] +pub struct Closure { + pub args: Vec<(Ident, Kind)>, + pub returns: Option, + pub body: Value, +} + +impl Closure { + pub(crate) async fn compute( + &self, + stk: &mut Stk, + ctx: &Context<'_>, + opt: &Options, + doc: Option<&CursorDoc<'_>>, + args: Vec, + ) -> Result { + let mut ctx = Context::new_isolated(ctx); + for (i, (name, kind)) in self.args.iter().enumerate() { + match (kind, args.get(i)) { + (Kind::Option(_), None) => continue, + (_, None) => { + return Err(Error::InvalidArguments { + name: "ANONYMOUS".to_string(), + message: format!("Expected a value for ${}", name), + }) + } + (kind, Some(val)) => { + if let Ok(val) = val.to_owned().coerce_to(kind) { + ctx.add_value(name.to_string(), val); + } else { + return Err(Error::InvalidArguments { + name: "ANONYMOUS".to_string(), + message: format!( + "Expected a value of type '{kind}' for argument ${}", + name + ), + }); + } + } + } + } + + let result = self.body.compute(stk, &ctx, opt, doc).await?; + if let Some(returns) = &self.returns { + if let Ok(result) = result.clone().coerce_to(returns) { + Ok(result) + } else { + Err(Error::InvalidFunction { + name: "ANONYMOUS".to_string(), + message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()), + }) + } + } else { + Ok(result) + } + } +} + +impl fmt::Display for Closure { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("|")?; + for (i, (name, kind)) in self.args.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + write!(f, "${name}: {kind}")?; + } + f.write_str("|")?; + if let Some(returns) = &self.returns { + write!(f, " -> {returns}")?; + } + write!(f, " {}", self.body) + } +} diff --git a/core/src/sql/function.rs b/core/src/sql/function.rs index c39fa187..efb2e848 100644 --- a/core/src/sql/function.rs +++ b/core/src/sql/function.rs @@ -29,6 +29,7 @@ pub enum Function { Normal(String, Vec), Custom(String, Vec), Script(Script, Vec), + Anonymous(Value, Vec), // Add new variants here } @@ -71,6 +72,7 @@ impl Function { /// Convert function call to a field name pub fn to_idiom(&self) -> Idiom { match self { + Self::Anonymous(_, _) => "function".to_string().into(), Self::Script(_, _) => "function".to_string().into(), Self::Normal(f, _) => f.to_owned().into(), Self::Custom(f, _) => format!("fn::{f}").into(), @@ -111,6 +113,11 @@ impl Function { } } + /// Check if this function is a closure function + pub fn is_inline(&self) -> bool { + matches!(self, Self::Anonymous(_, _)) + } + /// Check if this function is a rolling function pub fn is_rolling(&self) -> bool { match self { @@ -204,6 +211,35 @@ impl Function { // Run the normal function fnc::run(stk, ctx, opt, doc, s, a).await } + Self::Anonymous(v, x) => { + let val = match v { + Value::Closure(p) => &Value::Closure(p.to_owned()), + Value::Param(p) => ctx.value(p).unwrap_or(&Value::None), + Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => { + &stk.run(|stk| v.compute(stk, ctx, opt, doc)).await? + } + _ => &Value::None, + }; + + match val { + Value::Closure(closure) => { + // Compute the function arguments + let a = stk + .scope(|scope| { + try_join_all( + x.iter() + .map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))), + ) + }) + .await?; + stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await + } + v => Err(Error::InvalidFunction { + name: "ANONYMOUS".to_string(), + message: format!("'{}' is not a function", v.kindof()), + }), + } + } Self::Custom(s, x) => { // Get the full name of this function let name = format!("fn::{s}"); @@ -308,6 +344,7 @@ impl fmt::Display for Function { Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)), Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)), Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)), + Self::Anonymous(p, e) => write!(f, "{p}({})", Fmt::comma_separated(e)), } } } diff --git a/core/src/sql/kind.rs b/core/src/sql/kind.rs index 17434359..b57e1ab2 100644 --- a/core/src/sql/kind.rs +++ b/core/src/sql/kind.rs @@ -29,6 +29,7 @@ pub enum Kind { Either(Vec), Set(Box, Option), Array(Box, Option), + Function(Option>, Option>), } impl Default for Kind { @@ -71,7 +72,8 @@ impl Kind { | Kind::String | Kind::Uuid | Kind::Record(_) - | Kind::Geometry(_) => return None, + | Kind::Geometry(_) + | Kind::Function(_, _) => return None, Kind::Option(x) => { this = x; } @@ -114,6 +116,7 @@ impl Display for Kind { Kind::Point => f.write_str("point"), Kind::String => f.write_str("string"), Kind::Uuid => f.write_str("uuid"), + Kind::Function(_, _) => f.write_str("function"), Kind::Option(k) => write!(f, "option<{}>", k), Kind::Record(k) => match k { k if k.is_empty() => write!(f, "record"), diff --git a/core/src/sql/mod.rs b/core/src/sql/mod.rs index 698111d8..87f867b1 100644 --- a/core/src/sql/mod.rs +++ b/core/src/sql/mod.rs @@ -12,6 +12,7 @@ pub(crate) mod bytes; pub(crate) mod cast; pub(crate) mod change_feed_include; pub(crate) mod changefeed; +pub(crate) mod closure; pub(crate) mod cond; pub(crate) mod constant; pub(crate) mod data; @@ -88,6 +89,7 @@ pub use self::block::Entry; pub use self::bytes::Bytes; pub use self::cast::Cast; pub use self::changefeed::ChangeFeed; +pub use self::closure::Closure; pub use self::cond::Cond; pub use self::constant::Constant; pub use self::data::Data; diff --git a/core/src/sql/value/get.rs b/core/src/sql/value/get.rs index 85e99a0e..c1f69fc0 100644 --- a/core/src/sql/value/get.rs +++ b/core/src/sql/value/get.rs @@ -16,6 +16,7 @@ use crate::sql::paths::ID; use crate::sql::statements::select::SelectStatement; use crate::sql::thing::Thing; use crate::sql::value::{Value, Values}; +use crate::sql::Function; use reblessive::tree::Stk; impl Value { @@ -140,8 +141,28 @@ impl Value { stk.run(|stk| obj.get(stk, ctx, opt, doc, path.next())).await } Part::Method(name, args) => { - let v = idiom(ctx, doc, v.clone().into(), name, args.clone())?; - stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await + let res = idiom(ctx, doc, v.clone().into(), name, args.clone()); + let res = match &res { + Ok(_) => res, + Err(Error::InvalidFunction { + .. + }) => match v.get(name) { + Some(v) => { + let fnc = Function::Anonymous(v.clone(), args.clone()); + match stk.run(|stk| fnc.compute(stk, ctx, opt, doc)).await { + Ok(v) => Ok(v), + Err(Error::InvalidFunction { + .. + }) => res, + e => e, + } + } + None => res, + }, + _ => res, + }?; + + stk.run(|stk| res.get(stk, ctx, opt, doc, path.next())).await } _ => Ok(Value::None), }, diff --git a/core/src/sql/value/serde/de/mod.rs b/core/src/sql/value/serde/de/mod.rs index 8e713dab..e89f4f4c 100644 --- a/core/src/sql/value/serde/de/mod.rs +++ b/core/src/sql/value/serde/de/mod.rs @@ -277,6 +277,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue { Value::Query(query) => json!(query), Value::Subquery(subquery) => json!(subquery), Value::Expression(expression) => json!(expression), + Value::Closure(closure) => json!(closure), } } diff --git a/core/src/sql/value/serde/ser/closure/mod.rs b/core/src/sql/value/serde/ser/closure/mod.rs new file mode 100644 index 00000000..759c92f8 --- /dev/null +++ b/core/src/sql/value/serde/ser/closure/mod.rs @@ -0,0 +1,98 @@ +use crate::err::Error; +use crate::sql::value::serde::ser; +use crate::sql::Closure; +use crate::sql::Ident; +use crate::sql::Kind; +use crate::sql::Value; +use ser::statement::define::function::IdentKindVecSerializer; +use ser::Serializer as _; +use serde::ser::Error as _; +use serde::ser::Serialize; + +#[derive(Default)] +pub(super) struct SerializeClosure { + args: Option>, + returns: Option>, + body: Option, +} + +impl serde::ser::SerializeStruct for SerializeClosure { + type Ok = Closure; + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Error> + where + T: ?Sized + Serialize, + { + match key { + "args" => { + self.args = Some(value.serialize(IdentKindVecSerializer.wrap())?); + } + "returns" => { + self.returns = Some(value.serialize(ser::kind::opt::Serializer.wrap())?); + } + "body" => { + self.body = Some(value.serialize(ser::value::Serializer.wrap())?); + } + key => { + return Err(Error::custom(format!("unexpected field `Closure::{key}`"))); + } + } + Ok(()) + } + + fn end(self) -> Result { + match (self.args, self.returns, self.body) { + (Some(args), Some(returns), Some(body)) => Ok(Closure { + args, + returns, + body, + }), + _ => Err(Error::custom("`Closure` missing required field(s)")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::ser::Impossible; + use serde::Serialize; + + pub(super) struct Serializer; + + impl ser::Serializer for Serializer { + type Ok = Closure; + type Error = Error; + + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + type SerializeMap = Impossible; + type SerializeStruct = SerializeClosure; + type SerializeStructVariant = Impossible; + + const EXPECTED: &'static str = "a struct `Closure`"; + + #[inline] + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(SerializeClosure::default()) + } + } + + #[test] + fn closure() { + let closure = Closure { + args: Vec::new(), + returns: None, + body: Value::default(), + }; + let serialized = closure.serialize(Serializer.wrap()).unwrap(); + assert_eq!(closure, serialized); + } +} diff --git a/core/src/sql/value/serde/ser/function/mod.rs b/core/src/sql/value/serde/ser/function/mod.rs index 55a478ab..8a9ddf80 100644 --- a/core/src/sql/value/serde/ser/function/mod.rs +++ b/core/src/sql/value/serde/ser/function/mod.rs @@ -35,6 +35,7 @@ impl ser::Serializer for Serializer { "Normal" => Inner::Normal(None, None), "Custom" => Inner::Custom(None, None), "Script" => Inner::Script(None, None), + "Anonymous" => Inner::Anonymous(None, None), variant => { return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`"))); } @@ -55,6 +56,7 @@ enum Inner { Normal(Option, Option>), Custom(Option, Option>), Script(Option