From 06b4f9519827589ba73d5524939cc4f42685c1ee Mon Sep 17 00:00:00 2001 From: Tobie Morgan Hitchcock Date: Thu, 5 Sep 2024 09:10:48 +0100 Subject: [PATCH] Improve aggregate function argument checks (#4698) --- core/src/doc/table.rs | 105 ++++++++++++++++++++++++++++++++++++------ core/src/err/mod.rs | 8 ++++ sdk/tests/table.rs | 23 ++++++++- 3 files changed, 121 insertions(+), 15 deletions(-) diff --git a/core/src/doc/table.rs b/core/src/doc/table.rs index 31d4eb4b..e94af80f 100644 --- a/core/src/doc/table.rs +++ b/core/src/doc/table.rs @@ -390,21 +390,101 @@ impl Document { let val = f.compute(stk, ctx, opt, Some(fdc.doc)).await?; self.chg(&mut set_ops, &mut del_ops, &fdc.act, idiom, val)?; } - Some("math::sum") => { + Some(name) if name == "time::min" => { let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; + let val = match val { + val @ Value::Datetime(_) => val, + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a datetime but found {val}" + ), + }) + } + }; + self.min(&mut set_ops, &mut del_ops, fdc, field, idiom, val); + } + Some(name) if name == "time::max" => { + let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; + let val = match val { + val @ Value::Datetime(_) => val, + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a datetime but found {val}" + ), + }) + } + }; + self.max(&mut set_ops, &mut del_ops, fdc, field, idiom, val); + } + Some(name) if name == "math::sum" => { + let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; + let val = match val { + val @ Value::Number(_) => val, + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a number but found {val}" + ), + }) + } + }; self.chg(&mut set_ops, &mut del_ops, &fdc.act, idiom, val)?; } - Some("math::min") | Some("time::min") => { + Some(name) if name == "math::min" => { let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; - self.min(&mut set_ops, &mut del_ops, fdc, field, idiom, val)?; + let val = match val { + val @ Value::Number(_) => val, + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a number but found {val}" + ), + }) + } + }; + self.min(&mut set_ops, &mut del_ops, fdc, field, idiom, val); } - Some("math::max") | Some("time::max") => { + Some(name) if name == "math::max" => { let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; - self.max(&mut set_ops, &mut del_ops, fdc, field, idiom, val)?; + let val = match val { + val @ Value::Number(_) => val, + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a number but found {val}" + ), + }) + } + }; + self.max(&mut set_ops, &mut del_ops, fdc, field, idiom, val); } - Some("math::mean") => { + Some(name) if name == "math::mean" => { let val = f.args()[0].compute(stk, ctx, opt, Some(fdc.doc)).await?; - self.mean(&mut set_ops, &mut del_ops, &fdc.act, idiom, val)?; + let val = match val { + val @ Value::Number(_) => val.coerce_to_decimal()?.into(), + val => { + return Err(Error::InvalidAggregation { + name: name.to_string(), + table: fdc.ft.name.to_raw(), + message: format!( + "This function expects a number but found {val}" + ), + }) + } + }; + self.mean(&mut set_ops, &mut del_ops, &fdc.act, idiom, val); } _ => unreachable!(), }, @@ -454,7 +534,7 @@ impl Document { field: &Field, key: Idiom, val: Value, - ) -> Result<(), Error> { + ) { // Key for the value count let mut key_c = Idiom::from(vec![Part::from("__")]); key_c.0.push(Part::from(key.to_hash())); @@ -498,7 +578,6 @@ impl Document { del_ops.push((key_c, Operator::Equal, Value::from(0))); } } - Ok(()) } /// Set the new maximum value for the field in the foreign table fn max( @@ -509,7 +588,7 @@ impl Document { field: &Field, key: Idiom, val: Value, - ) -> Result<(), Error> { + ) { // Key for the value count let mut key_c = Idiom::from(vec![Part::from("__")]); key_c.0.push(Part::from(key.to_hash())); @@ -554,7 +633,6 @@ impl Document { del_ops.push((key_c, Operator::Equal, Value::from(0))); } } - Ok(()) } /// Set the new average value for the field in the foreign table @@ -565,7 +643,7 @@ impl Document { act: &FieldAction, key: Idiom, val: Value, - ) -> Result<(), Error> { + ) { // Key for the value count let mut key_c = Idiom::from(vec![Part::from("__")]); key_c.0.push(Part::from(key.to_hash())); @@ -600,7 +678,7 @@ impl Document { FieldAction::Sub => Operator::Sub, FieldAction::Add => Operator::Add, }, - r: val.convert_to_decimal()?.into(), + r: val, }, ))))), o: Operator::Div, @@ -632,7 +710,6 @@ impl Document { del_ops.push((key_c, Operator::Equal, Value::from(0))); } } - Ok(()) } /// Recomputes the value for one group diff --git a/core/src/err/mod.rs b/core/src/err/mod.rs index 8fdd90d1..596a42fa 100644 --- a/core/src/err/mod.rs +++ b/core/src/err/mod.rs @@ -228,6 +228,14 @@ pub enum Error { message: String, }, + /// The wrong quantity or magnitude of arguments was given for the specified function + #[error("Incorrect arguments for aggregate function {name}() on table '{table}'. {message}")] + InvalidAggregation { + name: String, + table: String, + message: String, + }, + /// The wrong quantity or magnitude of arguments was given for the specified function #[error("There was a problem running the {name} function. Expected this function to return a value of type {check}, but found {value}")] FunctionCheck { diff --git a/sdk/tests/table.rs b/sdk/tests/table.rs index 56205d27..c834f0bd 100644 --- a/sdk/tests/table.rs +++ b/sdk/tests/table.rs @@ -29,11 +29,13 @@ async fn define_foreign_table() -> Result<(), Error> { SELECT * FROM person_by_age; UPSERT person:two SET age = 39, score = 91; SELECT * FROM person_by_age; + UPSERT person:two SET age = 39, score = 'test'; + SELECT * FROM person_by_age; "; let dbs = new_ds().await?; let ses = Session::owner().with_ns("test").with_db("test"); let res = &mut dbs.execute(sql, &ses, None).await?; - assert_eq!(res.len(), 9); + assert_eq!(res.len(), 11); // let tmp = res.remove(0).result; assert!(tmp.is_ok()); @@ -137,6 +139,25 @@ async fn define_foreign_table() -> Result<(), Error> { ); assert_eq!(tmp, val); // + let tmp = res.remove(0).result.unwrap_err(); + assert!(matches!(tmp, Error::InvalidAggregation { .. })); + // + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + age: 39, + average: 81.5, + count: 2, + id: person_by_age:[39], + max: 91, + min: 72, + total: 78 + } + ]", + ); + assert_eq!(tmp, val); + // Ok(()) }