Add SQL math::pow() function and ** operator (#1239)

This commit is contained in:
Etienne Bruines 2022-12-18 15:56:07 +01:00 committed by GitHub
parent df954a9554
commit cef01ad790
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 106 additions and 0 deletions

View file

@ -120,6 +120,10 @@ pub fn percentile((array, n): (Value, Number)) -> Result<Value, Error> {
}) })
} }
pub fn pow((arg, pow): (Number, Number)) -> Result<Value, Error> {
Ok(arg.pow(pow).into())
}
pub fn product((array,): (Value,)) -> Result<Value, Error> { pub fn product((array,): (Value,)) -> Result<Value, Error> {
Ok(match array { Ok(match array {
Value::Array(v) => v.as_numbers().into_iter().product::<Number>().into(), Value::Array(v) => v.as_numbers().into_iter().product::<Number>().into(),

View file

@ -121,6 +121,7 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec<Value>) -> Result<Va
"math::mode" => math::mode, "math::mode" => math::mode,
"math::nearestrank" => math::nearestrank, "math::nearestrank" => math::nearestrank,
"math::percentile" => math::percentile, "math::percentile" => math::percentile,
"math::pow" => math::pow,
"math::product" => math::product, "math::product" => math::product,
"math::round" => math::round, "math::round" => math::round,
"math::spread" => math::spread, "math::spread" => math::spread,

View file

@ -49,6 +49,10 @@ pub fn div(a: Value, b: Value) -> Result<Value, Error> {
Ok(a.div(b)) Ok(a.div(b))
} }
pub fn pow(a: Value, b: Value) -> Result<Value, Error> {
Ok(a.pow(b))
}
pub fn exact(a: &Value, b: &Value) -> Result<Value, Error> { pub fn exact(a: &Value, b: &Value) -> Result<Value, Error> {
Ok(Value::from(a == b)) Ok(Value::from(a == b))
} }

View file

@ -98,6 +98,7 @@ impl Expression {
Operator::Sub => fnc::operate::sub(l, r), Operator::Sub => fnc::operate::sub(l, r),
Operator::Mul => fnc::operate::mul(l, r), Operator::Mul => fnc::operate::mul(l, r),
Operator::Div => fnc::operate::div(l, r), Operator::Div => fnc::operate::div(l, r),
Operator::Pow => fnc::operate::pow(l, r),
Operator::Equal => fnc::operate::equal(&l, &r), Operator::Equal => fnc::operate::equal(&l, &r),
Operator::Exact => fnc::operate::exact(&l, &r), Operator::Exact => fnc::operate::exact(&l, &r),
Operator::NotEqual => fnc::operate::not_equal(&l, &r), Operator::NotEqual => fnc::operate::not_equal(&l, &r),

View file

@ -340,6 +340,7 @@ fn function_math(i: &str) -> IResult<&str, &str> {
alt(( alt((
tag("math::nearestrank"), tag("math::nearestrank"),
tag("math::percentile"), tag("math::percentile"),
tag("math::pow"),
tag("math::product"), tag("math::product"),
tag("math::round"), tag("math::round"),
tag("math::spread"), tag("math::spread"),

View file

@ -2,6 +2,7 @@ use crate::err::Error;
use crate::sql::ending::number as ending; use crate::sql::ending::number as ending;
use crate::sql::error::IResult; use crate::sql::error::IResult;
use crate::sql::serde::is_internal_serialization; use crate::sql::serde::is_internal_serialization;
use bigdecimal::num_traits::Pow;
use bigdecimal::BigDecimal; use bigdecimal::BigDecimal;
use bigdecimal::FromPrimitive; use bigdecimal::FromPrimitive;
use bigdecimal::ToPrimitive; use bigdecimal::ToPrimitive;
@ -396,6 +397,21 @@ impl Number {
} }
} }
pub fn pow(self, power: Number) -> Number {
match (self, power) {
(Number::Int(v), Number::Int(p)) if p >= 0 && p < u32::MAX as i64 => {
Number::Int(v.pow(p as u32))
}
(Number::Decimal(v), Number::Int(p)) if p >= 0 && p < u32::MAX as i64 => {
let (as_int, scale) = v.as_bigint_and_exponent();
Number::Decimal(BigDecimal::new(as_int.pow(p as u32), scale * p))
}
// TODO: (Number::Decimal(v), Number::Float(p)) => todo!(),
// TODO: (Number::Decimal(v), Number::Decimal(p)) => todo!(),
(v, p) => Number::Float(v.as_float().pow(p.as_float())),
}
}
// ----------------------------------- // -----------------------------------
// //
// ----------------------------------- // -----------------------------------
@ -736,4 +752,72 @@ mod tests {
assert_eq!("-123.45", format!("{}", out)); assert_eq!("-123.45", format!("{}", out));
assert_eq!(out, Number::from(-123.45)); assert_eq!(out, Number::from(-123.45));
} }
#[test]
fn number_pow_int() {
let res = number("3");
assert!(res.is_ok());
let res = res.unwrap().1;
let power = number("4");
assert!(power.is_ok());
let power = power.unwrap().1;
assert_eq!(res.pow(power), Number::from(81));
}
#[test]
fn number_pow_negatives() {
let res = number("4");
assert!(res.is_ok());
let res = res.unwrap().1;
let power = number("-0.5");
assert!(power.is_ok());
let power = power.unwrap().1;
assert_eq!(res.pow(power), Number::from(0.5));
}
#[test]
fn number_pow_float() {
let res = number("2.5");
assert!(res.is_ok());
let res = res.unwrap().1;
let power = number("2");
assert!(power.is_ok());
let power = power.unwrap().1;
assert_eq!(res.pow(power), Number::from(6.25));
}
#[test]
fn number_pow_bigdecimal_one() {
let res = number("13.5719384719384719385639856394139476937756394756");
assert!(res.is_ok());
let res = res.unwrap().1;
let power = number("1");
assert!(power.is_ok());
let power = power.unwrap().1;
assert_eq!(
res.pow(power),
Number::from("13.5719384719384719385639856394139476937756394756")
);
}
#[test]
fn number_pow_bigdecimal_int() {
let res = number("13.5719384719384719385639856394139476937756394756");
assert!(res.is_ok());
let res = res.unwrap().1;
let power = number("2");
assert!(power.is_ok());
let power = power.unwrap().1;
assert_eq!(res.pow(power), Number::from("184.19751388608358465578173996877942643463869043732548087725588482334195240945031617770904299536"));
}
} }

View file

@ -21,6 +21,7 @@ pub enum Operator {
Sub, // - Sub, // -
Mul, // * Mul, // *
Div, // / Div, // /
Pow, // **
Inc, // += Inc, // +=
Dec, // -= Dec, // -=
// //
@ -89,6 +90,7 @@ impl fmt::Display for Operator {
Self::Sub => "-", Self::Sub => "-",
Self::Mul => "*", Self::Mul => "*",
Self::Div => "/", Self::Div => "/",
Self::Pow => "**",
Self::Inc => "+=", Self::Inc => "+=",
Self::Dec => "-=", Self::Dec => "-=",
Self::Equal => "=", Self::Equal => "=",
@ -168,6 +170,7 @@ pub fn symbols(i: &str) -> IResult<&str, Operator> {
map(char('∙'), |_| Operator::Mul), map(char('∙'), |_| Operator::Mul),
map(char('/'), |_| Operator::Div), map(char('/'), |_| Operator::Div),
map(char('÷'), |_| Operator::Div), map(char('÷'), |_| Operator::Div),
map(tag("**"), |_| Operator::Pow),
)), )),
alt(( alt((
map(char('∋'), |_| Operator::Contain), map(char('∋'), |_| Operator::Contain),

View file

@ -1229,6 +1229,14 @@ impl Value {
_ => self.partial_cmp(other), _ => self.partial_cmp(other),
} }
} }
// -----------------------------------
// Mathematical operations
// -----------------------------------
pub fn pow(self, other: Value) -> Value {
self.as_number().pow(other.as_number()).into()
}
} }
impl fmt::Display for Value { impl fmt::Display for Value {