From 2f19afec56d01104f8e8180745b2f5391e6f57e5 Mon Sep 17 00:00:00 2001 From: Emmanuel Keller Date: Tue, 9 Apr 2024 00:04:44 +0100 Subject: [PATCH] Bug when using array::group in a group by query (#3826) --- core/src/dbs/group.rs | 43 +++++++++++++++++++------------------------ lib/tests/group.rs | 39 +++++++++++++++++++++++++++++++++++++++ lib/tests/helpers.rs | 9 +++++++++ lib/tests/planner.rs | 9 +-------- 4 files changed, 68 insertions(+), 32 deletions(-) diff --git a/core/src/dbs/group.rs b/core/src/dbs/group.rs index 738be1d8..69fe15b2 100644 --- a/core/src/dbs/group.rs +++ b/core/src/dbs/group.rs @@ -8,7 +8,6 @@ use crate::sql::value::{TryAdd, TryDiv, Value}; use crate::sql::{Array, Field, Idiom}; use std::borrow::Cow; use std::collections::{BTreeMap, HashMap}; -use std::mem; pub(super) struct GroupsCollector { base: Vec, @@ -113,9 +112,8 @@ impl GroupsCollector { ) -> Result { let mut results = MemoryCollector::default(); if let Some(fields) = stm.expr() { - let grp = mem::take(&mut self.grp); // Loop over each grouped collection - for (_, mut aggregator) in grp { + for aggregator in self.grp.values_mut() { // Create a new value let mut obj = Value::base(); // Loop over each group clause @@ -133,27 +131,24 @@ impl GroupsCollector { if let Some(idioms_pos) = self.idioms.iter().position(|i| i.eq(idiom.as_ref())) { - let agr = &mut aggregator[idioms_pos]; - match expr { - Value::Function(f) if f.is_aggregate() => { - let a = f.get_optimised_aggregate(); - let x = if matches!(a, OptimisedAggregate::None) { - // The aggregation is not optimised, let's compute it with the values - let vals = agr.take(); - let x = vals - .all() - .get(ctx, opt, txn, None, idiom.as_ref()) - .await?; - f.aggregate(x).compute(ctx, opt, txn, None).await? - } else { - // The aggregation is optimised, just get the value - agr.compute(a)? - }; - obj.set(ctx, opt, txn, idiom.as_ref(), x).await?; - } - _ => { - let x = agr.take().first(); - obj.set(ctx, opt, txn, idiom.as_ref(), x).await?; + if let Some(agr) = aggregator.get_mut(idioms_pos) { + match expr { + Value::Function(f) if f.is_aggregate() => { + let a = f.get_optimised_aggregate(); + let x = if matches!(a, OptimisedAggregate::None) { + // The aggregation is not optimised, let's compute it with the values + let vals = agr.take(); + f.aggregate(vals).compute(ctx, opt, txn, None).await? + } else { + // The aggregation is optimised, just get the value + agr.compute(a)? + }; + obj.set(ctx, opt, txn, idiom.as_ref(), x).await?; + } + _ => { + let x = agr.take().first(); + obj.set(ctx, opt, txn, idiom.as_ref(), x).await?; + } } } } diff --git a/lib/tests/group.rs b/lib/tests/group.rs index 5d1fb3a9..341cf253 100644 --- a/lib/tests/group.rs +++ b/lib/tests/group.rs @@ -2,6 +2,7 @@ mod parse; use parse::Parse; mod helpers; use helpers::new_ds; +use helpers::skip_ok; use surrealdb::dbs::Session; use surrealdb::err::Error; use surrealdb::sql::Value; @@ -582,3 +583,41 @@ async fn select_multi_aggregate_composed() -> Result<(), Error> { // Ok(()) } + +#[tokio::test] +async fn select_array_group_group_by() -> Result<(), Error> { + let sql = " + CREATE test:1 SET user = 1, role = 1; + CREATE test:2 SET user = 1, role = 2; + CREATE test:3 SET user = 2, role = 1; + CREATE test:4 SET user = 2, role = 2; + SELECT user, array::group(role) FROM test GROUP BY user; + "; + let dbs = new_ds().await?; + let ses = Session::owner().with_ns("test").with_db("test"); + let mut res = &mut dbs.execute(sql, &ses, None).await?; + assert_eq!(res.len(), 5); + // + skip_ok(&mut res, 4)?; + // + let tmp = res.remove(0).result?; + let val = Value::parse( + r#"[ + { + "array::group": [ + 1,2 + ], + user: 1 + }, + { + "array::group": [ + 1,2 + ], + user: 2 + } + ]"#, + ); + assert_eq!(format!("{tmp:#}"), format!("{val:#}")); + // + Ok(()) +} diff --git a/lib/tests/helpers.rs b/lib/tests/helpers.rs index 413cd70d..e6785d91 100644 --- a/lib/tests/helpers.rs +++ b/lib/tests/helpers.rs @@ -8,6 +8,7 @@ use surrealdb::dbs::Session; use surrealdb::err::Error; use surrealdb::iam::{Auth, Level, Role}; use surrealdb::kvs::Datastore; +use surrealdb_core::dbs::Response; pub async fn new_ds() -> Result { Ok(Datastore::new("memory").await?.with_capabilities(Capabilities::all()).with_notifications()) @@ -193,3 +194,11 @@ pub fn with_enough_stack( .join() .unwrap() } + +#[allow(dead_code)] +pub fn skip_ok(res: &mut Vec, skip: usize) -> Result<(), Error> { + for _ in 0..skip { + let _ = res.remove(0).result?; + } + Ok(()) +} diff --git a/lib/tests/planner.rs b/lib/tests/planner.rs index d47405eb..487650fd 100644 --- a/lib/tests/planner.rs +++ b/lib/tests/planner.rs @@ -2,7 +2,7 @@ mod parse; use parse::Parse; mod helpers; -use helpers::new_ds; +use helpers::{new_ds, skip_ok}; use surrealdb::dbs::{Response, Session}; use surrealdb::err::Error; use surrealdb::kvs::Datastore; @@ -155,13 +155,6 @@ async fn execute_test( Ok(res) } -fn skip_ok(res: &mut Vec, skip: usize) -> Result<(), Error> { - for _ in 0..skip { - let _ = res.remove(0).result?; - } - Ok(()) -} - fn check_result(res: &mut Vec, expected: &str) -> Result<(), Error> { let tmp = res.remove(0).result?; let val = Value::parse(expected);