From 23be3353be0974dfaaf52c4a77a5e38817c0aa62 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Sun, 8 Jan 2023 17:11:35 +0000 Subject: [PATCH] Check expressions for SPLIT ON, GROUP BY, and ORDER BY clauses Closes #1229 Closes #1230 Closes #1457 Closes #1233 --- lib/src/err/mod.rs | 24 +++++ lib/src/sql/common.rs | 10 +- lib/src/sql/error.rs | 8 +- lib/src/sql/group.rs | 12 +++ lib/src/sql/mod.rs | 1 + lib/src/sql/parser.rs | 32 +++++-- lib/src/sql/special.rs | 156 +++++++++++++++++++++++++++++++ lib/src/sql/statements/select.rs | 6 ++ lib/src/sql/strand.rs | 6 +- lib/src/sql/value/value.rs | 27 +++++- lib/tests/define.rs | 6 +- lib/tests/group.rs | 9 +- lib/tests/model.rs | 4 +- 13 files changed, 277 insertions(+), 24 deletions(-) create mode 100644 lib/src/sql/special.rs diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index 197f6fed..78cc44a4 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -86,6 +86,30 @@ pub enum Error { name: String, }, + #[error("Found '{field}' in SELECT clause on line {line}, but field is not an aggregate function, and is not present in GROUP BY expression")] + InvalidField { + line: usize, + field: String, + }, + + #[error("Found '{field}' in SPLIT ON clause on line {line}, but field is not present in SELECT expression")] + InvalidSplit { + line: usize, + field: String, + }, + + #[error("Found '{field}' in ORDER BY clause on line {line}, but field is not present in SELECT expression")] + InvalidOrder { + line: usize, + field: String, + }, + + #[error("Found '{field}' in GROUP BY clause on line {line}, but field is not present in SELECT expression")] + InvalidGroup { + line: usize, + field: String, + }, + /// The LIMIT clause must evaluate to a positive integer #[error("Found {value} but the LIMIT clause must evaluate to a positive integer")] InvalidLimit { diff --git a/lib/src/sql/common.rs b/lib/src/sql/common.rs index 72dfcbaf..c700d2b2 100644 --- a/lib/src/sql/common.rs +++ b/lib/src/sql/common.rs @@ -1,6 +1,6 @@ use crate::sql::comment::mightbespace; use crate::sql::comment::shouldbespace; -use crate::sql::error::Error::ParserError; +use crate::sql::error::Error::Parser; use crate::sql::error::IResult; use nom::branch::alt; use nom::bytes::complete::take_while; @@ -53,7 +53,7 @@ pub fn take_u64(i: &str) -> IResult<&str, u64> { let (i, v) = take_while(is_digit)(i)?; match v.parse::() { Ok(v) => Ok((i, v)), - _ => Err(Error(ParserError(i))), + _ => Err(Error(Parser(i))), } } @@ -61,7 +61,7 @@ pub fn take_u32_len(i: &str) -> IResult<&str, (u32, usize)> { let (i, v) = take_while(is_digit)(i)?; match v.parse::() { Ok(n) => Ok((i, (n, v.len()))), - _ => Err(Error(ParserError(i))), + _ => Err(Error(Parser(i))), } } @@ -69,7 +69,7 @@ pub fn take_digits(i: &str, n: usize) -> IResult<&str, u32> { let (i, v) = take_while_m_n(n, n, is_digit)(i)?; match v.parse::() { Ok(v) => Ok((i, v)), - _ => Err(Error(ParserError(i))), + _ => Err(Error(Parser(i))), } } @@ -77,6 +77,6 @@ pub fn take_digits_range(i: &str, n: usize, range: impl RangeBounds) -> IRe let (i, v) = take_while_m_n(n, n, is_digit)(i)?; match v.parse::() { Ok(v) if range.contains(&v) => Ok((i, v)), - _ => Err(Error(ParserError(i))), + _ => Err(Error(Parser(i))), } } diff --git a/lib/src/sql/error.rs b/lib/src/sql/error.rs index 69bb902e..220449e9 100644 --- a/lib/src/sql/error.rs +++ b/lib/src/sql/error.rs @@ -5,14 +5,18 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum Error { - ParserError(I), + Parser(I), + Field(I, String), + Split(I, String), + Order(I, String), + Group(I, String), } pub type IResult> = Result<(I, O), Err>; impl ParseError for Error { fn from_error_kind(input: I, _: ErrorKind) -> Self { - Self::ParserError(input) + Self::Parser(input) } fn append(_: I, _: ErrorKind, other: Self) -> Self { other diff --git a/lib/src/sql/group.rs b/lib/src/sql/group.rs index 2c596b3b..2f7a07e2 100644 --- a/lib/src/sql/group.rs +++ b/lib/src/sql/group.rs @@ -3,6 +3,7 @@ use crate::sql::common::commas; use crate::sql::error::IResult; use crate::sql::fmt::Fmt; use crate::sql::idiom::{basic, Idiom}; +use nom::branch::alt; use nom::bytes::complete::tag_no_case; use nom::combinator::opt; use nom::multi::separated_list1; @@ -52,6 +53,17 @@ impl Display for Group { } pub fn group(i: &str) -> IResult<&str, Groups> { + alt((group_all, group_any))(i) +} + +fn group_all(i: &str) -> IResult<&str, Groups> { + let (i, _) = tag_no_case("GROUP")(i)?; + let (i, _) = shouldbespace(i)?; + let (i, _) = tag_no_case("ALL")(i)?; + Ok((i, Groups(vec![]))) +} + +fn group_any(i: &str) -> IResult<&str, Groups> { let (i, _) = tag_no_case("GROUP")(i)?; let (i, _) = opt(tuple((shouldbespace, tag_no_case("BY"))))(i)?; let (i, _) = shouldbespace(i)?; diff --git a/lib/src/sql/mod.rs b/lib/src/sql/mod.rs index f2a6a076..1ec6d61c 100644 --- a/lib/src/sql/mod.rs +++ b/lib/src/sql/mod.rs @@ -45,6 +45,7 @@ pub(crate) mod query; pub(crate) mod range; pub(crate) mod regex; pub(crate) mod script; +pub(crate) mod special; pub(crate) mod split; pub(crate) mod start; pub(crate) mod statement; diff --git a/lib/src/sql/parser.rs b/lib/src/sql/parser.rs index 3bac10a2..e8020eaf 100644 --- a/lib/src/sql/parser.rs +++ b/lib/src/sql/parser.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::sql::error::Error::ParserError; +use crate::sql::error::Error::{Field, Group, Order, Parser, Split}; use crate::sql::error::IResult; use crate::sql::query::{query, Query}; use crate::sql::thing::Thing; @@ -34,19 +34,39 @@ fn parse_impl(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Resu // There was unparsed SQL remaining Ok((_, _)) => Err(Error::QueryRemaining), // There was an error when parsing the query - Err(Err::Error(e)) | Err(Err::Failure(e)) => match e { + Err(Err::Error(e)) | Err(Err::Failure(e)) => Err(match e { // There was a parsing error - ParserError(e) => { + Parser(e) => { // Locate the parser position let (s, l, c) = locate(input, e); // Return the parser error - Err(Error::InvalidQuery { + Error::InvalidQuery { line: l, char: c, sql: s.to_string(), - }) + } } - }, + // There was a SPLIT ON error + Field(e, f) => Error::InvalidField { + line: locate(input, e).1, + field: f, + }, + // There was a SPLIT ON error + Split(e, f) => Error::InvalidSplit { + line: locate(input, e).1, + field: f, + }, + // There was a ORDER BY error + Order(e, f) => Error::InvalidOrder { + line: locate(input, e).1, + field: f, + }, + // There was a GROUP BY error + Group(e, f) => Error::InvalidGroup { + line: locate(input, e).1, + field: f, + }, + }), _ => unreachable!(), }, } diff --git a/lib/src/sql/special.rs b/lib/src/sql/special.rs new file mode 100644 index 00000000..5e57752b --- /dev/null +++ b/lib/src/sql/special.rs @@ -0,0 +1,156 @@ +use crate::sql::error::Error; +use crate::sql::field::{Field, Fields}; +use crate::sql::group::Groups; +use crate::sql::order::Orders; +use crate::sql::split::Splits; +use crate::sql::value::Value; +use nom::Err; +use nom::Err::Failure; + +pub fn check_split_on_fields<'a>( + i: &'a str, + fields: &Fields, + splits: &Option, +) -> Result<(), Err>> { + // Check to see if a ORDER BY clause has been defined + if let Some(splits) = splits { + // Loop over each of the expressions in the SPLIT ON clause + 'outer: for split in splits.iter() { + // Loop over each of the expressions in the SELECT clause + for field in fields.iter() { + // Check to see whether the expression is in the SELECT clause + match field { + // There is a SELECT * expression, so presume everything is ok + Field::All => break 'outer, + // This field is aliased, so check the alias name + Field::Alias(_, i) if i.as_ref() == split.as_ref() => continue 'outer, + // This field is not aliased, so check the field value + Field::Alone(v) => { + match v { + // If the expression in the SELECT clause is a field, check if it exists in the SPLIT ON clause + Value::Idiom(i) if i.as_ref() == split.as_ref() => continue 'outer, + // Otherwise check if the expression itself exists in the SPLIT ON clause + v if v.to_idiom().as_ref() == split.as_ref() => continue 'outer, + // If not, then this query should fail + _ => (), + } + } + // If not, then this query should fail + _ => (), + } + } + // If the expression isn't specified in the SELECT clause, then error + return Err(Failure(Error::Split(i, split.to_string()))); + } + } + // This query is ok to run + Ok(()) +} + +pub fn check_order_by_fields<'a>( + i: &'a str, + fields: &Fields, + orders: &Option, +) -> Result<(), Err>> { + // Check to see if a ORDER BY clause has been defined + if let Some(orders) = orders { + // Loop over each of the expressions in the ORDER BY clause + 'outer: for order in orders.iter() { + // Loop over each of the expressions in the SELECT clause + for field in fields.iter() { + // Check to see whether the expression is in the SELECT clause + match field { + // There is a SELECT * expression, so presume everything is ok + Field::All => break 'outer, + // This field is aliased, so check the alias name + Field::Alias(_, i) if i.as_ref() == order.as_ref() => continue 'outer, + // This field is not aliased, so check the field value + Field::Alone(v) => { + match v { + // If the expression in the SELECT clause is a field, check if it exists in the ORDER BY clause + Value::Idiom(i) if i.as_ref() == order.as_ref() => continue 'outer, + // Otherwise check if the expression itself exists in the ORDER BY clause + v if v.to_idiom().as_ref() == order.as_ref() => continue 'outer, + // If not, then this query should fail + _ => (), + } + } + // If not, then this query should fail + _ => (), + } + } + // If the expression isn't specified in the SELECT clause, then error + return Err(Failure(Error::Order(i, order.to_string()))); + } + } + // This query is ok to run + Ok(()) +} + +pub fn check_group_by_fields<'a>( + i: &'a str, + fields: &Fields, + groups: &Option, +) -> Result<(), Err>> { + // Check to see if a GROUP BY clause has been defined + if let Some(groups) = groups { + // Loop over each of the expressions in the GROUP BY clause + 'outer: for group in groups.iter() { + // Loop over each of the expressions in the SELECT clause + for field in fields.iter() { + // Check to see whether the expression is in the SELECT clause + match field { + // This field is aliased, so check the alias name + Field::Alias(_, i) if i.as_ref() == group.as_ref() => continue 'outer, + // This field is not aliased, so check the field value + Field::Alone(v) => { + match v { + // If the expression in the SELECT clause is a field, check if it exists in the GROUP BY clause + Value::Idiom(i) if i.as_ref() == group.as_ref() => continue 'outer, + // Otherwise check if the expression itself exists in the GROUP BY clause + v if v.to_idiom().as_ref() == group.as_ref() => continue 'outer, + // If not, then this query should fail + _ => (), + } + } + // If not, then this query should fail + _ => (), + } + } + // If the expression isn't specified in the SELECT clause, then error + return Err(Failure(Error::Group(i, group.to_string()))); + } + // Check if this is a GROUP ALL clause or a GROUP BY clause + if groups.len() > 0 { + // Loop over each of the expressions in the SELECT clause + 'outer: for field in fields.iter() { + // Loop over each of the expressions in the GROUP BY clause + for group in groups.iter() { + // Check to see whether the expression is in the SELECT clause + match field { + // This field is aliased, so check the alias name + Field::Alias(_, i) if i.as_ref() == group.as_ref() => continue 'outer, + // Otherwise, check the type of the field value + Field::Alias(v, _) | Field::Alone(v) => match v { + // If the expression in the SELECT clause is a field, check to see if it exists in the GROUP BY + Value::Idiom(i) if i == &group.0 => continue 'outer, + // If the expression in the SELECT clause is a function, check to see if it is an aggregate function + Value::Function(f) if f.is_aggregate() => continue 'outer, + // Otherwise check if the expression itself exists in the GROUP BY clause + v if v.to_idiom() == group.0 => continue 'outer, + // Check if this is a static value which can be used in the GROUP BY clause + v if v.is_static() => continue 'outer, + // If not, then this query should fail + _ => (), + }, + _ => (), + } + } + // If the expression isn't an aggregate function and isn't specified in the GROUP BY clause, then error + return Err(Failure(Error::Field(i, field.to_string()))); + } + } + } + // This query is ok to run + Ok(()) +} diff --git a/lib/src/sql/statements/select.rs b/lib/src/sql/statements/select.rs index d86e9e96..3c34d5b0 100644 --- a/lib/src/sql/statements/select.rs +++ b/lib/src/sql/statements/select.rs @@ -14,6 +14,9 @@ use crate::sql::field::{fields, Field, Fields}; use crate::sql::group::{group, Groups}; use crate::sql::limit::{limit, Limit}; use crate::sql::order::{order, Orders}; +use crate::sql::special::check_group_by_fields; +use crate::sql::special::check_order_by_fields; +use crate::sql::special::check_split_on_fields; use crate::sql::split::{split, Splits}; use crate::sql::start::{start, Start}; use crate::sql::timeout::{timeout, Timeout}; @@ -170,8 +173,11 @@ pub fn select(i: &str) -> IResult<&str, SelectStatement> { let (i, what) = selects(i)?; let (i, cond) = opt(preceded(shouldbespace, cond))(i)?; let (i, split) = opt(preceded(shouldbespace, split))(i)?; + check_split_on_fields(i, &expr, &split)?; let (i, group) = opt(preceded(shouldbespace, group))(i)?; + check_group_by_fields(i, &expr, &group)?; let (i, order) = opt(preceded(shouldbespace, order))(i)?; + check_order_by_fields(i, &expr, &order)?; let (i, limit) = opt(preceded(shouldbespace, limit))(i)?; let (i, start) = opt(preceded(shouldbespace, start))(i)?; let (i, fetch) = opt(preceded(shouldbespace, fetch))(i)?; diff --git a/lib/src/sql/strand.rs b/lib/src/sql/strand.rs index 7a905987..f5ca1bc4 100644 --- a/lib/src/sql/strand.rs +++ b/lib/src/sql/strand.rs @@ -1,4 +1,4 @@ -use crate::sql::error::Error::ParserError; +use crate::sql::error::Error::Parser; use crate::sql::error::IResult; use crate::sql::escape::escape_str; use crate::sql::serde::is_internal_serialization; @@ -164,14 +164,14 @@ fn strand_unicode(i: &str) -> IResult<&str, char> { // We can convert this to u32 as we only have 6 chars let v = match u32::from_str_radix(v, 16) { // We found an invalid unicode sequence - Err(_) => return Err(Error(ParserError(i))), + Err(_) => return Err(Error(Parser(i))), // The unicode sequence was valid Ok(v) => v, }; // We can convert this to char as we know it is valid let v = match std::char::from_u32(v) { // We found an invalid unicode sequence - None => return Err(Error(ParserError(i))), + None => return Err(Error(Parser(i))), // The unicode sequence was valid Some(v) => v, }; diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index e516b82d..14d60e42 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -1029,7 +1029,8 @@ impl Value { // JSON Path conversion // ----------------------------------- - pub fn jsonpath(&self) -> Idiom { + /// Converts this value to a JSONPatch path + pub(crate) fn jsonpath(&self) -> Idiom { self.to_strand() .as_str() .trim_start_matches('/') @@ -1039,6 +1040,30 @@ impl Value { .into() } + // ----------------------------------- + // JSON Path conversion + // ----------------------------------- + + /// Checkes whether this value is a static value + pub(crate) fn is_static(&self) -> bool { + match self { + Value::None => true, + Value::Null => true, + Value::False => true, + Value::True => true, + Value::Uuid(_) => true, + Value::Number(_) => true, + Value::Strand(_) => true, + Value::Duration(_) => true, + Value::Datetime(_) => true, + Value::Geometry(_) => true, + Value::Array(v) => v.iter().all(Value::is_static), + Value::Object(v) => v.values().all(Value::is_static), + Value::Constant(_) => true, + _ => false, + } + } + // ----------------------------------- // Value operations // ----------------------------------- diff --git a/lib/tests/define.rs b/lib/tests/define.rs index e9df051f..7b7f2565 100644 --- a/lib/tests/define.rs +++ b/lib/tests/define.rs @@ -186,7 +186,7 @@ async fn define_statement_event() -> Result<(), Error> { UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now(); - SELECT count() FROM activity GROUP BY ALL; + SELECT count() FROM activity GROUP ALL; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test"); @@ -243,7 +243,7 @@ async fn define_statement_event_when_event() -> Result<(), Error> { UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now(); - SELECT count() FROM activity GROUP BY ALL; + SELECT count() FROM activity GROUP ALL; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test"); @@ -300,7 +300,7 @@ async fn define_statement_event_when_logic() -> Result<(), Error> { UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now(); UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now(); - SELECT count() FROM activity GROUP BY ALL; + SELECT count() FROM activity GROUP ALL; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test"); diff --git a/lib/tests/group.rs b/lib/tests/group.rs index 8835bb37..07c56bb6 100644 --- a/lib/tests/group.rs +++ b/lib/tests/group.rs @@ -18,7 +18,7 @@ async fn select_limit_fetch() -> Result<(), Error> { CREATE temperature:8 SET country = 'AUD', time = '2021-01-01T08:00:00Z'; CREATE temperature:9 SET country = 'CHF', time = '2023-01-01T08:00:00Z'; SELECT *, time::year(time) AS year FROM temperature; - SELECT count(), time::year(time) AS year, country FROM temperature GROUP BY country; + SELECT count(), time::year(time) AS year, country FROM temperature GROUP BY country, year; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test"); @@ -213,10 +213,15 @@ async fn select_limit_fetch() -> Result<(), Error> { year: 2021 }, { - count: 5, + count: 3, country: 'GBP', year: 2020 }, + { + count: 2, + country: 'GBP', + year: 2021 + }, { count: 1, country: 'USD', diff --git a/lib/tests/model.rs b/lib/tests/model.rs index 64ae7da6..6d8356e4 100644 --- a/lib/tests/model.rs +++ b/lib/tests/model.rs @@ -9,7 +9,7 @@ use surrealdb::sql::Value; async fn model_count() -> Result<(), Error> { let sql = " CREATE |test:1000| SET time = time::now(); - SELECT count() FROM test GROUP BY ALL; + SELECT count() FROM test GROUP ALL; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test"); @@ -34,7 +34,7 @@ async fn model_count() -> Result<(), Error> { async fn model_range() -> Result<(), Error> { let sql = " CREATE |test:1..1000| SET time = time::now(); - SELECT count() FROM test GROUP BY ALL; + SELECT count() FROM test GROUP ALL; "; let dbs = Datastore::new("memory").await?; let ses = Session::for_kv().with_ns("test").with_db("test");