From 0a4b810fbddc0ca485764972093ab367ea747677 Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Mon, 17 Oct 2022 02:57:53 +0100 Subject: [PATCH] Allow parameters in LIMIT and START clauses in SQL SELECT statements Closes #1332 Closes #116 --- lib/src/dbs/iterator.rs | 56 ++++++++++++++++++++++++++------ lib/src/err/mod.rs | 12 +++++++ lib/src/sql/common.rs | 8 ----- lib/src/sql/limit.rs | 32 +++++++++++++++--- lib/src/sql/number.rs | 16 +++++++++ lib/src/sql/start.rs | 34 +++++++++++++++---- lib/src/sql/statements/select.rs | 23 ++++++------- lib/src/sql/subquery.rs | 2 +- lib/src/sql/value/value.rs | 27 +++++++++++++++ 9 files changed, 167 insertions(+), 43 deletions(-) diff --git a/lib/src/dbs/iterator.rs b/lib/src/dbs/iterator.rs index c884f971..f9c62d4f 100644 --- a/lib/src/dbs/iterator.rs +++ b/lib/src/dbs/iterator.rs @@ -44,6 +44,10 @@ pub enum Workable { pub struct Iterator { // Iterator status run: Canceller, + // Iterator limit value + limit: Option, + // Iterator start value + start: Option, // Iterator runtime error error: Option, // Iterator output results @@ -76,6 +80,10 @@ impl Iterator { // Enable context override let mut ctx = Context::new(ctx); self.run = ctx.add_cancel(); + // Process the query LIMIT clause + self.setup_limit(&ctx, opt, txn, stm).await?; + // Process the query START clause + self.setup_start(&ctx, opt, txn, stm).await?; // Process prepared values self.iterate(&ctx, opt, txn, stm).await?; // Return any document errors @@ -98,6 +106,34 @@ impl Iterator { Ok(mem::take(&mut self.results).into()) } + #[inline] + async fn setup_limit( + &mut self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + stm: &Statement<'_>, + ) -> Result<(), Error> { + if let Some(v) = stm.limit() { + self.limit = Some(v.process(ctx, opt, txn, None).await?); + } + Ok(()) + } + + #[inline] + async fn setup_start( + &mut self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + stm: &Statement<'_>, + ) -> Result<(), Error> { + if let Some(v) = stm.start() { + self.start = Some(v.process(ctx, opt, txn, None).await?); + } + Ok(()) + } + #[inline] async fn output_split( &mut self, @@ -273,10 +309,10 @@ impl Iterator { _ctx: &Context<'_>, _opt: &Options, _txn: &Transaction, - stm: &Statement<'_>, + _stm: &Statement<'_>, ) -> Result<(), Error> { - if let Some(v) = stm.start() { - self.results = mem::take(&mut self.results).into_iter().skip(v.0).collect(); + if let Some(v) = self.start { + self.results = mem::take(&mut self.results).into_iter().skip(v).collect(); } Ok(()) } @@ -287,10 +323,10 @@ impl Iterator { _ctx: &Context<'_>, _opt: &Options, _txn: &Transaction, - stm: &Statement<'_>, + _stm: &Statement<'_>, ) -> Result<(), Error> { - if let Some(v) = stm.limit() { - self.results = mem::take(&mut self.results).into_iter().take(v.0).collect(); + if let Some(v) = self.limit { + self.results = mem::take(&mut self.results).into_iter().take(v).collect(); } Ok(()) } @@ -464,12 +500,12 @@ impl Iterator { } // Check if we can exit if stm.group().is_none() && stm.order().is_none() { - if let Some(l) = stm.limit() { - if let Some(s) = stm.start() { - if self.results.len() == l.0 + s.0 { + if let Some(l) = self.limit { + if let Some(s) = self.start { + if self.results.len() == l + s { self.run.cancel() } - } else if self.results.len() == l.0 { + } else if self.results.len() == l { self.run.cancel() } } diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index 7e1faa7e..621868d0 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -75,6 +75,18 @@ pub enum Error { #[error("Remote HTTP request functions are not enabled")] HttpDisabled, + /// The LIMIT clause must evaluate to a positive integer + #[error("Found {value} but the LIMIT clause must evaluate to a positive integer")] + InvalidLimit { + value: String, + }, + + /// The START clause must evaluate to a positive integer + #[error("Found {value} but the START clause must evaluate to a positive integer")] + InvalidStart { + value: String, + }, + /// There was an error with the provided JavaScript code #[error("Problem with embedded script function. {message}")] InvalidScript { diff --git a/lib/src/sql/common.rs b/lib/src/sql/common.rs index fb496216..72dfcbaf 100644 --- a/lib/src/sql/common.rs +++ b/lib/src/sql/common.rs @@ -57,14 +57,6 @@ pub fn take_u64(i: &str) -> IResult<&str, u64> { } } -pub fn take_usize(i: &str) -> IResult<&str, usize> { - let (i, v) = take_while(is_digit)(i)?; - match v.parse::() { - Ok(v) => Ok((i, v)), - _ => Err(Error(ParserError(i))), - } -} - pub fn take_u32_len(i: &str) -> IResult<&str, (u32, usize)> { let (i, v) = take_while(is_digit)(i)?; match v.parse::() { diff --git a/lib/src/sql/limit.rs b/lib/src/sql/limit.rs index 39089faa..50d3e52c 100644 --- a/lib/src/sql/limit.rs +++ b/lib/src/sql/limit.rs @@ -1,6 +1,10 @@ +use crate::ctx::Context; +use crate::dbs::Options; +use crate::dbs::Transaction; +use crate::err::Error; use crate::sql::comment::shouldbespace; -use crate::sql::common::take_usize; use crate::sql::error::IResult; +use crate::sql::value::{value, Value}; use nom::bytes::complete::tag_no_case; use nom::combinator::opt; use nom::sequence::tuple; @@ -8,7 +12,25 @@ use serde::{Deserialize, Serialize}; use std::fmt; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] -pub struct Limit(pub usize); +pub struct Limit(pub Value); + +impl Limit { + pub(crate) async fn process( + &self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + doc: Option<&Value>, + ) -> Result { + match self.0.compute(ctx, opt, txn, doc).await { + Ok(v) if v.is_integer() && v.is_positive() => Ok(v.as_usize()), + Ok(v) => Err(Error::InvalidLimit { + value: v.as_string(), + }), + Err(e) => Err(e), + } + } +} impl fmt::Display for Limit { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -20,7 +42,7 @@ pub fn limit(i: &str) -> IResult<&str, Limit> { let (i, _) = tag_no_case("LIMIT")(i)?; let (i, _) = opt(tuple((shouldbespace, tag_no_case("BY"))))(i)?; let (i, _) = shouldbespace(i)?; - let (i, v) = take_usize(i)?; + let (i, v) = value(i)?; Ok((i, Limit(v))) } @@ -35,7 +57,7 @@ mod tests { let res = limit(sql); assert!(res.is_ok()); let out = res.unwrap().1; - assert_eq!(out, Limit(100)); + assert_eq!(out, Limit(Value::from(100))); assert_eq!("LIMIT 100", format!("{}", out)); } @@ -45,7 +67,7 @@ mod tests { let res = limit(sql); assert!(res.is_ok()); let out = res.unwrap().1; - assert_eq!(out, Limit(100)); + assert_eq!(out, Limit(Value::from(100))); assert_eq!("LIMIT 100", format!("{}", out)); } } diff --git a/lib/src/sql/number.rs b/lib/src/sql/number.rs index 6607f9fa..736a2c7d 100644 --- a/lib/src/sql/number.rs +++ b/lib/src/sql/number.rs @@ -173,6 +173,14 @@ impl Number { matches!(self, Number::Decimal(_)) } + pub fn is_integer(&self) -> bool { + match self { + Number::Int(_) => true, + Number::Float(_) => false, + Number::Decimal(v) => v.is_integer(), + } + } + pub fn is_truthy(&self) -> bool { match self { Number::Int(v) => v != &0, @@ -181,6 +189,14 @@ impl Number { } } + pub fn is_positive(&self) -> bool { + match self { + Number::Int(v) => v > &0, + Number::Float(v) => v > &0.0, + Number::Decimal(v) => v > &BigDecimal::from(0), + } + } + // ----------------------------------- // Simple conversion of number // ----------------------------------- diff --git a/lib/src/sql/start.rs b/lib/src/sql/start.rs index b0dbb9d5..7735f202 100644 --- a/lib/src/sql/start.rs +++ b/lib/src/sql/start.rs @@ -1,14 +1,36 @@ +use crate::ctx::Context; +use crate::dbs::Options; +use crate::dbs::Transaction; +use crate::err::Error; use crate::sql::comment::shouldbespace; -use crate::sql::common::take_usize; use crate::sql::error::IResult; +use crate::sql::value::{value, Value}; use nom::bytes::complete::tag_no_case; use nom::combinator::opt; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; use std::fmt; -#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] -pub struct Start(pub usize); +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] +pub struct Start(pub Value); + +impl Start { + pub(crate) async fn process( + &self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + doc: Option<&Value>, + ) -> Result { + match self.0.compute(ctx, opt, txn, doc).await { + Ok(v) if v.is_integer() && v.is_positive() => Ok(v.as_usize()), + Ok(v) => Err(Error::InvalidStart { + value: v.as_string(), + }), + Err(e) => Err(e), + } + } +} impl fmt::Display for Start { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -20,7 +42,7 @@ pub fn start(i: &str) -> IResult<&str, Start> { let (i, _) = tag_no_case("START")(i)?; let (i, _) = opt(tuple((shouldbespace, tag_no_case("AT"))))(i)?; let (i, _) = shouldbespace(i)?; - let (i, v) = take_usize(i)?; + let (i, v) = value(i)?; Ok((i, Start(v))) } @@ -35,7 +57,7 @@ mod tests { let res = start(sql); assert!(res.is_ok()); let out = res.unwrap().1; - assert_eq!(out, Start(100)); + assert_eq!(out, Start(Value::from(100))); assert_eq!("START 100", format!("{}", out)); } @@ -45,7 +67,7 @@ mod tests { let res = start(sql); assert!(res.is_ok()); let out = res.unwrap().1; - assert_eq!(out, Start(100)); + assert_eq!(out, Start(Value::from(100))); assert_eq!("START 100", format!("{}", out)); } } diff --git a/lib/src/sql/statements/select.rs b/lib/src/sql/statements/select.rs index 02be4b27..1ae713d3 100644 --- a/lib/src/sql/statements/select.rs +++ b/lib/src/sql/statements/select.rs @@ -43,19 +43,16 @@ pub struct SelectStatement { } impl SelectStatement { - /// Return the statement limit number or 0 if not set - pub fn limit(&self) -> usize { - match self.limit { - Some(Limit(v)) => v, - None => 0, - } - } - - /// Return the statement start number or 0 if not set - pub fn start(&self) -> usize { - match self.start { - Some(Start(v)) => v, - None => 0, + pub(crate) async fn limit( + &self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + doc: Option<&Value>, + ) -> Result { + match &self.limit { + Some(v) => v.process(ctx, opt, txn, doc).await, + None => Ok(0), } } diff --git a/lib/src/sql/subquery.rs b/lib/src/sql/subquery.rs index f5fa2ec6..3572c867 100644 --- a/lib/src/sql/subquery.rs +++ b/lib/src/sql/subquery.rs @@ -76,7 +76,7 @@ impl Subquery { // Process subquery let res = v.compute(&ctx, opt, txn, doc).await?; // Process result - match v.limit() { + match v.limit(&ctx, opt, txn, doc).await? { 1 => match v.expr.single() { Some(v) => res.first().get(&ctx, opt, txn, &v).await, None => res.first().ok(), diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 09937488..134f7131 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -571,6 +571,26 @@ impl Value { matches!(self, Value::Object(_)) } + pub fn is_int(&self) -> bool { + matches!(self, Value::Number(Number::Int(_))) + } + + pub fn is_float(&self) -> bool { + matches!(self, Value::Number(Number::Float(_))) + } + + pub fn is_decimal(&self) -> bool { + matches!(self, Value::Number(Number::Decimal(_))) + } + + pub fn is_integer(&self) -> bool { + matches!(self, Value::Number(v) if v.is_integer()) + } + + pub fn is_positive(&self) -> bool { + matches!(self, Value::Number(v) if v.is_positive()) + } + pub fn is_type_record(&self, types: &[Table]) -> bool { match self { Value::Thing(v) => types.iter().any(|tb| tb.0 == v.tb), @@ -685,6 +705,13 @@ impl Value { } } + pub fn as_usize(self) -> usize { + match self { + Value::Number(v) => v.as_usize(), + _ => 0, + } + } + // ----------------------------------- // Expensive conversion of value // -----------------------------------