Feature: 1903 Basic Vector Functions (#1907)

Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
Co-authored-by: Emmanuel Keller <emmanuel.keller@surrealdb.com>
This commit is contained in:
Tim 2023-07-14 20:00:07 +02:00 committed by GitHub
parent 5c08be973d
commit b83cd86f9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 377 additions and 22 deletions

View file

@ -29,6 +29,7 @@ pub mod string;
pub mod time;
pub mod r#type;
pub mod util;
pub mod vector;
/// Attempts to run any function
pub async fn run(
@ -282,6 +283,19 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec<Value>) -> Result<Va
"type::string" => r#type::string,
"type::table" => r#type::table,
"type::thing" => r#type::thing,
//
"vector::dotproduct" => vector::dotproduct,
"vector::magnitude" => vector::magnitude,
"vector::distance::chebyshev" => vector::distance::chebyshev,
"vector::distance::euclidean" => vector::distance::euclidean,
"vector::distance::hamming" => vector::distance::hamming,
"vector::distance::mahalanobis" => vector::distance::mahalanobis,
"vector::distance::manhattan" => vector::distance::manhattan,
"vector::distance::minkowski" => vector::distance::minkowski,
"vector::similarity::cosine" => vector::similarity::cosine,
"vector::similarity::jaccard" => vector::similarity::jaccard,
"vector::similarity::pearson" => vector::similarity::pearson,
"vector::similarity::spearman" => vector::similarity::spearman,
)
}

View file

@ -22,6 +22,7 @@ mod session;
mod string;
mod time;
mod r#type;
mod vector;
pub struct Package;
@ -48,7 +49,8 @@ impl_module_def!(
"sleep" => fut Async,
"string" => (string::Package),
"time" => (time::Package),
"type" => (r#type::Package)
"type" => (r#type::Package),
"vector" => (vector::Package)
);
fn run(js_ctx: js::Ctx<'_>, name: &str, args: Vec<Value>) -> Result<Value> {

View file

@ -0,0 +1,15 @@
use super::run;
use crate::fnc::script::modules::impl_module_def;
mod distance;
mod similarity;
pub struct Package;
impl_module_def!(
Package,
"vector",
"distance" => (distance::Package),
"dotproduct" => run,
"magnitude" => run,
"similarity" => (similarity::Package)
);

View file

@ -0,0 +1,15 @@
use super::run;
use crate::fnc::script::modules::impl_module_def;
pub struct Package;
impl_module_def!(
Package,
"vector::distance",
"chebyshev" => run,
"euclidean" => run,
"hamming" => run,
"mahalanobis" => run,
"manhattan" => run,
"minkowski" => run
);

View file

@ -0,0 +1,13 @@
use super::run;
use crate::fnc::script::modules::impl_module_def;
pub struct Package;
impl_module_def!(
Package,
"vector::similarity",
"cosine" => run,
"jaccard" => run,
"pearson" => run,
"spearman" => run
);

View file

@ -0,0 +1,15 @@
use crate::sql::Number;
pub trait DotProduct {
/// Dot Product of two vectors
fn dotproduct(&self, other: &Self) -> Option<Number>;
}
impl DotProduct for Vec<Number> {
fn dotproduct(&self, other: &Self) -> Option<Number> {
if self.len() != other.len() {
return None;
}
Some(self.iter().zip(other.iter()).map(|(a, b)| a * b).sum())
}
}

View file

@ -0,0 +1,21 @@
use crate::sql::Number;
pub trait EuclideanDistance {
/// Euclidean Distance between two vectors (L2 Norm)
fn euclidean_distance(&self, other: &Self) -> Option<Number>;
}
impl EuclideanDistance for Vec<Number> {
fn euclidean_distance(&self, other: &Self) -> Option<Number> {
if self.len() != other.len() {
return None;
}
Some(
self.iter()
.zip(other.iter())
.map(|(a, b)| (a - b).pow(Number::Int(2)))
.sum::<Number>()
.sqrt(),
)
}
}

View file

@ -0,0 +1,12 @@
use crate::sql::Number;
pub trait Magnitude {
/// Calculate the magnitude of a vector
fn magnitude(&self) -> Number;
}
impl Magnitude for Vec<Number> {
fn magnitude(&self) -> Number {
self.iter().map(|a| a.clone().pow(Number::Int(2))).sum::<Number>().sqrt()
}
}

View file

@ -4,7 +4,10 @@
pub mod bottom;
pub mod deviation;
pub mod dotproduct;
pub mod euclideandistance;
pub mod interquartile;
pub mod magnitude;
pub mod mean;
pub mod median;
pub mod midhinge;

101
lib/src/fnc/vector.rs Normal file
View file

@ -0,0 +1,101 @@
use crate::err::Error;
use crate::fnc::util::math::dotproduct::DotProduct;
use crate::fnc::util::math::magnitude::Magnitude;
use crate::sql::{Number, Value};
pub fn dotproduct((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
match a.dotproduct(&b) {
None => Err(Error::InvalidArguments {
name: String::from("vector::dotproduct"),
message: String::from("The two vectors must be of the same length."),
}),
Some(dot) => Ok(dot.into()),
}
}
pub fn magnitude((a,): (Vec<Number>,)) -> Result<Value, Error> {
Ok(a.magnitude().into())
}
pub mod distance {
use crate::err::Error;
use crate::fnc::util::math::euclideandistance::EuclideanDistance;
use crate::sql::{Number, Value};
pub fn chebyshev((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::distance::chebyshev() function",
})
}
pub fn euclidean((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
match a.euclidean_distance(&b) {
None => Err(Error::InvalidArguments {
name: String::from("vector::distance::euclidean"),
message: String::from("The two vectors must be of the same length."),
}),
Some(distance) => Ok(distance.into()),
}
}
pub fn hamming((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::distance::hamming() function",
})
}
pub fn mahalanobis((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::distance::mahalanobis() function",
})
}
pub fn manhattan((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::distance::manhattan() function",
})
}
pub fn minkowski((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::distance::minkowski() function",
})
}
}
pub mod similarity {
use crate::err::Error;
use crate::fnc::util::math::dotproduct::DotProduct;
use crate::fnc::util::math::magnitude::Magnitude;
use crate::sql::{Number, Value};
pub fn cosine((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
match a.dotproduct(&b) {
None => Err(Error::InvalidArguments {
name: String::from("vector::similarity::cosine"),
message: String::from("The two vectors must be of the same length."),
}),
Some(dot) => Ok((dot / (a.magnitude() * b.magnitude())).into()),
}
}
pub fn jaccard((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::similarity::jaccard() function",
})
}
pub fn pearson((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::similarity::pearson() function",
})
}
pub fn spearman((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
Err(Error::FeatureNotYetImplemented {
feature: "vector::similarity::spearman() function",
})
}
}

View file

@ -248,27 +248,27 @@ fn script(i: &str) -> IResult<&str, Function> {
pub(crate) fn function_names(i: &str) -> IResult<&str, &str> {
recognize(alt((
preceded(tag("array::"), function_array),
preceded(tag("bytes::"), function_bytes),
preceded(tag("crypto::"), function_crypto),
preceded(tag("duration::"), function_duration),
preceded(tag("encoding::"), function_encoding),
preceded(tag("geo::"), function_geo),
preceded(tag("http::"), function_http),
preceded(tag("is::"), function_is),
preceded(tag("math::"), function_math),
preceded(tag("meta::"), function_meta),
preceded(tag("parse::"), function_parse),
preceded(tag("rand::"), function_rand),
preceded(tag("search::"), function_search),
preceded(tag("session::"), function_session),
preceded(tag("string::"), function_string),
preceded(tag("time::"), function_time),
preceded(tag("type::"), function_type),
tag("count"),
tag("not"),
tag("rand"),
tag("sleep"),
alt((
preceded(tag("array::"), function_array),
preceded(tag("bytes::"), function_bytes),
preceded(tag("crypto::"), function_crypto),
preceded(tag("duration::"), function_duration),
preceded(tag("encoding::"), function_encoding),
preceded(tag("geo::"), function_geo),
preceded(tag("http::"), function_http),
preceded(tag("is::"), function_is),
preceded(tag("math::"), function_math),
preceded(tag("meta::"), function_meta),
preceded(tag("parse::"), function_parse),
preceded(tag("rand::"), function_rand),
preceded(tag("search::"), function_search),
preceded(tag("session::"), function_session),
preceded(tag("string::"), function_string),
preceded(tag("time::"), function_time),
preceded(tag("type::"), function_type),
preceded(tag("vector::"), function_vector),
)),
alt((tag("count"), tag("not"), tag("rand"), tag("sleep"))),
)))(i)
}
@ -551,6 +551,28 @@ fn function_type(i: &str) -> IResult<&str, &str> {
))(i)
}
fn function_vector(i: &str) -> IResult<&str, &str> {
alt((
tag("dotproduct"),
tag("magnitude"),
preceded(
tag("distance::"),
alt((
tag("chebyshev"),
tag("euclidean"),
tag("hamming"),
tag("mahalanobis"),
tag("manhattan"),
tag("minkowski"),
)),
),
preceded(
tag("similarity::"),
alt((tag("cosine"), tag("jaccard"), tag("pearson"), tag("spearman"))),
),
))(i)
}
#[cfg(test)]
mod tests {

View file

@ -4574,3 +4574,125 @@ async fn function_type_thing() -> Result<(), Error> {
//
Ok(())
}
#[tokio::test]
async fn function_vector_distance_euclidean() -> Result<(), Error> {
let sql = r#"
RETURN vector::distance::euclidean([1, 2, 3], [1, 2, 3]);
RETURN vector::distance::euclidean([1, 2, 3], [-1, -2, -3]);
RETURN vector::distance::euclidean([1, 2, 3], [4, 5]);
RETURN vector::distance::euclidean([1, 2], [4, 5, 5]);
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
let tmp = res.remove(0).result?;
let val = Value::from(0);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(7.483314773547883);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
Ok(())
}
#[tokio::test]
async fn function_vector_dotproduct() -> Result<(), Error> {
let sql = r#"
RETURN vector::dotproduct([1, 2, 3], [1, 2, 3]);
RETURN vector::dotproduct([1, 2, 3], [-1, -2, -3]);
RETURN vector::dotproduct([1, 2, 3], [4, 5]);
RETURN vector::dotproduct([1, 2], [4, 5, 5]);
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
let tmp = res.remove(0).result?;
let val = Value::from(14);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(-14);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
Ok(())
}
#[tokio::test]
async fn function_vector_magnitude() -> Result<(), Error> {
let sql = r#"
RETURN vector::magnitude([]);
RETURN vector::magnitude([1]);
RETURN vector::magnitude([5]);
RETURN vector::magnitude([1,2,3,3,3,4,5]);
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
let tmp = res.remove(0).result?;
let val = Value::from(0);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(1);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(5);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(8.54400374531753);
assert_eq!(tmp, val);
Ok(())
}
#[tokio::test]
async fn function_vector_similarity_cosine() -> Result<(), Error> {
let sql = r#"
RETURN vector::similarity::cosine([1, 2, 3], [1, 2, 3]);
RETURN vector::similarity::cosine([1, 2, 3], [-1, -2, -3]);
RETURN vector::similarity::cosine([1, 2, 3], [4, 5]);
RETURN vector::similarity::cosine([1, 2], [4, 5, 5]);
"#;
let dbs = Datastore::new("memory").await?;
let ses = Session::for_kv().with_ns("test").with_db("test");
let res = &mut dbs.execute(&sql, &ses, None).await?;
assert_eq!(res.len(), 4);
//
let tmp = res.remove(0).result?;
let val = Value::from(1.0);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result?;
let val = Value::from(-1.0);
assert_eq!(tmp, val);
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
//
let tmp = res.remove(0).result;
assert!(tmp.is_err());
Ok(())
}