Bug when using array::group in a group by query (#3826)

This commit is contained in:
Emmanuel Keller 2024-04-09 00:04:44 +01:00 committed by GitHub
parent e842515882
commit 2f19afec56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 32 deletions

View file

@ -8,7 +8,6 @@ use crate::sql::value::{TryAdd, TryDiv, Value};
use crate::sql::{Array, Field, Idiom}; use crate::sql::{Array, Field, Idiom};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use std::mem;
pub(super) struct GroupsCollector { pub(super) struct GroupsCollector {
base: Vec<Aggregator>, base: Vec<Aggregator>,
@ -113,9 +112,8 @@ impl GroupsCollector {
) -> Result<MemoryCollector, Error> { ) -> Result<MemoryCollector, Error> {
let mut results = MemoryCollector::default(); let mut results = MemoryCollector::default();
if let Some(fields) = stm.expr() { if let Some(fields) = stm.expr() {
let grp = mem::take(&mut self.grp);
// Loop over each grouped collection // Loop over each grouped collection
for (_, mut aggregator) in grp { for aggregator in self.grp.values_mut() {
// Create a new value // Create a new value
let mut obj = Value::base(); let mut obj = Value::base();
// Loop over each group clause // Loop over each group clause
@ -133,27 +131,24 @@ impl GroupsCollector {
if let Some(idioms_pos) = if let Some(idioms_pos) =
self.idioms.iter().position(|i| i.eq(idiom.as_ref())) self.idioms.iter().position(|i| i.eq(idiom.as_ref()))
{ {
let agr = &mut aggregator[idioms_pos]; if let Some(agr) = aggregator.get_mut(idioms_pos) {
match expr { match expr {
Value::Function(f) if f.is_aggregate() => { Value::Function(f) if f.is_aggregate() => {
let a = f.get_optimised_aggregate(); let a = f.get_optimised_aggregate();
let x = if matches!(a, OptimisedAggregate::None) { let x = if matches!(a, OptimisedAggregate::None) {
// The aggregation is not optimised, let's compute it with the values // The aggregation is not optimised, let's compute it with the values
let vals = agr.take(); let vals = agr.take();
let x = vals f.aggregate(vals).compute(ctx, opt, txn, None).await?
.all() } else {
.get(ctx, opt, txn, None, idiom.as_ref()) // The aggregation is optimised, just get the value
.await?; agr.compute(a)?
f.aggregate(x).compute(ctx, opt, txn, None).await? };
} else { obj.set(ctx, opt, txn, idiom.as_ref(), x).await?;
// The aggregation is optimised, just get the value }
agr.compute(a)? _ => {
}; let x = agr.take().first();
obj.set(ctx, opt, txn, idiom.as_ref(), x).await?; obj.set(ctx, opt, txn, idiom.as_ref(), x).await?;
} }
_ => {
let x = agr.take().first();
obj.set(ctx, opt, txn, idiom.as_ref(), x).await?;
} }
} }
} }

View file

@ -2,6 +2,7 @@ mod parse;
use parse::Parse; use parse::Parse;
mod helpers; mod helpers;
use helpers::new_ds; use helpers::new_ds;
use helpers::skip_ok;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::err::Error; use surrealdb::err::Error;
use surrealdb::sql::Value; use surrealdb::sql::Value;
@ -582,3 +583,41 @@ async fn select_multi_aggregate_composed() -> Result<(), Error> {
// //
Ok(()) Ok(())
} }
#[tokio::test]
async fn select_array_group_group_by() -> Result<(), Error> {
let sql = "
CREATE test:1 SET user = 1, role = 1;
CREATE test:2 SET user = 1, role = 2;
CREATE test:3 SET user = 2, role = 1;
CREATE test:4 SET user = 2, role = 2;
SELECT user, array::group(role) FROM test GROUP BY user;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 4)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
{
"array::group": [
1,2
],
user: 1
},
{
"array::group": [
1,2
],
user: 2
}
]"#,
);
assert_eq!(format!("{tmp:#}"), format!("{val:#}"));
//
Ok(())
}

View file

@ -8,6 +8,7 @@ use surrealdb::dbs::Session;
use surrealdb::err::Error; use surrealdb::err::Error;
use surrealdb::iam::{Auth, Level, Role}; use surrealdb::iam::{Auth, Level, Role};
use surrealdb::kvs::Datastore; use surrealdb::kvs::Datastore;
use surrealdb_core::dbs::Response;
pub async fn new_ds() -> Result<Datastore, Error> { pub async fn new_ds() -> Result<Datastore, Error> {
Ok(Datastore::new("memory").await?.with_capabilities(Capabilities::all()).with_notifications()) Ok(Datastore::new("memory").await?.with_capabilities(Capabilities::all()).with_notifications())
@ -193,3 +194,11 @@ pub fn with_enough_stack(
.join() .join()
.unwrap() .unwrap()
} }
/// Consumes the first `skip` responses from `res`, returning early with the
/// first execution error encountered. Test helper used to discard the
/// results of setup statements (e.g. CREATE) before checking a query.
#[allow(dead_code)]
pub fn skip_ok(res: &mut Vec<Response>, skip: usize) -> Result<(), Error> {
	let mut remaining = skip;
	while remaining > 0 {
		// Pop the oldest response; `?` propagates any error it carried.
		res.remove(0).result?;
		remaining -= 1;
	}
	Ok(())
}

View file

@ -2,7 +2,7 @@ mod parse;
use parse::Parse; use parse::Parse;
mod helpers; mod helpers;
use helpers::new_ds; use helpers::{new_ds, skip_ok};
use surrealdb::dbs::{Response, Session}; use surrealdb::dbs::{Response, Session};
use surrealdb::err::Error; use surrealdb::err::Error;
use surrealdb::kvs::Datastore; use surrealdb::kvs::Datastore;
@ -155,13 +155,6 @@ async fn execute_test(
Ok(res) Ok(res)
} }
// Removes the first `skip` responses from `res`, returning early with the
// error if any of them failed; test helper for skipping setup statements.
fn skip_ok(res: &mut Vec<Response>, skip: usize) -> Result<(), Error> {
// Discard the successful value, but propagate any execution error.
for _ in 0..skip {
let _ = res.remove(0).result?;
}
Ok(())
}
fn check_result(res: &mut Vec<Response>, expected: &str) -> Result<(), Error> { fn check_result(res: &mut Vec<Response>, expected: &str) -> Result<(), Error> {
let tmp = res.remove(0).result?; let tmp = res.remove(0).result?;
let val = Value::parse(expected); let val = Value::parse(expected);