Add typed LET
statement (#4476)
This commit is contained in:
parent
5d92c7c02c
commit
bb1eba4aab
10 changed files with 115 additions and 24 deletions
|
@ -596,6 +596,14 @@ pub enum Error {
|
|||
check: String,
|
||||
},
|
||||
|
||||
/// The specified value did not conform to the LET type check
|
||||
#[error("Found {value} for param ${name}, but expected a {check}")]
|
||||
SetCheck {
|
||||
value: String,
|
||||
name: String,
|
||||
check: String,
|
||||
},
|
||||
|
||||
/// The specified field did not conform to the field ASSERT clause
|
||||
#[error(
|
||||
"Found changed value for field `{field}`, with record `{thing}`, but field is readonly"
|
||||
|
@ -1200,8 +1208,21 @@ impl Serialize for Error {
|
|||
serializer.serialize_str(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn set_check_from_coerce(self, name: String) -> Error {
|
||||
match self {
|
||||
Error::CoerceTo {
|
||||
from,
|
||||
into,
|
||||
} => Error::SetCheck {
|
||||
name,
|
||||
value: from.to_string(),
|
||||
check: into,
|
||||
},
|
||||
e => e,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn function_check_from_coerce(self, name: impl Into<String>) -> Error {
|
||||
match self {
|
||||
Error::CoerceTo {
|
||||
|
@ -1212,12 +1233,7 @@ impl Error {
|
|||
value: from.to_string(),
|
||||
check: into,
|
||||
},
|
||||
fc @ Error::FunctionCheck {
|
||||
..
|
||||
} => fc,
|
||||
_ => Error::Internal(
|
||||
"function_check_from_coerce called on Error that wasn't CoerceTo".to_string(),
|
||||
),
|
||||
e => e,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use revision::revisioned;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use std::ops::Deref;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::str;
|
||||
|
||||
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Query";
|
||||
|
@ -43,6 +43,12 @@ impl Deref for Query {
|
|||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Query {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for Query {
|
||||
type Item = Statement;
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
|
|
|
@ -1,22 +1,24 @@
|
|||
use crate::cnf::PROTECTED_PARAM_NAMES;
|
||||
use crate::ctx::Context;
|
||||
use crate::dbs::Options;
|
||||
use crate::doc::CursorDoc;
|
||||
use crate::err::Error;
|
||||
use crate::sql::Value;
|
||||
use crate::{cnf::PROTECTED_PARAM_NAMES, sql::Kind};
|
||||
use derive::Store;
|
||||
use reblessive::tree::Stk;
|
||||
use revision::revisioned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
#[revisioned(revision = 1)]
|
||||
#[revisioned(revision = 2)]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
|
||||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||
#[non_exhaustive]
|
||||
pub struct SetStatement {
|
||||
pub name: String,
|
||||
pub what: Value,
|
||||
#[revision(start = 2)]
|
||||
pub kind: Option<Kind>,
|
||||
}
|
||||
|
||||
impl SetStatement {
|
||||
|
@ -35,7 +37,15 @@ impl SetStatement {
|
|||
// Check if the variable is a protected variable
|
||||
match PROTECTED_PARAM_NAMES.contains(&self.name.as_str()) {
|
||||
// The variable isn't protected and can be stored
|
||||
false => self.what.compute(stk, ctx, opt, doc).await,
|
||||
false => {
|
||||
let result = self.what.compute(stk, ctx, opt, doc).await?;
|
||||
match self.kind {
|
||||
Some(ref kind) => result
|
||||
.coerce_to(kind)
|
||||
.map_err(|e| e.set_check_from_coerce(self.name.to_string())),
|
||||
None => Ok(result),
|
||||
}
|
||||
}
|
||||
// The user tried to set a protected variable
|
||||
true => Err(Error::InvalidParam {
|
||||
// Move the parameter name, as we no longer need it
|
||||
|
|
|
@ -351,6 +351,7 @@ impl Parser<'_> {
|
|||
return Statement::Set(crate::sql::statements::SetStatement {
|
||||
name: x.0 .0,
|
||||
what: r,
|
||||
kind: None,
|
||||
});
|
||||
}
|
||||
Statement::Value(Value::Expression(x))
|
||||
|
@ -371,6 +372,7 @@ impl Parser<'_> {
|
|||
return Entry::Set(crate::sql::statements::SetStatement {
|
||||
name: x.0 .0,
|
||||
what: r,
|
||||
kind: None,
|
||||
});
|
||||
}
|
||||
Entry::Value(Value::Expression(x))
|
||||
|
@ -652,11 +654,17 @@ impl Parser<'_> {
|
|||
/// Expects `LET` to already be consumed.
|
||||
pub(crate) async fn parse_let_stmt(&mut self, ctx: &mut Stk) -> ParseResult<SetStatement> {
|
||||
let name = self.next_token_value::<Param>()?.0 .0;
|
||||
let kind = if self.eat(t!(":")) {
|
||||
Some(self.parse_inner_kind(ctx).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
expected!(self, t!("="));
|
||||
let what = self.parse_value(ctx).await?;
|
||||
Ok(SetStatement {
|
||||
name,
|
||||
what,
|
||||
kind,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1888,7 +1888,8 @@ fn parse_let() {
|
|||
res,
|
||||
Statement::Set(SetStatement {
|
||||
name: "param".to_owned(),
|
||||
what: Value::Number(Number::Int(1))
|
||||
what: Value::Number(Number::Int(1)),
|
||||
kind: None,
|
||||
})
|
||||
);
|
||||
|
||||
|
@ -1897,7 +1898,8 @@ fn parse_let() {
|
|||
res,
|
||||
Statement::Set(SetStatement {
|
||||
name: "param".to_owned(),
|
||||
what: Value::Number(Number::Int(1))
|
||||
what: Value::Number(Number::Int(1)),
|
||||
kind: None,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
|
@ -534,6 +534,7 @@ fn statements() -> Vec<Statement> {
|
|||
Statement::Set(SetStatement {
|
||||
name: "param".to_owned(),
|
||||
what: Value::Number(Number::Int(1)),
|
||||
kind: None,
|
||||
}),
|
||||
Statement::Show(ShowStatement {
|
||||
table: Some(Table("foo".to_owned())),
|
||||
|
|
|
@ -107,7 +107,7 @@ pub struct Live;
|
|||
|
||||
/// Responses returned with statistics
|
||||
#[derive(Debug)]
|
||||
pub struct WithStats<T>(T);
|
||||
pub struct WithStats<T>(pub T);
|
||||
|
||||
impl<C> Surreal<C>
|
||||
where
|
||||
|
|
23
lib/tests/set.rs
Normal file
23
lib/tests/set.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
mod helpers;
|
||||
mod parse;
|
||||
use helpers::Test;
|
||||
use surrealdb::err::Error;
|
||||
|
||||
#[tokio::test]
|
||||
async fn typed_set() -> Result<(), Error> {
|
||||
let sql = "
|
||||
LET $foo: int = 42;
|
||||
RETURN $foo;
|
||||
LET $bar: int = 'hello';
|
||||
RETURN $bar;
|
||||
";
|
||||
let error = "Found 'hello' for param $bar, but expected a int";
|
||||
Test::new(sql)
|
||||
.await?
|
||||
.expect_val("None")?
|
||||
.expect_val("42")?
|
||||
.expect_error(error)?
|
||||
.expect_val("None")?;
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -14,7 +14,7 @@ use serde_json::ser::PrettyFormatter;
|
|||
use surrealdb::engine::any::{connect, IntoEndpoint};
|
||||
use surrealdb::method::{Stats, WithStats};
|
||||
use surrealdb::opt::{capabilities::Capabilities, Config};
|
||||
use surrealdb::sql::{self, Statement, Value};
|
||||
use surrealdb::sql::{self, Param, Statement, Value};
|
||||
use surrealdb::{Notification, Response};
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
|
@ -185,10 +185,11 @@ pub async fn init(
|
|||
}
|
||||
// Complete the request
|
||||
match sql::parse(&line) {
|
||||
Ok(query) => {
|
||||
Ok(mut query) => {
|
||||
let mut namespace = None;
|
||||
let mut database = None;
|
||||
let mut vars = Vec::new();
|
||||
let init_length = query.len();
|
||||
// Capture `use` and `set/let` statements from the query
|
||||
for statement in query.iter() {
|
||||
match statement {
|
||||
|
@ -200,12 +201,15 @@ pub async fn init(
|
|||
database = Some(db.clone());
|
||||
}
|
||||
}
|
||||
Statement::Set(stmt) => {
|
||||
vars.push((stmt.name.clone(), stmt.what.clone()));
|
||||
}
|
||||
Statement::Set(stmt) => vars.push(stmt.name.clone()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
for var in &vars {
|
||||
query.push(Statement::Value(Value::Param(Param::from(var.as_str()))))
|
||||
}
|
||||
|
||||
// Extract the namespace and database from the current prompt
|
||||
let (prompt_ns, prompt_db) = split_prompt(&prompt);
|
||||
// The namespace should be set before the database can be set
|
||||
|
@ -216,17 +220,23 @@ pub async fn init(
|
|||
continue;
|
||||
}
|
||||
// Run the query provided
|
||||
let result = client.query(query).with_stats().await;
|
||||
let mut result = client.query(query).with_stats().await;
|
||||
|
||||
if let Ok(WithStats(res)) = &mut result {
|
||||
for (i, n) in vars.into_iter().enumerate() {
|
||||
if let Result::<Value, _>::Ok(v) = res.take(init_length + i) {
|
||||
let _ = client.set(n, v).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = process(pretty, json, result);
|
||||
let result_is_error = result.is_err();
|
||||
print(result);
|
||||
if result_is_error {
|
||||
continue;
|
||||
}
|
||||
// Persist the variables extracted from the query
|
||||
for (key, value) in vars {
|
||||
let _ = client.set(key, value).await;
|
||||
}
|
||||
|
||||
// Process the last `use` statements, if any
|
||||
if namespace.is_some() || database.is_some() {
|
||||
// Use the namespace provided in the query if any, otherwise use the one in the prompt
|
||||
|
|
|
@ -1450,6 +1450,21 @@ mod cli_integration {
|
|||
server.finish().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn double_create() {
|
||||
info!("* check only one output created");
|
||||
{
|
||||
let args = "sql --conn memory --ns test --db test --pretty --hide-welcome".to_string();
|
||||
let output = common::run(&args)
|
||||
.input("let $a = create foo;\n")
|
||||
.input("select * from foo;\n")
|
||||
.output()
|
||||
.unwrap();
|
||||
let output = remove_debug_info(output);
|
||||
assert_eq!(output.matches("foo:").count(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_debug_info(output: String) -> String {
|
||||
|
|
Loading…
Reference in a new issue