Implement anonymous functions (#4474)

Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
Micha de Vries 2024-08-10 13:44:12 +01:00 committed by GitHub
parent 3f5ef43248
commit c0be139e59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 713 additions and 71 deletions

View file

@ -72,6 +72,8 @@ pub struct Context<'a> {
temporary_directory: Option<Arc<PathBuf>>,
// An optional transaction
transaction: Option<Arc<Transaction>>,
// Does not read from parent `values`.
isolated: bool,
}
impl<'a> Default for Context<'a> {
@ -131,6 +133,7 @@ impl<'a> Context<'a> {
))]
temporary_directory,
transaction: None,
isolated: false,
};
if let Some(timeout) = time_out {
ctx.add_timeout(timeout)?;
@ -159,6 +162,7 @@ impl<'a> Context<'a> {
))]
temporary_directory: None,
transaction: None,
isolated: false,
}
}
@ -184,6 +188,33 @@ impl<'a> Context<'a> {
))]
temporary_directory: parent.temporary_directory.clone(),
transaction: parent.transaction.clone(),
isolated: false,
}
}
/// Create a new child from a frozen context.
pub fn new_isolated(parent: &'a Context) -> Self {
Context {
values: HashMap::default(),
parent: Some(parent),
deadline: parent.deadline,
cancelled: Arc::new(AtomicBool::new(false)),
notifications: parent.notifications.clone(),
query_planner: parent.query_planner,
query_executor: parent.query_executor.clone(),
iteration_stage: parent.iteration_stage.clone(),
capabilities: parent.capabilities.clone(),
index_stores: parent.index_stores.clone(),
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
))]
temporary_directory: parent.temporary_directory.clone(),
transaction: parent.transaction.clone(),
isolated: true,
}
}
@ -334,10 +365,11 @@ impl<'a> Context<'a> {
Cow::Borrowed(v) => Some(*v),
Cow::Owned(v) => Some(v),
},
None => match self.parent {
None if !self.isolated => match self.parent {
Some(p) => p.value(key),
_ => None,
},
None => None,
}
}

View file

@ -47,7 +47,8 @@ impl<'a> KnnConditionRewriter<'a> {
| Value::Table(_)
| Value::Mock(_)
| Value::Regex(_)
| Value::Constant(_) => Some(v.clone()),
| Value::Constant(_)
| Value::Closure(_) => Some(v.clone()),
}
}
@ -208,6 +209,9 @@ impl<'a> KnnConditionRewriter<'a> {
Function::Script(s, args) => {
self.eval_values(args).map(|args| Function::Script(s.clone(), args))
}
Function::Anonymous(p, args) => {
self.eval_values(args).map(|args| Function::Anonymous(p.clone(), args))
}
}
}

88
core/src/sql/closure.rs Normal file
View file

@ -0,0 +1,88 @@
use crate::{ctx::Context, dbs::Options, doc::CursorDoc, err::Error, sql::value::Value};
use reblessive::tree::Stk;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
use super::{Ident, Kind};
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Closure";
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Closure")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub struct Closure {
pub args: Vec<(Ident, Kind)>,
pub returns: Option<Kind>,
pub body: Value,
}
impl Closure {
pub(crate) async fn compute(
&self,
stk: &mut Stk,
ctx: &Context<'_>,
opt: &Options,
doc: Option<&CursorDoc<'_>>,
args: Vec<Value>,
) -> Result<Value, Error> {
let mut ctx = Context::new_isolated(ctx);
for (i, (name, kind)) in self.args.iter().enumerate() {
match (kind, args.get(i)) {
(Kind::Option(_), None) => continue,
(_, None) => {
return Err(Error::InvalidArguments {
name: "ANONYMOUS".to_string(),
message: format!("Expected a value for ${}", name),
})
}
(kind, Some(val)) => {
if let Ok(val) = val.to_owned().coerce_to(kind) {
ctx.add_value(name.to_string(), val);
} else {
return Err(Error::InvalidArguments {
name: "ANONYMOUS".to_string(),
message: format!(
"Expected a value of type '{kind}' for argument ${}",
name
),
});
}
}
}
}
let result = self.body.compute(stk, &ctx, opt, doc).await?;
if let Some(returns) = &self.returns {
if let Ok(result) = result.clone().coerce_to(returns) {
Ok(result)
} else {
Err(Error::InvalidFunction {
name: "ANONYMOUS".to_string(),
message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()),
})
}
} else {
Ok(result)
}
}
}
impl fmt::Display for Closure {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("|")?;
for (i, (name, kind)) in self.args.iter().enumerate() {
if i > 0 {
f.write_str(", ")?;
}
write!(f, "${name}: {kind}")?;
}
f.write_str("|")?;
if let Some(returns) = &self.returns {
write!(f, " -> {returns}")?;
}
write!(f, " {}", self.body)
}
}

View file

@ -29,6 +29,7 @@ pub enum Function {
Normal(String, Vec<Value>),
Custom(String, Vec<Value>),
Script(Script, Vec<Value>),
Anonymous(Value, Vec<Value>),
// Add new variants here
}
@ -71,6 +72,7 @@ impl Function {
/// Convert function call to a field name
pub fn to_idiom(&self) -> Idiom {
match self {
Self::Anonymous(_, _) => "function".to_string().into(),
Self::Script(_, _) => "function".to_string().into(),
Self::Normal(f, _) => f.to_owned().into(),
Self::Custom(f, _) => format!("fn::{f}").into(),
@ -111,6 +113,11 @@ impl Function {
}
}
/// Check if this function is a closure function
pub fn is_inline(&self) -> bool {
matches!(self, Self::Anonymous(_, _))
}
/// Check if this function is a rolling function
pub fn is_rolling(&self) -> bool {
match self {
@ -204,6 +211,35 @@ impl Function {
// Run the normal function
fnc::run(stk, ctx, opt, doc, s, a).await
}
Self::Anonymous(v, x) => {
let val = match v {
Value::Closure(p) => &Value::Closure(p.to_owned()),
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
}
_ => &Value::None,
};
match val {
Value::Closure(closure) => {
// Compute the function arguments
let a = stk
.scope(|scope| {
try_join_all(
x.iter()
.map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
)
})
.await?;
stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
}
v => Err(Error::InvalidFunction {
name: "ANONYMOUS".to_string(),
message: format!("'{}' is not a function", v.kindof()),
}),
}
}
Self::Custom(s, x) => {
// Get the full name of this function
let name = format!("fn::{s}");
@ -308,6 +344,7 @@ impl fmt::Display for Function {
Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
Self::Anonymous(p, e) => write!(f, "{p}({})", Fmt::comma_separated(e)),
}
}
}

View file

@ -29,6 +29,7 @@ pub enum Kind {
Either(Vec<Kind>),
Set(Box<Kind>, Option<u64>),
Array(Box<Kind>, Option<u64>),
Function(Option<Vec<Kind>>, Option<Box<Kind>>),
}
impl Default for Kind {
@ -71,7 +72,8 @@ impl Kind {
| Kind::String
| Kind::Uuid
| Kind::Record(_)
| Kind::Geometry(_) => return None,
| Kind::Geometry(_)
| Kind::Function(_, _) => return None,
Kind::Option(x) => {
this = x;
}
@ -114,6 +116,7 @@ impl Display for Kind {
Kind::Point => f.write_str("point"),
Kind::String => f.write_str("string"),
Kind::Uuid => f.write_str("uuid"),
Kind::Function(_, _) => f.write_str("function"),
Kind::Option(k) => write!(f, "option<{}>", k),
Kind::Record(k) => match k {
k if k.is_empty() => write!(f, "record"),

View file

@ -12,6 +12,7 @@ pub(crate) mod bytes;
pub(crate) mod cast;
pub(crate) mod change_feed_include;
pub(crate) mod changefeed;
pub(crate) mod closure;
pub(crate) mod cond;
pub(crate) mod constant;
pub(crate) mod data;
@ -88,6 +89,7 @@ pub use self::block::Entry;
pub use self::bytes::Bytes;
pub use self::cast::Cast;
pub use self::changefeed::ChangeFeed;
pub use self::closure::Closure;
pub use self::cond::Cond;
pub use self::constant::Constant;
pub use self::data::Data;

View file

@ -16,6 +16,7 @@ use crate::sql::paths::ID;
use crate::sql::statements::select::SelectStatement;
use crate::sql::thing::Thing;
use crate::sql::value::{Value, Values};
use crate::sql::Function;
use reblessive::tree::Stk;
impl Value {
@ -140,8 +141,28 @@ impl Value {
stk.run(|stk| obj.get(stk, ctx, opt, doc, path.next())).await
}
Part::Method(name, args) => {
let v = idiom(ctx, doc, v.clone().into(), name, args.clone())?;
stk.run(|stk| v.get(stk, ctx, opt, doc, path.next())).await
let res = idiom(ctx, doc, v.clone().into(), name, args.clone());
let res = match &res {
Ok(_) => res,
Err(Error::InvalidFunction {
..
}) => match v.get(name) {
Some(v) => {
let fnc = Function::Anonymous(v.clone(), args.clone());
match stk.run(|stk| fnc.compute(stk, ctx, opt, doc)).await {
Ok(v) => Ok(v),
Err(Error::InvalidFunction {
..
}) => res,
e => e,
}
}
None => res,
},
_ => res,
}?;
stk.run(|stk| res.get(stk, ctx, opt, doc, path.next())).await
}
_ => Ok(Value::None),
},

View file

@ -277,6 +277,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
Value::Query(query) => json!(query),
Value::Subquery(subquery) => json!(subquery),
Value::Expression(expression) => json!(expression),
Value::Closure(closure) => json!(closure),
}
}

View file

@ -0,0 +1,98 @@
use crate::err::Error;
use crate::sql::value::serde::ser;
use crate::sql::Closure;
use crate::sql::Ident;
use crate::sql::Kind;
use crate::sql::Value;
use ser::statement::define::function::IdentKindVecSerializer;
use ser::Serializer as _;
use serde::ser::Error as _;
use serde::ser::Serialize;
#[derive(Default)]
pub(super) struct SerializeClosure {
args: Option<Vec<(Ident, Kind)>>,
returns: Option<Option<Kind>>,
body: Option<Value>,
}
impl serde::ser::SerializeStruct for SerializeClosure {
type Ok = Closure;
type Error = Error;
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Error>
where
T: ?Sized + Serialize,
{
match key {
"args" => {
self.args = Some(value.serialize(IdentKindVecSerializer.wrap())?);
}
"returns" => {
self.returns = Some(value.serialize(ser::kind::opt::Serializer.wrap())?);
}
"body" => {
self.body = Some(value.serialize(ser::value::Serializer.wrap())?);
}
key => {
return Err(Error::custom(format!("unexpected field `Closure::{key}`")));
}
}
Ok(())
}
fn end(self) -> Result<Self::Ok, Error> {
match (self.args, self.returns, self.body) {
(Some(args), Some(returns), Some(body)) => Ok(Closure {
args,
returns,
body,
}),
_ => Err(Error::custom("`Closure` missing required field(s)")),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::ser::Impossible;
use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Closure;
type Error = Error;
type SerializeSeq = Impossible<Closure, Error>;
type SerializeTuple = Impossible<Closure, Error>;
type SerializeTupleStruct = Impossible<Closure, Error>;
type SerializeTupleVariant = Impossible<Closure, Error>;
type SerializeMap = Impossible<Closure, Error>;
type SerializeStruct = SerializeClosure;
type SerializeStructVariant = Impossible<Closure, Error>;
const EXPECTED: &'static str = "a struct `Closure`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeClosure::default())
}
}
#[test]
fn closure() {
let closure = Closure {
args: Vec::new(),
returns: None,
body: Value::default(),
};
let serialized = closure.serialize(Serializer.wrap()).unwrap();
assert_eq!(closure, serialized);
}
}

View file

@ -35,6 +35,7 @@ impl ser::Serializer for Serializer {
"Normal" => Inner::Normal(None, None),
"Custom" => Inner::Custom(None, None),
"Script" => Inner::Script(None, None),
"Anonymous" => Inner::Anonymous(None, None),
variant => {
return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`")));
}
@ -55,6 +56,7 @@ enum Inner {
Normal(Option<String>, Option<Vec<Value>>),
Custom(Option<String>, Option<Vec<Value>>),
Script(Option<Script>, Option<Vec<Value>>),
Anonymous(Option<Value>, Option<Vec<Value>>),
}
impl serde::ser::SerializeTupleVariant for SerializeFunction {
@ -72,11 +74,15 @@ impl serde::ser::SerializeTupleVariant for SerializeFunction {
(0, Inner::Script(ref mut var, _)) => {
*var = Some(Script(value.serialize(ser::string::Serializer.wrap())?));
}
(0, Inner::Anonymous(ref mut var, _)) => {
*var = Some(value.serialize(ser::value::Serializer.wrap())?);
}
(
1,
Inner::Normal(_, ref mut var)
| Inner::Custom(_, ref mut var)
| Inner::Script(_, ref mut var),
| Inner::Script(_, ref mut var)
| Inner::Anonymous(_, ref mut var),
) => {
*var = Some(value.serialize(ser::value::vec::Serializer.wrap())?);
}
@ -85,6 +91,7 @@ impl serde::ser::SerializeTupleVariant for SerializeFunction {
Inner::Normal(..) => "Normal",
Inner::Custom(..) => "Custom",
Inner::Script(..) => "Script",
Inner::Anonymous(..) => "Anonymous",
};
return Err(Error::custom(format!(
"unexpected `Function::{variant}` index `{index}`"
@ -100,6 +107,7 @@ impl serde::ser::SerializeTupleVariant for SerializeFunction {
Inner::Normal(Some(one), Some(two)) => Ok(Function::Normal(one, two)),
Inner::Custom(Some(one), Some(two)) => Ok(Function::Custom(one, two)),
Inner::Script(Some(one), Some(two)) => Ok(Function::Script(one, two)),
Inner::Anonymous(Some(one), Some(two)) => Ok(Function::Anonymous(one, two)),
_ => Err(Error::custom("`Function` missing required value(s)")),
}
}
@ -130,4 +138,11 @@ mod tests {
let serialized = function.serialize(Serializer.wrap()).unwrap();
assert_eq!(function, serialized);
}
#[test]
fn anonymous() {
let function = Function::Anonymous(Default::default(), vec![Default::default()]);
let serialized = function.serialize(Serializer.wrap()).unwrap();
assert_eq!(function, serialized);
}
}

View file

@ -81,26 +81,30 @@ impl ser::Serializer for Serializer {
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant, Self::Error> {
match variant {
"Set" => Ok(SerializeKindTuple {
variant,
..Default::default()
}),
"Array" => Ok(SerializeKindTuple {
variant,
..Default::default()
}),
variant => Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`"))),
}
let inner = match variant {
"Set" => Inner::Set(Default::default(), Default::default()),
"Array" => Inner::Array(Default::default(), Default::default()),
"Function" => Inner::Function(Default::default(), Default::default()),
variant => {
return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`")));
}
};
Ok(SerializeKindTuple {
inner,
index: 0,
})
}
}
#[derive(Default)]
pub(super) struct SerializeKindTuple {
index: usize,
variant: &'static str,
kind: Option<Kind>,
num: Option<u64>,
inner: Inner,
}
enum Inner {
Set(Box<Kind>, Option<u64>),
Array(Box<Kind>, Option<u64>),
Function(Option<Vec<Kind>>, Option<Box<Kind>>),
}
impl serde::ser::SerializeTupleVariant for SerializeKindTuple {
@ -111,15 +115,25 @@ impl serde::ser::SerializeTupleVariant for SerializeKindTuple {
where
T: Serialize + ?Sized,
{
match self.index {
0 => {
self.kind = Some(value.serialize(Serializer.wrap())?);
match (self.index, &mut self.inner) {
(0, Inner::Set(ref mut var, _) | Inner::Array(ref mut var, _)) => {
*var = Box::new(value.serialize(Serializer.wrap())?);
}
1 => {
self.num = value.serialize(ser::primitive::u64::opt::Serializer.wrap())?;
(1, Inner::Set(_, ref mut var) | Inner::Array(_, ref mut var)) => {
*var = value.serialize(ser::primitive::u64::opt::Serializer.wrap())?;
}
index => {
let variant = self.variant;
(0, Inner::Function(ref mut var, _)) => {
*var = value.serialize(ser::kind::vec::opt::Serializer.wrap())?;
}
(1, Inner::Function(_, ref mut var)) => {
*var = value.serialize(ser::kind::opt::Serializer.wrap())?.map(Box::new);
}
(index, inner) => {
let variant = match inner {
Inner::Set(..) => "Set",
Inner::Array(..) => "Array",
Inner::Function(..) => "Function",
};
return Err(Error::custom(format!("unexpected `Kind::{variant}` index `{index}`")));
}
}
@ -128,17 +142,10 @@ impl serde::ser::SerializeTupleVariant for SerializeKindTuple {
}
fn end(self) -> Result<Self::Ok, Self::Error> {
let variant = self.variant;
let kind = match self.kind {
Some(kind) => kind,
_ => {
return Err(Error::custom("`Kind::{variant}` missing required value(s)"));
}
};
match variant {
"Set" => Ok(Kind::Set(Box::new(kind), self.num)),
"Array" => Ok(Kind::Array(Box::new(kind), self.num)),
_ => Err(Error::custom("unknown tuple variant `Kind::{variant}`")),
match self.inner {
Inner::Set(one, two) => Ok(Kind::Set(one, two)),
Inner::Array(one, two) => Ok(Kind::Array(one, two)),
Inner::Function(one, two) => Ok(Kind::Function(one, two)),
}
}
}
@ -238,6 +245,13 @@ mod tests {
assert_eq!(kind, serialized);
}
#[test]
fn function() {
let kind = Kind::Function(Default::default(), Default::default());
let serialized = kind.serialize(Serializer.wrap()).unwrap();
assert_eq!(kind, serialized);
}
#[test]
fn record() {
let kind = Kind::Record(Default::default());

View file

@ -5,6 +5,8 @@ use ser::Serializer as _;
use serde::ser::Impossible;
use serde::ser::Serialize;
pub mod opt;
#[non_exhaustive]
pub struct Serializer;

View file

@ -0,0 +1,56 @@
use crate::err::Error;
use crate::sql::value::serde::ser;
use crate::sql::Kind;
use serde::ser::Impossible;
use serde::ser::Serialize;
#[non_exhaustive]
pub struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Option<Vec<Kind>>;
type Error = Error;
type SerializeSeq = Impossible<Option<Vec<Kind>>, Error>;
type SerializeTuple = Impossible<Option<Vec<Kind>>, Error>;
type SerializeTupleStruct = Impossible<Option<Vec<Kind>>, Error>;
type SerializeTupleVariant = Impossible<Option<Vec<Kind>>, Error>;
type SerializeMap = Impossible<Option<Vec<Kind>>, Error>;
type SerializeStruct = Impossible<Option<Vec<Kind>>, Error>;
type SerializeStructVariant = Impossible<Option<Vec<Kind>>, Error>;
const EXPECTED: &'static str = "an `Option<Vec<Kind>>`";
#[inline]
fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
Ok(None)
}
#[inline]
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>
where
T: ?Sized + Serialize,
{
Ok(Some(value.serialize(super::Serializer.wrap())?))
}
}
#[cfg(test)]
mod tests {
use super::*;
use ser::Serializer as _;
#[test]
fn none() {
let option: Option<Vec<Kind>> = None;
let serialized = option.serialize(Serializer.wrap()).unwrap();
assert_eq!(option, serialized);
}
#[test]
fn some() {
let option = Some(vec![Kind::default()]);
let serialized = option.serialize(Serializer.wrap()).unwrap();
assert_eq!(option, serialized);
}
}

View file

@ -4,6 +4,7 @@ mod base;
mod block;
mod cast;
mod changefeed;
mod closure;
mod cond;
mod constant;
mod data;

View file

@ -104,7 +104,7 @@ impl serde::ser::SerializeStruct for SerializeDefineFunctionStatement {
type IdentKindTuple = (Ident, Kind);
struct IdentKindVecSerializer;
pub struct IdentKindVecSerializer;
impl ser::Serializer for IdentKindVecSerializer {
type Ok = Vec<IdentKindTuple>;
@ -125,7 +125,7 @@ impl ser::Serializer for IdentKindVecSerializer {
}
}
struct SerializeIdentKindVec(Vec<IdentKindTuple>);
pub struct SerializeIdentKindVec(Vec<IdentKindTuple>);
impl serde::ser::SerializeSeq for SerializeIdentKindVec {
type Ok = Vec<IdentKindTuple>;

View file

@ -3,7 +3,7 @@ mod analyzer;
mod database;
mod event;
mod field;
mod function;
pub mod function;
mod index;
mod namespace;
mod param;

View file

@ -37,6 +37,8 @@ use serde::ser::SerializeSeq as _;
use std::fmt::Display;
use vec::SerializeValueVec;
use super::closure::SerializeClosure;
/// Convert a `T` into `surrealdb::sql::Value` which is an enum that can represent any valid SQL data.
pub fn to_value<T>(value: T) -> Result<Value, Error>
where
@ -344,6 +346,7 @@ impl ser::Serializer for Serializer {
sql::thing::TOKEN => SerializeStruct::Thing(Default::default()),
sql::edges::TOKEN => SerializeStruct::Edges(Default::default()),
sql::range::TOKEN => SerializeStruct::Range(Default::default()),
sql::closure::TOKEN => SerializeStruct::Closure(Default::default()),
_ => SerializeStruct::Unknown(Default::default()),
})
}
@ -519,6 +522,7 @@ pub(super) enum SerializeStruct {
Thing(SerializeThing),
Edges(SerializeEdges),
Range(SerializeRange),
Closure(SerializeClosure),
Unknown(SerializeValueMap),
}
@ -534,6 +538,7 @@ impl serde::ser::SerializeStruct for SerializeStruct {
Self::Thing(thing) => thing.serialize_field(key, value),
Self::Edges(edges) => edges.serialize_field(key, value),
Self::Range(range) => range.serialize_field(key, value),
Self::Closure(closure) => closure.serialize_field(key, value),
Self::Unknown(map) => map.serialize_entry(key, value),
}
}
@ -543,6 +548,7 @@ impl serde::ser::SerializeStruct for SerializeStruct {
Self::Thing(thing) => Ok(Value::Thing(thing.end()?)),
Self::Edges(edges) => Ok(Value::Edges(Box::new(edges.end()?))),
Self::Range(range) => Ok(Value::Range(Box::new(range.end()?))),
Self::Closure(closure) => Ok(Value::Closure(Box::new(closure.end()?))),
Self::Unknown(map) => Ok(Value::Object(Object(map.end()?))),
}
}

View file

@ -6,6 +6,7 @@ use crate::doc::CursorDoc;
use crate::err::Error;
use crate::fnc::util::string::fuzzy::Fuzzy;
use crate::sql::statements::info::InfoStructure;
use crate::sql::Closure;
use crate::sql::{
array::Uniq,
fmt::{Fmt, Pretty},
@ -120,6 +121,7 @@ pub enum Value {
Expression(Box<Expression>),
Query(Query),
Model(Box<Model>),
Closure(Box<Closure>),
// Add new variants here
}
@ -144,6 +146,12 @@ impl From<Uuid> for Value {
}
}
impl From<Closure> for Value {
fn from(v: Closure) -> Self {
Value::Closure(Box::new(v))
}
}
impl From<Param> for Value {
fn from(v: Param) -> Self {
Value::Param(v)
@ -1179,6 +1187,7 @@ impl Value {
Self::Strand(_) => "string",
Self::Duration(_) => "duration",
Self::Datetime(_) => "datetime",
Self::Closure(_) => "function",
Self::Number(Number::Int(_)) => "int",
Self::Number(Number::Float(_)) => "float",
Self::Number(Number::Decimal(_)) => "decimal",
@ -1216,6 +1225,7 @@ impl Value {
Kind::Point => self.coerce_to_point().map(Value::from),
Kind::Bytes => self.coerce_to_bytes().map(Value::from),
Kind::Uuid => self.coerce_to_uuid().map(Value::from),
Kind::Function(_, _) => self.coerce_to_function().map(Value::from),
Kind::Set(t, l) => match l {
Some(l) => self.coerce_to_set_type_len(t, l).map(Value::from),
None => self.coerce_to_set_type(t).map(Value::from),
@ -1531,6 +1541,19 @@ impl Value {
}
}
/// Try to coerce this value to a `Closure`
pub(crate) fn coerce_to_function(self) -> Result<Closure, Error> {
match self {
// Closures are allowed
Value::Closure(v) => Ok(*v),
// Anything else raises an error
_ => Err(Error::CoerceTo {
from: self,
into: "function".into(),
}),
}
}
/// Try to coerce this value to a `Datetime`
pub(crate) fn coerce_to_datetime(self) -> Result<Datetime, Error> {
match self {
@ -1771,6 +1794,7 @@ impl Value {
Kind::Point => self.convert_to_point().map(Value::from),
Kind::Bytes => self.convert_to_bytes().map(Value::from),
Kind::Uuid => self.convert_to_uuid().map(Value::from),
Kind::Function(_, _) => self.convert_to_function().map(Value::from),
Kind::Set(t, l) => match l {
Some(l) => self.convert_to_set_type_len(t, l).map(Value::from),
None => self.convert_to_set_type(t).map(Value::from),
@ -2066,6 +2090,19 @@ impl Value {
}
}
/// Try to convert this value to a `Closure`
pub(crate) fn convert_to_function(self) -> Result<Closure, Error> {
match self {
// Closures are allowed
Value::Closure(v) => Ok(*v),
// Anything else converts to a closure with self as the body
_ => Err(Error::ConvertTo {
from: self,
into: "function".into(),
}),
}
}
/// Try to convert this value to a `Datetime`
pub(crate) fn convert_to_datetime(self) -> Result<Datetime, Error> {
match self {
@ -2612,6 +2649,7 @@ impl fmt::Display for Value {
Value::Table(v) => write!(f, "{v}"),
Value::Thing(v) => write!(f, "{v}"),
Value::Uuid(v) => write!(f, "{v}"),
Value::Closure(v) => write!(f, "{v}"),
}
}
}

View file

@ -23,6 +23,25 @@ impl Parser<'_> {
/// Parse an inner kind, a kind without enclosing `<` `>`.
pub async fn parse_inner_kind(&mut self, ctx: &mut Stk) -> ParseResult<Kind> {
match self.parse_inner_single_kind(ctx).await? {
Kind::Any => Ok(Kind::Any),
Kind::Option(k) => Ok(Kind::Option(k)),
first => {
if self.peek_kind() == t!("|") {
let mut kind = vec![first];
while self.eat(t!("|")) {
kind.push(ctx.run(|ctx| self.parse_concrete_kind(ctx)).await?);
}
Ok(Kind::Either(kind))
} else {
Ok(first)
}
}
}
}
/// Parse a single inner kind, a kind without enclosing `<` `>`.
pub async fn parse_inner_single_kind(&mut self, ctx: &mut Stk) -> ParseResult<Kind> {
match self.peek_kind() {
t!("ANY") => {
self.pop_peek();
@ -43,18 +62,7 @@ impl Parser<'_> {
self.expect_closing_delimiter(t!(">"), delim)?;
Ok(Kind::Option(Box::new(first)))
}
_ => {
let first = ctx.run(|ctx| self.parse_concrete_kind(ctx)).await?;
if self.peek_kind() == t!("|") {
let mut kind = vec![first];
while self.eat(t!("|")) {
kind.push(ctx.run(|ctx| self.parse_concrete_kind(ctx)).await?);
}
Ok(Kind::Either(kind))
} else {
Ok(first)
}
}
_ => ctx.run(|ctx| self.parse_concrete_kind(ctx)).await,
}
}
@ -74,6 +82,7 @@ impl Parser<'_> {
t!("POINT") => Ok(Kind::Point),
t!("STRING") => Ok(Kind::String),
t!("UUID") => Ok(Kind::Uuid),
t!("FUNCTION") => Ok(Kind::Function(Default::default(), Default::default())),
t!("RECORD") => {
let span = self.peek().span;
if self.eat(t!("<")) {

View file

@ -5,8 +5,8 @@ use super::{ParseResult, Parser};
use crate::{
enter_object_recursion, enter_query_recursion,
sql::{
Array, Dir, Function, Geometry, Ident, Idiom, Mock, Number, Part, Script, Strand, Subquery,
Table, Value,
Array, Closure, Dir, Function, Geometry, Ident, Idiom, Kind, Mock, Number, Param, Part,
Script, Strand, Subquery, Table, Value,
},
syn::{
parser::{
@ -42,13 +42,14 @@ impl Parser<'_> {
Ok(Value::Uuid(uuid))
}
t!("$param") => {
let param = self.next_token_value()?;
Ok(Value::Param(param))
let value = Value::Param(self.next_token_value()?);
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
}
t!("FUNCTION") => {
self.pop_peek();
let func = self.parse_script(ctx).await?;
Ok(Value::Function(Box::new(func)))
let value = Value::Function(Box::new(func));
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
}
t!("IF") => {
let stmt = ctx.run(|ctx| self.parse_if_stmt(ctx)).await?;
@ -56,9 +57,11 @@ impl Parser<'_> {
}
t!("(") => {
let token = self.pop_peek();
self.parse_inner_subquery(ctx, Some(token.span))
let value = self
.parse_inner_subquery(ctx, Some(token.span))
.await
.map(|x| Value::Subquery(Box::new(x)))
.map(|x| Value::Subquery(Box::new(x)))?;
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
}
t!("<") => {
self.pop_peek();
@ -70,8 +73,9 @@ impl Parser<'_> {
}
t!("|") => {
let start = self.pop_peek().span;
self.parse_mock(start).map(Value::Mock)
self.parse_closure_or_mock(ctx, start).await
}
t!("||") => self.parse_closure_after_args(ctx, Vec::new()).await,
t!("/") => self.next_token_value().map(Value::Regex),
t!("RETURN")
| t!("SELECT")
@ -85,8 +89,15 @@ impl Parser<'_> {
| t!("REBUILD") => {
self.parse_inner_subquery(ctx, None).await.map(|x| Value::Subquery(Box::new(x)))
}
t!("fn") => self.parse_custom_function(ctx).await.map(|x| Value::Function(Box::new(x))),
t!("ml") => self.parse_model(ctx).await.map(|x| Value::Model(Box::new(x))),
t!("fn") => {
let value =
self.parse_custom_function(ctx).await.map(|x| Value::Function(Box::new(x)))?;
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
}
t!("ml") => {
let value = self.parse_model(ctx).await.map(|x| Value::Model(Box::new(x)))?;
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
}
x => {
if !self.peek_can_start_ident() {
unexpected!(self, x, "a value")
@ -117,6 +128,36 @@ impl Parser<'_> {
}
}
pub async fn try_parse_inline(
&mut self,
ctx: &mut Stk,
subject: &Value,
) -> ParseResult<Option<Value>> {
if self.eat(t!("(")) {
let start = self.last_span();
let mut args = Vec::new();
loop {
if self.eat(t!(")")) {
break;
}
let arg = ctx.run(|ctx| self.parse_value_field(ctx)).await?;
args.push(arg);
if !self.eat(t!(",")) {
self.expect_closing_delimiter(t!(")"), start)?;
break;
}
}
let value = Value::Function(Box::new(Function::Anonymous(subject.clone(), args)));
let value = ctx.run(|ctx| self.try_parse_inline(ctx, &value)).await?.unwrap_or(value);
Ok(Some(value))
} else {
Ok(None)
}
}
pub fn parse_number_like_prime(&mut self) -> ParseResult<Value> {
let token = self.glue_numeric()?;
match token.kind {
@ -190,8 +231,8 @@ impl Parser<'_> {
Value::Number(Number::Float(f64::NAN))
}
t!("$param") => {
let param = self.next_token_value()?;
Value::Param(param)
let value = Value::Param(self.next_token_value()?);
self.try_parse_inline(ctx, &value).await?.unwrap_or(value)
}
t!("FUNCTION") => {
self.pop_peek();
@ -219,11 +260,16 @@ impl Parser<'_> {
}
t!("{") => {
self.pop_peek();
self.parse_object_like(ctx, token.span).await?
let value = self.parse_object_like(ctx, token.span).await?;
self.try_parse_inline(ctx, &value).await?.unwrap_or(value)
}
t!("|") => {
self.pop_peek();
self.parse_mock(token.span).map(Value::Mock)?
self.parse_closure_or_mock(ctx, token.span).await?
}
t!("||") => {
self.pop_peek();
ctx.run(|ctx| self.parse_closure_after_args(ctx, Vec::new())).await?
}
t!("IF") => {
enter_query_recursion!(this = self => {
@ -234,7 +280,8 @@ impl Parser<'_> {
}
t!("(") => {
self.pop_peek();
self.parse_inner_subquery_or_coordinate(ctx, token.span).await?
let value = self.parse_inner_subquery_or_coordinate(ctx, token.span).await?;
self.try_parse_inline(ctx, &value).await?.unwrap_or(value)
}
t!("/") => self.next_token_value().map(Value::Regex)?,
t!("RETURN")
@ -284,13 +331,14 @@ impl Parser<'_> {
// Parse the rest of the idiom if it is being continued.
if Self::continues_idiom(self.peek_kind()) {
match value {
let value = match value {
Value::Idiom(Idiom(x)) => self.parse_remaining_value_idiom(ctx, x).await,
Value::Table(Table(x)) => {
self.parse_remaining_value_idiom(ctx, vec![Part::Field(Ident(x))]).await
}
x => self.parse_remaining_value_idiom(ctx, vec![Part::Start(x)]).await,
}
}?;
Ok(self.try_parse_inline(ctx, &value).await?.unwrap_or(value))
} else {
Ok(value)
}
@ -338,6 +386,72 @@ impl Parser<'_> {
}
}
pub async fn parse_closure_or_mock(
&mut self,
ctx: &mut Stk,
start: Span,
) -> ParseResult<Value> {
match self.peek_kind() {
t!("$param") => ctx.run(|ctx| self.parse_closure(ctx, start)).await,
v => {
println!("{:?}", v);
self.parse_mock(start).map(Value::Mock)
}
}
}
pub async fn parse_closure(&mut self, ctx: &mut Stk, start: Span) -> ParseResult<Value> {
let mut args = Vec::new();
loop {
if self.eat(t!("|")) {
break;
}
let param = self.next_token_value::<Param>()?.0;
let kind = if self.eat(t!(":")) {
if self.eat(t!("<")) {
let delim = self.last_span();
ctx.run(|ctx| self.parse_kind(ctx, delim)).await?
} else {
ctx.run(|ctx| self.parse_inner_single_kind(ctx)).await?
}
} else {
Kind::Any
};
args.push((param, kind));
if !self.eat(t!(",")) {
self.expect_closing_delimiter(t!("|"), start)?;
break;
}
}
self.parse_closure_after_args(ctx, args).await
}
pub async fn parse_closure_after_args(
&mut self,
ctx: &mut Stk,
args: Vec<(Ident, Kind)>,
) -> ParseResult<Value> {
let (returns, body) = if self.eat(t!("->")) {
let returns = Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?);
let start = expected!(self, t!("{")).span;
let body = Value::Block(Box::new(ctx.run(|ctx| self.parse_block(ctx, start)).await?));
(returns, body)
} else {
let body = ctx.run(|ctx| self.parse_value(ctx)).await?;
(None, body)
};
Ok(Value::Closure(Box::new(Closure {
args,
returns,
body,
})))
}
pub async fn parse_full_subquery(&mut self, ctx: &mut Stk) -> ParseResult<Subquery> {
let peek = self.peek();
match peek.kind {

101
lib/tests/closure.rs Normal file
View file

@ -0,0 +1,101 @@
mod parse;
use parse::Parse;
mod helpers;
use helpers::new_ds;
use surrealdb::dbs::Session;
use surrealdb::err::Error;
use surrealdb::sql::Value;
#[tokio::test]
async fn closures() -> Result<(), Error> {
let sql = "
LET $double = |$n: number| $n * 2;
$double(2);
LET $pipe = |$arg| $arg;
$pipe('abc');
LET $rettype = |$arg| -> string { $arg };
$rettype('works');
$rettype(123);
LET $argtype = |$arg: string| $arg;
$argtype('works');
$argtype(123);
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 10);
//
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("4");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("'abc'");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("'works'");
assert_eq!(tmp, val);
//
match res.remove(0).result {
Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (),
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'")
}
//
let tmp = res.remove(0).result?;
let val = Value::None;
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("'works'");
assert_eq!(tmp, val);
//
match res.remove(0).result {
Err(Error::InvalidArguments { name, message }) if name == "ANONYMOUS" && message == "Expected a value of type 'string' for argument $arg" => (),
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected a value of type 'string' for argument $arg")
}
//
Ok(())
}
#[tokio::test]
async fn closures_inline() -> Result<(), Error> {
let sql = "
(||1)();
{||2}();
{ a: ||3 }.a();
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse("1");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("2");
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::parse("3");
assert_eq!(tmp, val);
//
Ok(())
}