Bug: Wrong count when using COUNT with a subquery (#3855)

This commit is contained in:
Emmanuel Keller 2024-04-11 12:38:42 +01:00 committed by GitHub
parent 2f19afec56
commit 56b4f7d71e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 91 additions and 7 deletions

View file

@ -5,7 +5,7 @@ use crate::dbs::{Options, Statement, Transaction};
use crate::err::Error;
use crate::sql::function::OptimisedAggregate;
use crate::sql::value::{TryAdd, TryDiv, Value};
use crate::sql::{Array, Field, Idiom};
use crate::sql::{Array, Field, Function, Idiom};
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
@ -20,6 +20,7 @@ struct Aggregator {
array: Option<Array>,
first_val: Option<Value>,
count: Option<usize>,
count_function: Option<(Box<Function>, usize)>,
math_max: Option<Value>,
math_min: Option<Value>,
math_sum: Option<Value>,
@ -94,7 +95,7 @@ impl GroupsCollector {
) -> Result<(), Error> {
for (agr, idiom) in agrs.iter_mut().zip(idioms) {
let val = obj.get(ctx, opt, txn, None, idiom).await?;
agr.push(val)?;
agr.push(ctx, opt, txn, val).await?;
}
Ok(())
}
@ -174,15 +175,15 @@ impl GroupsCollector {
impl Aggregator {
fn prepare(&mut self, expr: &Value) {
let a = match expr {
Value::Function(f) => f.get_optimised_aggregate(),
let (a, f) = match expr {
Value::Function(f) => (f.get_optimised_aggregate(), Some(f)),
_ => {
// We set it only if we don't already have an array
if self.array.is_none() && self.first_val.is_none() {
self.first_val = Some(Value::None);
return;
}
OptimisedAggregate::None
(OptimisedAggregate::None, None)
}
};
match a {
@ -198,6 +199,11 @@ impl Aggregator {
self.count = Some(0);
}
}
OptimisedAggregate::CountFunction => {
if self.count_function.is_none() {
self.count_function = Some((f.unwrap().clone(), 0));
}
}
OptimisedAggregate::MathMax => {
if self.math_max.is_none() {
self.math_max = Some(Value::None);
@ -236,6 +242,7 @@ impl Aggregator {
array: self.array.as_ref().map(|_| Array::new()),
first_val: self.first_val.as_ref().map(|_| Value::None),
count: self.count.as_ref().map(|_| 0),
count_function: self.count_function.as_ref().map(|(f, _)| (f.clone(), 0)),
math_max: self.math_max.as_ref().map(|_| Value::None),
math_min: self.math_min.as_ref().map(|_| Value::None),
math_sum: self.math_sum.as_ref().map(|_| 0.into()),
@ -245,10 +252,21 @@ impl Aggregator {
}
}
fn push(&mut self, val: Value) -> Result<(), Error> {
async fn push(
&mut self,
ctx: &Context<'_>,
opt: &Options,
txn: &Transaction,
val: Value,
) -> Result<(), Error> {
if let Some(ref mut c) = self.count {
*c += 1;
}
if let Some((ref f, ref mut c)) = self.count_function {
if f.aggregate(val.clone()).compute(ctx, opt, txn, None).await?.is_truthy() {
*c += 1;
}
}
if val.is_number() {
if let Some(s) = self.math_sum.take() {
self.math_sum = Some(s.try_add(val.clone())?);
@ -302,6 +320,9 @@ impl Aggregator {
Ok(match a {
OptimisedAggregate::None => Value::None,
OptimisedAggregate::Count => self.count.take().map(|v| v.into()).unwrap_or(Value::None),
OptimisedAggregate::CountFunction => {
self.count_function.take().map(|(_, v)| v.into()).unwrap_or(Value::None)
}
OptimisedAggregate::MathMax => self.math_max.take().unwrap_or(Value::None),
OptimisedAggregate::MathMin => self.math_min.take().unwrap_or(Value::None),
OptimisedAggregate::MathSum => self.math_sum.take().unwrap_or(Value::None),
@ -339,6 +360,9 @@ impl Aggregator {
if self.count.is_some() {
collections.push("count".into());
}
if self.count_function.is_some() {
collections.push("count+func".into());
}
if self.math_mean.is_some() {
collections.push("math::mean".into());
}

View file

@ -35,6 +35,7 @@ pub enum Function {
pub(crate) enum OptimisedAggregate {
None,
Count,
CountFunction,
MathMax,
MathMin,
MathSum,
@ -156,7 +157,13 @@ impl Function {
}
pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
match self {
Self::Normal(f, _) if f == "count" => OptimisedAggregate::Count,
Self::Normal(f, v) if f == "count" => {
if v.is_empty() {
OptimisedAggregate::Count
} else {
OptimisedAggregate::CountFunction
}
}
Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,

View file

@ -621,3 +621,56 @@ async fn select_array_group_group_by() -> Result<(), Error> {
//
Ok(())
}
#[tokio::test]
async fn select_array_count_subquery_group_by() -> Result<(), Error> {
let sql = r#"
CREATE table CONTENT { bar: "hello", foo: "Man"};
CREATE table CONTENT { bar: "hello", foo: "World"};
CREATE table CONTENT { bar: "world"};
SELECT COUNT(foo != none) FROM table GROUP ALL EXPLAIN;
SELECT COUNT(foo != none) FROM table GROUP ALL;
"#;
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
{
detail: {
table: 'table'
},
operation: 'Iterate Table'
},
{
detail: {
idioms: {
count: [
'count+func'
]
},
type: 'Group'
},
operation: 'Collector'
}
]"#,
);
assert_eq!(format!("{tmp:#}"), format!("{val:#}"));
//
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
{
count: 2
}
]"#,
);
assert_eq!(format!("{tmp:#}"), format!("{val:#}"));
//
Ok(())
}