implement typed function returns (#4475)

Co-authored-by: Micha de Vries <micha@devrie.sh>
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
Raphael Darley 2024-08-13 18:10:13 +01:00 committed by GitHub
parent c3d788ff4a
commit 5d92c7c02c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 77 additions and 14 deletions

View file

@ -228,6 +228,14 @@ pub enum Error {
message: String,
},
/// The wrong quantity or magnitude of arguments was given for the specified function
#[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")]
FunctionCheck {
name: String,
value: String,
check: String,
},
/// The URL is invalid
#[error("The URL `{0}` is invalid")]
InvalidUrl(String),
@ -1192,3 +1200,24 @@ impl Serialize for Error {
serializer.serialize_str(self.to_string().as_str())
}
}
impl Error {
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
match self {
Error::CoerceTo {
from,
into,
} => Error::FunctionCheck {
name: name.into(),
value: from.to_string(),
check: into,
},
fc @ Error::FunctionCheck {
..
} => fc,
_ => Error::Internal(
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
),
}
}
}

View file

@ -56,14 +56,7 @@ impl Closure {
let result = self.body.compute(stk, &ctx, opt, doc).await?;
if let Some(returns) = &self.returns {
if let Ok(result) = result.clone().coerce_to(returns) {
Ok(result)
} else {
Err(Error::InvalidFunction {
name: "ANONYMOUS".to_string(),
message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()),
})
}
result.coerce_to(returns).map_err(|e| e.function_check_from_coerce("ANONYMOUS"))
} else {
Ok(result)
}

View file

@ -213,7 +213,7 @@ impl Function {
}
Self::Anonymous(v, x) => {
let val = match v {
Value::Closure(p) => &Value::Closure(p.to_owned()),
c @ Value::Closure(_) => c,
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
@ -297,17 +297,25 @@ impl Function {
})
.await?;
// Duplicate context
let mut ctx = Context::new(ctx);
let mut ctx = Context::new_isolated(ctx);
// Process the function arguments
for (val, (name, kind)) in a.into_iter().zip(&val.args) {
ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
}
// Run the custom function
match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
Err(Error::Return {
value,
}) => Ok(value),
res => res,
}?;
if let Some(ref returns) = val.returns {
result
.coerce_to(returns)
.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
} else {
Ok(result)
}
}
#[allow(unused_variables)]

View file

@ -11,7 +11,7 @@ use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Display, Write};
#[revisioned(revision = 3)]
#[revisioned(revision = 4)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
@ -25,6 +25,8 @@ pub struct DefineFunctionStatement {
pub if_not_exists: bool,
#[revision(start = 3)]
pub overwrite: bool,
#[revision(start = 4)]
pub returns: Option<Kind>,
}
impl DefineFunctionStatement {

View file

@ -150,6 +150,11 @@ impl Parser<'_> {
break;
}
}
let returns = if self.eat(t!("->")) {
Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?)
} else {
None
};
let next = expected!(self, t!("{")).span;
let block = self.parse_block(ctx, next).await?;
@ -160,6 +165,7 @@ impl Parser<'_> {
block,
if_not_exists,
overwrite,
returns,
..Default::default()
};

View file

@ -211,6 +211,7 @@ fn parse_define_function() {
permissions: Permission::Full,
if_not_exists: false,
overwrite: false,
returns: None,
}))
)
}

View file

@ -198,6 +198,7 @@ fn statements() -> Vec<Statement> {
permissions: Permission::Full,
if_not_exists: false,
overwrite: false,
returns: None,
})),
Statement::Define(DefineStatement::Access(DefineAccessStatement {
name: Ident("a".to_string()),

View file

@ -53,8 +53,12 @@ async fn closures() -> Result<(), Error> {
assert_eq!(tmp, val);
//
match res.remove(0).result {
Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (),
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'")
Err(Error::FunctionCheck {
name,
value,
check,
}) if name == "ANONYMOUS" && value == "123" && check == "string" => (),
_ => panic!("Invocation should have failed with error"),
}
//
let tmp = res.remove(0).result?;

View file

@ -6252,3 +6252,22 @@ async fn function_idiom_chaining() -> Result<(), Error> {
.expect_val("false")?;
Ok(())
}
// tests for custom functions with return types
#[tokio::test]
async fn function_custom_typed_returns() -> Result<(), Error> {
let sql = r#"
DEFINE FUNCTION fn::two() -> int {2};
DEFINE FUNCTION fn::two_bad_type() -> string {2};
RETURN fn::two();
RETURN fn::two_bad_type();
"#;
let error = "There was a problem running the two_bad_type function. Expected this function to return a value of type string, but found 2";
Test::new(sql)
.await?
.expect_val("None")?
.expect_val("None")?
.expect_val("2")?
.expect_error(error)?;
Ok(())
}