Add vector::scale()
method (#4292)
Co-authored-by: Micha de Vries <micha@devrie.sh>
This commit is contained in:
parent
0069cba8a3
commit
212d5a9e5a
7 changed files with 70 additions and 1 deletions
|
@ -399,6 +399,7 @@ pub fn synchronous(
|
|||
"vector::multiply" => vector::multiply,
|
||||
"vector::normalize" => vector::normalize,
|
||||
"vector::project" => vector::project,
|
||||
"vector::scale" => vector::scale,
|
||||
"vector::subtract" => vector::subtract,
|
||||
"vector::distance::chebyshev" => vector::distance::chebyshev,
|
||||
"vector::distance::euclidean" => vector::distance::euclidean,
|
||||
|
@ -489,6 +490,7 @@ pub async fn idiom(
|
|||
"vector_multiply" => vector::multiply,
|
||||
"vector_normalize" => vector::normalize,
|
||||
"vector_project" => vector::project,
|
||||
"vector_scale" => vector::scale,
|
||||
"vector_subtract" => vector::subtract,
|
||||
"vector_distance_chebyshev" => vector::distance::chebyshev,
|
||||
"vector_distance_euclidean" => vector::distance::euclidean,
|
||||
|
|
|
@ -21,5 +21,6 @@ impl_module_def!(
|
|||
"multiply" => run,
|
||||
"normalize" => run,
|
||||
"project" => run,
|
||||
"scale" => run,
|
||||
"subtract" => run
|
||||
);
|
||||
|
|
|
@ -170,6 +170,16 @@ impl Multiply for Vec<Number> {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait Scale {
|
||||
fn scale(&self, other: &Number) -> Result<Vec<Number>, Error>;
|
||||
}
|
||||
|
||||
impl Scale for Vec<Number> {
|
||||
fn scale(&self, other: &Number) -> Result<Vec<Number>, Error> {
|
||||
Ok(self.iter().map(|a| a * other).collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Project {
|
||||
/// Projection of two vectors
|
||||
fn project(&self, other: &Self) -> Result<Vec<Number>, Error>;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::err::Error;
|
||||
use crate::fnc::util::math::vector::{
|
||||
Add, Angle, CrossProduct, Divide, DotProduct, Magnitude, Multiply, Normalize, Project, Subtract,
|
||||
Add, Angle, CrossProduct, Divide, DotProduct, Magnitude, Multiply, Normalize, Project, Scale,
|
||||
Subtract,
|
||||
};
|
||||
use crate::sql::{Number, Value};
|
||||
|
||||
|
@ -44,6 +45,10 @@ pub fn subtract((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
|||
Ok(a.subtract(&b)?.into())
|
||||
}
|
||||
|
||||
pub fn scale((a, b): (Vec<Number>, Number)) -> Result<Value, Error> {
|
||||
Ok(a.scale(&b)?.into())
|
||||
}
|
||||
|
||||
pub mod distance {
|
||||
use crate::ctx::Context;
|
||||
use crate::doc::CursorDoc;
|
||||
|
@ -160,3 +165,51 @@ impl TryFrom<Value> for Vec<Number> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sql::Number;
|
||||
use rust_decimal::Decimal;
|
||||
|
||||
#[test]
|
||||
fn vector_scale_int() {
|
||||
let input_vector: Vec<Number> = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect();
|
||||
let scalar_int = Number::Int(2);
|
||||
|
||||
let result: Result<Value, Error> = scale((input_vector.clone(), scalar_int.clone()));
|
||||
|
||||
let expected_output: Vec<Number> = vec![2, 4, 6, 8].into_iter().map(Number::Int).collect();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected_output.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vector_scale_float() {
|
||||
let input_vector: Vec<Number> = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect();
|
||||
let scalar_float = Number::Float(1.51);
|
||||
|
||||
let result: Result<Value, Error> = scale((input_vector.clone(), scalar_float.clone()));
|
||||
let expected_output: Vec<Number> =
|
||||
vec![1.51, 3.02, 4.53, 6.04].into_iter().map(Number::Float).collect();
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected_output.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vector_scale_decimal() {
|
||||
let input_vector: Vec<Number> = vec![1, 2, 3, 4].into_iter().map(Number::Int).collect();
|
||||
let scalar_decimal = Number::Decimal(Decimal::new(3141, 3));
|
||||
|
||||
let result: Result<Value, Error> = scale((input_vector.clone(), scalar_decimal.clone()));
|
||||
let expected_output: Vec<Number> = vec![
|
||||
Number::Decimal(Decimal::new(3141, 3)), // 3.141 * 1
|
||||
Number::Decimal(Decimal::new(6282, 3)), // 3.141 * 2
|
||||
Number::Decimal(Decimal::new(9423, 3)), // 3.141 * 3
|
||||
Number::Decimal(Decimal::new(12564, 3)), // 3.141 * 4
|
||||
];
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected_output.into());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -387,6 +387,7 @@ pub(crate) static PATHS: phf::Map<UniCase<&'static str>, PathKind> = phf_map! {
|
|||
UniCase::ascii("vector::multiply") => PathKind::Function,
|
||||
UniCase::ascii("vector::normalize") => PathKind::Function,
|
||||
UniCase::ascii("vector::project") => PathKind::Function,
|
||||
UniCase::ascii("vector::scale") => PathKind::Function,
|
||||
UniCase::ascii("vector::subtract") => PathKind::Function,
|
||||
UniCase::ascii("vector::distance::chebyshev") => PathKind::Function,
|
||||
UniCase::ascii("vector::distance::euclidean") => PathKind::Function,
|
||||
|
|
|
@ -422,6 +422,7 @@
|
|||
"vector::multiply("
|
||||
"vector::normalize("
|
||||
"vector::project("
|
||||
"vector::scale("
|
||||
"vector::subtract("
|
||||
"vector::distance::chebyshev("
|
||||
"vector::distance::euclidean("
|
||||
|
|
|
@ -421,6 +421,7 @@
|
|||
"vector::multiply("
|
||||
"vector::normalize("
|
||||
"vector::project("
|
||||
"vector::scale("
|
||||
"vector::subtract("
|
||||
"vector::distance::chebyshev("
|
||||
"vector::distance::euclidean("
|
||||
|
|
Loading…
Reference in a new issue