Add typed LET statement (#4476)

This commit is contained in:
Raphael Darley 2024-08-13 20:47:17 +01:00 committed by GitHub
parent 5d92c7c02c
commit bb1eba4aab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 115 additions and 24 deletions

View file

@ -596,6 +596,14 @@ pub enum Error {
check: String,
},
/// The specified value did not conform to the LET type check
#[error("Found {value} for param ${name}, but expected a {check}")]
SetCheck {
value: String,
name: String,
check: String,
},
/// The specified field did not conform to the field ASSERT clause
#[error(
"Found changed value for field `{field}`, with record `{thing}`, but field is readonly"
@ -1200,8 +1208,21 @@ impl Serialize for Error {
serializer.serialize_str(self.to_string().as_str())
}
}
impl Error {
pub fn set_check_from_coerce(self, name: String) -> Error {
match self {
Error::CoerceTo {
from,
into,
} => Error::SetCheck {
name,
value: from.to_string(),
check: into,
},
e => e,
}
}
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
match self {
Error::CoerceTo {
@ -1212,12 +1233,7 @@ impl Error {
value: from.to_string(),
check: into,
},
fc @ Error::FunctionCheck {
..
} => fc,
_ => Error::Internal(
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
),
e => e,
}
}
}

View file

@ -6,7 +6,7 @@ use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::fmt::{self, Display, Formatter};
use std::ops::Deref;
use std::ops::{Deref, DerefMut};
use std::str;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
@ -43,6 +43,12 @@ impl Deref for Query {
}
}
impl DerefMut for Query {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0 .0
}
}
impl IntoIterator for Query {
type Item = Statement;
type IntoIter = std::vec::IntoIter<Self::Item>;

View file

@ -1,22 +1,24 @@
use crate::cnf::PROTECTED_PARAM_NAMES;
use crate::ctx::Context;
use crate::dbs::Options;
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::sql::Value;
use crate::{cnf::PROTECTED_PARAM_NAMES, sql::Kind};
use derive::Store;
use reblessive::tree::Stk;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
#[revisioned(revision = 1)]
#[revisioned(revision = 2)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub struct SetStatement {
pub name: String,
pub what: Value,
#[revision(start = 2)]
pub kind: Option<Kind>,
}
impl SetStatement {
@ -35,7 +37,15 @@ impl SetStatement {
// Check if the variable is a protected variable
match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) {
// The variable isn't protected and can be stored
false => self.what.compute(stk, ctx, opt, doc).await,
false => {
let result = self.what.compute(stk, ctx, opt, doc).await?;
match self.kind {
Some(ref kind) => result
.coerce_to(kind)
.map_err(|e| e.set_check_from_coerce(self.name.to_string())),
None => Ok(result),
}
}
// The user tried to set a protected variable
true => Err(Error::InvalidParam {
// Move the parameter name, as we no longer need it

View file

@ -351,6 +351,7 @@ impl Parser<'_> {
return Statement::Set(crate::sql::statements::SetStatement {
name: x.0 .0,
what: r,
kind: None,
});
}
Statement::Value(Value::Expression(x))
@ -371,6 +372,7 @@ impl Parser<'_> {
return Entry::Set(crate::sql::statements::SetStatement {
name: x.0 .0,
what: r,
kind: None,
});
}
Entry::Value(Value::Expression(x))
@ -652,11 +654,17 @@ impl Parser<'_> {
/// Expects `LET` to already be consumed.
pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> {
let name = self.next_token_value::<Param>()?.0 .0;
let kind = if self.eat(t!(":")) {
Some(self.parse_inner_kind(ctx).await?)
} else {
None
};
expected!(self, t!("="));
let what = self.parse_value(ctx).await?;
Ok(SetStatement {
name,
what,
kind,
})
}

View file

@ -1888,7 +1888,8 @@ fn parse_let() {
res,
Statement::Set(SetStatement {
name: "param".to_owned(),
what: Value::Number(Number::Int(1))
what: Value::Number(Number::Int(1)),
kind: None,
})
);
@ -1897,7 +1898,8 @@ fn parse_let() {
res,
Statement::Set(SetStatement {
name: "param".to_owned(),
what: Value::Number(Number::Int(1))
what: Value::Number(Number::Int(1)),
kind: None,
})
);
}

View file

@ -534,6 +534,7 @@ fn statements() -> Vec<Statement> {
Statement::Set(SetStatement {
name: "param".to_owned(),
what: Value::Number(Number::Int(1)),
kind: None,
}),
Statement::Show(ShowStatement {
table: Some(Table("foo".to_owned())),

View file

@ -107,7 +107,7 @@ pub struct Live;
/// Responses returned with statistics
#[derive(Debug)]
pub struct WithStats<T>(T);
pub struct WithStats<T>(pub T);
impl<C> Surreal<C>
where

23
lib/tests/set.rs Normal file
View file

@ -0,0 +1,23 @@
mod helpers;
mod parse;
use helpers::Test;
use surrealdb::err::Error;
#[tokio::test]
async fn typed_set() -> Result<(), Error> {
let sql = "
LET $foo: int = 42;
RETURN $foo;
LET $bar: int = 'hello';
RETURN $bar;
";
let error = "Found 'hello' for param $bar, but expected a int";
Test::new(sql)
.await?
.expect_val("None")?
.expect_val("42")?
.expect_error(error)?
.expect_val("None")?;
Ok(())
}

View file

@ -14,7 +14,7 @@ use serde_json::ser::PrettyFormatter;
use surrealdb::engine::any::{connect, IntoEndpoint};
use surrealdb::method::{Stats, WithStats};
use surrealdb::opt::{capabilities::Capabilities, Config};
use surrealdb::sql::{self, Statement, Value};
use surrealdb::sql::{self, Param, Statement, Value};
use surrealdb::{Notification, Response};
#[derive(Args, Debug)]
@ -185,10 +185,11 @@ pub async fn init(
}
// Complete the request
match sql::parse(&line) {
Ok(query) => {
Ok(mut query) => {
let mut namespace = None;
let mut database = None;
let mut vars = Vec::new();
let init_length = query.len();
// Capture `use` and `set/let` statements from the query
for statement in query.iter() {
match statement {
@ -200,12 +201,15 @@ pub async fn init(
database = Some(db.clone());
}
}
Statement::Set(stmt) => {
vars.push((stmt.name.clone(), stmt.what.clone()));
}
Statement::Set(stmt) => vars.push(stmt.name.clone()),
_ => {}
}
}
for var in &vars {
query.push(Statement::Value(Value::Param(Param::from(var.as_str()))))
}
// Extract the namespace and database from the current prompt
let (prompt_ns, prompt_db) = split_prompt(&prompt);
// The namespace should be set before the database can be set
@ -216,17 +220,23 @@ pub async fn init(
continue;
}
// Run the query provided
let result = client.query(query).with_stats().await;
let mut result = client.query(query).with_stats().await;
if let Ok(WithStats(res)) = &mut result {
for (i, n) in vars.into_iter().enumerate() {
if let Result::<Value, _>::Ok(v) = res.take(init_length + i) {
let _ = client.set(n, v).await;
}
}
}
let result = process(pretty, json, result);
let result_is_error = result.is_err();
print(result);
if result_is_error {
continue;
}
// Persist the variables extracted from the query
for (key, value) in vars {
let _ = client.set(key, value).await;
}
// Process the last `use` statements, if any
if namespace.is_some() || database.is_some() {
// Use the namespace provided in the query if any, otherwise use the one in the prompt

View file

@ -1450,6 +1450,21 @@ mod cli_integration {
server.finish().unwrap();
}
}
#[test(tokio::test)]
async fn double_create() {
info!("* check only one output created");
{
let args = "sql --conn memory --ns test --db test --pretty --hide-welcome".to_string();
let output = common::run(&args)
.input("let $a = create foo;\n")
.input("select * from foo;\n")
.output()
.unwrap();
let output = remove_debug_info(output);
assert_eq!(output.matches("foo:").count(), 1);
}
}
}
fn remove_debug_info(output: String) -> String {