From 3d10df0fcb200d0055dc716b45211fb9f7d88817 Mon Sep 17 00:00:00 2001 From: Micha de Vries Date: Sat, 21 Sep 2024 15:51:39 +0100 Subject: [PATCH] Fixes ULID/UUID gen, programmatically generating ranges, RETURN inside FOR/IF and improves arithmetic operations (#4847) Co-authored-by: Tobie Morgan Hitchcock --- core/src/dbs/executor.rs | 8 +++- core/src/err/mod.rs | 4 ++ core/src/fnc/rand.rs | 42 ++++++++++++++++++-- core/src/fnc/type.rs | 3 ++ core/src/sql/datetime.rs | 12 ++++++ core/src/sql/duration.rs | 79 ++++++++++++++++++++++++++++++++++++- core/src/sql/strand.rs | 15 +++++++ core/src/sql/value/value.rs | 18 ++++----- 8 files changed, 165 insertions(+), 16 deletions(-) diff --git a/core/src/dbs/executor.rs b/core/src/dbs/executor.rs index 3fa1528f..16a4a2ab 100644 --- a/core/src/dbs/executor.rs +++ b/core/src/dbs/executor.rs @@ -387,8 +387,12 @@ impl<'a> Executor<'a> { .await; ctx = MutableContext::unfreeze(c)?; // Check if this is a RETURN statement - let can_return = - matches!(stm, Statement::Output(_) | Statement::Value(_)); + let can_return = matches!( + stm, + Statement::Output(_) + | Statement::Value(_) | Statement::Ifelse(_) + | Statement::Foreach(_) + ); // Catch global timeout let res = match ctx.is_timedout() { true => Err(Error::QueryTimedout), diff --git a/core/src/err/mod.rs b/core/src/err/mod.rs index 16eb9d1a..69302e51 100644 --- a/core/src/err/mod.rs +++ b/core/src/err/mod.rs @@ -1151,6 +1151,10 @@ pub enum Error { #[error("Found a non-computed value where they are not allowed")] NonComputed, + + /// Represents a failure in timestamp arithmetic related to database internals + #[error("Failed to compute: \"{0}\", as the operation results in an overflow.")] + ArithmeticOverflow(String), } impl From for String { diff --git a/core/src/fnc/rand.rs b/core/src/fnc/rand.rs index 5e2d7456..e8cfd6fd 100644 --- a/core/src/fnc/rand.rs +++ b/core/src/fnc/rand.rs @@ -153,7 +153,19 @@ pub fn time((range,): (Option<(i64, i64)>,)) -> Result { pub fn ulid((timestamp,): (Option,)) -> Result { let ulid = match timestamp { - Some(timestamp) => Ulid::from_datetime(timestamp.0.into()), + Some(timestamp) => { + #[cfg(target_arch = "wasm32")] + if timestamp.0 < chrono::DateTime::UNIX_EPOCH { + return Err(Error::InvalidArguments { + name: String::from("rand::ulid"), + message: format!( + "To generate a ULID from a datetime, it must be a time beyond UNIX epoch." + ), + }); + } + + Ulid::from_datetime(timestamp.0.into()) + } None => Ulid::new(), }; @@ -162,7 +174,19 @@ pub fn ulid((timestamp,): (Option,)) -> Result { pub fn uuid((timestamp,): (Option,)) -> Result { let uuid = match timestamp { - Some(timestamp) => Uuid::new_v7_from_datetime(timestamp), + Some(timestamp) => { + #[cfg(target_arch = "wasm32")] + if timestamp.0 < chrono::DateTime::UNIX_EPOCH { + return Err(Error::InvalidArguments { + name: String::from("rand::ulid"), + message: format!( + "To generate a ULID from a datetime, it must be a time beyond UNIX epoch." + ), + }); + } + + Uuid::new_v7_from_datetime(timestamp) + } None => Uuid::new(), }; Ok(uuid.into()) @@ -181,7 +205,19 @@ pub mod uuid { pub fn v7((timestamp,): (Option,)) -> Result { let uuid = match timestamp { - Some(timestamp) => Uuid::new_v7_from_datetime(timestamp), + Some(timestamp) => { + #[cfg(target_arch = "wasm32")] + if timestamp.0 < chrono::DateTime::UNIX_EPOCH { + return Err(Error::InvalidArguments { + name: String::from("rand::ulid"), + message: format!( + "To generate a ULID from a datetime, it must be a time beyond UNIX epoch." + ), + }); + } + + Uuid::new_v7_from_datetime(timestamp) + } None => Uuid::new(), }; Ok(uuid.into()) diff --git a/core/src/fnc/type.rs b/core/src/fnc/type.rs index 815766d3..8cb07bdc 100644 --- a/core/src/fnc/type.rs +++ b/core/src/fnc/type.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + use crate::ctx::Context; use crate::dbs::Options; use crate::doc::CursorDoc; @@ -146,6 +148,7 @@ pub fn thing((arg1, arg2): (Value, Option)) -> Result { Value::Array(v) => v.into(), Value::Object(v) => v.into(), Value::Number(v) => v.into(), + Value::Range(v) => v.deref().to_owned().try_into()?, v => v.as_string().into(), }, })), diff --git a/core/src/sql/datetime.rs b/core/src/sql/datetime.rs index 82120a24..860af303 100644 --- a/core/src/sql/datetime.rs +++ b/core/src/sql/datetime.rs @@ -1,3 +1,4 @@ +use crate::err::Error; use crate::sql::duration::Duration; use crate::sql::strand::Strand; use crate::syn; @@ -11,6 +12,7 @@ use std::str; use std::str::FromStr; use super::escape::quote_str; +use super::value::TrySub; pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Datetime"; @@ -108,3 +110,13 @@ impl ops::Sub for Datetime { } } } + +impl TrySub for Datetime { + type Output = Duration; + fn try_sub(self, other: Self) -> Result { + (self.0 - other.0) + .to_std() + .map_err(|_| Error::ArithmeticOverflow(format!("{self} - {other}"))) + .map(Duration::from) + } +} diff --git a/core/src/sql/duration.rs b/core/src/sql/duration.rs index 51b0ad69..c663adec 100644 --- a/core/src/sql/duration.rs +++ b/core/src/sql/duration.rs @@ -1,3 +1,4 @@ +use crate::err::Error; use crate::sql::datetime::Datetime; use crate::sql::statements::info::InfoStructure; use crate::sql::strand::Strand; @@ -12,6 +13,8 @@ use std::ops::Deref; use std::str::FromStr; use std::time; +use super::value::{TryAdd, TrySub}; + pub(crate) static SECONDS_PER_YEAR: u64 = 365 * SECONDS_PER_DAY; pub(crate) static SECONDS_PER_WEEK: u64 = 7 * SECONDS_PER_DAY; pub(crate) static SECONDS_PER_DAY: u64 = 24 * SECONDS_PER_HOUR; @@ -235,6 +238,16 @@ impl ops::Add for Duration { } } +impl TryAdd for Duration { + type Output = Self; + fn try_add(self, other: Self) -> Result { + self.0 + .checked_add(other.0) + .ok_or_else(|| Error::ArithmeticOverflow(format!("{self} + {other}"))) + .map(Duration::from) + } +} + impl<'a, 'b> ops::Add<&'b Duration> for &'a Duration { type Output = Duration; fn add(self, other: &'b Duration) -> Duration { @@ -245,6 +258,16 @@ impl<'a, 'b> ops::Add<&'b Duration> for &'a Duration { } } +impl<'a, 'b> TryAdd<&'b Duration> for &'a Duration { + type Output = Duration; + fn try_add(self, other: &'b Duration) -> Result { + self.0 + .checked_add(other.0) + .ok_or_else(|| Error::ArithmeticOverflow(format!("{self} + {other}"))) + .map(Duration::from) + } +} + impl ops::Sub for Duration { type Output = Self; fn sub(self, other: Self) -> Self { @@ -255,6 +278,16 @@ impl ops::Sub for Duration { } } +impl TrySub for Duration { + type Output = Self; + fn try_sub(self, other: Self) -> Result { + self.0 + .checked_sub(other.0) + .ok_or_else(|| Error::ArithmeticOverflow(format!("{self} - {other}"))) + .map(Duration::from) + } +} + impl<'a, 'b> ops::Sub<&'b Duration> for &'a Duration { type Output = Duration; fn sub(self, other: &'b Duration) -> Duration { @@ -265,26 +298,68 @@ impl<'a, 'b> ops::Sub<&'b Duration> for &'a Duration { } } +impl<'a, 'b> TrySub<&'b Duration> for &'a Duration { + type Output = Duration; + fn try_sub(self, other: &'b Duration) -> Result { + self.0 + .checked_sub(other.0) + .ok_or_else(|| Error::ArithmeticOverflow(format!("{self} - {other}"))) + .map(Duration::from) + } +} + impl ops::Add for Duration { type Output = Datetime; fn add(self, other: Datetime) -> Datetime { match chrono::Duration::from_std(self.0) { - Ok(d) => Datetime::from(other.0 + d), + Ok(d) => match other.0.checked_add_signed(d) { + Some(v) => Datetime::from(v), + None => Datetime::default(), + }, Err(_) => Datetime::default(), } } } +impl TryAdd for Duration { + type Output = Datetime; + fn try_add(self, other: Datetime) -> Result { + match chrono::Duration::from_std(self.0) { + Ok(d) => match other.0.checked_add_signed(d) { + Some(v) => Ok(Datetime::from(v)), + None => Err(Error::ArithmeticOverflow(format!("{self} + {other}"))), + }, + Err(_) => Err(Error::ArithmeticOverflow(format!("{self} + {other}"))), + } + } +} + impl ops::Sub for Duration { type Output = Datetime; fn sub(self, other: Datetime) -> Datetime { match chrono::Duration::from_std(self.0) { - Ok(d) => Datetime::from(other.0 - d), + Ok(d) => match other.0.checked_sub_signed(d) { + Some(v) => Datetime::from(v), + None => Datetime::default(), + }, Err(_) => Datetime::default(), } } } +impl TrySub for Duration { + type Output = Datetime; + fn try_sub(self, other: Datetime) -> Result { + match chrono::Duration::from_std(self.0) { + Ok(d) => match other.0.checked_sub_signed(d) { + Some(v) => Ok(Datetime::from(v)), + None => Err(Error::ArithmeticOverflow(format!("{self} - {other}"))), + }, + Err(_) => Err(Error::ArithmeticOverflow(format!("{self} - {other}"))), + } + } +} + impl Sum for Duration { fn sum(iter: I) -> Duration where diff --git a/core/src/sql/strand.rs b/core/src/sql/strand.rs index 3499d91d..a54ef636 100644 --- a/core/src/sql/strand.rs +++ b/core/src/sql/strand.rs @@ -1,3 +1,4 @@ +use crate::err::Error; use crate::sql::escape::quote_plain_str; use revision::revisioned; use serde::{Deserialize, Serialize}; @@ -6,6 +7,8 @@ use std::ops::Deref; use std::ops::{self}; use std::str; +use super::value::TryAdd; + pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Strand"; /// A string that doesn't contain NUL bytes. @@ -72,6 +75,18 @@ impl ops::Add for Strand { } } +impl TryAdd for Strand { + type Output = Self; + fn try_add(mut self, other: Self) -> Result { + if self.0.try_reserve(other.len()).is_ok() { + self.0.push_str(other.as_str()); + Ok(self) + } else { + Err(Error::ArithmeticOverflow(format!("{self} + {other}"))) + } + } +} + // serde(with = no_nul_bytes) will (de)serialize with no NUL bytes. pub(crate) mod no_nul_bytes { use serde::{ diff --git a/core/src/sql/value/value.rs b/core/src/sql/value/value.rs index 3caa9290..7ea9952d 100644 --- a/core/src/sql/value/value.rs +++ b/core/src/sql/value/value.rs @@ -2981,10 +2981,10 @@ impl TryAdd for Value { fn try_add(self, other: Self) -> Result { Ok(match (self, other) { (Self::Number(v), Self::Number(w)) => Self::Number(v.try_add(w)?), - (Self::Strand(v), Self::Strand(w)) => Self::Strand(v + w), - (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w + v), - (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v + w), - (Self::Duration(v), Self::Duration(w)) => Self::Duration(v + w), + (Self::Strand(v), Self::Strand(w)) => Self::Strand(v.try_add(w)?), + (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w.try_add(v)?), + (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v.try_add(w)?), + (Self::Duration(v), Self::Duration(w)) => Self::Duration(v.try_add(w)?), (v, w) => return Err(Error::TryAdd(v.to_raw_string(), w.to_raw_string())), }) } @@ -2994,7 +2994,7 @@ impl TryAdd for Value { pub(crate) trait TrySub { type Output; - fn try_sub(self, v: Self) -> Result; + fn try_sub(self, v: Rhs) -> Result; } impl TrySub for Value { @@ -3002,10 +3002,10 @@ impl TrySub for Value { fn try_sub(self, other: Self) -> Result { Ok(match (self, other) { (Self::Number(v), Self::Number(w)) => Self::Number(v.try_sub(w)?), - (Self::Datetime(v), Self::Datetime(w)) => Self::Duration(v - w), - (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w - v), - (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v - w), - (Self::Duration(v), Self::Duration(w)) => Self::Duration(v - w), + (Self::Datetime(v), Self::Datetime(w)) => Self::Duration(v.try_sub(w)?), + (Self::Datetime(v), Self::Duration(w)) => Self::Datetime(w.try_sub(v)?), + (Self::Duration(v), Self::Datetime(w)) => Self::Datetime(v.try_sub(w)?), + (Self::Duration(v), Self::Duration(w)) => Self::Duration(v.try_sub(w)?), (v, w) => return Err(Error::TrySub(v.to_raw_string(), w.to_raw_string())), }) }