From 56b4f7d71e993d73b87d5cf0e4eaec6806cf025a Mon Sep 17 00:00:00 2001 From: Emmanuel Keller Date: Thu, 11 Apr 2024 12:38:42 +0100 Subject: [PATCH] Bug: Wrong count when using COUNT with a subquery (#3855) --- core/src/dbs/group.rs | 36 ++++++++++++++++++++++----- core/src/sql/function.rs | 9 ++++++- lib/tests/group.rs | 53 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/core/src/dbs/group.rs b/core/src/dbs/group.rs index 69fe15b2..33bf1f0d 100644 --- a/core/src/dbs/group.rs +++ b/core/src/dbs/group.rs @@ -5,7 +5,7 @@ use crate::dbs::{Options, Statement, Transaction}; use crate::err::Error; use crate::sql::function::OptimisedAggregate; use crate::sql::value::{TryAdd, TryDiv, Value}; -use crate::sql::{Array, Field, Idiom}; +use crate::sql::{Array, Field, Function, Idiom}; use std::borrow::Cow; use std::collections::{BTreeMap, HashMap}; @@ -20,6 +20,7 @@ struct Aggregator { array: Option, first_val: Option, count: Option, + count_function: Option<(Box, usize)>, math_max: Option, math_min: Option, math_sum: Option, @@ -94,7 +95,7 @@ impl GroupsCollector { ) -> Result<(), Error> { for (agr, idiom) in agrs.iter_mut().zip(idioms) { let val = obj.get(ctx, opt, txn, None, idiom).await?; - agr.push(val)?; + agr.push(ctx, opt, txn, val).await?; } Ok(()) } @@ -174,15 +175,15 @@ impl GroupsCollector { impl Aggregator { fn prepare(&mut self, expr: &Value) { - let a = match expr { - Value::Function(f) => f.get_optimised_aggregate(), + let (a, f) = match expr { + Value::Function(f) => (f.get_optimised_aggregate(), Some(f)), _ => { // We set it only if we don't already have an array if self.array.is_none() && self.first_val.is_none() { self.first_val = Some(Value::None); return; } - OptimisedAggregate::None + (OptimisedAggregate::None, None) } }; match a { @@ -198,6 +199,11 @@ impl Aggregator { self.count = Some(0); } } + OptimisedAggregate::CountFunction => { + if self.count_function.is_none() { + self.count_function = Some((f.unwrap().clone(), 0)); + } + } OptimisedAggregate::MathMax => { if self.math_max.is_none() { self.math_max = Some(Value::None); @@ -236,6 +242,7 @@ impl Aggregator { array: self.array.as_ref().map(|_| Array::new()), first_val: self.first_val.as_ref().map(|_| Value::None), count: self.count.as_ref().map(|_| 0), + count_function: self.count_function.as_ref().map(|(f, _)| (f.clone(), 0)), math_max: self.math_max.as_ref().map(|_| Value::None), math_min: self.math_min.as_ref().map(|_| Value::None), math_sum: self.math_sum.as_ref().map(|_| 0.into()), @@ -245,10 +252,21 @@ impl Aggregator { } } - fn push(&mut self, val: Value) -> Result<(), Error> { + async fn push( + &mut self, + ctx: &Context<'_>, + opt: &Options, + txn: &Transaction, + val: Value, + ) -> Result<(), Error> { if let Some(ref mut c) = self.count { *c += 1; } + if let Some((ref f, ref mut c)) = self.count_function { + if f.aggregate(val.clone()).compute(ctx, opt, txn, None).await?.is_truthy() { + *c += 1; + } + } if val.is_number() { if let Some(s) = self.math_sum.take() { self.math_sum = Some(s.try_add(val.clone())?); @@ -302,6 +320,9 @@ impl Aggregator { Ok(match a { OptimisedAggregate::None => Value::None, OptimisedAggregate::Count => self.count.take().map(|v| v.into()).unwrap_or(Value::None), + OptimisedAggregate::CountFunction => { + self.count_function.take().map(|(_, v)| v.into()).unwrap_or(Value::None) + } OptimisedAggregate::MathMax => self.math_max.take().unwrap_or(Value::None), OptimisedAggregate::MathMin => self.math_min.take().unwrap_or(Value::None), OptimisedAggregate::MathSum => self.math_sum.take().unwrap_or(Value::None), @@ -339,6 +360,9 @@ impl Aggregator { if self.count.is_some() { collections.push("count".into()); } + if self.count_function.is_some() { + collections.push("count+func".into()); + } if self.math_mean.is_some() { collections.push("math::mean".into()); } diff --git a/core/src/sql/function.rs b/core/src/sql/function.rs index 5434f56c..b642f88b 100644 --- a/core/src/sql/function.rs +++ b/core/src/sql/function.rs @@ -35,6 +35,7 @@ pub enum Function { pub(crate) enum OptimisedAggregate { None, Count, + CountFunction, MathMax, MathMin, MathSum, @@ -156,7 +157,13 @@ impl Function { } pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate { match self { - Self::Normal(f, _) if f == "count" => OptimisedAggregate::Count, + Self::Normal(f, v) if f == "count" => { + if v.is_empty() { + OptimisedAggregate::Count + } else { + OptimisedAggregate::CountFunction + } + } Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax, Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean, Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin, diff --git a/lib/tests/group.rs b/lib/tests/group.rs index 341cf253..da9360ec 100644 --- a/lib/tests/group.rs +++ b/lib/tests/group.rs @@ -621,3 +621,56 @@ async fn select_array_group_group_by() -> Result<(), Error> { // Ok(()) } + +#[tokio::test] +async fn select_array_count_subquery_group_by() -> Result<(), Error> { + let sql = r#" + CREATE table CONTENT { bar: "hello", foo: "Man"}; + CREATE table CONTENT { bar: "hello", foo: "World"}; + CREATE table CONTENT { bar: "world"}; + SELECT COUNT(foo != none) FROM table GROUP ALL EXPLAIN; + SELECT COUNT(foo != none) FROM table GROUP ALL; + "#; + let dbs = new_ds().await?; + let ses = Session::owner().with_ns("test").with_db("test"); + let mut res = &mut dbs.execute(sql, &ses, None).await?; + assert_eq!(res.len(), 5); + // + skip_ok(&mut res, 3)?; + // + let tmp = res.remove(0).result?; + let val = Value::parse( + r#"[ + { + detail: { + table: 'table' + }, + operation: 'Iterate Table' + }, + { + detail: { + idioms: { + count: [ + 'count+func' + ] + }, + type: 'Group' + }, + operation: 'Collector' + } + ]"#, + ); + assert_eq!(format!("{tmp:#}"), format!("{val:#}")); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + r#"[ + { + count: 2 + } + ]"#, + ); + assert_eq!(format!("{tmp:#}"), format!("{val:#}")); + // + Ok(()) +}