From 5d92c7c02cf1acf6b3aaba142ba4c3aa8ccc26f0 Mon Sep 17 00:00:00 2001 From: Raphael Darley Date: Tue, 13 Aug 2024 18:10:13 +0100 Subject: [PATCH] implement typed function returns (#4475) Co-authored-by: Micha de Vries Co-authored-by: Tobie Morgan Hitchcock --- core/src/err/mod.rs | 29 ++++++++++++++++++++++ core/src/sql/closure.rs | 9 +------ core/src/sql/function.rs | 14 ++++++++--- core/src/sql/statements/define/function.rs | 4 ++- core/src/syn/parser/stmt/define.rs | 6 +++++ core/src/syn/parser/test/stmt.rs | 1 + core/src/syn/parser/test/streaming.rs | 1 + lib/tests/closure.rs | 8 ++++-- lib/tests/function.rs | 19 ++++++++++++++ 9 files changed, 77 insertions(+), 14 deletions(-) diff --git a/core/src/err/mod.rs b/core/src/err/mod.rs index d98c4df8..f8d8b186 100644 --- a/core/src/err/mod.rs +++ b/core/src/err/mod.rs @@ -228,6 +228,14 @@ pub enum Error { message: String, }, + /// The wrong quantity or magnitude of arguments was given for the specified function + #[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")] + FunctionCheck { + name: String, + value: String, + check: String, + }, + /// The URL is invalid #[error("The URL `{0}` is invalid")] InvalidUrl(String), @@ -1192,3 +1200,24 @@ impl Serialize for Error { serializer.serialize_str(self.to_string().as_str()) } } + +impl Error { + pub fn function_check_from_coerce(self, name: impl Into) -> Error { + match self { + Error::CoerceTo { + from, + into, + } => Error::FunctionCheck { + name: name.into(), + value: from.to_string(), + check: into, + }, + fc @ Error::FunctionCheck { + .. + } => fc, + _ => Error::Internal( + "function_check_from_coerce called on Error that wasn't CoerceTo".to_string(), + ), + } + } +} diff --git a/core/src/sql/closure.rs b/core/src/sql/closure.rs index 14ee9b69..a9ae55ba 100644 --- a/core/src/sql/closure.rs +++ b/core/src/sql/closure.rs @@ -56,14 +56,7 @@ impl Closure { let result = self.body.compute(stk, &ctx, opt, doc).await?; if let Some(returns) = &self.returns { - if let Ok(result) = result.clone().coerce_to(returns) { - Ok(result) - } else { - Err(Error::InvalidFunction { - name: "ANONYMOUS".to_string(), - message: format!("Expected this closure to return a value of type '{returns}', but found '{}'", result.kindof()), - }) - } + result.coerce_to(returns).map_err(|e| e.function_check_from_coerce("ANONYMOUS")) } else { Ok(result) } diff --git a/core/src/sql/function.rs b/core/src/sql/function.rs index efb2e848..a483ab43 100644 --- a/core/src/sql/function.rs +++ b/core/src/sql/function.rs @@ -213,7 +213,7 @@ impl Function { } Self::Anonymous(v, x) => { let val = match v { - Value::Closure(p) => &Value::Closure(p.to_owned()), + c @ Value::Closure(_) => c, Value::Param(p) => ctx.value(p).unwrap_or(&Value::None), Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => { &stk.run(|stk| v.compute(stk, ctx, opt, doc)).await? @@ -297,17 +297,25 @@ impl Function { }) .await?; // Duplicate context - let mut ctx = Context::new(ctx); + let mut ctx = Context::new_isolated(ctx); // Process the function arguments for (val, (name, kind)) in a.into_iter().zip(&val.args) { ctx.add_value(name.to_raw(), val.coerce_to(kind)?); } // Run the custom function - match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await { + let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await { Err(Error::Return { value, }) => Ok(value), res => res, + }?; + + if let Some(ref returns) = val.returns { + result + .coerce_to(returns) + .map_err(|e| e.function_check_from_coerce(val.name.to_string())) + } else { + Ok(result) } } #[allow(unused_variables)] diff --git a/core/src/sql/statements/define/function.rs b/core/src/sql/statements/define/function.rs index c87f1029..686d8438 100644 --- a/core/src/sql/statements/define/function.rs +++ b/core/src/sql/statements/define/function.rs @@ -11,7 +11,7 @@ use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt::{self, Display, Write}; -#[revisioned(revision = 3)] +#[revisioned(revision = 4)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] @@ -25,6 +25,8 @@ pub struct DefineFunctionStatement { pub if_not_exists: bool, #[revision(start = 3)] pub overwrite: bool, + #[revision(start = 4)] + pub returns: Option, } impl DefineFunctionStatement { diff --git a/core/src/syn/parser/stmt/define.rs b/core/src/syn/parser/stmt/define.rs index 7a2c3228..5df637f8 100644 --- a/core/src/syn/parser/stmt/define.rs +++ b/core/src/syn/parser/stmt/define.rs @@ -150,6 +150,11 @@ impl Parser<'_> { break; } } + let returns = if self.eat(t!("->")) { + Some(ctx.run(|ctx| self.parse_inner_kind(ctx)).await?) + } else { + None + }; let next = expected!(self, t!("{")).span; let block = self.parse_block(ctx, next).await?; @@ -160,6 +165,7 @@ impl Parser<'_> { block, if_not_exists, overwrite, + returns, ..Default::default() }; diff --git a/core/src/syn/parser/test/stmt.rs b/core/src/syn/parser/test/stmt.rs index 0e3d8320..b86348b1 100644 --- a/core/src/syn/parser/test/stmt.rs +++ b/core/src/syn/parser/test/stmt.rs @@ -211,6 +211,7 @@ fn parse_define_function() { permissions: Permission::Full, if_not_exists: false, overwrite: false, + returns: None, })) ) } diff --git a/core/src/syn/parser/test/streaming.rs b/core/src/syn/parser/test/streaming.rs index de9a8291..aa855749 100644 --- a/core/src/syn/parser/test/streaming.rs +++ b/core/src/syn/parser/test/streaming.rs @@ -198,6 +198,7 @@ fn statements() -> Vec { permissions: Permission::Full, if_not_exists: false, overwrite: false, + returns: None, })), Statement::Define(DefineStatement::Access(DefineAccessStatement { name: Ident("a".to_string()), diff --git a/lib/tests/closure.rs b/lib/tests/closure.rs index 78de1bba..114ba3fe 100644 --- a/lib/tests/closure.rs +++ b/lib/tests/closure.rs @@ -53,8 +53,12 @@ async fn closures() -> Result<(), Error> { assert_eq!(tmp, val); // match res.remove(0).result { - Err(Error::InvalidFunction { name, message }) if name == "ANONYMOUS" && message == "Expected this closure to return a value of type 'string', but found 'int'" => (), - _ => panic!("Invocation should have failed with error: There was a problem running the ANONYMOUS() function. Expected this closure to return a value of type 'string', but found 'int'") + Err(Error::FunctionCheck { + name, + value, + check, + }) if name == "ANONYMOUS" && value == "123" && check == "string" => (), + _ => panic!("Invocation should have failed with error"), } // let tmp = res.remove(0).result?; diff --git a/lib/tests/function.rs b/lib/tests/function.rs index aaf5fa45..41590d89 100644 --- a/lib/tests/function.rs +++ b/lib/tests/function.rs @@ -6252,3 +6252,22 @@ async fn function_idiom_chaining() -> Result<(), Error> { .expect_val("false")?; Ok(()) } + +// tests for custom functions with return types +#[tokio::test] +async fn function_custom_typed_returns() -> Result<(), Error> { + let sql = r#" + DEFINE FUNCTION fn::two() -> int {2}; + DEFINE FUNCTION fn::two_bad_type() -> string {2}; + RETURN fn::two(); + RETURN fn::two_bad_type(); + "#; + let error = "There was a problem running the two_bad_type function. Expected this function to return a value of type string, but found 2"; + Test::new(sql) + .await? + .expect_val("None")? + .expect_val("None")? + .expect_val("2")? + .expect_error(error)?; + Ok(()) +}