diff --git a/lib/src/dbs/iterator.rs b/lib/src/dbs/iterator.rs index ee82c51e..b8334832 100644 --- a/lib/src/dbs/iterator.rs +++ b/lib/src/dbs/iterator.rs @@ -7,6 +7,8 @@ use crate::dbs::Statement; use crate::dbs::Transaction; use crate::doc::Document; use crate::err::Error; +use crate::sql::array::Array; +use crate::sql::field::Field; use crate::sql::id::Id; use crate::sql::part::Part; use crate::sql::statements::create::CreateStatement; @@ -19,6 +21,7 @@ use crate::sql::table::Table; use crate::sql::thing::Thing; use crate::sql::value::Value; use std::cmp::Ordering; +use std::collections::BTreeMap; use std::mem; use std::sync::Arc; @@ -155,7 +158,8 @@ impl Iterator { txn: &Transaction, ) -> Result<(), Error> { if let Some(splits) = self.stm.split() { - for split in &splits.0 { + // Loop over each split clause + for split in splits.iter() { // Get the query result let res = mem::take(&mut self.results); // Loop over each value @@ -192,12 +196,86 @@ impl Iterator { #[inline] async fn output_group( &mut self, - _ctx: &Runtime, - _opt: &Options, - _txn: &Transaction, + ctx: &Runtime, + opt: &Options, + txn: &Transaction, ) -> Result<(), Error> { - if self.stm.group().is_some() { - // Ignore + if let Some(fields) = self.stm.expr() { + if let Some(groups) = self.stm.group() { + // Create the new grouped collection + let mut grp: BTreeMap = BTreeMap::new(); + // Get the query result + let res = mem::take(&mut self.results); + // Loop over each value + for obj in res { + // Create a new column set + let mut arr = Array::with_capacity(groups.len()); + // Loop over each group clause + for group in groups.iter() { + // Get the value at the path + let val = obj.pick(&group.group); + // Set the value at the path + arr.value.push(val); + } + // Add to grouped collection + match grp.get_mut(&arr) { + Some(v) => v.value.push(obj), + None => { + grp.insert(arr, Array::from(obj)); + } + } + } + // Loop over each grouped collection + for (_, vals) in grp { + // Create a new value + let mut obj = Value::base(); + // Save the collected values + let vals = Value::from(vals); + // Loop over each group clause + for field in fields.other() { + // Process it if it is a normal field + if let Field::Alone(v) = field { + match v { + Value::Function(f) if f.is_aggregate() => { + let x = vals + .all(ctx, opt, txn) + .await? + .get(ctx, opt, txn, v.to_idiom().as_ref()) + .await?; + let x = f.aggregate(x).compute(ctx, opt, txn, None).await?; + obj.set(ctx, opt, txn, v.to_idiom().as_ref(), x).await?; + } + _ => { + let x = vals.first(ctx, opt, txn).await?; + let x = v.compute(ctx, opt, txn, Some(&x)).await?; + obj.set(ctx, opt, txn, v.to_idiom().as_ref(), x).await?; + } + } + } + // Process it if it is a aliased field + if let Field::Alias(v, i) = field { + match v { + Value::Function(f) if f.is_aggregate() => { + let x = vals + .all(ctx, opt, txn) + .await? + .get(ctx, opt, txn, i) + .await?; + let x = f.aggregate(x).compute(ctx, opt, txn, None).await?; + obj.set(ctx, opt, txn, i, x).await?; + } + _ => { + let x = vals.first(ctx, opt, txn).await?; + let x = v.compute(ctx, opt, txn, Some(&x)).await?; + obj.set(ctx, opt, txn, i, x).await?; + } + } + } + } + // Add the object to the results + self.results.push(obj); + } + } } Ok(()) } @@ -210,8 +288,11 @@ impl Iterator { _txn: &Transaction, ) -> Result<(), Error> { if let Some(orders) = self.stm.order() { + // Sort the full result set self.results.sort_by(|a, b| { - for order in &orders.0 { + // Loop over each order clause + for order in orders.iter() { + // Reverse the ordering if DESC let o = match order.direction { true => a.compare(b, &order.order), false => b.compare(a, &order.order), diff --git a/lib/src/dbs/statement.rs b/lib/src/dbs/statement.rs index d5d587bd..c95be73f 100644 --- a/lib/src/dbs/statement.rs +++ b/lib/src/dbs/statement.rs @@ -1,5 +1,6 @@ use crate::sql::cond::Cond; use crate::sql::fetch::Fetchs; +use crate::sql::field::Fields; use crate::sql::group::Groups; use crate::sql::limit::Limit; use crate::sql::order::Orders; @@ -83,6 +84,13 @@ impl fmt::Display for Statement { } impl Statement { + // Returns any query fields if specified + pub fn expr(self: &Statement) -> Option<&Fields> { + match self { + Statement::Select(v) => Some(&v.expr), + _ => None, + } + } // Returns any SPLIT clause if specified pub fn conds(self: &Statement) -> Option<&Cond> { match self { diff --git a/lib/src/doc/pluck.rs b/lib/src/doc/pluck.rs index 126b7eec..1372c516 100644 --- a/lib/src/doc/pluck.rs +++ b/lib/src/doc/pluck.rs @@ -67,14 +67,40 @@ impl<'a> Document<'a> { for v in stm.expr.other() { match v { Field::All => (), - Field::Alone(v) => { - let x = v.compute(ctx, opt, txn, Some(&self.current)).await?; - out.set(ctx, opt, txn, v.to_idiom().as_ref(), x).await?; - } - Field::Alias(v, i) => { - let x = v.compute(ctx, opt, txn, Some(&self.current)).await?; - out.set(ctx, opt, txn, i, x).await?; - } + Field::Alone(v) => match v { + Value::Function(f) if stm.group.is_some() && f.is_aggregate() => { + let x = match f.args().len() { + 0 => f.compute(ctx, opt, txn, Some(&self.current)).await?, + _ => { + f.args()[0] + .compute(ctx, opt, txn, Some(&self.current)) + .await? + } + }; + out.set(ctx, opt, txn, v.to_idiom().as_ref(), x).await?; + } + _ => { + let x = v.compute(ctx, opt, txn, Some(&self.current)).await?; + out.set(ctx, opt, txn, v.to_idiom().as_ref(), x).await?; + } + }, + Field::Alias(v, i) => match v { + Value::Function(f) if stm.group.is_some() && f.is_aggregate() => { + let x = match f.args().len() { + 0 => f.compute(ctx, opt, txn, Some(&self.current)).await?, + _ => { + f.args()[0] + .compute(ctx, opt, txn, Some(&self.current)) + .await? + } + }; + out.set(ctx, opt, txn, i, x).await?; + } + _ => { + let x = v.compute(ctx, opt, txn, Some(&self.current)).await?; + out.set(ctx, opt, txn, i, x).await?; + } + }, } } Ok(out) diff --git a/lib/src/sql/array.rs b/lib/src/sql/array.rs index c8d4f955..0117e302 100644 --- a/lib/src/sql/array.rs +++ b/lib/src/sql/array.rs @@ -17,11 +17,19 @@ use serde::{Deserialize, Serialize}; use std::fmt; use std::ops; -#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)] +#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Deserialize)] pub struct Array { pub value: Vec, } +impl From for Array { + fn from(v: Value) -> Self { + Array { + value: vec![v], + } + } +} + impl From> for Array { fn from(v: Vec) -> Self { Array { @@ -71,6 +79,18 @@ impl From> for Array { } impl Array { + pub fn new() -> Self { + Array { + value: Vec::default(), + } + } + + pub fn with_capacity(len: usize) -> Self { + Array { + value: Vec::with_capacity(len), + } + } + pub fn len(&self) -> usize { self.value.len() } diff --git a/lib/src/sql/field.rs b/lib/src/sql/field.rs index c3ff20b8..ffc9d9ac 100644 --- a/lib/src/sql/field.rs +++ b/lib/src/sql/field.rs @@ -8,6 +8,7 @@ use nom::bytes::complete::tag_no_case; use nom::multi::separated_list1; use serde::{Deserialize, Serialize}; use std::fmt; +use std::ops::Deref; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] pub struct Fields(pub Vec); @@ -32,6 +33,21 @@ impl Fields { } } +impl Deref for Fields { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoIterator for Fields { + type Item = Field; + type IntoIter = std::vec::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + impl fmt::Display for Fields { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.0.iter().map(|ref v| format!("{}", v)).collect::>().join(", ")) diff --git a/lib/src/sql/function.rs b/lib/src/sql/function.rs index cb80a473..5c70accb 100644 --- a/lib/src/sql/function.rs +++ b/lib/src/sql/function.rs @@ -31,6 +31,76 @@ impl PartialOrd for Function { } } +impl Function { + // Get function arguments if applicable + pub fn args(&self) -> &[Value] { + match self { + Function::Normal(_, a) => a, + _ => &[], + } + } + // Convert this function to an aggregate + pub fn aggregate(&self, val: Value) -> Function { + match self { + Function::Normal(n, a) => { + let mut a = a.to_owned(); + match a.len() { + 0 => a.insert(0, val), + _ => { + a.remove(0); + a.insert(0, val); + } + } + Function::Normal(n.to_owned(), a) + } + _ => unreachable!(), + } + } + // Check if this function is a rolling function + pub fn is_rolling(&self) -> bool { + match self { + Function::Normal(f, _) if f == "array::concat" => true, + Function::Normal(f, _) if f == "array::distinct" => true, + Function::Normal(f, _) if f == "array::union" => true, + Function::Normal(f, _) if f == "count" => true, + Function::Normal(f, _) if f == "math::max" => true, + Function::Normal(f, _) if f == "math::mean" => true, + Function::Normal(f, _) if f == "math::min" => true, + Function::Normal(f, _) if f == "math::stddev" => true, + Function::Normal(f, _) if f == "math::sum" => true, + Function::Normal(f, _) if f == "math::variance" => true, + _ => false, + } + } + // Check if this function is a grouping function + pub fn is_aggregate(&self) -> bool { + match self { + Function::Normal(f, _) if f == "array::concat" => true, + Function::Normal(f, _) if f == "array::distinct" => true, + Function::Normal(f, _) if f == "array::union" => true, + Function::Normal(f, _) if f == "count" => true, + Function::Normal(f, _) if f == "math::bottom" => true, + Function::Normal(f, _) if f == "math::interquartile" => true, + Function::Normal(f, _) if f == "math::max" => true, + Function::Normal(f, _) if f == "math::mean" => true, + Function::Normal(f, _) if f == "math::median" => true, + Function::Normal(f, _) if f == "math::midhinge" => true, + Function::Normal(f, _) if f == "math::min" => true, + Function::Normal(f, _) if f == "math::mode" => true, + Function::Normal(f, _) if f == "math::nearestrank" => true, + Function::Normal(f, _) if f == "math::percentile" => true, + Function::Normal(f, _) if f == "math::sample" => true, + Function::Normal(f, _) if f == "math::spread" => true, + Function::Normal(f, _) if f == "math::stddev" => true, + Function::Normal(f, _) if f == "math::sum" => true, + Function::Normal(f, _) if f == "math::top" => true, + Function::Normal(f, _) if f == "math::trimean" => true, + Function::Normal(f, _) if f == "math::variance" => true, + _ => false, + } + } +} + impl Function { pub async fn compute( &self, diff --git a/lib/src/sql/group.rs b/lib/src/sql/group.rs index 6a797cdb..7b6b0c47 100644 --- a/lib/src/sql/group.rs +++ b/lib/src/sql/group.rs @@ -8,10 +8,32 @@ use nom::multi::separated_list1; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; use std::fmt; +use std::ops::Deref; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] pub struct Groups(pub Vec); +impl Groups { + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl Deref for Groups { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoIterator for Groups { + type Item = Group; + type IntoIter = std::vec::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + impl fmt::Display for Groups { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( diff --git a/lib/src/sql/object.rs b/lib/src/sql/object.rs index c4334b04..d1153cb6 100644 --- a/lib/src/sql/object.rs +++ b/lib/src/sql/object.rs @@ -21,7 +21,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::fmt; -#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Deserialize)] +#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Deserialize)] pub struct Object { pub value: BTreeMap, } diff --git a/lib/src/sql/order.rs b/lib/src/sql/order.rs index a5d6e6cb..84dbb90c 100644 --- a/lib/src/sql/order.rs +++ b/lib/src/sql/order.rs @@ -9,10 +9,32 @@ use nom::multi::separated_list1; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; use std::fmt; +use std::ops::Deref; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] pub struct Orders(pub Vec); +impl Orders { + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl Deref for Orders { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoIterator for Orders { + type Item = Order; + type IntoIter = std::vec::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + impl fmt::Display for Orders { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( diff --git a/lib/src/sql/split.rs b/lib/src/sql/split.rs index 69486ade..72e36481 100644 --- a/lib/src/sql/split.rs +++ b/lib/src/sql/split.rs @@ -8,10 +8,32 @@ use nom::multi::separated_list1; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; use std::fmt; +use std::ops::Deref; #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize)] pub struct Splits(pub Vec); +impl Splits { + pub fn len(&self) -> usize { + self.0.len() + } +} + +impl Deref for Splits { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoIterator for Splits { + type Item = Split; + type IntoIter = std::vec::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + impl fmt::Display for Splits { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( diff --git a/lib/src/sql/value/all.rs b/lib/src/sql/value/all.rs new file mode 100644 index 00000000..54145442 --- /dev/null +++ b/lib/src/sql/value/all.rs @@ -0,0 +1,17 @@ +use crate::dbs::Options; +use crate::dbs::Runtime; +use crate::dbs::Transaction; +use crate::err::Error; +use crate::sql::part::Part; +use crate::sql::value::Value; + +impl Value { + pub async fn all( + &self, + ctx: &Runtime, + opt: &Options, + txn: &Transaction, + ) -> Result { + self.get(ctx, opt, txn, &[Part::All]).await + } +} diff --git a/lib/src/sql/value/mod.rs b/lib/src/sql/value/mod.rs index bfd6f56d..dffd7f7f 100644 --- a/lib/src/sql/value/mod.rs +++ b/lib/src/sql/value/mod.rs @@ -1,5 +1,6 @@ pub use self::value::*; +mod all; mod array; mod clear; mod compare; diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 157ebf2b..0f4d4aff 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -36,10 +36,12 @@ use nom::combinator::map; use nom::multi::separated_list1; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; use std::collections::BTreeMap; use std::collections::HashMap; use std::fmt; use std::ops; +use std::ops::Deref; use std::str::FromStr; static MATCHER: Lazy = Lazy::new(|| SkimMatcherV2::default().ignore_case()); @@ -47,9 +49,10 @@ static MATCHER: Lazy = Lazy::new(|| SkimMatcherV2::default().igno #[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize)] pub struct Values(pub Vec); -impl fmt::Display for Values { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0.iter().map(|ref v| format!("{}", v)).collect::>().join(", ")) +impl Deref for Values { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 } } @@ -61,6 +64,12 @@ impl IntoIterator for Values { } } +impl fmt::Display for Values { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0.iter().map(|ref v| format!("{}", v)).collect::>().join(", ")) + } +} + pub fn values(i: &str) -> IResult<&str, Values> { let (i, v) = separated_list1(commas, value)(i)?; Ok((i, Values(v))) @@ -104,6 +113,12 @@ pub enum Value { impl Eq for Value {} +impl Ord for Value { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap_or(Ordering::Equal) + } +} + impl Default for Value { fn default() -> Value { Value::None