Feature: 1903 Basic Vector Functions (#1907)
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com> Co-authored-by: Emmanuel Keller <emmanuel.keller@surrealdb.com>
This commit is contained in:
parent
5c08be973d
commit
b83cd86f9d
12 changed files with 377 additions and 22 deletions
|
@ -29,6 +29,7 @@ pub mod string;
|
|||
pub mod time;
|
||||
pub mod r#type;
|
||||
pub mod util;
|
||||
pub mod vector;
|
||||
|
||||
/// Attempts to run any function
|
||||
pub async fn run(
|
||||
|
@ -282,6 +283,19 @@ pub fn synchronous(ctx: &Context<'_>, name: &str, args: Vec<Value>) -> Result<Va
|
|||
"type::string" => r#type::string,
|
||||
"type::table" => r#type::table,
|
||||
"type::thing" => r#type::thing,
|
||||
//
|
||||
"vector::dotproduct" => vector::dotproduct,
|
||||
"vector::magnitude" => vector::magnitude,
|
||||
"vector::distance::chebyshev" => vector::distance::chebyshev,
|
||||
"vector::distance::euclidean" => vector::distance::euclidean,
|
||||
"vector::distance::hamming" => vector::distance::hamming,
|
||||
"vector::distance::mahalanobis" => vector::distance::mahalanobis,
|
||||
"vector::distance::manhattan" => vector::distance::manhattan,
|
||||
"vector::distance::minkowski" => vector::distance::minkowski,
|
||||
"vector::similarity::cosine" => vector::similarity::cosine,
|
||||
"vector::similarity::jaccard" => vector::similarity::jaccard,
|
||||
"vector::similarity::pearson" => vector::similarity::pearson,
|
||||
"vector::similarity::spearman" => vector::similarity::spearman,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ mod session;
|
|||
mod string;
|
||||
mod time;
|
||||
mod r#type;
|
||||
mod vector;
|
||||
|
||||
pub struct Package;
|
||||
|
||||
|
@ -48,7 +49,8 @@ impl_module_def!(
|
|||
"sleep" => fut Async,
|
||||
"string" => (string::Package),
|
||||
"time" => (time::Package),
|
||||
"type" => (r#type::Package)
|
||||
"type" => (r#type::Package),
|
||||
"vector" => (vector::Package)
|
||||
);
|
||||
|
||||
fn run(js_ctx: js::Ctx<'_>, name: &str, args: Vec<Value>) -> Result<Value> {
|
||||
|
|
15
lib/src/fnc/script/modules/surrealdb/functions/vector.rs
Normal file
15
lib/src/fnc/script/modules/surrealdb/functions/vector.rs
Normal file
|
@ -0,0 +1,15 @@
|
|||
use super::run;
|
||||
use crate::fnc::script::modules::impl_module_def;
|
||||
|
||||
mod distance;
|
||||
mod similarity;
|
||||
pub struct Package;
|
||||
|
||||
impl_module_def!(
|
||||
Package,
|
||||
"vector",
|
||||
"distance" => (distance::Package),
|
||||
"dotproduct" => run,
|
||||
"magnitude" => run,
|
||||
"similarity" => (similarity::Package)
|
||||
);
|
|
@ -0,0 +1,15 @@
|
|||
use super::run;
|
||||
use crate::fnc::script::modules::impl_module_def;
|
||||
|
||||
pub struct Package;
|
||||
|
||||
impl_module_def!(
|
||||
Package,
|
||||
"vector::distance",
|
||||
"chebyshev" => run,
|
||||
"euclidean" => run,
|
||||
"hamming" => run,
|
||||
"mahalanobis" => run,
|
||||
"manhattan" => run,
|
||||
"minkowski" => run
|
||||
);
|
|
@ -0,0 +1,13 @@
|
|||
use super::run;
|
||||
use crate::fnc::script::modules::impl_module_def;
|
||||
|
||||
pub struct Package;
|
||||
|
||||
impl_module_def!(
|
||||
Package,
|
||||
"vector::similarity",
|
||||
"cosine" => run,
|
||||
"jaccard" => run,
|
||||
"pearson" => run,
|
||||
"spearman" => run
|
||||
);
|
15
lib/src/fnc/util/math/dotproduct.rs
Normal file
15
lib/src/fnc/util/math/dotproduct.rs
Normal file
|
@ -0,0 +1,15 @@
|
|||
use crate::sql::Number;
|
||||
|
||||
pub trait DotProduct {
|
||||
/// Dot Product of two vectors
|
||||
fn dotproduct(&self, other: &Self) -> Option<Number>;
|
||||
}
|
||||
|
||||
impl DotProduct for Vec<Number> {
|
||||
fn dotproduct(&self, other: &Self) -> Option<Number> {
|
||||
if self.len() != other.len() {
|
||||
return None;
|
||||
}
|
||||
Some(self.iter().zip(other.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
21
lib/src/fnc/util/math/euclideandistance.rs
Normal file
21
lib/src/fnc/util/math/euclideandistance.rs
Normal file
|
@ -0,0 +1,21 @@
|
|||
use crate::sql::Number;
|
||||
|
||||
pub trait EuclideanDistance {
|
||||
/// Euclidean Distance between two vectors (L2 Norm)
|
||||
fn euclidean_distance(&self, other: &Self) -> Option<Number>;
|
||||
}
|
||||
|
||||
impl EuclideanDistance for Vec<Number> {
|
||||
fn euclidean_distance(&self, other: &Self) -> Option<Number> {
|
||||
if self.len() != other.len() {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
self.iter()
|
||||
.zip(other.iter())
|
||||
.map(|(a, b)| (a - b).pow(Number::Int(2)))
|
||||
.sum::<Number>()
|
||||
.sqrt(),
|
||||
)
|
||||
}
|
||||
}
|
12
lib/src/fnc/util/math/magnitude.rs
Normal file
12
lib/src/fnc/util/math/magnitude.rs
Normal file
|
@ -0,0 +1,12 @@
|
|||
use crate::sql::Number;
|
||||
|
||||
pub trait Magnitude {
|
||||
/// Calculate the magnitude of a vector
|
||||
fn magnitude(&self) -> Number;
|
||||
}
|
||||
|
||||
impl Magnitude for Vec<Number> {
|
||||
fn magnitude(&self) -> Number {
|
||||
self.iter().map(|a| a.clone().pow(Number::Int(2))).sum::<Number>().sqrt()
|
||||
}
|
||||
}
|
|
@ -4,7 +4,10 @@
|
|||
|
||||
pub mod bottom;
|
||||
pub mod deviation;
|
||||
pub mod dotproduct;
|
||||
pub mod euclideandistance;
|
||||
pub mod interquartile;
|
||||
pub mod magnitude;
|
||||
pub mod mean;
|
||||
pub mod median;
|
||||
pub mod midhinge;
|
||||
|
|
101
lib/src/fnc/vector.rs
Normal file
101
lib/src/fnc/vector.rs
Normal file
|
@ -0,0 +1,101 @@
|
|||
use crate::err::Error;
|
||||
use crate::fnc::util::math::dotproduct::DotProduct;
|
||||
use crate::fnc::util::math::magnitude::Magnitude;
|
||||
use crate::sql::{Number, Value};
|
||||
|
||||
pub fn dotproduct((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
match a.dotproduct(&b) {
|
||||
None => Err(Error::InvalidArguments {
|
||||
name: String::from("vector::dotproduct"),
|
||||
message: String::from("The two vectors must be of the same length."),
|
||||
}),
|
||||
Some(dot) => Ok(dot.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn magnitude((a,): (Vec<Number>,)) -> Result<Value, Error> {
|
||||
Ok(a.magnitude().into())
|
||||
}
|
||||
|
||||
pub mod distance {
|
||||
|
||||
use crate::err::Error;
|
||||
use crate::fnc::util::math::euclideandistance::EuclideanDistance;
|
||||
use crate::sql::{Number, Value};
|
||||
|
||||
pub fn chebyshev((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::distance::chebyshev() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn euclidean((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
match a.euclidean_distance(&b) {
|
||||
None => Err(Error::InvalidArguments {
|
||||
name: String::from("vector::distance::euclidean"),
|
||||
message: String::from("The two vectors must be of the same length."),
|
||||
}),
|
||||
Some(distance) => Ok(distance.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hamming((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::distance::hamming() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn mahalanobis((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::distance::mahalanobis() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn manhattan((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::distance::manhattan() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn minkowski((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::distance::minkowski() function",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub mod similarity {
|
||||
|
||||
use crate::err::Error;
|
||||
use crate::fnc::util::math::dotproduct::DotProduct;
|
||||
use crate::fnc::util::math::magnitude::Magnitude;
|
||||
use crate::sql::{Number, Value};
|
||||
|
||||
pub fn cosine((a, b): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
match a.dotproduct(&b) {
|
||||
None => Err(Error::InvalidArguments {
|
||||
name: String::from("vector::similarity::cosine"),
|
||||
message: String::from("The two vectors must be of the same length."),
|
||||
}),
|
||||
Some(dot) => Ok((dot / (a.magnitude() * b.magnitude())).into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn jaccard((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::similarity::jaccard() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pearson((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::similarity::pearson() function",
|
||||
})
|
||||
}
|
||||
|
||||
pub fn spearman((_, _): (Vec<Number>, Vec<Number>)) -> Result<Value, Error> {
|
||||
Err(Error::FeatureNotYetImplemented {
|
||||
feature: "vector::similarity::spearman() function",
|
||||
})
|
||||
}
|
||||
}
|
|
@ -248,6 +248,7 @@ fn script(i: &str) -> IResult<&str, Function> {
|
|||
|
||||
pub(crate) fn function_names(i: &str) -> IResult<&str, &str> {
|
||||
recognize(alt((
|
||||
alt((
|
||||
preceded(tag("array::"), function_array),
|
||||
preceded(tag("bytes::"), function_bytes),
|
||||
preceded(tag("crypto::"), function_crypto),
|
||||
|
@ -265,10 +266,9 @@ pub(crate) fn function_names(i: &str) -> IResult<&str, &str> {
|
|||
preceded(tag("string::"), function_string),
|
||||
preceded(tag("time::"), function_time),
|
||||
preceded(tag("type::"), function_type),
|
||||
tag("count"),
|
||||
tag("not"),
|
||||
tag("rand"),
|
||||
tag("sleep"),
|
||||
preceded(tag("vector::"), function_vector),
|
||||
)),
|
||||
alt((tag("count"), tag("not"), tag("rand"), tag("sleep"))),
|
||||
)))(i)
|
||||
}
|
||||
|
||||
|
@ -551,6 +551,28 @@ fn function_type(i: &str) -> IResult<&str, &str> {
|
|||
))(i)
|
||||
}
|
||||
|
||||
fn function_vector(i: &str) -> IResult<&str, &str> {
|
||||
alt((
|
||||
tag("dotproduct"),
|
||||
tag("magnitude"),
|
||||
preceded(
|
||||
tag("distance::"),
|
||||
alt((
|
||||
tag("chebyshev"),
|
||||
tag("euclidean"),
|
||||
tag("hamming"),
|
||||
tag("mahalanobis"),
|
||||
tag("manhattan"),
|
||||
tag("minkowski"),
|
||||
)),
|
||||
),
|
||||
preceded(
|
||||
tag("similarity::"),
|
||||
alt((tag("cosine"), tag("jaccard"), tag("pearson"), tag("spearman"))),
|
||||
),
|
||||
))(i)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
|
|
@ -4574,3 +4574,125 @@ async fn function_type_thing() -> Result<(), Error> {
|
|||
//
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_vector_distance_euclidean() -> Result<(), Error> {
|
||||
let sql = r#"
|
||||
RETURN vector::distance::euclidean([1, 2, 3], [1, 2, 3]);
|
||||
RETURN vector::distance::euclidean([1, 2, 3], [-1, -2, -3]);
|
||||
RETURN vector::distance::euclidean([1, 2, 3], [4, 5]);
|
||||
RETURN vector::distance::euclidean([1, 2], [4, 5, 5]);
|
||||
"#;
|
||||
|
||||
let dbs = Datastore::new("memory").await?;
|
||||
let ses = Session::for_kv().with_ns("test").with_db("test");
|
||||
let res = &mut dbs.execute(&sql, &ses, None).await?;
|
||||
assert_eq!(res.len(), 4);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(0);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(7.483314773547883);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_vector_dotproduct() -> Result<(), Error> {
|
||||
let sql = r#"
|
||||
RETURN vector::dotproduct([1, 2, 3], [1, 2, 3]);
|
||||
RETURN vector::dotproduct([1, 2, 3], [-1, -2, -3]);
|
||||
RETURN vector::dotproduct([1, 2, 3], [4, 5]);
|
||||
RETURN vector::dotproduct([1, 2], [4, 5, 5]);
|
||||
"#;
|
||||
|
||||
let dbs = Datastore::new("memory").await?;
|
||||
let ses = Session::for_kv().with_ns("test").with_db("test");
|
||||
let res = &mut dbs.execute(&sql, &ses, None).await?;
|
||||
assert_eq!(res.len(), 4);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(14);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(-14);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_vector_magnitude() -> Result<(), Error> {
|
||||
let sql = r#"
|
||||
RETURN vector::magnitude([]);
|
||||
RETURN vector::magnitude([1]);
|
||||
RETURN vector::magnitude([5]);
|
||||
RETURN vector::magnitude([1,2,3,3,3,4,5]);
|
||||
"#;
|
||||
|
||||
let dbs = Datastore::new("memory").await?;
|
||||
let ses = Session::for_kv().with_ns("test").with_db("test");
|
||||
let res = &mut dbs.execute(&sql, &ses, None).await?;
|
||||
assert_eq!(res.len(), 4);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(0);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(1);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(5);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(8.54400374531753);
|
||||
assert_eq!(tmp, val);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn function_vector_similarity_cosine() -> Result<(), Error> {
|
||||
let sql = r#"
|
||||
RETURN vector::similarity::cosine([1, 2, 3], [1, 2, 3]);
|
||||
RETURN vector::similarity::cosine([1, 2, 3], [-1, -2, -3]);
|
||||
RETURN vector::similarity::cosine([1, 2, 3], [4, 5]);
|
||||
RETURN vector::similarity::cosine([1, 2], [4, 5, 5]);
|
||||
"#;
|
||||
|
||||
let dbs = Datastore::new("memory").await?;
|
||||
let ses = Session::for_kv().with_ns("test").with_db("test");
|
||||
let res = &mut dbs.execute(&sql, &ses, None).await?;
|
||||
assert_eq!(res.len(), 4);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(1.0);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result?;
|
||||
let val = Value::from(-1.0);
|
||||
assert_eq!(tmp, val);
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
//
|
||||
let tmp = res.remove(0).result;
|
||||
assert!(tmp.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue