implement typed function returns (#4475)
Co-authored-by: Micha de Vries <micha@devrie.sh> Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
parent
c3d788ff4a
commit
5d92c7c02c
9 changed files with 77 additions and 14 deletions
|
@ -228,6 +228,14 @@ pub enum Error {
|
|||
message: String,
|
||||
},
|
||||
|
||||
/// The wrong quantity or magnitude of arguments was given for the specified function
|
||||
#[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")]
|
||||
FunctionCheck {
|
||||
name: String,
|
||||
value: String,
|
||||
check: String,
|
||||
},
|
||||
|
||||
/// The URL is invalid
|
||||
#[error("The URL `{0}` is invalid")]
|
||||
InvalidUrl(String),
|
||||
|
@ -1192,3 +1200,24 @@ impl Serialize for Error {
|
|||
serializer.serialize_str(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
|
||||
match self {
|
||||
Error::CoerceTo {
|
||||
from,
|
||||
into,
|
||||
} => Error::FunctionCheck {
|
||||
name: name.into(),
|
||||
value: from.to_string(),
|
||||
check: into,
|
||||
},
|
||||
fc @ Error::FunctionCheck {
|
||||
..
|
||||
} => fc,
|
||||
_ => Error::Internal(
|
||||
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,14 +56,7 @@ impl Closure {
|
|||
|
||||
let result = self.body.compute(stk, &ctx, opt, doc).await?;
|
||||
if let Some(returns) = &self.returns {
|
||||
if let Ok(result) = result.clone().coerce_to(returns) {
|
||||
Ok(result)
|
||||
} else {
|
||||
Err(Error::InvalidFunction {
|
||||
name: "ANONYMOUS".to_string(),
|
||||
message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()),
|
||||
})
|
||||
}
|
||||
result.coerce_to(returns).map_err(|e| e.function_check_from_coerce("ANONYMOUS"))
|
||||
} else {
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
@ -213,7 +213,7 @@ impl Function {
|
|||
}
|
||||
Self::Anonymous(v, x) => {
|
||||
let val = match v {
|
||||
Value::Closure(p) => &Value::Closure(p.to_owned()),
|
||||
c @ Value::Closure(_) => c,
|
||||
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
|
||||
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
|
||||
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
|
||||
|
@ -297,17 +297,25 @@ impl Function {
|
|||
})
|
||||
.await?;
|
||||
// Duplicate context
|
||||
let mut ctx = Context::new(ctx);
|
||||
let mut ctx = Context::new_isolated(ctx);
|
||||
// Process the function arguments
|
||||
for (val, (name, kind)) in a.into_iter().zip(&val.args) {
|
||||
ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
|
||||
}
|
||||
// Run the custom function
|
||||
match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
|
||||
let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
|
||||
Err(Error::Return {
|
||||
value,
|
||||
}) => Ok(value),
|
||||
res => res,
|
||||
}?;
|
||||
|
||||
if let Some(ref returns) = val.returns {
|
||||
result
|
||||
.coerce_to(returns)
|
||||
.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
|
||||
} else {
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
#[allow(unused_variables)]
|
||||
|
|
|
@ -11,7 +11,7 @@ use revision::revisioned;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{self, Display, Write};
|
||||
|
||||
#[revisioned(revision = 3)]
|
||||
#[revisioned(revision = 4)]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||
#[non_exhaustive]
|
||||
|
@ -25,6 +25,8 @@ pub struct DefineFunctionStatement {
|
|||
pub if_not_exists: bool,
|
||||
#[revision(start = 3)]
|
||||
pub overwrite: bool,
|
||||
#[revision(start = 4)]
|
||||
pub returns: Option<Kind>,
|
||||
}
|
||||
|
||||
impl DefineFunctionStatement {
|
||||
|
|
|
@ -150,6 +150,11 @@ impl Parser<'_> {
|
|||
break;
|
||||
}
|
||||
}
|
||||
let returns = if self.eat(t!("->")) {
|
||||
Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let next = expected!(self, t!("{")).span;
|
||||
let block = self.parse_block(ctx, next).await?;
|
||||
|
@ -160,6 +165,7 @@ impl Parser<'_> {
|
|||
block,
|
||||
if_not_exists,
|
||||
overwrite,
|
||||
returns,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
|
@ -211,6 +211,7 @@ fn parse_define_function() {
|
|||
permissions: Permission::Full,
|
||||
if_not_exists: false,
|
||||
overwrite: false,
|
||||
returns: None,
|
||||
}))
|
||||
)
|
||||
}
|
||||
|
|
|
@ -198,6 +198,7 @@ fn statements() -> Vec<Statement> {
|
|||
permissions: Permission::Full,
|
||||
if_not_exists: false,
|
||||
overwrite: false,
|
||||
returns: None,
|
||||
})),
|
||||
Statement::Define(DefineStatement::Access(DefineAccessStatement {
|
||||
name: Ident("a".to_string()),
|
||||
|
|
|
@ -53,8 +53,12 @@ async fn closures() -> Result<(), Error> {
|
|||
assert_eq!(tmp, val);
|
||||
//
|
||||
match res.remove(0).result {
|
||||
Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (),
|
||||
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'")
|
||||
Err(Error::FunctionCheck {
|
||||
name,
|
||||
value,
|
||||
check,
|
||||
}) if name == "ANONYMOUS" && value == "123" && check == "string" => (),
|
||||
_ => panic!("Invocation should have failed with error"),
|
||||
}
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
|
|
|
@ -6252,3 +6252,22 @@ async fn function_idiom_chaining() -> Result<(), Error> {
|
|||
.expect_val("false")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// tests for custom functions with return types
|
||||
#[tokio::test]
|
||||
async fn function_custom_typed_returns() -> Result<(), Error> {
|
||||
let sql = r#"
|
||||
DEFINE FUNCTION fn::two() -> int {2};
|
||||
DEFINE FUNCTION fn::two_bad_type() -> string {2};
|
||||
RETURN fn::two();
|
||||
RETURN fn::two_bad_type();
|
||||
"#;
|
||||
let error = "There was a problem running the two_bad_type function. Expected this function to return a value of type string, but found 2";
|
||||
Test::new(sql)
|
||||
.await?
|
||||
.expect_val("None")?
|
||||
.expect_val("None")?
|
||||
.expect_val("2")?
|
||||
.expect_error(error)?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue