diff --git a/lib/src/sql/number.rs b/lib/src/sql/number.rs index 975feec5..5f0f331e 100644 --- a/lib/src/sql/number.rs +++ b/lib/src/sql/number.rs @@ -1,3 +1,4 @@ +use super::value::{TryAdd, TryDiv, TryMul, TryNeg, TryPow, TrySub}; use crate::err::Error; use crate::sql::ending::number as ending; use crate::sql::error::Error::Parser; @@ -17,7 +18,7 @@ use std::fmt::{self, Display, Formatter}; use std::hash; use std::iter::Product; use std::iter::Sum; -use std::ops::{self, Neg}; +use std::ops::{self, Add, Div, Mul, Neg, Sub}; use std::str::FromStr; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Number"; @@ -474,6 +475,96 @@ impl PartialOrd for Number { } } +macro_rules! impl_simple_try_op { + ($trt:ident, $fn:ident, $unchecked:ident, $checked:ident) => { + impl $trt for Number { + type Output = Self; + fn $fn(self, other: Self) -> Result { + Ok(match (self, other) { + (Number::Int(v), Number::Int(w)) => Number::Int( + v.$checked(w).ok_or_else(|| Error::$trt(v.to_string(), w.to_string()))?, + ), + (Number::Float(v), Number::Float(w)) => Number::Float(v.$unchecked(w)), + (Number::Decimal(v), Number::Decimal(w)) => Number::Decimal( + v.$checked(w).ok_or_else(|| Error::$trt(v.to_string(), w.to_string()))?, + ), + (Number::Int(v), Number::Float(w)) => Number::Float((v as f64).$unchecked(w)), + (Number::Float(v), Number::Int(w)) => Number::Float(v.$unchecked(w as f64)), + (v, w) => Number::Decimal( + v.to_decimal() + .$checked(w.to_decimal()) + .ok_or_else(|| Error::$trt(v.to_string(), w.to_string()))?, + ), + }) + } + } + }; +} + +impl_simple_try_op!(TryAdd, try_add, add, checked_add); +impl_simple_try_op!(TrySub, try_sub, sub, checked_sub); +impl_simple_try_op!(TryMul, try_mul, mul, checked_mul); +impl_simple_try_op!(TryDiv, try_div, div, checked_div); + +impl TryPow for Number { + type Output = Self; + fn try_pow(self, power: Self) -> Result { + Ok(match (self, power) { + (Self::Int(v), Self::Int(p)) => Self::Int(match v { + 0 => match p.cmp(&0) { + // 0^(-x) + Ordering::Less => return Err(Error::TryPow(v.to_string(), p.to_string())), + // 0^0 + Ordering::Equal => 1, + // 0^x + Ordering::Greater => 0, + }, + // 1^p + 1 => 1, + -1 => { + if p % 2 == 0 { + // (-1)^even + 1 + } else { + // (-1)^odd + -1 + } + } + // try_into may cause an error, which would be wrong for the above cases. + _ => p + .try_into() + .ok() + .and_then(|p| v.checked_pow(p)) + .ok_or_else(|| Error::TryPow(v.to_string(), p.to_string()))?, + }), + (Self::Decimal(v), Self::Int(p)) => Self::Decimal( + v.checked_powi(p).ok_or_else(|| Error::TryPow(v.to_string(), p.to_string()))?, + ), + (Self::Decimal(v), Self::Float(p)) => Self::Decimal( + v.checked_powf(p).ok_or_else(|| Error::TryPow(v.to_string(), p.to_string()))?, + ), + (Self::Decimal(v), Self::Decimal(p)) => Self::Decimal( + v.checked_powd(p).ok_or_else(|| Error::TryPow(v.to_string(), p.to_string()))?, + ), + (v, p) => v.as_float().powf(p.as_float()).into(), + }) + } +} + +impl TryNeg for Number { + type Output = Self; + + fn try_neg(self) -> Result { + Ok(match self { + Self::Int(n) => { + Number::Int(n.checked_neg().ok_or_else(|| Error::TryNeg(n.to_string()))?) + } + Self::Float(n) => Number::Float(-n), + Self::Decimal(n) => Number::Decimal(-n), + }) + } +} + impl ops::Add for Number { type Output = Self; fn add(self, other: Self) -> Self { diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index f79c9396..55501e55 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -60,7 +60,6 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::fmt::{self, Display, Formatter, Write}; use std::ops::Deref; -use std::ops::Neg; use std::str::FromStr; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Value"; @@ -2600,28 +2599,14 @@ pub(crate) trait TryAdd { impl TryAdd for Value { type Output = Self; fn try_add(self, other: Self) -> Result { - match (self, other) { - (Value::Number(v), Value::Number(w)) => match (v, w) { - (Number::Int(v), Number::Int(w)) if v.checked_add(w).is_none() => { - Err(Error::TryAdd(v.to_string(), w.to_string())) - } - (Number::Decimal(v), Number::Decimal(w)) if v.checked_add(w).is_none() => { - Err(Error::TryAdd(v.to_string(), w.to_string())) - } - (Number::Decimal(v), w) if v.checked_add(w.to_decimal()).is_none() => { - Err(Error::TryAdd(v.to_string(), w.to_string())) - } - (v, Number::Decimal(w)) if v.to_decimal().checked_add(w).is_none() => { - Err(Error::TryAdd(v.to_string(), w.to_string())) - } - (v, w) => Ok(Value::Number(v + w)), - }, - (Value::Strand(v), Value::Strand(w)) => Ok(Value::Strand(v + w)), - (Value::Datetime(v), Value::Duration(w)) => Ok(Value::Datetime(w + v)), - (Value::Duration(v), Value::Datetime(w)) => Ok(Value::Datetime(v + w)), - (Value::Duration(v), Value::Duration(w)) => Ok(Value::Duration(v + w)), - (v, w) => Err(Error::TryAdd(v.to_raw_string(), w.to_raw_string())), - } + Ok(match (self, other) { + (Self::Number(v), Self::Number(w)) => Self::Number(v.try_add(w)?), + (Self::Strand(v), Self::Strand(w)) => Self::Strand(v + w), + (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w + v), + (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v + w), + (Self::Duration(v), Self::Duration(w)) => Self::Duration(v + w), + (v, w) => return Err(Error::TryAdd(v.to_raw_string(), w.to_raw_string())), + }) } } @@ -2635,28 +2620,14 @@ pub(crate) trait TrySub { impl TrySub for Value { type Output = Self; fn try_sub(self, other: Self) -> Result { - match (self, other) { - (Value::Number(v), Value::Number(w)) => match (v, w) { - (Number::Int(v), Number::Int(w)) if v.checked_sub(w).is_none() => { - Err(Error::TrySub(v.to_string(), w.to_string())) - } - (Number::Decimal(v), Number::Decimal(w)) if v.checked_sub(w).is_none() => { - Err(Error::TrySub(v.to_string(), w.to_string())) - } - (Number::Decimal(v), w) if v.checked_sub(w.to_decimal()).is_none() => { - Err(Error::TrySub(v.to_string(), w.to_string())) - } - (v, Number::Decimal(w)) if v.to_decimal().checked_sub(w).is_none() => { - Err(Error::TrySub(v.to_string(), w.to_string())) - } - (v, w) => Ok(Value::Number(v - w)), - }, - (Value::Datetime(v), Value::Datetime(w)) => Ok(Value::Duration(v - w)), - (Value::Datetime(v), Value::Duration(w)) => Ok(Value::Datetime(w - v)), - (Value::Duration(v), Value::Datetime(w)) => Ok(Value::Datetime(v - w)), - (Value::Duration(v), Value::Duration(w)) => Ok(Value::Duration(v - w)), - (v, w) => Err(Error::TrySub(v.to_raw_string(), w.to_raw_string())), - } + Ok(match (self, other) { + (Self::Number(v), Self::Number(w)) => Self::Number(v.try_sub(w)?), + (Self::Datetime(v), Self::Datetime(w)) => Self::Duration(v - w), + (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w - v), + (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v - w), + (Self::Duration(v), Self::Duration(w)) => Self::Duration(v - w), + (v, w) => return Err(Error::TrySub(v.to_raw_string(), w.to_raw_string())), + }) } } @@ -2670,24 +2641,10 @@ pub(crate) trait TryMul { impl TryMul for Value { type Output = Self; fn try_mul(self, other: Self) -> Result { - match (self, other) { - (Value::Number(v), Value::Number(w)) => match (v, w) { - (Number::Int(v), Number::Int(w)) if v.checked_mul(w).is_none() => { - Err(Error::TryMul(v.to_string(), w.to_string())) - } - (Number::Decimal(v), Number::Decimal(w)) if v.checked_mul(w).is_none() => { - Err(Error::TryMul(v.to_string(), w.to_string())) - } - (Number::Decimal(v), w) if v.checked_mul(w.to_decimal()).is_none() => { - Err(Error::TryMul(v.to_string(), w.to_string())) - } - (v, Number::Decimal(w)) if v.to_decimal().checked_mul(w).is_none() => { - Err(Error::TryMul(v.to_string(), w.to_string())) - } - (v, w) => Ok(Value::Number(v * w)), - }, - (v, w) => Err(Error::TryMul(v.to_raw_string(), w.to_raw_string())), - } + Ok(match (self, other) { + (Self::Number(v), Self::Number(w)) => Self::Number(v.try_mul(w)?), + (v, w) => return Err(Error::TryMul(v.to_raw_string(), w.to_raw_string())), + }) } } @@ -2701,17 +2658,10 @@ pub(crate) trait TryDiv { impl TryDiv for Value { type Output = Self; fn try_div(self, other: Self) -> Result { - match (self, other) { - (Value::Number(v), Value::Number(w)) => match (v, w) { - (_, Number::Int(0)) => Ok(Value::None), - (Number::Decimal(v), Number::Decimal(w)) if v.checked_div(w).is_none() => { - // Divided a large number by a small number, got an overflowing number - Err(Error::TryDiv(v.to_string(), w.to_string())) - } - (v, w) => Ok(Value::Number(v / w)), - }, - (v, w) => Err(Error::TryDiv(v.to_raw_string(), w.to_raw_string())), - } + Ok(match (self, other) { + (Self::Number(v), Self::Number(w)) => Self::Number(v.try_div(w)?), + (v, w) => return Err(Error::TryDiv(v.to_raw_string(), w.to_raw_string())), + }) } } @@ -2725,20 +2675,10 @@ pub(crate) trait TryPow { impl TryPow for Value { type Output = Self; fn try_pow(self, other: Self) -> Result { - match (self, other) { - (Value::Number(v), Value::Number(w)) => match (v, w) { - (Number::Int(v), Number::Int(w)) - if w.try_into().ok().and_then(|w| v.checked_pow(w)).is_none() => - { - Err(Error::TryPow(v.to_string(), w.to_string())) - } - (Number::Decimal(v), Number::Int(w)) if v.checked_powi(w).is_none() => { - Err(Error::TryPow(v.to_string(), w.to_string())) - } - (v, w) => Ok(Value::Number(v.pow(w))), - }, - (v, w) => Err(Error::TryPow(v.to_raw_string(), w.to_raw_string())), - } + Ok(match (self, other) { + (Value::Number(v), Value::Number(w)) => Self::Number(v.try_pow(w)?), + (v, w) => return Err(Error::TryPow(v.to_raw_string(), w.to_raw_string())), + }) } } @@ -2752,10 +2692,10 @@ pub(crate) trait TryNeg { impl TryNeg for Value { type Output = Self; fn try_neg(self) -> Result { - match self { - Self::Number(n) if !matches!(n, Number::Int(i64::MIN)) => Ok(Self::Number(n.neg())), - v => Err(Error::TryNeg(v.to_string())), - } + Ok(match self { + Self::Number(n) => Self::Number(n.try_neg()?), + v => return Err(Error::TryNeg(v.to_string())), + }) } }