Check expressions for SPLIT ON, GROUP BY, and ORDER BY clauses

Closes #1229
Closes #1230
Closes #1457
Closes #1233
This commit is contained in:
Tobie Morgan Hitchcock 2023-01-08 17:11:35 +00:00
parent 1162e4a8ce
commit 23be3353be
13 changed files with 277 additions and 24 deletions

View file

@ -86,6 +86,30 @@ pub enum Error {
name: String,
},
#[error("Found '{field}' in SELECT clause on line {line}, but field is not an aggregate function, and is not present in GROUP BY expression")]
InvalidField {
line: usize,
field: String,
},
#[error("Found '{field}' in SPLIT ON clause on line {line}, but field is not present in SELECT expression")]
InvalidSplit {
line: usize,
field: String,
},
#[error("Found '{field}' in ORDER BY clause on line {line}, but field is not present in SELECT expression")]
InvalidOrder {
line: usize,
field: String,
},
#[error("Found '{field}' in GROUP BY clause on line {line}, but field is not present in SELECT expression")]
InvalidGroup {
line: usize,
field: String,
},
/// The LIMIT clause must evaluate to a positive integer
#[error("Found {value} but the LIMIT clause must evaluate to a positive integer")]
InvalidLimit {

View file

@ -1,6 +1,6 @@
use crate::sql::comment::mightbespace;
use crate::sql::comment::shouldbespace;
use crate::sql::error::Error::ParserError;
use crate::sql::error::Error::Parser;
use crate::sql::error::IResult;
use nom::branch::alt;
use nom::bytes::complete::take_while;
@ -53,7 +53,7 @@ pub fn take_u64(i: &str) -> IResult<&str, u64> {
let (i, v) = take_while(is_digit)(i)?;
match v.parse::<u64>() {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
_ => Err(Error(Parser(i))),
}
}
@ -61,7 +61,7 @@ pub fn take_u32_len(i: &str) -> IResult<&str, (u32, usize)> {
let (i, v) = take_while(is_digit)(i)?;
match v.parse::<u32>() {
Ok(n) => Ok((i, (n, v.len()))),
_ => Err(Error(ParserError(i))),
_ => Err(Error(Parser(i))),
}
}
@ -69,7 +69,7 @@ pub fn take_digits(i: &str, n: usize) -> IResult<&str, u32> {
let (i, v) = take_while_m_n(n, n, is_digit)(i)?;
match v.parse::<u32>() {
Ok(v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
_ => Err(Error(Parser(i))),
}
}
@ -77,6 +77,6 @@ pub fn take_digits_range(i: &str, n: usize, range: impl RangeBounds<u32>) -> IRe
let (i, v) = take_while_m_n(n, n, is_digit)(i)?;
match v.parse::<u32>() {
Ok(v) if range.contains(&v) => Ok((i, v)),
_ => Err(Error(ParserError(i))),
_ => Err(Error(Parser(i))),
}
}

View file

@ -5,14 +5,18 @@ use thiserror::Error;
#[derive(Error, Debug)]
pub enum Error<I> {
ParserError(I),
Parser(I),
Field(I, String),
Split(I, String),
Order(I, String),
Group(I, String),
}
pub type IResult<I, O, E = Error<I>> = Result<(I, O), Err<E>>;
impl<I> ParseError<I> for Error<I> {
fn from_error_kind(input: I, _: ErrorKind) -> Self {
Self::ParserError(input)
Self::Parser(input)
}
fn append(_: I, _: ErrorKind, other: Self) -> Self {
other

View file

@ -3,6 +3,7 @@ use crate::sql::common::commas;
use crate::sql::error::IResult;
use crate::sql::fmt::Fmt;
use crate::sql::idiom::{basic, Idiom};
use nom::branch::alt;
use nom::bytes::complete::tag_no_case;
use nom::combinator::opt;
use nom::multi::separated_list1;
@ -52,6 +53,17 @@ impl Display for Group {
}
/// Parses a GROUP clause from the input.
///
/// Tries `group_all` first and falls back to `group_any` if that parser
/// does not match, returning the remaining input together with the parsed
/// `Groups`.
pub fn group(i: &str) -> IResult<&str, Groups> {
alt((group_all, group_any))(i)
}
/// Parses the `GROUP ALL` form of the clause (keywords are matched
/// case-insensitively) and yields an empty `Groups` list.
fn group_all(i: &str) -> IResult<&str, Groups> {
    // The GROUP keyword, followed by mandatory whitespace
    let (rest, _) = tag_no_case("GROUP")(i)?;
    let (rest, _) = shouldbespace(rest)?;
    // The ALL keyword; no grouping expressions are recorded for this form
    tag_no_case("ALL")(rest).map(|(rest, _)| (rest, Groups(Vec::new())))
}
fn group_any(i: &str) -> IResult<&str, Groups> {
let (i, _) = tag_no_case("GROUP")(i)?;
let (i, _) = opt(tuple((shouldbespace, tag_no_case("BY"))))(i)?;
let (i, _) = shouldbespace(i)?;

View file

@ -45,6 +45,7 @@ pub(crate) mod query;
pub(crate) mod range;
pub(crate) mod regex;
pub(crate) mod script;
pub(crate) mod special;
pub(crate) mod split;
pub(crate) mod start;
pub(crate) mod statement;

View file

@ -1,5 +1,5 @@
use crate::err::Error;
use crate::sql::error::Error::ParserError;
use crate::sql::error::Error::{Field, Group, Order, Parser, Split};
use crate::sql::error::IResult;
use crate::sql::query::{query, Query};
use crate::sql::thing::Thing;
@ -34,19 +34,39 @@ fn parse_impl<O>(input: &str, parser: impl Fn(&str) -> IResult<&str, O>) -> Resu
// There was unparsed SQL remaining
Ok((_, _)) => Err(Error::QueryRemaining),
// There was an error when parsing the query
Err(Err::Error(e)) | Err(Err::Failure(e)) => match e {
Err(Err::Error(e)) | Err(Err::Failure(e)) => Err(match e {
// There was a parsing error
ParserError(e) => {
Parser(e) => {
// Locate the parser position
let (s, l, c) = locate(input, e);
// Return the parser error
Err(Error::InvalidQuery {
Error::InvalidQuery {
line: l,
char: c,
sql: s.to_string(),
})
}
}
// There was a SELECT clause field error
Field(e, f) => Error::InvalidField {
line: locate(input, e).1,
field: f,
},
// There was a SPLIT ON error
Split(e, f) => Error::InvalidSplit {
line: locate(input, e).1,
field: f,
},
// There was an ORDER BY error
Order(e, f) => Error::InvalidOrder {
line: locate(input, e).1,
field: f,
},
// There was a GROUP BY error
Group(e, f) => Error::InvalidGroup {
line: locate(input, e).1,
field: f,
},
}),
_ => unreachable!(),
},
}

156
lib/src/sql/special.rs Normal file
View file

@ -0,0 +1,156 @@
use crate::sql::error::Error;
use crate::sql::field::{Field, Fields};
use crate::sql::group::Groups;
use crate::sql::order::Orders;
use crate::sql::split::Splits;
use crate::sql::value::Value;
use nom::Err;
use nom::Err::Failure;
/// Verifies that every expression in the SPLIT ON clause is also present in
/// the SELECT clause.
///
/// Returns `Ok(())` when no SPLIT ON clause was given, when a `SELECT *`
/// field is reached while scanning, or when every split expression matches
/// a selected field (by alias, by the idiom as written, or by the result of
/// `to_idiom()` on the selected value). Otherwise returns a `Failure`
/// wrapping `Error::Split` with the input position `i` and the first
/// unmatched split expression rendered as a string.
pub fn check_split_on_fields<'a>(
    i: &'a str,
    fields: &Fields,
    splits: &Option<Splits>,
) -> Result<(), Err<Error<&'a str>>> {
    // Without a SPLIT ON clause there is nothing to validate
    let splits = match splits {
        Some(splits) => splits,
        None => return Ok(()),
    };
    // Each SPLIT ON expression must be matched by some SELECT expression
    'splits: for split in splits.iter() {
        for field in fields.iter() {
            let selected = match field {
                // SELECT * is present, so presume the whole clause is ok
                Field::All => return Ok(()),
                // An aliased field is matched on its alias name
                Field::Alias(_, alias) if alias.as_ref() == split.as_ref() => true,
                // A plain field is matched on the idiom exactly as written ...
                Field::Alone(Value::Idiom(idiom)) if idiom.as_ref() == split.as_ref() => true,
                // ... or on the idiom derived from the selected expression
                Field::Alone(value) => value.to_idiom().as_ref() == split.as_ref(),
                // Anything else does not satisfy this split expression
                _ => false,
            };
            if selected {
                continue 'splits;
            }
        }
        // No SELECT expression matched this SPLIT ON expression
        return Err(Failure(Error::Split(i, split.to_string())));
    }
    // This query is ok to run
    Ok(())
}
/// Verifies that every expression in the ORDER BY clause is also present in
/// the SELECT clause.
///
/// Returns `Ok(())` when no ORDER BY clause was given, when a `SELECT *`
/// field is reached while scanning, or when every order expression matches
/// a selected field (by alias, by the idiom as written, or by the result of
/// `to_idiom()` on the selected value). Otherwise returns a `Failure`
/// wrapping `Error::Order` with the input position `i` and the first
/// unmatched order expression rendered as a string.
pub fn check_order_by_fields<'a>(
    i: &'a str,
    fields: &Fields,
    orders: &Option<Orders>,
) -> Result<(), Err<Error<&'a str>>> {
    // Without an ORDER BY clause there is nothing to validate
    let orders = match orders {
        Some(orders) => orders,
        None => return Ok(()),
    };
    // Each ORDER BY expression must be matched by some SELECT expression
    'orders: for order in orders.iter() {
        for field in fields.iter() {
            let selected = match field {
                // SELECT * is present, so presume the whole clause is ok
                Field::All => return Ok(()),
                // An aliased field is matched on its alias name
                Field::Alias(_, alias) if alias.as_ref() == order.as_ref() => true,
                // A plain field is matched on the idiom exactly as written ...
                Field::Alone(Value::Idiom(idiom)) if idiom.as_ref() == order.as_ref() => true,
                // ... or on the idiom derived from the selected expression
                Field::Alone(value) => value.to_idiom().as_ref() == order.as_ref(),
                // Anything else does not satisfy this order expression
                _ => false,
            };
            if selected {
                continue 'orders;
            }
        }
        // No SELECT expression matched this ORDER BY expression
        return Err(Failure(Error::Order(i, order.to_string())));
    }
    // This query is ok to run
    Ok(())
}
/// Verifies that the GROUP BY clause and the SELECT clause are consistent.
///
/// Two passes are made:
///
/// 1. Every GROUP BY expression must match a SELECT expression (by alias,
///    by the idiom as written, or by the result of `to_idiom()` on the
///    selected value); otherwise a `Failure` wrapping `Error::Group` is
///    returned for the first unmatched group expression.
/// 2. When the clause lists at least one expression (an empty list is what
///    the `GROUP ALL` form produces), every SELECT expression must either
///    match a GROUP BY expression, be a function reporting
///    `is_aggregate()`, or be a value reporting `is_static()`; otherwise a
///    `Failure` wrapping `Error::Field` is returned for the first offending
///    field.
///
/// Returns `Ok(())` when no GROUP clause was given or when both passes
/// succeed. `i` is the input position attached to any error.
pub fn check_group_by_fields<'a>(
    i: &'a str,
    fields: &Fields,
    groups: &Option<Groups>,
) -> Result<(), Err<Error<&'a str>>> {
    // Without a GROUP BY clause there is nothing to validate
    let groups = match groups {
        Some(groups) => groups,
        None => return Ok(()),
    };
    // Pass 1: each GROUP BY expression must be matched by a SELECT expression
    'groups: for group in groups.iter() {
        for field in fields.iter() {
            let selected = match field {
                // An aliased field is matched on its alias name
                Field::Alias(_, alias) if alias.as_ref() == group.as_ref() => true,
                // A plain field is matched on the idiom exactly as written ...
                Field::Alone(Value::Idiom(idiom)) if idiom.as_ref() == group.as_ref() => true,
                // ... or on the idiom derived from the selected expression
                Field::Alone(value) => value.to_idiom().as_ref() == group.as_ref(),
                // Anything else does not satisfy this group expression
                _ => false,
            };
            if selected {
                continue 'groups;
            }
        }
        // No SELECT expression matched this GROUP BY expression
        return Err(Failure(Error::Group(i, group.to_string())));
    }
    // An empty list means GROUP ALL, which places no constraint on the fields
    if groups.is_empty() {
        return Ok(());
    }
    // Pass 2: each SELECT expression must be grouped, aggregated, or static
    'fields: for field in fields.iter() {
        for group in groups.iter() {
            // Check to see whether the SELECT expression is allowed by this GROUP BY expression
            match field {
                // This field is aliased, so check the alias name
                Field::Alias(_, alias) if alias.as_ref() == group.as_ref() => continue 'fields,
                // Otherwise, check the type of the field value
                Field::Alias(value, _) | Field::Alone(value) => match value {
                    // The selected field itself is listed in the GROUP BY clause
                    Value::Idiom(idiom) if idiom == &group.0 => continue 'fields,
                    // Aggregate functions are always allowed alongside GROUP BY
                    Value::Function(func) if func.is_aggregate() => continue 'fields,
                    // The idiom derived from the expression is listed in the GROUP BY clause
                    other if other.to_idiom() == group.0 => continue 'fields,
                    // Static values do not depend on the grouped records
                    other if other.is_static() => continue 'fields,
                    // This group expression does not permit the field
                    _ => (),
                },
                // Any other kind of field is never permitted by a group expression
                _ => (),
            }
        }
        // The field is neither aggregated, static, nor present in the GROUP BY clause
        return Err(Failure(Error::Field(i, field.to_string())));
    }
    // This query is ok to run
    Ok(())
}

View file

@ -14,6 +14,9 @@ use crate::sql::field::{fields, Field, Fields};
use crate::sql::group::{group, Groups};
use crate::sql::limit::{limit, Limit};
use crate::sql::order::{order, Orders};
use crate::sql::special::check_group_by_fields;
use crate::sql::special::check_order_by_fields;
use crate::sql::special::check_split_on_fields;
use crate::sql::split::{split, Splits};
use crate::sql::start::{start, Start};
use crate::sql::timeout::{timeout, Timeout};
@ -170,8 +173,11 @@ pub fn select(i: &str) -> IResult<&str, SelectStatement> {
let (i, what) = selects(i)?;
let (i, cond) = opt(preceded(shouldbespace, cond))(i)?;
let (i, split) = opt(preceded(shouldbespace, split))(i)?;
check_split_on_fields(i, &expr, &split)?;
let (i, group) = opt(preceded(shouldbespace, group))(i)?;
check_group_by_fields(i, &expr, &group)?;
let (i, order) = opt(preceded(shouldbespace, order))(i)?;
check_order_by_fields(i, &expr, &order)?;
let (i, limit) = opt(preceded(shouldbespace, limit))(i)?;
let (i, start) = opt(preceded(shouldbespace, start))(i)?;
let (i, fetch) = opt(preceded(shouldbespace, fetch))(i)?;

View file

@ -1,4 +1,4 @@
use crate::sql::error::Error::ParserError;
use crate::sql::error::Error::Parser;
use crate::sql::error::IResult;
use crate::sql::escape::escape_str;
use crate::sql::serde::is_internal_serialization;
@ -164,14 +164,14 @@ fn strand_unicode(i: &str) -> IResult<&str, char> {
// We can convert this to u32 as we only have 6 chars
let v = match u32::from_str_radix(v, 16) {
// We found an invalid unicode sequence
Err(_) => return Err(Error(ParserError(i))),
Err(_) => return Err(Error(Parser(i))),
// The unicode sequence was valid
Ok(v) => v,
};
// We can convert this to char as we know it is valid
let v = match std::char::from_u32(v) {
// We found an invalid unicode sequence
None => return Err(Error(ParserError(i))),
None => return Err(Error(Parser(i))),
// The unicode sequence was valid
Some(v) => v,
};

View file

@ -1029,7 +1029,8 @@ impl Value {
// JSON Path conversion
// -----------------------------------
pub fn jsonpath(&self) -> Idiom {
/// Converts this value to a JSONPatch path
pub(crate) fn jsonpath(&self) -> Idiom {
self.to_strand()
.as_str()
.trim_start_matches('/')
@ -1039,6 +1040,30 @@ impl Value {
.into()
}
// -----------------------------------
// Static value checking
// -----------------------------------
/// Checks whether this value is a static value.
///
/// The variants listed in the first arm are always static; arrays and
/// objects are static only when every contained value is static. Every
/// other variant is reported as not static.
pub(crate) fn is_static(&self) -> bool {
    match self {
        // Variants that are static regardless of their contents
        Value::None
        | Value::Null
        | Value::False
        | Value::True
        | Value::Uuid(_)
        | Value::Number(_)
        | Value::Strand(_)
        | Value::Duration(_)
        | Value::Datetime(_)
        | Value::Geometry(_)
        | Value::Constant(_) => true,
        // Collections are static only if all of their elements are
        Value::Array(items) => items.iter().all(Value::is_static),
        Value::Object(entries) => entries.values().all(Value::is_static),
        // Everything else is not static
        _ => false,
    }
}
// -----------------------------------
// Value operations
// -----------------------------------

View file

@ -186,7 +186,7 @@ async fn define_statement_event() -> Result<(), Error> {
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now();
SELECT count() FROM activity GROUP BY ALL;
SELECT count() FROM activity GROUP ALL;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
@ -243,7 +243,7 @@ async fn define_statement_event_when_event() -> Result<(), Error> {
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now();
SELECT count() FROM activity GROUP BY ALL;
SELECT count() FROM activity GROUP ALL;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
@ -300,7 +300,7 @@ async fn define_statement_event_when_logic() -> Result<(), Error> {
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'info@surrealdb.com', updated_at = time::now();
UPDATE user:test SET email = 'test@surrealdb.com', updated_at = time::now();
SELECT count() FROM activity GROUP BY ALL;
SELECT count() FROM activity GROUP ALL;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");

View file

@ -18,7 +18,7 @@ async fn select_limit_fetch() -> Result<(), Error> {
CREATE temperature:8 SET country = 'AUD', time = '2021-01-01T08:00:00Z';
CREATE temperature:9 SET country = 'CHF', time = '2023-01-01T08:00:00Z';
SELECT *, time::year(time) AS year FROM temperature;
SELECT count(), time::year(time) AS year, country FROM temperature GROUP BY country;
SELECT count(), time::year(time) AS year, country FROM temperature GROUP BY country, year;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
@ -213,10 +213,15 @@ async fn select_limit_fetch() -> Result<(), Error> {
year: 2021
},
{
count: 5,
count: 3,
country: 'GBP',
year: 2020
},
{
count: 2,
country: 'GBP',
year: 2021
},
{
count: 1,
country: 'USD',

View file

@ -9,7 +9,7 @@ use surrealdb::sql::Value;
async fn model_count() -> Result<(), Error> {
let sql = "
CREATE |test:1000| SET time = time::now();
SELECT count() FROM test GROUP BY ALL;
SELECT count() FROM test GROUP ALL;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
@ -34,7 +34,7 @@ async fn model_count() -> Result<(), Error> {
async fn model_range() -> Result<(), Error> {
let sql = "
CREATE |test:1..1000| SET time = time::now();
SELECT count() FROM test GROUP BY ALL;
SELECT count() FROM test GROUP ALL;
";
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");