Bug: Wrong count when using COUNT with a subquery (#3855)

This commit is contained in:
Emmanuel Keller 2024-04-11 12:38:42 +01:00 committed by GitHub
parent 2f19afec56
commit 56b4f7d71e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 91 additions and 7 deletions

View file

@ -5,7 +5,7 @@ use crate::dbs::{Options, Statement, Transaction};
use crate::err::Error; use crate::err::Error;
use crate::sql::function::OptimisedAggregate; use crate::sql::function::OptimisedAggregate;
use crate::sql::value::{TryAdd, TryDiv, Value}; use crate::sql::value::{TryAdd, TryDiv, Value};
use crate::sql::{Array, Field, Idiom}; use crate::sql::{Array, Field, Function, Idiom};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
@ -20,6 +20,7 @@ struct Aggregator {
array: Option<Array>, array: Option<Array>,
first_val: Option<Value>, first_val: Option<Value>,
count: Option<usize>, count: Option<usize>,
count_function: Option<(Box<Function>, usize)>,
math_max: Option<Value>, math_max: Option<Value>,
math_min: Option<Value>, math_min: Option<Value>,
math_sum: Option<Value>, math_sum: Option<Value>,
@ -94,7 +95,7 @@ impl GroupsCollector {
) -> Result<(), Error> { ) -> Result<(), Error> {
for (agr, idiom) in agrs.iter_mut().zip(idioms) { for (agr, idiom) in agrs.iter_mut().zip(idioms) {
let val = obj.get(ctx, opt, txn, None, idiom).await?; let val = obj.get(ctx, opt, txn, None, idiom).await?;
agr.push(val)?; agr.push(ctx, opt, txn, val).await?;
} }
Ok(()) Ok(())
} }
@ -174,15 +175,15 @@ impl GroupsCollector {
impl Aggregator { impl Aggregator {
fn prepare(&mut self, expr: &Value) { fn prepare(&mut self, expr: &Value) {
let a = match expr { let (a, f) = match expr {
Value::Function(f) => f.get_optimised_aggregate(), Value::Function(f) => (f.get_optimised_aggregate(), Some(f)),
_ => { _ => {
// We set it only if we don't already have an array // We set it only if we don't already have an array
if self.array.is_none() && self.first_val.is_none() { if self.array.is_none() && self.first_val.is_none() {
self.first_val = Some(Value::None); self.first_val = Some(Value::None);
return; return;
} }
OptimisedAggregate::None (OptimisedAggregate::None, None)
} }
}; };
match a { match a {
@ -198,6 +199,11 @@ impl Aggregator {
self.count = Some(0); self.count = Some(0);
} }
} }
OptimisedAggregate::CountFunction => {
if self.count_function.is_none() {
self.count_function = Some((f.unwrap().clone(), 0));
}
}
OptimisedAggregate::MathMax => { OptimisedAggregate::MathMax => {
if self.math_max.is_none() { if self.math_max.is_none() {
self.math_max = Some(Value::None); self.math_max = Some(Value::None);
@ -236,6 +242,7 @@ impl Aggregator {
array: self.array.as_ref().map(|_| Array::new()), array: self.array.as_ref().map(|_| Array::new()),
first_val: self.first_val.as_ref().map(|_| Value::None), first_val: self.first_val.as_ref().map(|_| Value::None),
count: self.count.as_ref().map(|_| 0), count: self.count.as_ref().map(|_| 0),
count_function: self.count_function.as_ref().map(|(f, _)| (f.clone(), 0)),
math_max: self.math_max.as_ref().map(|_| Value::None), math_max: self.math_max.as_ref().map(|_| Value::None),
math_min: self.math_min.as_ref().map(|_| Value::None), math_min: self.math_min.as_ref().map(|_| Value::None),
math_sum: self.math_sum.as_ref().map(|_| 0.into()), math_sum: self.math_sum.as_ref().map(|_| 0.into()),
@ -245,10 +252,21 @@ impl Aggregator {
} }
} }
fn push(&mut self, val: Value) -> Result<(), Error> { async fn push(
&mut self,
ctx: &Context<'_>,
opt: &Options,
txn: &Transaction,
val: Value,
) -> Result<(), Error> {
if let Some(ref mut c) = self.count { if let Some(ref mut c) = self.count {
*c += 1; *c += 1;
} }
if let Some((ref f, ref mut c)) = self.count_function {
if f.aggregate(val.clone()).compute(ctx, opt, txn, None).await?.is_truthy() {
*c += 1;
}
}
if val.is_number() { if val.is_number() {
if let Some(s) = self.math_sum.take() { if let Some(s) = self.math_sum.take() {
self.math_sum = Some(s.try_add(val.clone())?); self.math_sum = Some(s.try_add(val.clone())?);
@ -302,6 +320,9 @@ impl Aggregator {
Ok(match a { Ok(match a {
OptimisedAggregate::None => Value::None, OptimisedAggregate::None => Value::None,
OptimisedAggregate::Count => self.count.take().map(|v| v.into()).unwrap_or(Value::None), OptimisedAggregate::Count => self.count.take().map(|v| v.into()).unwrap_or(Value::None),
OptimisedAggregate::CountFunction => {
self.count_function.take().map(|(_, v)| v.into()).unwrap_or(Value::None)
}
OptimisedAggregate::MathMax => self.math_max.take().unwrap_or(Value::None), OptimisedAggregate::MathMax => self.math_max.take().unwrap_or(Value::None),
OptimisedAggregate::MathMin => self.math_min.take().unwrap_or(Value::None), OptimisedAggregate::MathMin => self.math_min.take().unwrap_or(Value::None),
OptimisedAggregate::MathSum => self.math_sum.take().unwrap_or(Value::None), OptimisedAggregate::MathSum => self.math_sum.take().unwrap_or(Value::None),
@ -339,6 +360,9 @@ impl Aggregator {
if self.count.is_some() { if self.count.is_some() {
collections.push("count".into()); collections.push("count".into());
} }
if self.count_function.is_some() {
collections.push("count+func".into());
}
if self.math_mean.is_some() { if self.math_mean.is_some() {
collections.push("math::mean".into()); collections.push("math::mean".into());
} }

View file

@ -35,6 +35,7 @@ pub enum Function {
pub(crate) enum OptimisedAggregate { pub(crate) enum OptimisedAggregate {
None, None,
Count, Count,
CountFunction,
MathMax, MathMax,
MathMin, MathMin,
MathSum, MathSum,
@ -156,7 +157,13 @@ impl Function {
} }
pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate { pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
match self { match self {
Self::Normal(f, _) if f == "count" => OptimisedAggregate::Count, Self::Normal(f, v) if f == "count" => {
if v.is_empty() {
OptimisedAggregate::Count
} else {
OptimisedAggregate::CountFunction
}
}
Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax, Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean, Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin, Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,

View file

@ -621,3 +621,56 @@ async fn select_array_group_group_by() -> Result<(), Error> {
// //
Ok(()) Ok(())
} }
#[tokio::test]
async fn select_array_count_subquery_group_by() -> Result<(), Error> {
let sql = r#"
CREATE table CONTENT { bar: "hello", foo: "Man"};
CREATE table CONTENT { bar: "hello", foo: "World"};
CREATE table CONTENT { bar: "world"};
SELECT COUNT(foo != none) FROM table GROUP ALL EXPLAIN;
SELECT COUNT(foo != none) FROM table GROUP ALL;
"#;
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
{
detail: {
table: 'table'
},
operation: 'Iterate Table'
},
{
detail: {
idioms: {
count: [
'count+func'
]
},
type: 'Group'
},
operation: 'Collector'
}
]"#,
);
assert_eq!(format!("{tmp:#}"), format!("{val:#}"));
//
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
{
count: 2
}
]"#,
);
assert_eq!(format!("{tmp:#}"), format!("{val:#}"));
//
Ok(())
}