Add typed LET statement (#4476)

This commit is contained in:
Raphael Darley 2024-08-13 20:47:17 +01:00 committed by GitHub
parent 5d92c7c02c
commit bb1eba4aab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 115 additions and 24 deletions

View file

@ -596,6 +596,14 @@ pub enum Error {
check: String, check: String,
}, },
/// The specified value did not conform to the LET type check
#[error("Found {value} for param ${name}, but expected a {check}")]
SetCheck {
value: String,
name: String,
check: String,
},
/// The specified field did not conform to the field ASSERT clause /// The specified field did not conform to the field ASSERT clause
#[error( #[error(
"Found changed value for field `{field}`, with record `{thing}`, but field is readonly" "Found changed value for field `{field}`, with record `{thing}`, but field is readonly"
@ -1200,8 +1208,21 @@ impl Serialize for Error {
serializer.serialize_str(self.to_string().as_str()) serializer.serialize_str(self.to_string().as_str())
} }
} }
impl Error { impl Error {
pub fn set_check_from_coerce(self, name: String) -> Error {
match self {
Error::CoerceTo {
from,
into,
} => Error::SetCheck {
name,
value: from.to_string(),
check: into,
},
e => e,
}
}
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error { pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
match self { match self {
Error::CoerceTo { Error::CoerceTo {
@ -1212,12 +1233,7 @@ impl Error {
value: from.to_string(), value: from.to_string(),
check: into, check: into,
}, },
fc @ Error::FunctionCheck { e => e,
..
} => fc,
_ => Error::Internal(
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
),
} }
} }
} }

View file

@ -6,7 +6,7 @@ use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Write; use std::fmt::Write;
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use std::ops::Deref; use std::ops::{Deref, DerefMut};
use std::str; use std::str;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query"; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
@ -43,6 +43,12 @@ impl Deref for Query {
} }
} }
impl DerefMut for Query {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0 .0
}
}
impl IntoIterator for Query { impl IntoIterator for Query {
type Item = Statement; type Item = Statement;
type IntoIter = std::vec::IntoIter<Self::Item>; type IntoIter = std::vec::IntoIter<Self::Item>;

View file

@ -1,22 +1,24 @@
use crate::cnf::PROTECTED_PARAM_NAMES;
use crate::ctx::Context; use crate::ctx::Context;
use crate::dbs::Options; use crate::dbs::Options;
use crate::doc::CursorDoc; use crate::doc::CursorDoc;
use crate::err::Error; use crate::err::Error;
use crate::sql::Value; use crate::sql::Value;
use crate::{cnf::PROTECTED_PARAM_NAMES, sql::Kind};
use derive::Store; use derive::Store;
use reblessive::tree::Stk; use reblessive::tree::Stk;
use revision::revisioned; use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[revisioned(revision = 1)] #[revisioned(revision = 2)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive] #[non_exhaustive]
pub struct SetStatement { pub struct SetStatement {
pub name: String, pub name: String,
pub what: Value, pub what: Value,
#[revision(start = 2)]
pub kind: Option<Kind>,
} }
impl SetStatement { impl SetStatement {
@ -35,7 +37,15 @@ impl SetStatement {
// Check if the variable is a protected variable // Check if the variable is a protected variable
match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) { match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) {
// The variable isn't protected and can be stored // The variable isn't protected and can be stored
false => self.what.compute(stk, ctx, opt, doc).await, false => {
let result = self.what.compute(stk, ctx, opt, doc).await?;
match self.kind {
Some(ref kind) => result
.coerce_to(kind)
.map_err(|e| e.set_check_from_coerce(self.name.to_string())),
None => Ok(result),
}
}
// The user tried to set a protected variable // The user tried to set a protected variable
true => Err(Error::InvalidParam { true => Err(Error::InvalidParam {
// Move the parameter name, as we no longer need it // Move the parameter name, as we no longer need it

View file

@ -351,6 +351,7 @@ impl Parser<'_> {
return Statement::Set(crate::sql::statements::SetStatement { return Statement::Set(crate::sql::statements::SetStatement {
name: x.0 .0, name: x.0 .0,
what: r, what: r,
kind: None,
}); });
} }
Statement::Value(Value::Expression(x)) Statement::Value(Value::Expression(x))
@ -371,6 +372,7 @@ impl Parser<'_> {
return Entry::Set(crate::sql::statements::SetStatement { return Entry::Set(crate::sql::statements::SetStatement {
name: x.0 .0, name: x.0 .0,
what: r, what: r,
kind: None,
}); });
} }
Entry::Value(Value::Expression(x)) Entry::Value(Value::Expression(x))
@ -652,11 +654,17 @@ impl Parser<'_> {
/// Expects `LET` to already be consumed. /// Expects `LET` to already be consumed.
pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> { pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> {
let name = self.next_token_value::<Param>()?.0 .0; let name = self.next_token_value::<Param>()?.0 .0;
let kind = if self.eat(t!(":")) {
Some(self.parse_inner_kind(ctx).await?)
} else {
None
};
expected!(self, t!("=")); expected!(self, t!("="));
let what = self.parse_value(ctx).await?; let what = self.parse_value(ctx).await?;
Ok(SetStatement { Ok(SetStatement {
name, name,
what, what,
kind,
}) })
} }

View file

@ -1888,7 +1888,8 @@ fn parse_let() {
res, res,
Statement::Set(SetStatement { Statement::Set(SetStatement {
name: "param".to_owned(), name: "param".to_owned(),
what: Value::Number(Number::Int(1)) what: Value::Number(Number::Int(1)),
kind: None,
}) })
); );
@ -1897,7 +1898,8 @@ fn parse_let() {
res, res,
Statement::Set(SetStatement { Statement::Set(SetStatement {
name: "param".to_owned(), name: "param".to_owned(),
what: Value::Number(Number::Int(1)) what: Value::Number(Number::Int(1)),
kind: None,
}) })
); );
} }

View file

@ -534,6 +534,7 @@ fn statements() -> Vec<Statement> {
Statement::Set(SetStatement { Statement::Set(SetStatement {
name: "param".to_owned(), name: "param".to_owned(),
what: Value::Number(Number::Int(1)), what: Value::Number(Number::Int(1)),
kind: None,
}), }),
Statement::Show(ShowStatement { Statement::Show(ShowStatement {
table: Some(Table("foo".to_owned())), table: Some(Table("foo".to_owned())),

View file

@ -107,7 +107,7 @@ pub struct Live;
/// Responses returned with statistics /// Responses returned with statistics
#[derive(Debug)] #[derive(Debug)]
pub struct WithStats<T>(T); pub struct WithStats<T>(pub T);
impl<C> Surreal<C> impl<C> Surreal<C>
where where

23
lib/tests/set.rs Normal file
View file

@ -0,0 +1,23 @@
mod helpers;
mod parse;
use helpers::Test;
use surrealdb::err::Error;
#[tokio::test]
async fn typed_set() -> Result<(), Error> {
let sql = "
LET $foo: int = 42;
RETURN $foo;
LET $bar: int = 'hello';
RETURN $bar;
";
let error = "Found 'hello' for param $bar, but expected a int";
Test::new(sql)
.await?
.expect_val("None")?
.expect_val("42")?
.expect_error(error)?
.expect_val("None")?;
Ok(())
}

View file

@ -14,7 +14,7 @@ use serde_json::ser::PrettyFormatter;
use surrealdb::engine::any::{connect, IntoEndpoint}; use surrealdb::engine::any::{connect, IntoEndpoint};
use surrealdb::method::{Stats, WithStats}; use surrealdb::method::{Stats, WithStats};
use surrealdb::opt::{capabilities::Capabilities, Config}; use surrealdb::opt::{capabilities::Capabilities, Config};
use surrealdb::sql::{self, Statement, Value}; use surrealdb::sql::{self, Param, Statement, Value};
use surrealdb::{Notification, Response}; use surrealdb::{Notification, Response};
#[derive(Args, Debug)] #[derive(Args, Debug)]
@ -185,10 +185,11 @@ pub async fn init(
} }
// Complete the request // Complete the request
match sql::parse(&line) { match sql::parse(&line) {
Ok(query) => { Ok(mut query) => {
let mut namespace = None; let mut namespace = None;
let mut database = None; let mut database = None;
let mut vars = Vec::new(); let mut vars = Vec::new();
let init_length = query.len();
// Capture `use` and `set/let` statements from the query // Capture `use` and `set/let` statements from the query
for statement in query.iter() { for statement in query.iter() {
match statement { match statement {
@ -200,12 +201,15 @@ pub async fn init(
database = Some(db.clone()); database = Some(db.clone());
} }
} }
Statement::Set(stmt) => { Statement::Set(stmt) => vars.push(stmt.name.clone()),
vars.push((stmt.name.clone(), stmt.what.clone()));
}
_ => {} _ => {}
} }
} }
for var in &vars {
query.push(Statement::Value(Value::Param(Param::from(var.as_str()))))
}
// Extract the namespace and database from the current prompt // Extract the namespace and database from the current prompt
let (prompt_ns, prompt_db) = split_prompt(&prompt); let (prompt_ns, prompt_db) = split_prompt(&prompt);
// The namespace should be set before the database can be set // The namespace should be set before the database can be set
@ -216,17 +220,23 @@ pub async fn init(
continue; continue;
} }
// Run the query provided // Run the query provided
let result = client.query(query).with_stats().await; let mut result = client.query(query).with_stats().await;
if let Ok(WithStats(res)) = &mut result {
for (i, n) in vars.into_iter().enumerate() {
if let Result::<Value, _>::Ok(v) = res.take(init_length + i) {
let _ = client.set(n, v).await;
}
}
}
let result = process(pretty, json, result); let result = process(pretty, json, result);
let result_is_error = result.is_err(); let result_is_error = result.is_err();
print(result); print(result);
if result_is_error { if result_is_error {
continue; continue;
} }
// Persist the variables extracted from the query
for (key, value) in vars {
let _ = client.set(key, value).await;
}
// Process the last `use` statements, if any // Process the last `use` statements, if any
if namespace.is_some() || database.is_some() { if namespace.is_some() || database.is_some() {
// Use the namespace provided in the query if any, otherwise use the one in the prompt // Use the namespace provided in the query if any, otherwise use the one in the prompt

View file

@ -1450,6 +1450,21 @@ mod cli_integration {
server.finish().unwrap(); server.finish().unwrap();
} }
} }
#[test(tokio::test)]
async fn double_create() {
info!("* check only one output created");
{
let args = "sql --conn memory --ns test --db test --pretty --hide-welcome".to_string();
let output = common::run(&args)
.input("let $a = create foo;\n")
.input("select * from foo;\n")
.output()
.unwrap();
let output = remove_debug_info(output);
assert_eq!(output.matches("foo:").count(), 1);
}
}
} }
fn remove_debug_info(output: String) -> String { fn remove_debug_info(output: String) -> String {