implement typed function returns (#4475)
Co-authored-by: Micha de Vries <micha@devrie.sh> Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
parent
c3d788ff4a
commit
5d92c7c02c
9 changed files with 77 additions and 14 deletions
|
@ -228,6 +228,14 @@ pub enum Error {
|
||||||
message: String,
|
message: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// The wrong quantity or magnitude of arguments was given for the specified function
|
||||||
|
#[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")]
|
||||||
|
FunctionCheck {
|
||||||
|
name: String,
|
||||||
|
value: String,
|
||||||
|
check: String,
|
||||||
|
},
|
||||||
|
|
||||||
/// The URL is invalid
|
/// The URL is invalid
|
||||||
#[error("The URL `{0}` is invalid")]
|
#[error("The URL `{0}` is invalid")]
|
||||||
InvalidUrl(String),
|
InvalidUrl(String),
|
||||||
|
@ -1192,3 +1200,24 @@ impl Serialize for Error {
|
||||||
serializer.serialize_str(self.to_string().as_str())
|
serializer.serialize_str(self.to_string().as_str())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
|
||||||
|
match self {
|
||||||
|
Error::CoerceTo {
|
||||||
|
from,
|
||||||
|
into,
|
||||||
|
} => Error::FunctionCheck {
|
||||||
|
name: name.into(),
|
||||||
|
value: from.to_string(),
|
||||||
|
check: into,
|
||||||
|
},
|
||||||
|
fc @ Error::FunctionCheck {
|
||||||
|
..
|
||||||
|
} => fc,
|
||||||
|
_ => Error::Internal(
|
||||||
|
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -56,14 +56,7 @@ impl Closure {
|
||||||
|
|
||||||
let result = self.body.compute(stk, &ctx, opt, doc).await?;
|
let result = self.body.compute(stk, &ctx, opt, doc).await?;
|
||||||
if let Some(returns) = &self.returns {
|
if let Some(returns) = &self.returns {
|
||||||
if let Ok(result) = result.clone().coerce_to(returns) {
|
result.coerce_to(returns).map_err(|e| e.function_check_from_coerce("ANONYMOUS"))
|
||||||
Ok(result)
|
|
||||||
} else {
|
|
||||||
Err(Error::InvalidFunction {
|
|
||||||
name: "ANONYMOUS".to_string(),
|
|
||||||
message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
|
@ -213,7 +213,7 @@ impl Function {
|
||||||
}
|
}
|
||||||
Self::Anonymous(v, x) => {
|
Self::Anonymous(v, x) => {
|
||||||
let val = match v {
|
let val = match v {
|
||||||
Value::Closure(p) => &Value::Closure(p.to_owned()),
|
c @ Value::Closure(_) => c,
|
||||||
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
|
Value::Param(p) => ctx.value(p).unwrap_or(&Value::None),
|
||||||
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
|
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
|
||||||
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
|
&stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
|
||||||
|
@ -297,17 +297,25 @@ impl Function {
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
// Duplicate context
|
// Duplicate context
|
||||||
let mut ctx = Context::new(ctx);
|
let mut ctx = Context::new_isolated(ctx);
|
||||||
// Process the function arguments
|
// Process the function arguments
|
||||||
for (val, (name, kind)) in a.into_iter().zip(&val.args) {
|
for (val, (name, kind)) in a.into_iter().zip(&val.args) {
|
||||||
ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
|
ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
|
||||||
}
|
}
|
||||||
// Run the custom function
|
// Run the custom function
|
||||||
match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
|
let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
|
||||||
Err(Error::Return {
|
Err(Error::Return {
|
||||||
value,
|
value,
|
||||||
}) => Ok(value),
|
}) => Ok(value),
|
||||||
res => res,
|
res => res,
|
||||||
|
}?;
|
||||||
|
|
||||||
|
if let Some(ref returns) = val.returns {
|
||||||
|
result
|
||||||
|
.coerce_to(returns)
|
||||||
|
.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
|
||||||
|
} else {
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
|
|
|
@ -11,7 +11,7 @@ use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::{self, Display, Write};
|
use std::fmt::{self, Display, Write};
|
||||||
|
|
||||||
#[revisioned(revision = 3)]
|
#[revisioned(revision = 4)]
|
||||||
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
|
@ -25,6 +25,8 @@ pub struct DefineFunctionStatement {
|
||||||
pub if_not_exists: bool,
|
pub if_not_exists: bool,
|
||||||
#[revision(start = 3)]
|
#[revision(start = 3)]
|
||||||
pub overwrite: bool,
|
pub overwrite: bool,
|
||||||
|
#[revision(start = 4)]
|
||||||
|
pub returns: Option<Kind>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DefineFunctionStatement {
|
impl DefineFunctionStatement {
|
||||||
|
|
|
@ -150,6 +150,11 @@ impl Parser<'_> {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let returns = if self.eat(t!("->")) {
|
||||||
|
Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let next = expected!(self, t!("{")).span;
|
let next = expected!(self, t!("{")).span;
|
||||||
let block = self.parse_block(ctx, next).await?;
|
let block = self.parse_block(ctx, next).await?;
|
||||||
|
@ -160,6 +165,7 @@ impl Parser<'_> {
|
||||||
block,
|
block,
|
||||||
if_not_exists,
|
if_not_exists,
|
||||||
overwrite,
|
overwrite,
|
||||||
|
returns,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -211,6 +211,7 @@ fn parse_define_function() {
|
||||||
permissions: Permission::Full,
|
permissions: Permission::Full,
|
||||||
if_not_exists: false,
|
if_not_exists: false,
|
||||||
overwrite: false,
|
overwrite: false,
|
||||||
|
returns: None,
|
||||||
}))
|
}))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,6 +198,7 @@ fn statements() -> Vec<Statement> {
|
||||||
permissions: Permission::Full,
|
permissions: Permission::Full,
|
||||||
if_not_exists: false,
|
if_not_exists: false,
|
||||||
overwrite: false,
|
overwrite: false,
|
||||||
|
returns: None,
|
||||||
})),
|
})),
|
||||||
Statement::Define(DefineStatement::Access(DefineAccessStatement {
|
Statement::Define(DefineStatement::Access(DefineAccessStatement {
|
||||||
name: Ident("a".to_string()),
|
name: Ident("a".to_string()),
|
||||||
|
|
|
@ -53,8 +53,12 @@ async fn closures() -> Result<(), Error> {
|
||||||
assert_eq!(tmp, val);
|
assert_eq!(tmp, val);
|
||||||
//
|
//
|
||||||
match res.remove(0).result {
|
match res.remove(0).result {
|
||||||
Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (),
|
Err(Error::FunctionCheck {
|
||||||
_ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'")
|
name,
|
||||||
|
value,
|
||||||
|
check,
|
||||||
|
}) if name == "ANONYMOUS" && value == "123" && check == "string" => (),
|
||||||
|
_ => panic!("Invocation should have failed with error"),
|
||||||
}
|
}
|
||||||
//
|
//
|
||||||
let tmp = res.remove(0).result?;
|
let tmp = res.remove(0).result?;
|
||||||
|
|
|
@ -6252,3 +6252,22 @@ async fn function_idiom_chaining() -> Result<(), Error> {
|
||||||
.expect_val("false")?;
|
.expect_val("false")?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tests for custom functions with return types
|
||||||
|
#[tokio::test]
|
||||||
|
async fn function_custom_typed_returns() -> Result<(), Error> {
|
||||||
|
let sql = r#"
|
||||||
|
DEFINE FUNCTION fn::two() -> int {2};
|
||||||
|
DEFINE FUNCTION fn::two_bad_type() -> string {2};
|
||||||
|
RETURN fn::two();
|
||||||
|
RETURN fn::two_bad_type();
|
||||||
|
"#;
|
||||||
|
let error = "There was a problem running the two_bad_type function. Expected this function to return a value of type string, but found 2";
|
||||||
|
Test::new(sql)
|
||||||
|
.await?
|
||||||
|
.expect_val("None")?
|
||||||
|
.expect_val("None")?
|
||||||
|
.expect_val("2")?
|
||||||
|
.expect_error(error)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue