implement typed function returns (#4475)

Co-authored-by: Micha de Vries <micha@devrie.sh>
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
Raphael Darley 2024-08-13 18:10:13 +01:00 committed by GitHub
parent c3d788ff4a
commit 5d92c7c02c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 77 additions and 14 deletions

View file

@ -228,6 +228,14 @@ pub enum Error {
message: String, message: String,
}, },
/// The wrong quantity or magnitude of arguments was given for the specified function
#[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")]
FunctionCheck {
name: String,
value: String,
check: String,
},
/// The URL is invalid /// The URL is invalid
#[error("The URL `{0}` is invalid")] #[error("The URL `{0}` is invalid")]
InvalidUrl(String), InvalidUrl(String),
@ -1192,3 +1200,24 @@ impl Serialize for Error {
serializer.serialize_str(self.to_string().as_str()) serializer.serialize_str(self.to_string().as_str())
} }
} }
impl Error {
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
match self {
Error::CoerceTo {
from,
into,
} => Error::FunctionCheck {
name: name.into(),
value: from.to_string(),
check: into,
},
fc @ Error::FunctionCheck {
..
} => fc,
_ => Error::Internal(
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
),
}
}
}

View file

@ -56,14 +56,7 @@ impl Closure {
let result = self.body.compute(stk, &ctx, opt, doc).await?; let result = self.body.compute(stk, &ctx, opt, doc).await?;
if let Some(returns) = &self.returns { if let Some(returns) = &self.returns {
if let Ok(result) = result.clone().coerce_to(returns) { result.coerce_to(returns).map_err(|e| e.function_check_from_coerce("ANONYMOUS"))
Ok(result)
} else {
Err(Error::InvalidFunction {
name: "ANONYMOUS".to_string(),
message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()),
})
}
} else { } else {
Ok(result) Ok(result)
} }

View file

@ -213,7 +213,7 @@ impl Function {
} }
Self::Anonymous(v, x) => { Self::Anonymous(v, x) => {
let val = match v { let val = match v {
Value::Closure(p) => &Value::Closure(p.to_owned()), c @ Value::Closure(_) => c,
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None), Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => { Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await? &stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
@ -297,17 +297,25 @@ impl Function {
}) })
.await?; .await?;
// Duplicate context // Duplicate context
let mut ctx = Context::new(ctx); let mut ctx = Context::new_isolated(ctx);
// Process the function arguments // Process the function arguments
for (val, (name, kind)) in a.into_iter().zip(&val.args) { for (val, (name, kind)) in a.into_iter().zip(&val.args) {
ctx.add_value(name.to_raw(), val.coerce_to(kind)?); ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
} }
// Run the custom function // Run the custom function
match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await { let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
Err(Error::Return { Err(Error::Return {
value, value,
}) => Ok(value), }) => Ok(value),
res => res, res => res,
}?;
if let Some(ref returns) = val.returns {
result
.coerce_to(returns)
.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
} else {
Ok(result)
} }
} }
#[allow(unused_variables)] #[allow(unused_variables)]

View file

@ -11,7 +11,7 @@ use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::{self, Display, Write}; use std::fmt::{self, Display, Write};
#[revisioned(revision = 3)] #[revisioned(revision = 4)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive] #[non_exhaustive]
@ -25,6 +25,8 @@ pub struct DefineFunctionStatement {
pub if_not_exists: bool, pub if_not_exists: bool,
#[revision(start = 3)] #[revision(start = 3)]
pub overwrite: bool, pub overwrite: bool,
#[revision(start = 4)]
pub returns: Option<Kind>,
} }
impl DefineFunctionStatement { impl DefineFunctionStatement {

View file

@ -150,6 +150,11 @@ impl Parser<'_> {
break; break;
} }
} }
let returns = if self.eat(t!("->")) {
Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?)
} else {
None
};
let next = expected!(self, t!("{")).span; let next = expected!(self, t!("{")).span;
let block = self.parse_block(ctx, next).await?; let block = self.parse_block(ctx, next).await?;
@ -160,6 +165,7 @@ impl Parser<'_> {
block, block,
if_not_exists, if_not_exists,
overwrite, overwrite,
returns,
..Default::default() ..Default::default()
}; };

View file

@ -211,6 +211,7 @@ fn parse_define_function() {
permissions: Permission::Full, permissions: Permission::Full,
if_not_exists: false, if_not_exists: false,
overwrite: false, overwrite: false,
returns: None,
})) }))
) )
} }

View file

@ -198,6 +198,7 @@ fn statements() -> Vec<Statement> {
permissions: Permission::Full, permissions: Permission::Full,
if_not_exists: false, if_not_exists: false,
overwrite: false, overwrite: false,
returns: None,
})), })),
Statement::Define(DefineStatement::Access(DefineAccessStatement { Statement::Define(DefineStatement::Access(DefineAccessStatement {
name: Ident("a".to_string()), name: Ident("a".to_string()),

View file

@ -53,8 +53,12 @@ async fn closures() -> Result<(), Error> {
assert_eq!(tmp, val); assert_eq!(tmp, val);
// //
match res.remove(0).result { match res.remove(0).result {
Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (), Err(Error::FunctionCheck {
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'") name,
value,
check,
}) if name == "ANONYMOUS" && value == "123" && check == "string" => (),
_ => panic!("Invocation should have failed with error"),
} }
// //
let tmp = res.remove(0).result?; let tmp = res.remove(0).result?;

View file

@ -6252,3 +6252,22 @@ async fn function_idiom_chaining() -> Result<(), Error> {
.expect_val("false")?; .expect_val("false")?;
Ok(()) Ok(())
} }
// tests for custom functions with return types
#[tokio::test]
async fn function_custom_typed_returns() -> Result<(), Error> {
let sql = r#"
DEFINE FUNCTION fn::two() -> int {2};
DEFINE FUNCTION fn::two_bad_type() -> string {2};
RETURN fn::two();
RETURN fn::two_bad_type();
"#;
let error = "There was a problem running the two_bad_type function. Expected this function to return a value of type string, but found 2";
Test::new(sql)
.await?
.expect_val("None")?
.expect_val("None")?
.expect_val("2")?
.expect_error(error)?;
Ok(())
}