Index & vectors minor fixes (#4009)

This commit is contained in:
Emmanuel Keller 2024-05-10 08:12:07 +01:00 committed by GitHub
parent f607703f7e
commit 1d441e1c21
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 206 additions and 93 deletions

2
Cargo.lock generated
View file

@ -3466,7 +3466,6 @@ dependencies = [
"num-integer",
"num-traits",
"rawpointer",
"serde",
]
[[package]]
@ -5941,6 +5940,7 @@ name = "surrealdb-core"
version = "2.0.0-1.5.0"
dependencies = [
"addr",
"ahash 0.8.11",
"any_ascii",
"arbitrary",
"argon2",

View file

@ -59,6 +59,7 @@ targets = []
[dependencies]
addr = { version = "0.15.6", default-features = false, features = ["std"] }
ahash = "0.8.11"
arbitrary = { version = "1.3.2", features = ["derive"], optional = true }
argon2 = "0.5.2"
ascii = { version = "0.3.2", package = "any_ascii" }
@ -106,7 +107,7 @@ lexicmp = "0.1.0"
linfa-linalg = "=0.1.0"
md-5 = "0.10.6"
nanoid = "0.4.0"
ndarray = { version = "=0.15.6", features = ["serde"] }
ndarray = { version = "=0.15.6" }
ndarray-stats = "=0.5.1"
num-traits = "0.2.18"
nom = { version = "7.1.3", features = ["alloc"] }

View file

@ -512,7 +512,6 @@ mod tests {
use crate::idx::IndexKeyBase;
use crate::kvs::{Datastore, LockType::*, TransactionType};
use crate::sql::index::SearchParams;
use crate::sql::scoring::Scoring;
use crate::sql::statements::{DefineAnalyzerStatement, DefineStatement};
use crate::sql::{Array, Statement, Thing, Value};
use crate::syn;
@ -584,7 +583,7 @@ mod tests {
doc_lengths_order: order,
postings_order: order,
terms_order: order,
sc: Scoring::bm25(),
sc: Default::default(),
hl,
doc_ids_cache: 100,
doc_lengths_cache: 100,

View file

@ -1889,6 +1889,7 @@ mod tests {
}
#[test(tokio::test(flavor = "multi_thread"))]
#[ignore]
async fn test_mtree_random_small() -> Result<(), Error> {
let mut stack = reblessive::tree::TreeStack::new();
stack

View file

@ -2,20 +2,22 @@ use crate::err::Error;
use crate::fnc::util::math::ToFloat;
use crate::sql::index::{Distance, VectorType};
use crate::sql::{Array, Number, Value};
use ahash::AHasher;
use hashbrown::HashSet;
use linfa_linalg::norm::Norm;
use ndarray::{Array1, LinalgScalar, Zip};
use ndarray_stats::DeviationExt;
use num_traits::Zero;
use revision::revisioned;
use rust_decimal::prelude::FromPrimitive;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::cmp::PartialEq;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::{Hash, Hasher};
use std::ops::{Add, Deref, Div, Sub};
use std::sync::Arc;
/// In the context of a Symmetric MTree index, the term object refers to a vector, representing the indexed item.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum Vector {
F64(Array1<f64>),
@ -25,6 +27,41 @@ pub enum Vector {
I16(Array1<i16>),
}
/// Serialization-facing mirror of [`Vector`]: each variant carries a plain
/// `Vec` instead of an `ndarray::Array1`, so the on-disk format is decoupled
/// from the in-memory representation. Used by the `Serialize`/`Deserialize`
/// impls of `SharedVector`.
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
enum SerializedVector {
	F64(Vec<f64>),
	F32(Vec<f32>),
	I64(Vec<i64>),
	I32(Vec<i32>),
	I16(Vec<i16>),
}
impl From<&Vector> for SerializedVector {
fn from(value: &Vector) -> Self {
match value {
Vector::F64(v) => Self::F64(v.to_vec()),
Vector::F32(v) => Self::F32(v.to_vec()),
Vector::I64(v) => Self::I64(v.to_vec()),
Vector::I32(v) => Self::I32(v.to_vec()),
Vector::I16(v) => Self::I16(v.to_vec()),
}
}
}
impl From<SerializedVector> for Vector {
fn from(value: SerializedVector) -> Self {
match value {
SerializedVector::F64(v) => Self::F64(Array1::from_vec(v)),
SerializedVector::F32(v) => Self::F32(Array1::from_vec(v)),
SerializedVector::I64(v) => Self::I64(Array1::from_vec(v)),
SerializedVector::I32(v) => Self::I32(Array1::from_vec(v)),
SerializedVector::I16(v) => Self::I16(Array1::from_vec(v)),
}
}
}
impl Vector {
#[inline]
fn chebyshev<T>(a: &Array1<T>, b: &Array1<T>) -> f64
@ -282,7 +319,7 @@ impl Vector {
pub struct SharedVector(Arc<Vector>, u64);
impl From<Vector> for SharedVector {
fn from(v: Vector) -> Self {
let mut h = DefaultHasher::new();
let mut h = AHasher::default();
v.hash(&mut h);
Self(Arc::new(v), h.finish())
}
@ -315,7 +352,8 @@ impl Serialize for SharedVector {
S: Serializer,
{
// We only serialize the vector part, not the u64
self.0.serialize(serializer)
let ser: SerializedVector = self.0.as_ref().into();
ser.serialize(serializer)
}
}
@ -325,8 +363,7 @@ impl<'de> Deserialize<'de> for SharedVector {
D: Deserializer<'de>,
{
// We deserialize into a vector and construct the struct
// assuming some default or dummy value for the u64, e.g., 0
let v = Vector::deserialize(deserializer)?;
let v: Vector = SerializedVector::deserialize(deserializer)?.into();
Ok(v.into())
}
}

View file

@ -251,7 +251,7 @@ impl Display for Index {
Self::Hnsw(p) => {
write!(
f,
"HNSW DIMENSION {} DIST {} TYPE {} EFC {} M {} M0 {} ML {}",
"HNSW DIMENSION {} DIST {} TYPE {} EFC {} M {} M0 {} LM {}",
p.dimension, p.distance, p.vector_type, p.ef_construction, p.m, p.m0, p.ml
)?;
if p.extend_candidates {

View file

@ -51,8 +51,8 @@ impl Hash for Scoring {
}
}
impl Scoring {
pub(crate) fn bm25() -> Self {
impl Default for Scoring {
fn default() -> Self {
Self::Bm {
k1: 1.2,
b: 0.75,

View file

@ -107,7 +107,7 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("END") => TokenKind::Keyword(Keyword::End),
UniCase::ascii("EXISTS") => TokenKind::Keyword(Keyword::Exists),
UniCase::ascii("EXPLAIN") => TokenKind::Keyword(Keyword::Explain),
UniCase::ascii("EXTEND_CANDIDATE") => TokenKind::Keyword(Keyword::ExtendCandidates),
UniCase::ascii("EXTEND_CANDIDATES") => TokenKind::Keyword(Keyword::ExtendCandidates),
UniCase::ascii("false") => TokenKind::Keyword(Keyword::False),
UniCase::ascii("FETCH") => TokenKind::Keyword(Keyword::Fetch),
UniCase::ascii("FIELD") => TokenKind::Keyword(Keyword::Field),
@ -139,6 +139,7 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("LIMIT") => TokenKind::Keyword(Keyword::Limit),
UniCase::ascii("LIVE") => TokenKind::Keyword(Keyword::Live),
UniCase::ascii("LOWERCASE") => TokenKind::Keyword(Keyword::Lowercase),
UniCase::ascii("LM") => TokenKind::Keyword(Keyword::Lm),
UniCase::ascii("M") => TokenKind::Keyword(Keyword::M),
UniCase::ascii("M0") => TokenKind::Keyword(Keyword::M0),
UniCase::ascii("ML") => TokenKind::Keyword(Keyword::ML),

View file

@ -556,75 +556,87 @@ impl Parser<'_> {
}
t!("SEARCH") => {
self.pop_peek();
let analyzer =
self.eat(t!("ANALYZER")).then(|| self.next_token_value()).transpose()?;
let scoring = match self.next().kind {
t!("VS") => Scoring::Vs,
t!("BM25") => {
if self.eat(t!("(")) {
let open = self.last_span();
let k1 = self.next_token_value()?;
expected!(self, t!(","));
let b = self.next_token_value()?;
self.expect_closing_delimiter(t!(")"), open)?;
Scoring::Bm {
k1,
b,
}
} else {
Scoring::bm25()
let mut analyzer: Option<Ident> = None;
let mut scoring = None;
let mut doc_ids_order = 100;
let mut doc_lengths_order = 100;
let mut postings_order = 100;
let mut terms_order = 100;
let mut doc_ids_cache = 100;
let mut doc_lengths_cache = 100;
let mut postings_cache = 100;
let mut terms_cache = 100;
let mut hl = false;
loop {
match self.peek_kind() {
t!("ANALYZER") => {
self.pop_peek();
analyzer = Some(self.next_token_value()).transpose()?;
}
t!("VS") => {
self.pop_peek();
scoring = Some(Scoring::Vs);
}
t!("BM25") => {
self.pop_peek();
if self.eat(t!("(")) {
let open = self.last_span();
let k1 = self.next_token_value()?;
expected!(self, t!(","));
let b = self.next_token_value()?;
self.expect_closing_delimiter(t!(")"), open)?;
scoring = Some(Scoring::Bm {
k1,
b,
})
} else {
scoring = Some(Default::default());
};
}
t!("DOC_IDS_ORDER") => {
self.pop_peek();
doc_ids_order = self.next_token_value()?;
}
t!("DOC_LENGTHS_ORDER") => {
self.pop_peek();
doc_lengths_order = self.next_token_value()?;
}
t!("POSTINGS_ORDER") => {
self.pop_peek();
postings_order = self.next_token_value()?;
}
t!("TERMS_ORDER") => {
self.pop_peek();
terms_order = self.next_token_value()?;
}
t!("DOC_IDS_CACHE") => {
self.pop_peek();
doc_ids_cache = self.next_token_value()?;
}
t!("DOC_LENGTHS_CACHE") => {
self.pop_peek();
doc_lengths_cache = self.next_token_value()?;
}
t!("POSTINGS_CACHE") => {
self.pop_peek();
postings_cache = self.next_token_value()?;
}
t!("TERMS_CACHE") => {
self.pop_peek();
terms_cache = self.next_token_value()?;
}
t!("HIGHLIGHTS") => {
self.pop_peek();
hl = true;
}
_ => break,
}
x => unexpected!(self, x, "`VS` or `BM25`"),
};
// TODO: Propose change in how order syntax works.
let doc_ids_order = self
.eat(t!("DOC_IDS_ORDER"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let doc_lengths_order = self
.eat(t!("DOC_LENGTHS_ORDER"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let postings_order = self
.eat(t!("POSTINGS_ORDER"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let terms_order = self
.eat(t!("TERMS_ORDER"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let doc_ids_cache = self
.eat(t!("DOC_IDS_CACHE"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let doc_lengths_cache = self
.eat(t!("DOC_LENGTHS_CACHE"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let postings_cache = self
.eat(t!("POSTINGS_CACHE"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let terms_cache = self
.eat(t!("TERMS_CACHE"))
.then(|| self.next_token_value())
.transpose()?
.unwrap_or(100);
let hl = self.eat(t!("HIGHLIGHTS"));
}
res.index = Index::Search(crate::sql::index::SearchParams {
az: analyzer.unwrap_or_else(|| Ident::from("like")),
sc: scoring,
sc: scoring.unwrap_or_else(Default::default),
hl,
doc_ids_order,
doc_lengths_order,
@ -708,17 +720,17 @@ impl Parser<'_> {
self.pop_peek();
vector_type = self.parse_vector_type()?;
}
t!("M") => {
t!("LM") => {
self.pop_peek();
m = Some(self.next_token_value()?);
ml = Some(self.next_token_value()?);
}
t!("M0") => {
self.pop_peek();
m0 = Some(self.next_token_value()?);
}
t!("ML") => {
t!("M") => {
self.pop_peek();
ml = Some(self.next_token_value()?);
m = Some(self.next_token_value()?);
}
t!("EFC") => {
self.pop_peek();
@ -732,7 +744,10 @@ impl Parser<'_> {
self.pop_peek();
keep_pruned_connections = true;
}
_ => break,
t => {
println!("TOKEN: {t:?}");
break;
}
}
}

View file

@ -3,7 +3,7 @@ use crate::{
block::Entry,
changefeed::ChangeFeed,
filter::Filter,
index::{Distance, MTreeParams, SearchParams, VectorType},
index::{Distance, HnswParams, MTreeParams, SearchParams, VectorType},
language::Language,
statements::{
analyze::AnalyzeStatement, show::ShowSince, show::ShowStatement, sleep::SleepStatement,
@ -452,7 +452,7 @@ fn parse_define_index() {
);
let res =
test_parse!(parse_stmt, r#"DEFINE INDEX index ON TABLE table FIELDS a MTREE DIMENSION 4 DISTANCE MINKOWSKI 5 CAPACITY 6 DOC_IDS_ORDER 7 DOC_IDS_CACHE 8 MTREE_CACHE 9"#).unwrap();
test_parse!(parse_stmt, r#"DEFINE INDEX index ON TABLE table FIELDS a MTREE DIMENSION 4 DISTANCE MINKOWSKI 5 CAPACITY 6 TYPE I16 DOC_IDS_ORDER 7 DOC_IDS_CACHE 8 MTREE_CACHE 9"#).unwrap();
assert_eq!(
res,
@ -468,7 +468,32 @@ fn parse_define_index() {
doc_ids_order: 7,
doc_ids_cache: 8,
mtree_cache: 9,
vector_type: VectorType::F64,
vector_type: VectorType::I16,
}),
comment: None,
if_not_exists: false,
}))
);
let res =
test_parse!(parse_stmt, r#"DEFINE INDEX index ON TABLE table FIELDS a HNSW DIMENSION 128 EFC 250 TYPE F32 DISTANCE MANHATTAN M 6 M0 12 LM 0.5 EXTEND_CANDIDATES KEEP_PRUNED_CONNECTIONS"#).unwrap();
assert_eq!(
res,
Statement::Define(DefineStatement::Index(DefineIndexStatement {
name: Ident("index".to_owned()),
what: Ident("table".to_owned()),
cols: Idioms(vec![Idiom(vec![Part::Field(Ident("a".to_owned()))]),]),
index: Index::Hnsw(HnswParams {
dimension: 128,
distance: Distance::Manhattan,
vector_type: VectorType::F32,
m: 6,
m0: 12,
ef_construction: 250,
extend_candidates: true,
keep_pruned_connections: true,
ml: 0.5.into(),
}),
comment: None,
if_not_exists: false,

View file

@ -100,11 +100,11 @@ keyword! {
Limit => "LIMIT",
Live => "LIVE",
Lowercase => "LOWERCASE",
Lm => "LM",
M => "M",
M0 => "M0",
Merge => "MERGE",
Model => "MODEL",
Ml => "ML",
MTree => "MTREE",
MTreeCache => "MTREE_CACHE",
Namespace => "NAMESPACE",

View file

@ -103,7 +103,7 @@ mod database_upgrade {
upgrade_test_1_1("v1.1.1").await;
}
async fn upgrade_test_1_2(version: &str) {
async fn upgrade_test_1_2_to_1_4(version: &str) {
// Start the docker instance
let (path, mut docker, client) = start_docker(version).await;
@ -136,14 +136,49 @@ mod database_upgrade {
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_2_0() {
upgrade_test_1_2("v1.2.0").await;
upgrade_test_1_2_to_1_4("v1.2.0").await;
}
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_2_1() {
upgrade_test_1_2("v1.2.1").await;
upgrade_test_1_2_to_1_4("v1.2.1").await;
}
// Datastore upgrade tests for historical versions v1.2.2 through v1.4.2.
// Each spins up a dockerised instance of the given version (inside the
// shared `upgrade_test_1_2_to_1_4` routine) and verifies the data survives
// the upgrade; requires the RocksDB storage backend and runs serially.
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_2_2() {
	upgrade_test_1_2_to_1_4("v1.2.2").await;
}
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_3_0() {
	upgrade_test_1_2_to_1_4("v1.3.0").await;
}
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_3_1() {
	upgrade_test_1_2_to_1_4("v1.3.1").await;
}
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_4_0() {
	upgrade_test_1_2_to_1_4("v1.4.0").await;
}
#[test(tokio::test(flavor = "multi_thread"))]
#[cfg(feature = "storage-rocksdb")]
#[serial]
async fn upgrade_test_1_4_2() {
	upgrade_test_1_2_to_1_4("v1.4.2").await;
}
// *******
@ -178,11 +213,10 @@ mod database_upgrade {
Expected::Two("{\"dist\": 2.0, \"id\": \"pts:1\"}", "{ \"dist\": 4.0, \"id\": \"pts:2\"}"))];
const CHECK_MTREE_DB: [Check; 1] = [
("SELECT id, vector::distance::euclidean(point, [2,3,4,5]) AS dist FROM pts WHERE point <2> [2,3,4,5]",
("SELECT id, vector::distance::euclidean(point, [2,3,4,5]) AS dist FROM pts WHERE point <|2|> [2,3,4,5]",
Expected::Two("{\"dist\": 2.0, \"id\": {\"tb\": \"pts\", \"id\": {\"Number\": 1}}}", "{ \"dist\": 4.0, \"id\": {\"tb\": \"pts\", \"id\": {\"Number\": 2}}}"))];
const CHECK_KNN_BRUTEFORCE: [Check; 1] = [
("SELECT id, vector::distance::euclidean(point, [2,3,4,5]) AS dist FROM pts WHERE point <2,EUCLIDEAN> [2,3,4,5]",
("SELECT id, vector::distance::euclidean(point, [2,3,4,5]) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> [2,3,4,5]",
Expected::Two("{\"dist\": 2.0, \"id\": {\"tb\": \"pts\", \"id\": {\"Number\": 1}}}", "{ \"dist\": 4.0, \"id\": {\"tb\": \"pts\", \"id\": {\"Number\": 2}}}"))];
type Check = (&'static str, Expected);