[Feat] Implements a few math functions ()

This commit is contained in:
Emmanuel Keller 2024-06-05 16:21:49 +01:00 committed by GitHub
parent ddf2b874b2
commit 9c196fa154
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1698 additions and 1908 deletions

View file

@ -19,6 +19,22 @@ pub fn abs((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.abs().into())
}
pub fn acos((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.acos().into())
}
pub fn acot((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.acot().into())
}
pub fn asin((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.asin().into())
}
pub fn atan((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.atan().into())
}
pub fn bottom((array, c): (Vec<Number>, i64)) -> Result<Value, Error> {
if c > 0 {
Ok(array.bottom(c).into())
@ -34,6 +50,21 @@ pub fn ceil((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.ceil().into())
}
pub fn clamp((arg, min, max): (Number, Number, Number)) -> Result<Value, Error> {
Ok(arg.clamp(min, max).into())
}
pub fn cos((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.cos().into())
}
pub fn cot((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.cot().into())
}
pub fn deg2rad((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.deg2rad().into())
}
pub fn fixed((arg, p): (Number, i64)) -> Result<Value, Error> {
if p > 0 {
Ok(arg.fixed(p as usize).into())
@ -53,6 +84,30 @@ pub fn interquartile((mut array,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(array.sorted().interquartile().into())
}
pub fn lerp((from, to, factor): (Number, Number, Number)) -> Result<Value, Error> {
Ok(factor.lerp(from, to).into())
}
pub fn lerpangle((from, to, factor): (Number, Number, Number)) -> Result<Value, Error> {
Ok(factor.lerp_angle(from, to).into())
}
pub fn ln((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.ln().into())
}
pub fn log((arg, base): (Number, Number)) -> Result<Value, Error> {
Ok(arg.log(base).into())
}
pub fn log10((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.log10().into())
}
pub fn log2((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.log2().into())
}
pub fn max((array,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(match array.into_iter().max() {
Some(v) => v.into(),
@ -102,10 +157,22 @@ pub fn product((array,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(array.into_iter().product::<Number>().into())
}
pub fn rad2deg((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.rad2deg().into())
}
pub fn round((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.round().into())
}
pub fn sign((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.sign().into())
}
pub fn sin((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.sin().into())
}
pub fn spread((array,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(array.spread().into())
}
@ -124,6 +191,9 @@ pub fn stddev((array,): (Vec<Number>,)) -> Result<Value, Error> {
pub fn sum((array,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(array.into_iter().sum::<Number>().into())
}
pub fn tan((arg,): (Number,)) -> Result<Value, Error> {
Ok(arg.tan().into())
}
pub fn top((array, c): (Vec<Number>, i64)) -> Result<Value, Error> {
if c > 0 {

View file

@ -176,11 +176,25 @@ pub fn synchronous(
"geo::hash::encode" => geo::hash::encode,
//
"math::abs" => math::abs,
"math::acos" => math::acos,
"math::acot" => math::acot,
"math::asin" => math::asin,
"math::atan" => math::atan,
"math::bottom" => math::bottom,
"math::ceil" => math::ceil,
"math::clamp" => math::clamp,
"math::cos" => math::cos,
"math::cot" => math::cot,
"math::deg2rad" => math::deg2rad,
"math::fixed" => math::fixed,
"math::floor" => math::floor,
"math::interquartile" => math::interquartile,
"math::lerp" => math::lerp,
"math::lerpangle" => math::lerpangle,
"math::ln" => math::ln,
"math::log" => math::log,
"math::log10" => math::log10,
"math::log2" => math::log2,
"math::max" => math::max,
"math::mean" => math::mean,
"math::median" => math::median,
@ -191,11 +205,15 @@ pub fn synchronous(
"math::percentile" => math::percentile,
"math::pow" => math::pow,
"math::product" => math::product,
"math::rad2deg" => math::rad2deg,
"math::round" => math::round,
"math::sign" => math::sign,
"math::sin" => math::sin,
"math::spread" => math::spread,
"math::sqrt" => math::sqrt,
"math::stddev" => math::stddev,
"math::sum" => math::sum,
"math::tan" => math::tan,
"math::top" => math::top,
"math::trimean" => math::trimean,
"math::variance" => math::variance,

View file

@ -8,11 +8,25 @@ impl_module_def!(
Package,
"math",
"abs" => run,
"acos" => run,
"acot" => run,
"asin" => run,
"atan" => run,
"bottom" => run,
"ceil" => run,
"clamp" => run,
"cos" => run,
"cot" => run,
"deg2rad" => run,
"fixed" => run,
"floor" => run,
"interquartile" => run,
"lerp" => run,
"lerpangle" => run,
"ln" => run,
"log" => run,
"log2" => run,
"log10" => run,
"max" => run,
"mean" => run,
"median" => run,
@ -23,11 +37,15 @@ impl_module_def!(
"percentile" => run,
"pow" => run,
"product" => run,
"rad2deg" => run,
"round" => run,
"sign" => run,
"sin" => run,
"spread" => run,
"sqrt" => run,
"stddev" => run,
"sum" => run,
"tan" => run,
"top" => run,
"trimean" => run,
"variance" => run

View file

@ -1,6 +1,3 @@
use crate::ctx::Context;
use crate::dbs::Options;
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::sql::value::Value;
use crate::sql::Datetime;
@ -36,6 +33,7 @@ pub enum Constant {
MathLog10E,
MathLog210,
MathLog2E,
MathNegInf,
MathPi,
MathSqrt2,
MathTau,
@ -70,6 +68,7 @@ impl Constant {
Self::MathLog10E => ConstantValue::Float(f64c::LOG10_E),
Self::MathLog210 => ConstantValue::Float(f64c::LOG2_10),
Self::MathLog2E => ConstantValue::Float(f64c::LOG2_E),
Self::MathNegInf => ConstantValue::Float(f64::NEG_INFINITY),
Self::MathPi => ConstantValue::Float(f64c::PI),
Self::MathSqrt2 => ConstantValue::Float(f64c::SQRT_2),
Self::MathTau => ConstantValue::Float(f64c::TAU),
@ -77,12 +76,7 @@ impl Constant {
}
}
/// Process this type returning a computed simple Value
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_doc: Option<&CursorDoc<'_>>,
) -> Result<Value, Error> {
pub fn compute(&self) -> Result<Value, Error> {
Ok(match self.value() {
ConstantValue::Datetime(d) => d.into(),
ConstantValue::Float(f) => f.into(),
@ -110,6 +104,7 @@ impl fmt::Display for Constant {
Self::MathLog10E => "math::LOG10_E",
Self::MathLog210 => "math::LOG2_10",
Self::MathLog2E => "math::LOG2_E",
Self::MathNegInf => "math::NEG_INF",
Self::MathPi => "math::PI",
Self::MathSqrt2 => "math::SQRT_2",
Self::MathTau => "math::TAU",

View file

@ -7,6 +7,7 @@ use revision::revisioned;
use rust_decimal::prelude::*;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::f64::consts::PI;
use std::fmt::{self, Display, Formatter};
use std::hash;
use std::iter::Product;
@ -391,6 +392,18 @@ impl Number {
self.to_float().acos().into()
}
pub fn asin(self) -> Self {
self.to_float().asin().into()
}
pub fn atan(self) -> Self {
self.to_float().atan().into()
}
pub fn acot(self) -> Self {
(PI / 2.0 - self.atan().to_float()).into()
}
pub fn ceil(self) -> Self {
match self {
Number::Int(v) => v.into(),
@ -399,6 +412,27 @@ impl Number {
}
}
pub fn clamp(self, min: Self, max: Self) -> Self {
match (self, min, max) {
(Number::Int(n), Number::Int(min), Number::Int(max)) => n.clamp(min, max).into(),
(Number::Decimal(n), min, max) => n.clamp(min.to_decimal(), max.to_decimal()).into(),
(Number::Float(n), min, max) => n.clamp(min.to_float(), max.to_float()).into(),
(Number::Int(n), min, max) => n.to_float().clamp(min.to_float(), max.to_float()).into(),
}
}
pub fn cos(self) -> Self {
self.to_float().cos().into()
}
pub fn cot(self) -> Self {
(1.0 / self.to_float().tan()).into()
}
pub fn deg2rad(self) -> Self {
self.to_float().to_radians().into()
}
pub fn floor(self) -> Self {
match self {
Number::Int(v) => v.into(),
@ -407,6 +441,77 @@ impl Number {
}
}
fn lerp_f64(from: f64, to: f64, factor: f64) -> f64 {
from + factor * (to - from)
}
fn lerp_decimal(from: Decimal, to: Decimal, factor: Decimal) -> Decimal {
from + factor * (to - from)
}
pub fn lerp(self, from: Self, to: Self) -> Self {
match (self, from, to) {
(Number::Decimal(val), from, to) => {
Self::lerp_decimal(from.to_decimal(), to.to_decimal(), val).into()
}
(val, from, to) => {
Self::lerp_f64(from.to_float(), to.to_float(), val.to_float()).into()
}
}
}
fn repeat_f64(t: f64, m: f64) -> f64 {
(t - (t / m).floor() * m).clamp(0.0, m)
}
fn repeat_decimal(t: Decimal, m: Decimal) -> Decimal {
(t - (t / m).floor() * m).clamp(Decimal::ZERO, m)
}
pub fn lerp_angle(self, from: Self, to: Self) -> Self {
match (self, from, to) {
(Number::Decimal(val), from, to) => {
let from = from.to_decimal();
let to = to.to_decimal();
let mut dt = Self::repeat_decimal(to - from, Decimal::from(360));
if dt > Decimal::from(180) {
dt = Decimal::from(360) - dt;
}
Self::lerp_decimal(from, from + dt, val).into()
}
(val, from, to) => {
let val = val.to_float();
let from = from.to_float();
let to = to.to_float();
let mut dt = Self::repeat_f64(to - from, 360.0);
if dt > 180.0 {
dt = 360.0 - dt;
}
Self::lerp_f64(from, from + dt, val).into()
}
}
}
pub fn ln(self) -> Self {
self.to_float().ln().into()
}
pub fn log(self, base: Self) -> Self {
self.to_float().log(base.to_float()).into()
}
pub fn log2(self) -> Self {
self.to_float().log2().into()
}
pub fn log10(self) -> Self {
self.to_float().log10().into()
}
pub fn rad2deg(self) -> Self {
self.to_float().to_degrees().into()
}
pub fn round(self) -> Self {
match self {
Number::Int(v) => v.into(),
@ -423,6 +528,22 @@ impl Number {
}
}
pub fn sign(self) -> Self {
match self {
Number::Int(n) => n.signum().into(),
Number::Float(n) => n.signum().into(),
Number::Decimal(n) => n.signum().into(),
}
}
pub fn sin(self) -> Self {
self.to_float().sin().into()
}
pub fn tan(self) -> Self {
self.to_float().tan().into()
}
pub fn sqrt(self) -> Self {
match self {
Number::Int(v) => (v as f64).sqrt().into(),

View file

@ -38,15 +38,18 @@ impl ser::Serializer for Serializer {
"MathFracPi4" => Ok(Constant::MathFracPi4),
"MathFracPi6" => Ok(Constant::MathFracPi6),
"MathFracPi8" => Ok(Constant::MathFracPi8),
"MathInf" => Ok(Constant::MathInf),
"MathLn10" => Ok(Constant::MathLn10),
"MathLn2" => Ok(Constant::MathLn2),
"MathLog102" => Ok(Constant::MathLog102),
"MathLog10E" => Ok(Constant::MathLog10E),
"MathLog210" => Ok(Constant::MathLog210),
"MathLog2E" => Ok(Constant::MathLog2E),
"MathNegInf" => Ok(Constant::MathNegInf),
"MathPi" => Ok(Constant::MathPi),
"MathSqrt2" => Ok(Constant::MathSqrt2),
"MathTau" => Ok(Constant::MathTau),
"TimeEpoch" => Ok(Constant::TimeEpoch),
variant => Err(Error::custom(format!("unknown variant `{name}::{variant}`"))),
}
}
@ -128,6 +131,13 @@ mod tests {
assert_eq!(constant, serialized);
}
#[test]
fn math_inf() {
let constant = Constant::MathInf;
let serialized = constant.serialize(Serializer.wrap()).unwrap();
assert_eq!(constant, serialized);
}
#[test]
fn math_ln10() {
let constant = Constant::MathLn10;
@ -170,6 +180,13 @@ mod tests {
assert_eq!(constant, serialized);
}
#[test]
fn math_neg_inf() {
let constant = Constant::MathNegInf;
let serialized = constant.serialize(Serializer.wrap()).unwrap();
assert_eq!(constant, serialized);
}
#[test]
fn math_pi() {
let constant = Constant::MathPi;
@ -190,4 +207,11 @@ mod tests {
let serialized = constant.serialize(Serializer.wrap()).unwrap();
assert_eq!(constant, serialized);
}
#[test]
fn time_epoch() {
let constant = Constant::TimeEpoch;
let serialized = constant.serialize(Serializer.wrap()).unwrap();
assert_eq!(constant, serialized);
}
}

View file

@ -2656,7 +2656,7 @@ impl Value {
Value::Array(v) => stk.run(|stk| v.compute(stk, ctx, opt, doc)).await,
Value::Object(v) => stk.run(|stk| v.compute(stk, ctx, opt, doc)).await,
Value::Future(v) => stk.run(|stk| v.compute(stk, ctx, opt, doc)).await,
Value::Constant(v) => v.compute(ctx, opt, doc).await,
Value::Constant(v) => v.compute(),
Value::Function(v) => v.compute(stk, ctx, opt, doc).await,
Value::Model(v) => v.compute(stk, ctx, opt, doc).await,
Value::Subquery(v) => stk.run(|stk| v.compute(stk, ctx, opt, doc)).await,

View file

@ -168,11 +168,25 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("geo::hash::encode") => PathKind::Function,
//
UniCase::ascii("math::abs") => PathKind::Function,
UniCase::ascii("math::acos") => PathKind::Function,
UniCase::ascii("math::asin") => PathKind::Function,
UniCase::ascii("math::acot") => PathKind::Function,
UniCase::ascii("math::atan") => PathKind::Function,
UniCase::ascii("math::bottom") => PathKind::Function,
UniCase::ascii("math::ceil") => PathKind::Function,
UniCase::ascii("math::clamp") => PathKind::Function,
UniCase::ascii("math::cos") => PathKind::Function,
UniCase::ascii("math::cot") => PathKind::Function,
UniCase::ascii("math::deg2rad") => PathKind::Function,
UniCase::ascii("math::fixed") => PathKind::Function,
UniCase::ascii("math::floor") => PathKind::Function,
UniCase::ascii("math::interquartile") => PathKind::Function,
UniCase::ascii("math::lerp") => PathKind::Function,
UniCase::ascii("math::lerpangle") => PathKind::Function,
UniCase::ascii("math::ln") => PathKind::Function,
UniCase::ascii("math::log") => PathKind::Function,
UniCase::ascii("math::log2") => PathKind::Function,
UniCase::ascii("math::log10") => PathKind::Function,
UniCase::ascii("math::max") => PathKind::Function,
UniCase::ascii("math::mean") => PathKind::Function,
UniCase::ascii("math::median") => PathKind::Function,
@ -184,8 +198,12 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("math::pow") => PathKind::Function,
UniCase::ascii("math::product") => PathKind::Function,
UniCase::ascii("math::round") => PathKind::Function,
UniCase::ascii("math::rad2deg") => PathKind::Function,
UniCase::ascii("math::sign") => PathKind::Function,
UniCase::ascii("math::sin") => PathKind::Function,
UniCase::ascii("math::spread") => PathKind::Function,
UniCase::ascii("math::sqrt") => PathKind::Function,
UniCase::ascii("math::tan") => PathKind::Function,
UniCase::ascii("math::stddev") => PathKind::Function,
UniCase::ascii("math::sum") => PathKind::Function,
UniCase::ascii("math::top") => PathKind::Function,
@ -409,6 +427,7 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
UniCase::ascii("math::LOG10_E") => PathKind::Constant(Constant::MathLog10E),
UniCase::ascii("math::LOG2_10") => PathKind::Constant(Constant::MathLog210),
UniCase::ascii("math::LOG2_E") => PathKind::Constant(Constant::MathLog2E),
UniCase::ascii("math::NEG_INF") => PathKind::Constant(Constant::MathNegInf),
UniCase::ascii("math::PI") => PathKind::Constant(Constant::MathPi),
UniCase::ascii("math::SQRT_2") => PathKind::Constant(Constant::MathSqrt2),
UniCase::ascii("math::TAU") => PathKind::Constant(Constant::MathTau),

View file

@ -122,18 +122,36 @@ fn parse_i64() {
fn constant_lowercase() {
let out = test_parse!(parse_value, r#" math::pi "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathPi));
let out = test_parse!(parse_value, r#" math::inf "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathInf));
let out = test_parse!(parse_value, r#" math::neg_inf "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathNegInf));
}
#[test]
fn constant_uppercase() {
let out = test_parse!(parse_value, r#" MATH::PI "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathPi));
let out = test_parse!(parse_value, r#" MATH::INF "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathInf));
let out = test_parse!(parse_value, r#" MATH::NEG_INF "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathNegInf));
}
#[test]
fn constant_mixedcase() {
let out = test_parse!(parse_value, r#" MaTh::Pi "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathPi));
let out = test_parse!(parse_value, r#" MaTh::Inf "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathInf));
let out = test_parse!(parse_value, r#" MaTh::Neg_Inf "#).unwrap();
assert_eq!(out, Value::Constant(Constant::MathNegInf));
}
#[test]

View file

@ -854,7 +854,7 @@ async fn define_statement_index_multiple() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
skip_ok(res, 2)?;
skip_ok(res, 2);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -1243,7 +1243,7 @@ async fn define_statement_index_on_schemafull_without_permission() -> Result<(),
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 2);
//
skip_ok(&mut res, 1)?;
skip_ok(&mut res, 1);
//
let tmp = res.remove(0).result;
let s = format!("{:?}", tmp);

File diff suppressed because it is too large Load diff

View file

@ -598,7 +598,7 @@ async fn select_array_group_group_by() -> Result<(), Error> {
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 4)?;
skip_ok(&mut res, 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -636,7 +636,7 @@ async fn select_array_count_subquery_group_by() -> Result<(), Error> {
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(

View file

@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::Arc;
use std::thread::Builder;
@ -9,6 +10,7 @@ use surrealdb::err::Error;
use surrealdb::iam::{Auth, Level, Role};
use surrealdb::kvs::Datastore;
use surrealdb_core::dbs::Response;
use surrealdb_core::sql::{value, Number, Value};
pub async fn new_ds() -> Result<Datastore, Error> {
Ok(Datastore::new("memory").await?.with_capabilities(Capabilities::all()).with_notifications())
@ -195,8 +197,10 @@ pub fn with_enough_stack(
.unwrap()
}
/// Skip the specified number of successful results from a vector of responses.
/// This function will panic if there are not enough results in the vector or if an error occurs.
#[allow(dead_code)]
pub fn skip_ok(res: &mut Vec<Response>, skip: usize) -> Result<(), Error> {
pub fn skip_ok(res: &mut Vec<Response>, skip: usize) {
for i in 0..skip {
if res.is_empty() {
panic!("No more result #{i}");
@ -206,5 +210,208 @@ pub fn skip_ok(res: &mut Vec<Response>, skip: usize) -> Result<(), Error> {
panic!("Statement #{i} fails with: {e}");
});
}
Ok(())
}
/// Struct representing a test scenario.
///
/// # Fields
/// - `ds`: The datastore for the test.
/// - `session`: The session for the test.
/// - `responses`: The list of responses for the test.
/// - `pos`: The current position in the responses list.
#[allow(dead_code)]
pub struct Test {
pub ds: Datastore,
pub session: Session,
pub responses: Vec<Response>,
pos: usize,
}
impl Debug for Test {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Responses left: {:?}.", self.responses)
}
}
impl Test {
/// Creates a new instance of the `Self` struct with the given SQL query.
/// Arguments `sql` - A string slice representing the SQL query.
/// Panics if an error occurs.
#[allow(dead_code)]
pub async fn new(sql: &str) -> Self {
Self::try_new(sql).await.unwrap_or_else(|e| panic!("{e}"))
}
/// Create a new instance of the Test struct and execute the given SQL statement.
///
#[allow(dead_code)]
pub async fn try_new(sql: &str) -> Result<Self, Error> {
let ds = new_ds().await?;
let session = Session::owner().with_ns("test").with_db("test");
let responses = ds.execute(sql, &session, None).await?;
Ok(Self {
ds,
session,
responses,
pos: 0,
})
}
/// Checks if the number of responses matches the expected size.
/// Panics if the number of responses does not match the expected size
#[allow(dead_code)]
pub fn expect_size(&mut self, expected: usize) -> &mut Self {
assert_eq!(
self.responses.len(),
expected,
"Unexpected number of results: {} - Expected: {expected}",
self.responses.len()
);
self
}
/// Retrieves the next response from the responses list.
/// This method will panic if the responses list is empty, indicating that there are no more responses to retrieve.
/// The panic message will include the last position in the responses list before it was emptied.
#[allow(dead_code)]
pub fn next(&mut self) -> Response {
if self.responses.is_empty() {
panic!("No response left - last position: {}", self.pos);
}
self.pos += 1;
self.responses.remove(0)
}
/// Retrieves the next value from the responses list.
/// This method will panic if the responses list is empty, indicating that there are no more responses to retrieve.
/// The panic message will include the last position in the responses list before it was emptied.
pub fn next_value(&mut self) -> Value {
self.next()
.result
.unwrap_or_else(|e| panic!("Unexpected error: {e} - last position: {}", self.pos))
}
/// Skips a specified number of elements from the beginning of the `responses` vector
/// and updates the position.
#[allow(dead_code)]
pub fn skip_ok(&mut self, skip: usize) -> &mut Self {
skip_ok(&mut self.responses, skip);
self.pos += skip;
self
}
/// Expects the next value to be equal to the provided value.
/// Panics if the expected value is not equal to the actual value.
/// Compliant with NaN and Constants.
#[allow(dead_code)]
pub fn expect_value(&mut self, val: Value) -> &mut Self {
let tmp = self.next_value();
// Then check they are indeed the same values
//
// If it is a constant we need to transform it as a number
let val = if let Value::Constant(c) = val {
c.compute().unwrap_or_else(|e| panic!("Can't convert constant {c} - {e}"))
} else {
val
};
if val.is_nan() {
assert!(tmp.is_nan(), "Expected NaN but got: {tmp}");
} else {
assert_eq!(tmp, val, "{tmp:#}");
}
//
self
}
/// Expect values in the given slice to be present in the responses, following the same order.
#[allow(dead_code)]
pub fn expect_values(&mut self, values: &[Value]) -> &mut Self {
for value in values {
self.expect_value(value.clone());
}
self
}
/// Expect the given value to be equals to the next response.
#[allow(dead_code)]
pub fn expect_val(&mut self, val: &str) -> &mut Self {
self.expect_value(value(val).unwrap())
}
#[allow(dead_code)]
/// Expect values in the given slice to be present in the responses, following the same order.
pub fn expect_vals(&mut self, vals: &[&str]) -> &mut Self {
for val in vals {
self.expect_val(val);
}
self
}
/// Expects the next result to be an error with the specified error message.
/// This function will panic if the next result is not an error or if the error
/// message does not match the specified error.
#[allow(dead_code)]
pub fn expect_error(&mut self, error: &str) -> &mut Self {
let tmp = self.next().result;
assert!(
matches!(
&tmp,
Err(e) if e.to_string() == error
),
"{tmp:?} didn't match {error}"
);
self
}
#[allow(dead_code)]
pub fn expect_errors(&mut self, errors: &[&str]) -> &mut Self {
for error in errors {
self.expect_error(error);
}
self
}
/// Expects the next value to be a floating-point number and compares it with the given value.
///
/// # Arguments
///
/// * `val` - The expected floating-point value
/// * `precision` - The allowed difference between the expected and actual value
///
/// # Panics
///
/// Panics if the next value is not a number or if the difference
/// between the expected and actual value exceeds the precision.
#[allow(dead_code)]
pub fn expect_float(&mut self, val: f64, precision: f64) -> &mut Self {
let tmp = self.next_value();
if let Value::Number(Number::Float(n)) = tmp {
let diff = (n - val).abs();
assert!(
diff <= precision,
"{tmp} does not match expected: {val} - diff: {diff} - precision: {precision}"
);
} else {
panic!("At position {}: Value {tmp} is not a number", self.pos);
}
self
}
#[allow(dead_code)]
pub fn expect_floats(&mut self, vals: &[f64], precision: f64) -> &mut Self {
for val in vals {
self.expect_float(*val, precision);
}
self
}
}
impl Drop for Test {
/// Drops the instance of the struct
/// This method will panic if there are remaining responses that have not been checked.
fn drop(&mut self) {
if !self.responses.is_empty() {
panic!("Not every response has been checked");
}
}
}

View file

@ -21,7 +21,7 @@ async fn select_where_matches_using_index() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(res, 3)?;
skip_ok(res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -74,7 +74,7 @@ async fn select_where_matches_without_using_index_iterator() -> Result<(), Error
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 6);
//
skip_ok(res, 4)?;
skip_ok(res, 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -133,7 +133,7 @@ async fn select_where_matches_using_index_and_arrays(parallel: bool) -> Result<(
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(res, 3)?;
skip_ok(res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -204,7 +204,7 @@ async fn select_where_matches_partial_highlight() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 9);
//
skip_ok(res, 3)?;
skip_ok(res, 3);
//
for i in 0..2 {
let tmp = res.remove(0).result?;
@ -289,7 +289,7 @@ async fn select_where_matches_partial_highlight_ngram() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 10);
//
skip_ok(res, 3)?;
skip_ok(res, 3);
//
for i in 0..3 {
let tmp = res.remove(0).result?;
@ -375,7 +375,7 @@ async fn select_where_matches_using_index_and_objects(parallel: bool) -> Result<
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(res, 3)?;
skip_ok(res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -443,7 +443,7 @@ async fn select_where_matches_using_index_offsets() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(res, 4)?;
skip_ok(res, 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -480,7 +480,7 @@ async fn select_where_matches_using_index_and_score() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
skip_ok(res, 6)?;
skip_ok(res, 6);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -517,7 +517,7 @@ async fn select_where_matches_without_using_index_and_score() -> Result<(), Erro
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 9);
//
skip_ok(res, 7)?;
skip_ok(res, 7);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -565,7 +565,7 @@ async fn select_where_matches_without_complex_query() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 10);
//
skip_ok(res, 6)?;
skip_ok(res, 6);
//
let tmp = res.remove(0).result?;
let val_docs = Value::parse(
@ -654,7 +654,7 @@ async fn select_where_matches_mixing_indexes() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
skip_ok(res, 5)?;
skip_ok(res, 5);
//
let tmp = res.remove(0).result?;
let val = Value::parse(

View file

@ -12,7 +12,7 @@ use surrealdb::sql::Value;
async fn select_where_iterate_three_multi_index() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &three_multi_index_query("", ""), 12).await?;
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }, { name: 'Lizzie' }]")?;
// OR results
check_result(&mut res, THREE_MULTI_INDEX_EXPLAIN)?;
@ -26,7 +26,7 @@ async fn select_where_iterate_three_multi_index() -> Result<(), Error> {
async fn select_where_iterate_three_multi_index_parallel() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &three_multi_index_query("", "PARALLEL"), 12).await?;
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }, { name: 'Lizzie' }]")?;
check_result(&mut res, THREE_MULTI_INDEX_EXPLAIN)?;
@ -45,7 +45,7 @@ async fn select_where_iterate_three_multi_index_with_all_index() -> Result<(), E
12,
)
.await?;
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }, { name: 'Lizzie' }]")?;
check_result(&mut res, THREE_MULTI_INDEX_EXPLAIN)?;
@ -60,7 +60,7 @@ async fn select_where_iterate_three_multi_index_with_one_ft_index() -> Result<()
let dbs = new_ds().await?;
let mut res =
execute_test(&dbs, &three_multi_index_query("WITH INDEX ft_company", ""), 12).await?;
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Lizzie' }, { name: 'Tobie' } ]")?;
@ -76,7 +76,7 @@ async fn select_where_iterate_three_multi_index_with_one_index() -> Result<(), E
let dbs = new_ds().await?;
let mut res =
execute_test(&dbs, &three_multi_index_query("WITH INDEX uniq_name", ""), 12).await?;
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Lizzie' }, { name: 'Tobie' } ]")?;
@ -91,7 +91,7 @@ async fn select_where_iterate_three_multi_index_with_one_index() -> Result<(), E
async fn select_where_iterate_two_multi_index() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &two_multi_index_query("", ""), 9).await?;
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }]")?;
check_result(&mut res, TWO_MULTI_INDEX_EXPLAIN)?;
@ -105,7 +105,7 @@ async fn select_where_iterate_two_multi_index() -> Result<(), Error> {
async fn select_where_iterate_two_multi_index_with_one_index() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &two_multi_index_query("WITH INDEX idx_genre", ""), 9).await?;
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }]")?;
check_result(&mut res, &table_explain(2))?;
@ -120,7 +120,7 @@ async fn select_where_iterate_two_multi_index_with_two_index() -> Result<(), Err
let dbs = new_ds().await?;
let mut res =
execute_test(&dbs, &two_multi_index_query("WITH INDEX idx_genre,uniq_name", ""), 9).await?;
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }]")?;
check_result(&mut res, TWO_MULTI_INDEX_EXPLAIN)?;
@ -134,7 +134,7 @@ async fn select_where_iterate_two_multi_index_with_two_index() -> Result<(), Err
async fn select_where_iterate_two_no_index() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &two_multi_index_query("WITH NOINDEX", ""), 9).await?;
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
// OR results
check_result(&mut res, "[{ name: 'Jaime' }, { name: 'Tobie' }]")?;
check_result(&mut res, &table_explain_no_index(2))?;
@ -547,7 +547,7 @@ async fn select_range(
) -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &range_test(unique, from_incl, to_incl), 8).await?;
skip_ok(&mut res, 6)?;
skip_ok(&mut res, 6);
{
let tmp = res.remove(0).result?;
let val = Value::parse(explain);
@ -790,7 +790,7 @@ async fn select_single_range_operator(
) -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, &single_range_operator_test(unique, op), 6).await?;
skip_ok(&mut res, 4)?;
skip_ok(&mut res, 4);
{
let tmp = res.remove(0).result?;
let val = Value::parse(explain);
@ -990,7 +990,7 @@ async fn select_with_idiom_param_value() -> Result<(), Error> {
.to_owned();
let mut res = dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 6);
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
let tmp = res.remove(0).result?;
let val = Value::parse(
r#"[
@ -1079,7 +1079,7 @@ async fn test_contains(
let val = Value::parse(result);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
}
skip_ok(&mut res, 1)?;
skip_ok(&mut res, 1);
{
let tmp = res.remove(0).result?;
let val = Value::parse(index_explain);
@ -1097,7 +1097,7 @@ async fn test_contains(
async fn select_contains() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, CONTAINS_CONTENT, 3).await?;
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
const SQL: &str = r#"
SELECT id FROM student WHERE marks.*.subject CONTAINS "english" EXPLAIN;
@ -1142,7 +1142,7 @@ async fn select_contains() -> Result<(), Error> {
async fn select_contains_all() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, CONTAINS_CONTENT, 3).await?;
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
const SQL: &str = r#"
SELECT id FROM student WHERE marks.*.subject CONTAINSALL ["hindi", "maths"] EXPLAIN;
SELECT id FROM student WHERE marks.*.subject CONTAINSALL ["hindi", "maths"];
@ -1185,7 +1185,7 @@ async fn select_contains_all() -> Result<(), Error> {
async fn select_contains_any() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, CONTAINS_CONTENT, 3).await?;
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
const SQL: &str = r#"
SELECT id FROM student WHERE marks.*.subject CONTAINSANY ["tamil", "french"] EXPLAIN;
SELECT id FROM student WHERE marks.*.subject CONTAINSANY ["tamil", "french"];
@ -1233,7 +1233,7 @@ const CONTAINS_UNIQUE_CONTENT: &str = r#"
async fn select_unique_contains() -> Result<(), Error> {
let dbs = new_ds().await?;
let mut res = execute_test(&dbs, CONTAINS_UNIQUE_CONTENT, 3).await?;
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
const SQL: &str = r#"
SELECT id FROM student WHERE subject CONTAINS "english" EXPLAIN;
@ -1291,7 +1291,7 @@ async fn select_with_datetime_value() -> Result<(), Error> {
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 8);
skip_ok(&mut res, 4)?;
skip_ok(&mut res, 4);
for _ in 0..2 {
let tmp = res.remove(0).result?;
@ -1354,7 +1354,7 @@ async fn select_with_uuid_value() -> Result<(), Error> {
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
for _ in 0..2 {
let tmp = res.remove(0).result?;
@ -1415,7 +1415,7 @@ async fn select_with_in_operator() -> Result<(), Error> {
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
for _ in 0..2 {
let tmp = res.remove(0).result?;
@ -1476,7 +1476,7 @@ async fn select_with_in_operator_uniq_index() -> Result<(), Error> {
let mut res = dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 8);
skip_ok(&mut res, 2)?;
skip_ok(&mut res, 2);
let tmp = res.remove(0).result?;
let val = Value::parse(r#"[]"#);
@ -1549,7 +1549,7 @@ async fn select_with_in_operator_multiple_indexes() -> Result<(), Error> {
let mut res = dbs.execute(sql, &ses, None).await?;
//
assert_eq!(res.len(), 17);
skip_ok(&mut res, 9)?;
skip_ok(&mut res, 9);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -1933,7 +1933,7 @@ async fn select_with_record_id_link_no_index() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;
skip_ok(&mut res, 6);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -1992,7 +1992,7 @@ async fn select_with_record_id_link_index() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
//
let expected = Value::parse(
r#"[
@ -2057,7 +2057,7 @@ async fn select_with_record_id_link_unique_index() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
//
let expected = Value::parse(
r#"[
@ -2121,7 +2121,7 @@ async fn select_with_record_id_link_unique_remote_index() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 10);
skip_ok(&mut res, 8)?;
skip_ok(&mut res, 8);
//
let expected = Value::parse(
r#"[
@ -2188,7 +2188,7 @@ async fn select_with_record_id_link_full_text_index() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 9);
skip_ok(&mut res, 7)?;
skip_ok(&mut res, 7);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -2245,7 +2245,7 @@ async fn select_with_record_id_link_full_text_no_record_index() -> Result<(), Er
let mut res = dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 8);
skip_ok(&mut res, 6)?;
skip_ok(&mut res, 6);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -2313,7 +2313,7 @@ async fn select_with_record_id_index() -> Result<(), Error> {
);
//
assert_eq!(res.len(), 15);
skip_ok(&mut res, 2)?;
skip_ok(&mut res, 2);
//
for t in ["CONTAINS", "CONTAINSANY", "IN"] {
let tmp = res.remove(0).result?;
@ -2345,7 +2345,7 @@ async fn select_with_record_id_index() -> Result<(), Error> {
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
}
//
skip_ok(&mut res, 1)?;
skip_ok(&mut res, 1);
// CONTAINS
let tmp = res.remove(0).result?;
assert_eq!(format!("{:#}", tmp), format!("{:#}", expected));
@ -2450,7 +2450,7 @@ async fn select_with_exact_operator() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 8);
skip_ok(&mut res, 4)?;
skip_ok(&mut res, 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -2553,7 +2553,7 @@ async fn select_with_non_boolean_expression() -> Result<(), Error> {
let mut res = dbs.execute(&sql, &ses, None).await?;
//
assert_eq!(res.len(), 15);
skip_ok(&mut res, 5)?;
skip_ok(&mut res, 5);
//
for i in 0..5 {
let tmp = res.remove(0).result?;

View file

@ -133,7 +133,7 @@ async fn define_foreign_table_no_doubles() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
skip_ok(res, 5)?;
skip_ok(res, 5);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -218,11 +218,11 @@ async fn define_foreign_table_group(cond: bool, agr: &str) -> Result<(), Error>
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 29);
//
skip_ok(res, 2)?;
skip_ok(res, 2);
//
for i in 0..9 {
// Skip the UPDATE or DELETE statement
skip_ok(res, 1)?;
skip_ok(res, 1);
// Get the computed result
let comp = res.remove(0).result?;
// Get the projected result

View file

@ -170,7 +170,7 @@ async fn select_where_brute_force_knn() -> Result<(), Error> {
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
skip_ok(res, 4)?;
skip_ok(res, 4);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -301,7 +301,7 @@ async fn select_mtree_knn_with_condition() -> Result<(), Error> {
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -373,7 +373,7 @@ async fn select_hnsw_knn_with_condition() -> Result<(), Error> {
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 5);
//
skip_ok(&mut res, 3)?;
skip_ok(&mut res, 3);
//
let tmp = res.remove(0).result?;
let val = Value::parse(
@ -444,7 +444,7 @@ async fn select_bruteforce_knn_with_condition() -> Result<(), Error> {
let mut res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
skip_ok(&mut res, 2)?;
skip_ok(&mut res, 2);
//
let tmp = res.remove(0).result?;
let val = Value::parse(