From 212d5a9e5abbf101cb63639fbf01e26caaf7c54d Mon Sep 17 00:00:00 2001 From: ekgns33 <76658405+ekgns33@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:01:10 +0900 Subject: [PATCH] Add `vector::scale()` method (#4292) Co-authored-by: Micha de Vries --- core/src/fnc/mod.rs | 2 + .../modules/surrealdb/functions/vector.rs | 1 + core/src/fnc/util/math/vector.rs | 10 ++++ core/src/fnc/vector.rs | 55 ++++++++++++++++++- core/src/syn/parser/builtin.rs | 1 + lib/fuzz/fuzz_targets/fuzz_executor.dict | 1 + lib/fuzz/fuzz_targets/fuzz_sql_parser.dict | 1 + 7 files changed, 70 insertions(+), 1 deletion(-) diff --git a/core/src/fnc/mod.rs b/core/src/fnc/mod.rs index 2b1e06fa..03143b95 100644 --- a/core/src/fnc/mod.rs +++ b/core/src/fnc/mod.rs @@ -399,6 +399,7 @@ pub fn synchronous( "vector::multiply" => vector::multiply, "vector::normalize" => vector::normalize, "vector::project" => vector::project, + "vector::scale" => vector::scale, "vector::subtract" => vector::subtract, "vector::distance::chebyshev" => vector::distance::chebyshev, "vector::distance::euclidean" => vector::distance::euclidean, @@ -489,6 +490,7 @@ pub async fn idiom( "vector_multiply" => vector::multiply, "vector_normalize" => vector::normalize, "vector_project" => vector::project, + "vector_scale" => vector::scale, "vector_subtract" => vector::subtract, "vector_distance_chebyshev" => vector::distance::chebyshev, "vector_distance_euclidean" => vector::distance::euclidean, diff --git a/core/src/fnc/script/modules/surrealdb/functions/vector.rs b/core/src/fnc/script/modules/surrealdb/functions/vector.rs index b81035d7..bb6bb1be 100644 --- a/core/src/fnc/script/modules/surrealdb/functions/vector.rs +++ b/core/src/fnc/script/modules/surrealdb/functions/vector.rs @@ -21,5 +21,6 @@ impl_module_def!( "multiply" => run, "normalize" => run, "project" => run, + "scale" => run, "subtract" => run ); diff --git a/core/src/fnc/util/math/vector.rs b/core/src/fnc/util/math/vector.rs index b315e1cb..779b4826 100644 --- a/core/src/fnc/util/math/vector.rs +++ b/core/src/fnc/util/math/vector.rs @@ -170,6 +170,16 @@ impl Multiply for Vec { } } +pub trait Scale { + fn scale(&self, other: &Number) -> Result, Error>; +} + +impl Scale for Vec { + fn scale(&self, other: &Number) -> Result, Error> { + Ok(self.iter().map(|a| a * other).collect()) + } +} + pub trait Project { /// Projection of two vectors fn project(&self, other: &Self) -> Result, Error>; diff --git a/core/src/fnc/vector.rs b/core/src/fnc/vector.rs index 48b9e529..80be98a3 100644 --- a/core/src/fnc/vector.rs +++ b/core/src/fnc/vector.rs @@ -1,6 +1,7 @@ use crate::err::Error; use crate::fnc::util::math::vector::{ - Add, Angle, CrossProduct, Divide, DotProduct, Magnitude, Multiply, Normalize, Project, Subtract, + Add, Angle, CrossProduct, Divide, DotProduct, Magnitude, Multiply, Normalize, Project, Scale, + Subtract, }; use crate::sql::{Number, Value}; @@ -44,6 +45,10 @@ pub fn subtract((a, b): (Vec, Vec)) -> Result { Ok(a.subtract(&b)?.into()) } +pub fn scale((a, b): (Vec, Number)) -> Result { + Ok(a.scale(&b)?.into()) +} + pub mod distance { use crate::ctx::Context; use crate::doc::CursorDoc; @@ -160,3 +165,51 @@ impl TryFrom for Vec { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::sql::Number; + use rust_decimal::Decimal; + + #[test] + fn vector_scale_int() { + let input_vector: Vec = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect(); + let scalar_int = Number::Int(2); + + let result: Result = scale((input_vector.clone(), scalar_int.clone())); + + let expected_output: Vec = vec![2, 4, 6, 8].into_iter().map(Number::Int).collect(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_output.into()); + } + + #[test] + fn vector_scale_float() { + let input_vector: Vec = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect(); + let scalar_float = Number::Float(1.51); + + let result: Result = scale((input_vector.clone(), scalar_float.clone())); + let expected_output: Vec = + vec![1.51, 3.02, 4.53, 6.04].into_iter().map(Number::Float).collect(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_output.into()); + } + + #[test] + fn vector_scale_decimal() { + let input_vector: Vec = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect(); + let scalar_decimal = Number::Decimal(Decimal::new(3141, 3)); + + let result: Result = scale((input_vector.clone(), scalar_decimal.clone())); + let expected_output: Vec = vec![ + Number::Decimal(Decimal::new(3141, 3)), // 3.141 * 1 + Number::Decimal(Decimal::new(6282, 3)), // 3.141 * 2 + Number::Decimal(Decimal::new(9423, 3)), // 3.141 * 3 + Number::Decimal(Decimal::new(12564, 3)), // 3.141 * 4 + ]; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected_output.into()); + } +} diff --git a/core/src/syn/parser/builtin.rs b/core/src/syn/parser/builtin.rs index 6fa4859f..a557810c 100644 --- a/core/src/syn/parser/builtin.rs +++ b/core/src/syn/parser/builtin.rs @@ -387,6 +387,7 @@ pub(crate) static PATHS: phf::Map, PathKind> = phf_map! { UniCase::ascii("vector::multiply") => PathKind::Function, UniCase::ascii("vector::normalize") => PathKind::Function, UniCase::ascii("vector::project") => PathKind::Function, + UniCase::ascii("vector::scale") => PathKind::Function, UniCase::ascii("vector::subtract") => PathKind::Function, UniCase::ascii("vector::distance::chebyshev") => PathKind::Function, UniCase::ascii("vector::distance::euclidean") => PathKind::Function, diff --git a/lib/fuzz/fuzz_targets/fuzz_executor.dict b/lib/fuzz/fuzz_targets/fuzz_executor.dict index 86630bbe..5910a923 100644 --- a/lib/fuzz/fuzz_targets/fuzz_executor.dict +++ b/lib/fuzz/fuzz_targets/fuzz_executor.dict @@ -422,6 +422,7 @@ "vector::multiply(" "vector::normalize(" "vector::project(" +"vector::scale(" "vector::subtract(" "vector::distance::chebyshev(" "vector::distance::euclidean(" diff --git a/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict b/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict index 4a8afc4b..6b28e3e5 100644 --- a/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict +++ b/lib/fuzz/fuzz_targets/fuzz_sql_parser.dict @@ -421,6 +421,7 @@ "vector::multiply(" "vector::normalize(" "vector::project(" +"vector::scale(" "vector::subtract(" "vector::distance::chebyshev(" "vector::distance::euclidean("