From cef01ad790c2bc490fc35109f8786d16b7e36543 Mon Sep 17 00:00:00 2001 From: Etienne Bruines Date: Sun, 18 Dec 2022 15:56:07 +0100 Subject: [PATCH] Add SQL math::pow() function and ** operator (#1239) --- lib/src/fnc/math.rs | 4 ++ lib/src/fnc/mod.rs | 1 + lib/src/fnc/operate.rs | 4 ++ lib/src/sql/expression.rs | 1 + lib/src/sql/function.rs | 1 + lib/src/sql/number.rs | 84 ++++++++++++++++++++++++++++++++++++++ lib/src/sql/operator.rs | 3 ++ lib/src/sql/value/value.rs | 8 ++++ 8 files changed, 106 insertions(+) diff --git a/lib/src/fnc/math.rs b/lib/src/fnc/math.rs index 328b59a1..438e0227 100644 --- a/lib/src/fnc/math.rs +++ b/lib/src/fnc/math.rs @@ -120,6 +120,10 @@ pub fn percentile((array, n): (Value, Number)) -> Result { }) } +pub fn pow((arg, pow): (Number, Number)) -> Result { + Ok(arg.pow(pow).into()) +} + pub fn product((array,): (Value,)) -> Result { Ok(match array { Value::Array(v) => v.as_numbers().into_iter().product::().into(), diff --git a/lib/src/fnc/mod.rs b/lib/src/fnc/mod.rs index 174f9ef6..54054a7b 100644 --- a/lib/src/fnc/mod.rs +++ b/lib/src/fnc/mod.rs @@ -121,6 +121,7 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec) -> Result math::mode, "math::nearestrank" => math::nearestrank, "math::percentile" => math::percentile, + "math::pow" => math::pow, "math::product" => math::product, "math::round" => math::round, "math::spread" => math::spread, diff --git a/lib/src/fnc/operate.rs b/lib/src/fnc/operate.rs index 6ac4365d..72bb72bc 100644 --- a/lib/src/fnc/operate.rs +++ b/lib/src/fnc/operate.rs @@ -49,6 +49,10 @@ pub fn div(a: Value, b: Value) -> Result { Ok(a.div(b)) } +pub fn pow(a: Value, b: Value) -> Result { + Ok(a.pow(b)) +} + pub fn exact(a: &Value, b: &Value) -> Result { Ok(Value::from(a == b)) } diff --git a/lib/src/sql/expression.rs b/lib/src/sql/expression.rs index 41a882d1..028ec829 100644 --- a/lib/src/sql/expression.rs +++ b/lib/src/sql/expression.rs @@ -98,6 +98,7 @@ impl Expression { Operator::Sub => fnc::operate::sub(l, r), Operator::Mul => fnc::operate::mul(l, r), Operator::Div => fnc::operate::div(l, r), + Operator::Pow => fnc::operate::pow(l, r), Operator::Equal => fnc::operate::equal(&l, &r), Operator::Exact => fnc::operate::exact(&l, &r), Operator::NotEqual => fnc::operate::not_equal(&l, &r), diff --git a/lib/src/sql/function.rs b/lib/src/sql/function.rs index cf9cc212..168197af 100644 --- a/lib/src/sql/function.rs +++ b/lib/src/sql/function.rs @@ -340,6 +340,7 @@ fn function_math(i: &str) -> IResult<&str, &str> { alt(( tag("math::nearestrank"), tag("math::percentile"), + tag("math::pow"), tag("math::product"), tag("math::round"), tag("math::spread"), diff --git a/lib/src/sql/number.rs b/lib/src/sql/number.rs index ff070800..40d49085 100644 --- a/lib/src/sql/number.rs +++ b/lib/src/sql/number.rs @@ -2,6 +2,7 @@ use crate::err::Error; use crate::sql::ending::number as ending; use crate::sql::error::IResult; use crate::sql::serde::is_internal_serialization; +use bigdecimal::num_traits::Pow; use bigdecimal::BigDecimal; use bigdecimal::FromPrimitive; use bigdecimal::ToPrimitive; @@ -396,6 +397,21 @@ impl Number { } } + pub fn pow(self, power: Number) -> Number { + match (self, power) { + (Number::Int(v), Number::Int(p)) if p >= 0 && p < u32::MAX as i64 => { + Number::Int(v.pow(p as u32)) + } + (Number::Decimal(v), Number::Int(p)) if p >= 0 && p < u32::MAX as i64 => { + let (as_int, scale) = v.as_bigint_and_exponent(); + Number::Decimal(BigDecimal::new(as_int.pow(p as u32), scale * p)) + } + // TODO: (Number::Decimal(v), Number::Float(p)) => todo!(), + // TODO: (Number::Decimal(v), Number::Decimal(p)) => todo!(), + (v, p) => Number::Float(v.as_float().pow(p.as_float())), + } + } + // ----------------------------------- // // ----------------------------------- @@ -736,4 +752,72 @@ mod tests { assert_eq!("-123.45", format!("{}", out)); assert_eq!(out, Number::from(-123.45)); } + + #[test] + fn number_pow_int() { + let res = number("3"); + assert!(res.is_ok()); + let res = res.unwrap().1; + + let power = number("4"); + assert!(power.is_ok()); + let power = power.unwrap().1; + + assert_eq!(res.pow(power), Number::from(81)); + } + + #[test] + fn number_pow_negatives() { + let res = number("4"); + assert!(res.is_ok()); + let res = res.unwrap().1; + + let power = number("-0.5"); + assert!(power.is_ok()); + let power = power.unwrap().1; + + assert_eq!(res.pow(power), Number::from(0.5)); + } + + #[test] + fn number_pow_float() { + let res = number("2.5"); + assert!(res.is_ok()); + let res = res.unwrap().1; + + let power = number("2"); + assert!(power.is_ok()); + let power = power.unwrap().1; + + assert_eq!(res.pow(power), Number::from(6.25)); + } + + #[test] + fn number_pow_bigdecimal_one() { + let res = number("13.5719384719384719385639856394139476937756394756"); + assert!(res.is_ok()); + let res = res.unwrap().1; + + let power = number("1"); + assert!(power.is_ok()); + let power = power.unwrap().1; + + assert_eq!( + res.pow(power), + Number::from("13.5719384719384719385639856394139476937756394756") + ); + } + + #[test] + fn number_pow_bigdecimal_int() { + let res = number("13.5719384719384719385639856394139476937756394756"); + assert!(res.is_ok()); + let res = res.unwrap().1; + + let power = number("2"); + assert!(power.is_ok()); + let power = power.unwrap().1; + + assert_eq!(res.pow(power), Number::from("184.19751388608358465578173996877942643463869043732548087725588482334195240945031617770904299536")); + } } diff --git a/lib/src/sql/operator.rs b/lib/src/sql/operator.rs index 3e479b68..6d88faff 100644 --- a/lib/src/sql/operator.rs +++ b/lib/src/sql/operator.rs @@ -21,6 +21,7 @@ pub enum Operator { Sub, // - Mul, // * Div, // / + Pow, // ** Inc, // += Dec, // -= // @@ -89,6 +90,7 @@ impl fmt::Display for Operator { Self::Sub => "-", Self::Mul => "*", Self::Div => "/", + Self::Pow => "**", Self::Inc => "+=", Self::Dec => "-=", Self::Equal => "=", @@ -168,6 +170,7 @@ pub fn symbols(i: &str) -> IResult<&str, Operator> { map(char('∙'), |_| Operator::Mul), map(char('/'), |_| Operator::Div), map(char('÷'), |_| Operator::Div), + map(tag("**"), |_| Operator::Pow), )), alt(( map(char('∋'), |_| Operator::Contain), diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 3535d5a4..28547ce8 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -1229,6 +1229,14 @@ impl Value { _ => self.partial_cmp(other), } } + + // ----------------------------------- + // Mathematical operations + // ----------------------------------- + + pub fn pow(self, other: Value) -> Value { + self.as_number().pow(other.as_number()).into() + } } impl fmt::Display for Value {