Add typed LET
statement (#4476)
This commit is contained in:
parent
5d92c7c02c
commit
bb1eba4aab
10 changed files with 115 additions and 24 deletions
|
@ -596,6 +596,14 @@ pub enum Error {
|
||||||
check: String,
|
check: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// The specified value did not conform to the LET type check
|
||||||
|
#[error("Found {value} for param ${name}, but expected a {check}")]
|
||||||
|
SetCheck {
|
||||||
|
value: String,
|
||||||
|
name: String,
|
||||||
|
check: String,
|
||||||
|
},
|
||||||
|
|
||||||
/// The specified field did not conform to the field ASSERT clause
|
/// The specified field did not conform to the field ASSERT clause
|
||||||
#[error(
|
#[error(
|
||||||
"Found changed value for field `{field}`, with record `{thing}`, but field is readonly"
|
"Found changed value for field `{field}`, with record `{thing}`, but field is readonly"
|
||||||
|
@ -1200,8 +1208,21 @@ impl Serialize for Error {
|
||||||
serializer.serialize_str(self.to_string().as_str())
|
serializer.serialize_str(self.to_string().as_str())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
impl Error {
|
||||||
|
pub fn set_check_from_coerce(self, name: String) -> Error {
|
||||||
|
match self {
|
||||||
|
Error::CoerceTo {
|
||||||
|
from,
|
||||||
|
into,
|
||||||
|
} => Error::SetCheck {
|
||||||
|
name,
|
||||||
|
value: from.to_string(),
|
||||||
|
check: into,
|
||||||
|
},
|
||||||
|
e => e,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
|
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
|
||||||
match self {
|
match self {
|
||||||
Error::CoerceTo {
|
Error::CoerceTo {
|
||||||
|
@ -1212,12 +1233,7 @@ impl Error {
|
||||||
value: from.to_string(),
|
value: from.to_string(),
|
||||||
check: into,
|
check: into,
|
||||||
},
|
},
|
||||||
fc @ Error::FunctionCheck {
|
e => e,
|
||||||
..
|
|
||||||
} => fc,
|
|
||||||
_ => Error::Internal(
|
|
||||||
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::fmt::{self, Display, Formatter};
|
use std::fmt::{self, Display, Formatter};
|
||||||
use std::ops::Deref;
|
use std::ops::{Deref, DerefMut};
|
||||||
use std::str;
|
use std::str;
|
||||||
|
|
||||||
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
|
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
|
||||||
|
@ -43,6 +43,12 @@ impl Deref for Query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DerefMut for Query {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
&mut self.0 .0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl IntoIterator for Query {
|
impl IntoIterator for Query {
|
||||||
type Item = Statement;
|
type Item = Statement;
|
||||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||||
|
|
|
@ -1,22 +1,24 @@
|
||||||
use crate::cnf::PROTECTED_PARAM_NAMES;
|
|
||||||
use crate::ctx::Context;
|
use crate::ctx::Context;
|
||||||
use crate::dbs::Options;
|
use crate::dbs::Options;
|
||||||
use crate::doc::CursorDoc;
|
use crate::doc::CursorDoc;
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
use crate::sql::Value;
|
use crate::sql::Value;
|
||||||
|
use crate::{cnf::PROTECTED_PARAM_NAMES, sql::Kind};
|
||||||
use derive::Store;
|
use derive::Store;
|
||||||
use reblessive::tree::Stk;
|
use reblessive::tree::Stk;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[revisioned(revision = 1)]
|
#[revisioned(revision = 2)]
|
||||||
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub struct SetStatement {
|
pub struct SetStatement {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub what: Value,
|
pub what: Value,
|
||||||
|
#[revision(start = 2)]
|
||||||
|
pub kind: Option<Kind>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SetStatement {
|
impl SetStatement {
|
||||||
|
@ -35,7 +37,15 @@ impl SetStatement {
|
||||||
// Check if the variable is a protected variable
|
// Check if the variable is a protected variable
|
||||||
match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) {
|
match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) {
|
||||||
// The variable isn't protected and can be stored
|
// The variable isn't protected and can be stored
|
||||||
false => self.what.compute(stk, ctx, opt, doc).await,
|
false => {
|
||||||
|
let result = self.what.compute(stk, ctx, opt, doc).await?;
|
||||||
|
match self.kind {
|
||||||
|
Some(ref kind) => result
|
||||||
|
.coerce_to(kind)
|
||||||
|
.map_err(|e| e.set_check_from_coerce(self.name.to_string())),
|
||||||
|
None => Ok(result),
|
||||||
|
}
|
||||||
|
}
|
||||||
// The user tried to set a protected variable
|
// The user tried to set a protected variable
|
||||||
true => Err(Error::InvalidParam {
|
true => Err(Error::InvalidParam {
|
||||||
// Move the parameter name, as we no longer need it
|
// Move the parameter name, as we no longer need it
|
||||||
|
|
|
@ -351,6 +351,7 @@ impl Parser<'_> {
|
||||||
return Statement::Set(crate::sql::statements::SetStatement {
|
return Statement::Set(crate::sql::statements::SetStatement {
|
||||||
name: x.0 .0,
|
name: x.0 .0,
|
||||||
what: r,
|
what: r,
|
||||||
|
kind: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Statement::Value(Value::Expression(x))
|
Statement::Value(Value::Expression(x))
|
||||||
|
@ -371,6 +372,7 @@ impl Parser<'_> {
|
||||||
return Entry::Set(crate::sql::statements::SetStatement {
|
return Entry::Set(crate::sql::statements::SetStatement {
|
||||||
name: x.0 .0,
|
name: x.0 .0,
|
||||||
what: r,
|
what: r,
|
||||||
|
kind: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Entry::Value(Value::Expression(x))
|
Entry::Value(Value::Expression(x))
|
||||||
|
@ -652,11 +654,17 @@ impl Parser<'_> {
|
||||||
/// Expects `LET` to already be consumed.
|
/// Expects `LET` to already be consumed.
|
||||||
pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> {
|
pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> {
|
||||||
let name = self.next_token_value::<Param>()?.0 .0;
|
let name = self.next_token_value::<Param>()?.0 .0;
|
||||||
|
let kind = if self.eat(t!(":")) {
|
||||||
|
Some(self.parse_inner_kind(ctx).await?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
expected!(self, t!("="));
|
expected!(self, t!("="));
|
||||||
let what = self.parse_value(ctx).await?;
|
let what = self.parse_value(ctx).await?;
|
||||||
Ok(SetStatement {
|
Ok(SetStatement {
|
||||||
name,
|
name,
|
||||||
what,
|
what,
|
||||||
|
kind,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1888,7 +1888,8 @@ fn parse_let() {
|
||||||
res,
|
res,
|
||||||
Statement::Set(SetStatement {
|
Statement::Set(SetStatement {
|
||||||
name: "param".to_owned(),
|
name: "param".to_owned(),
|
||||||
what: Value::Number(Number::Int(1))
|
what: Value::Number(Number::Int(1)),
|
||||||
|
kind: None,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -1897,7 +1898,8 @@ fn parse_let() {
|
||||||
res,
|
res,
|
||||||
Statement::Set(SetStatement {
|
Statement::Set(SetStatement {
|
||||||
name: "param".to_owned(),
|
name: "param".to_owned(),
|
||||||
what: Value::Number(Number::Int(1))
|
what: Value::Number(Number::Int(1)),
|
||||||
|
kind: None,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -534,6 +534,7 @@ fn statements() -> Vec<Statement> {
|
||||||
Statement::Set(SetStatement {
|
Statement::Set(SetStatement {
|
||||||
name: "param".to_owned(),
|
name: "param".to_owned(),
|
||||||
what: Value::Number(Number::Int(1)),
|
what: Value::Number(Number::Int(1)),
|
||||||
|
kind: None,
|
||||||
}),
|
}),
|
||||||
Statement::Show(ShowStatement {
|
Statement::Show(ShowStatement {
|
||||||
table: Some(Table("foo".to_owned())),
|
table: Some(Table("foo".to_owned())),
|
||||||
|
|
|
@ -107,7 +107,7 @@ pub struct Live;
|
||||||
|
|
||||||
/// Responses returned with statistics
|
/// Responses returned with statistics
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WithStats<T>(T);
|
pub struct WithStats<T>(pub T);
|
||||||
|
|
||||||
impl<C> Surreal<C>
|
impl<C> Surreal<C>
|
||||||
where
|
where
|
||||||
|
|
23
lib/tests/set.rs
Normal file
23
lib/tests/set.rs
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
mod helpers;
|
||||||
|
mod parse;
|
||||||
|
use helpers::Test;
|
||||||
|
use surrealdb::err::Error;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn typed_set() -> Result<(), Error> {
|
||||||
|
let sql = "
|
||||||
|
LET $foo: int = 42;
|
||||||
|
RETURN $foo;
|
||||||
|
LET $bar: int = 'hello';
|
||||||
|
RETURN $bar;
|
||||||
|
";
|
||||||
|
let error = "Found 'hello' for param $bar, but expected a int";
|
||||||
|
Test::new(sql)
|
||||||
|
.await?
|
||||||
|
.expect_val("None")?
|
||||||
|
.expect_val("42")?
|
||||||
|
.expect_error(error)?
|
||||||
|
.expect_val("None")?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -14,7 +14,7 @@ use serde_json::ser::PrettyFormatter;
|
||||||
use surrealdb::engine::any::{connect, IntoEndpoint};
|
use surrealdb::engine::any::{connect, IntoEndpoint};
|
||||||
use surrealdb::method::{Stats, WithStats};
|
use surrealdb::method::{Stats, WithStats};
|
||||||
use surrealdb::opt::{capabilities::Capabilities, Config};
|
use surrealdb::opt::{capabilities::Capabilities, Config};
|
||||||
use surrealdb::sql::{self, Statement, Value};
|
use surrealdb::sql::{self, Param, Statement, Value};
|
||||||
use surrealdb::{Notification, Response};
|
use surrealdb::{Notification, Response};
|
||||||
|
|
||||||
#[derive(Args, Debug)]
|
#[derive(Args, Debug)]
|
||||||
|
@ -185,10 +185,11 @@ pub async fn init(
|
||||||
}
|
}
|
||||||
// Complete the request
|
// Complete the request
|
||||||
match sql::parse(&line) {
|
match sql::parse(&line) {
|
||||||
Ok(query) => {
|
Ok(mut query) => {
|
||||||
let mut namespace = None;
|
let mut namespace = None;
|
||||||
let mut database = None;
|
let mut database = None;
|
||||||
let mut vars = Vec::new();
|
let mut vars = Vec::new();
|
||||||
|
let init_length = query.len();
|
||||||
// Capture `use` and `set/let` statements from the query
|
// Capture `use` and `set/let` statements from the query
|
||||||
for statement in query.iter() {
|
for statement in query.iter() {
|
||||||
match statement {
|
match statement {
|
||||||
|
@ -200,12 +201,15 @@ pub async fn init(
|
||||||
database = Some(db.clone());
|
database = Some(db.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Statement::Set(stmt) => {
|
Statement::Set(stmt) => vars.push(stmt.name.clone()),
|
||||||
vars.push((stmt.name.clone(), stmt.what.clone()));
|
|
||||||
}
|
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for var in &vars {
|
||||||
|
query.push(Statement::Value(Value::Param(Param::from(var.as_str()))))
|
||||||
|
}
|
||||||
|
|
||||||
// Extract the namespace and database from the current prompt
|
// Extract the namespace and database from the current prompt
|
||||||
let (prompt_ns, prompt_db) = split_prompt(&prompt);
|
let (prompt_ns, prompt_db) = split_prompt(&prompt);
|
||||||
// The namespace should be set before the database can be set
|
// The namespace should be set before the database can be set
|
||||||
|
@ -216,17 +220,23 @@ pub async fn init(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Run the query provided
|
// Run the query provided
|
||||||
let result = client.query(query).with_stats().await;
|
let mut result = client.query(query).with_stats().await;
|
||||||
|
|
||||||
|
if let Ok(WithStats(res)) = &mut result {
|
||||||
|
for (i, n) in vars.into_iter().enumerate() {
|
||||||
|
if let Result::<Value, _>::Ok(v) = res.take(init_length + i) {
|
||||||
|
let _ = client.set(n, v).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let result = process(pretty, json, result);
|
let result = process(pretty, json, result);
|
||||||
let result_is_error = result.is_err();
|
let result_is_error = result.is_err();
|
||||||
print(result);
|
print(result);
|
||||||
if result_is_error {
|
if result_is_error {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// Persist the variables extracted from the query
|
|
||||||
for (key, value) in vars {
|
|
||||||
let _ = client.set(key, value).await;
|
|
||||||
}
|
|
||||||
// Process the last `use` statements, if any
|
// Process the last `use` statements, if any
|
||||||
if namespace.is_some() || database.is_some() {
|
if namespace.is_some() || database.is_some() {
|
||||||
// Use the namespace provided in the query if any, otherwise use the one in the prompt
|
// Use the namespace provided in the query if any, otherwise use the one in the prompt
|
||||||
|
|
|
@ -1450,6 +1450,21 @@ mod cli_integration {
|
||||||
server.finish().unwrap();
|
server.finish().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn double_create() {
|
||||||
|
info!("* check only one output created");
|
||||||
|
{
|
||||||
|
let args = "sql --conn memory --ns test --db test --pretty --hide-welcome".to_string();
|
||||||
|
let output = common::run(&args)
|
||||||
|
.input("let $a = create foo;\n")
|
||||||
|
.input("select * from foo;\n")
|
||||||
|
.output()
|
||||||
|
.unwrap();
|
||||||
|
let output = remove_debug_info(output);
|
||||||
|
assert_eq!(output.matches("foo:").count(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_debug_info(output: String) -> String {
|
fn remove_debug_info(output: String) -> String {
|
||||||
|
|
Loading…
Reference in a new issue