diff --git a/core/src/err/mod.rs b/core/src/err/mod.rs index f8d8b186..6f876faa 100644 --- a/core/src/err/mod.rs +++ b/core/src/err/mod.rs @@ -596,6 +596,14 @@ pub enum Error { check: String, }, + /// The specified value did not conform to the LET type check + #[error("Found {value} for param ${name}, but expected a {check}")] + SetCheck { + value: String, + name: String, + check: String, + }, + /// The specified field did not conform to the field ASSERT clause #[error( "Found changed value for field `{field}`, with record `{thing}`, but field is readonly" @@ -1200,8 +1208,21 @@ impl Serialize for Error { serializer.serialize_str(self.to_string().as_str()) } } - impl Error { + pub fn set_check_from_coerce(self, name: String) -> Error { + match self { + Error::CoerceTo { + from, + into, + } => Error::SetCheck { + name, + value: from.to_string(), + check: into, + }, + e => e, + } + } + pub fn function_check_from_coerce(self, name: impl Into) -> Error { match self { Error::CoerceTo { @@ -1212,12 +1233,7 @@ impl Error { value: from.to_string(), check: into, }, - fc @ Error::FunctionCheck { - .. - } => fc, - _ => Error::Internal( - "function_check_from_coerce called on Error that wasn't CoerceTo".to_string(), - ), + e => e, } } } diff --git a/core/src/sql/query.rs b/core/src/sql/query.rs index a7d48e97..9a60d5aa 100644 --- a/core/src/sql/query.rs +++ b/core/src/sql/query.rs @@ -6,7 +6,7 @@ use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::fmt::{self, Display, Formatter}; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::str; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query"; @@ -43,6 +43,12 @@ impl Deref for Query { } } +impl DerefMut for Query { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 .0 + } +} + impl IntoIterator for Query { type Item = Statement; type IntoIter = std::vec::IntoIter; diff --git a/core/src/sql/statements/set.rs b/core/src/sql/statements/set.rs index 5f100c68..9992a665 100644 --- a/core/src/sql/statements/set.rs +++ b/core/src/sql/statements/set.rs @@ -1,22 +1,24 @@ -use crate::cnf::PROTECTED_PARAM_NAMES; use crate::ctx::Context; use crate::dbs::Options; use crate::doc::CursorDoc; use crate::err::Error; use crate::sql::Value; +use crate::{cnf::PROTECTED_PARAM_NAMES, sql::Kind}; use derive::Store; use reblessive::tree::Stk; use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt; -#[revisioned(revision = 1)] +#[revisioned(revision = 2)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[non_exhaustive] pub struct SetStatement { pub name: String, pub what: Value, + #[revision(start = 2)] + pub kind: Option, } impl SetStatement { @@ -35,7 +37,15 @@ impl SetStatement { // Check if the variable is a protected variable match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) { // The variable isn't protected and can be stored - false => self.what.compute(stk, ctx, opt, doc).await, + false => { + let result = self.what.compute(stk, ctx, opt, doc).await?; + match self.kind { + Some(ref kind) => result + .coerce_to(kind) + .map_err(|e| e.set_check_from_coerce(self.name.to_string())), + None => Ok(result), + } + } // The user tried to set a protected variable true => Err(Error::InvalidParam { // Move the parameter name, as we no longer need it diff --git a/core/src/syn/parser/stmt/mod.rs b/core/src/syn/parser/stmt/mod.rs index dd2e51cc..dd78157f 100644 --- a/core/src/syn/parser/stmt/mod.rs +++ b/core/src/syn/parser/stmt/mod.rs @@ -351,6 +351,7 @@ impl Parser<'_> { return Statement::Set(crate::sql::statements::SetStatement { name: x.0 .0, what: r, + kind: None, }); } Statement::Value(Value::Expression(x)) @@ -371,6 +372,7 @@ impl Parser<'_> { return Entry::Set(crate::sql::statements::SetStatement { name: x.0 .0, what: r, + kind: None, }); } Entry::Value(Value::Expression(x)) @@ -652,11 +654,17 @@ impl Parser<'_> { /// Expects `LET` to already be consumed. pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult { let name = self.next_token_value::()?.0 .0; + let kind = if self.eat(t!(":")) { + Some(self.parse_inner_kind(ctx).await?) + } else { + None + }; expected!(self, t!("=")); let what = self.parse_value(ctx).await?; Ok(SetStatement { name, what, + kind, }) } diff --git a/core/src/syn/parser/test/stmt.rs b/core/src/syn/parser/test/stmt.rs index b86348b1..76685dbc 100644 --- a/core/src/syn/parser/test/stmt.rs +++ b/core/src/syn/parser/test/stmt.rs @@ -1888,7 +1888,8 @@ fn parse_let() { res, Statement::Set(SetStatement { name: "param".to_owned(), - what: Value::Number(Number::Int(1)) + what: Value::Number(Number::Int(1)), + kind: None, }) ); @@ -1897,7 +1898,8 @@ fn parse_let() { res, Statement::Set(SetStatement { name: "param".to_owned(), - what: Value::Number(Number::Int(1)) + what: Value::Number(Number::Int(1)), + kind: None, }) ); } diff --git a/core/src/syn/parser/test/streaming.rs b/core/src/syn/parser/test/streaming.rs index aa855749..0df465e0 100644 --- a/core/src/syn/parser/test/streaming.rs +++ b/core/src/syn/parser/test/streaming.rs @@ -534,6 +534,7 @@ fn statements() -> Vec { Statement::Set(SetStatement { name: "param".to_owned(), what: Value::Number(Number::Int(1)), + kind: None, }), Statement::Show(ShowStatement { table: Some(Table("foo".to_owned())), diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index 65754737..3e8ce1ea 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -107,7 +107,7 @@ pub struct Live; /// Responses returned with statistics #[derive(Debug)] -pub struct WithStats(T); +pub struct WithStats(pub T); impl Surreal where diff --git a/lib/tests/set.rs b/lib/tests/set.rs new file mode 100644 index 00000000..e3196ee8 --- /dev/null +++ b/lib/tests/set.rs @@ -0,0 +1,23 @@ +mod helpers; +mod parse; +use helpers::Test; +use surrealdb::err::Error; + +#[tokio::test] +async fn typed_set() -> Result<(), Error> { + let sql = " + LET $foo: int = 42; + RETURN $foo; + LET $bar: int = 'hello'; + RETURN $bar; + "; + let error = "Found 'hello' for param $bar, but expected a int"; + Test::new(sql) + .await? + .expect_val("None")? + .expect_val("42")? + .expect_error(error)? + .expect_val("None")?; + + Ok(()) +} diff --git a/src/cli/sql.rs b/src/cli/sql.rs index 9c1bd544..1b062777 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -14,7 +14,7 @@ use serde_json::ser::PrettyFormatter; use surrealdb::engine::any::{connect, IntoEndpoint}; use surrealdb::method::{Stats, WithStats}; use surrealdb::opt::{capabilities::Capabilities, Config}; -use surrealdb::sql::{self, Statement, Value}; +use surrealdb::sql::{self, Param, Statement, Value}; use surrealdb::{Notification, Response}; #[derive(Args, Debug)] @@ -185,10 +185,11 @@ pub async fn init( } // Complete the request match sql::parse(&line) { - Ok(query) => { + Ok(mut query) => { let mut namespace = None; let mut database = None; let mut vars = Vec::new(); + let init_length = query.len(); // Capture `use` and `set/let` statements from the query for statement in query.iter() { match statement { @@ -200,12 +201,15 @@ pub async fn init( database = Some(db.clone()); } } - Statement::Set(stmt) => { - vars.push((stmt.name.clone(), stmt.what.clone())); - } + Statement::Set(stmt) => vars.push(stmt.name.clone()), _ => {} } } + + for var in &vars { + query.push(Statement::Value(Value::Param(Param::from(var.as_str())))) + } + // Extract the namespace and database from the current prompt let (prompt_ns, prompt_db) = split_prompt(&prompt); // The namespace should be set before the database can be set @@ -216,17 +220,23 @@ pub async fn init( continue; } // Run the query provided - let result = client.query(query).with_stats().await; + let mut result = client.query(query).with_stats().await; + + if let Ok(WithStats(res)) = &mut result { + for (i, n) in vars.into_iter().enumerate() { + if let Result::::Ok(v) = res.take(init_length + i) { + let _ = client.set(n, v).await; + } + } + } + let result = process(pretty, json, result); let result_is_error = result.is_err(); print(result); if result_is_error { continue; } - // Persist the variables extracted from the query - for (key, value) in vars { - let _ = client.set(key, value).await; - } + // Process the last `use` statements, if any if namespace.is_some() || database.is_some() { // Use the namespace provided in the query if any, otherwise use the one in the prompt diff --git a/tests/cli_integration.rs b/tests/cli_integration.rs index e559926e..9d68ec63 100644 --- a/tests/cli_integration.rs +++ b/tests/cli_integration.rs @@ -1450,6 +1450,21 @@ mod cli_integration { server.finish().unwrap(); } } + + #[test(tokio::test)] + async fn double_create() { + info!("* check only one output created"); + { + let args = "sql --conn memory --ns test --db test --pretty --hide-welcome".to_string(); + let output = common::run(&args) + .input("let $a = create foo;\n") + .input("select * from foo;\n") + .output() + .unwrap(); + let output = remove_debug_info(output); + assert_eq!(output.matches("foo:").count(), 1); + } + } } fn remove_debug_info(output: String) -> String {