From 0772a8c5925e10fc67f041ff947e55e262e4b48c Mon Sep 17 00:00:00 2001 From: Emmanuel Keller Date: Tue, 12 Sep 2023 21:26:03 +0100 Subject: [PATCH] Feature: Vector Search: mtree index + knn operator (#2546) Co-authored-by: Tobie Morgan Hitchcock --- lib/src/dbs/iterator.rs | 2 +- lib/src/dbs/processor.rs | 2 +- lib/src/doc/document.rs | 2 +- lib/src/doc/index.rs | 31 +- lib/src/err/mod.rs | 20 + lib/src/fnc/operate.rs | 66 +- .../modules/surrealdb/functions/array.rs | 1 + lib/src/fnc/util/math/vector.rs | 4 +- lib/src/fnc/vector.rs | 2 +- lib/src/idx/{ft => }/docids.rs | 23 +- lib/src/idx/ft/analyzer/mod.rs | 14 +- lib/src/idx/ft/doclength.rs | 7 +- lib/src/idx/ft/mod.rs | 47 +- lib/src/idx/ft/offsets.rs | 2 +- lib/src/idx/ft/postings.rs | 8 +- lib/src/idx/ft/scorer.rs | 6 +- lib/src/idx/ft/termdocs.rs | 2 +- lib/src/idx/mod.rs | 15 +- lib/src/idx/planner/executor.rs | 125 +- lib/src/idx/planner/iterators.rs | 57 +- lib/src/idx/planner/plan.rs | 5 + lib/src/idx/planner/tree.rs | 83 +- lib/src/idx/trees/btree.rs | 153 +- lib/src/idx/trees/mod.rs | 1 + lib/src/idx/trees/mtree.rs | 1792 +++++++++++++++++ lib/src/idx/trees/store.rs | 23 +- lib/src/key/index/bf.rs | 2 +- lib/src/key/index/bk.rs | 2 +- lib/src/key/index/bo.rs | 2 +- lib/src/key/index/mod.rs | 1 + lib/src/key/index/vm.rs | 68 + lib/src/kvs/ds.rs | 3 + lib/src/sql/expression.rs | 1 + lib/src/sql/index.rs | 5 +- lib/src/sql/operator.rs | 23 +- lib/src/sql/statements/analyze.rs | 6 + lib/src/sql/statements/define/index.rs | 27 +- lib/src/sql/value/serde/ser/distance/mod.rs | 63 +- lib/tests/changefeeds.rs | 16 +- lib/tests/define.rs | 4 +- lib/tests/vector.rs | 60 + 41 files changed, 2541 insertions(+), 235 deletions(-) rename lib/src/idx/{ft => }/docids.rs (94%) create mode 100644 lib/src/idx/trees/mtree.rs create mode 100644 lib/src/key/index/vm.rs create mode 100644 lib/tests/vector.rs diff --git a/lib/src/dbs/iterator.rs b/lib/src/dbs/iterator.rs index 48c75a5c..3bda9deb 100644 --- a/lib/src/dbs/iterator.rs +++ b/lib/src/dbs/iterator.rs @@ -8,7 +8,7 @@ use crate::dbs::Statement; use crate::dbs::{Options, Transaction}; use crate::doc::Document; use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::planner::executor::IteratorRef; use crate::sql::array::Array; use crate::sql::edges::Edges; diff --git a/lib/src/dbs/processor.rs b/lib/src/dbs/processor.rs index ee02daa3..1029fc5d 100644 --- a/lib/src/dbs/processor.rs +++ b/lib/src/dbs/processor.rs @@ -594,7 +594,7 @@ impl<'a> Processor<'a> { } } Err(Error::QueryNotExecutedDetail { - message: "No QueryExecutor has not been found.".to_string(), + message: "No QueryExecutor has been found.".to_string(), }) } } diff --git a/lib/src/doc/document.rs b/lib/src/doc/document.rs index 19602e91..8ca4acc8 100644 --- a/lib/src/doc/document.rs +++ b/lib/src/doc/document.rs @@ -4,7 +4,7 @@ use crate::dbs::Workable; use crate::err::Error; use crate::iam::Action; use crate::iam::ResourceKind; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::planner::executor::IteratorRef; use crate::sql::statements::define::DefineEventStatement; use crate::sql::statements::define::DefineFieldStatement; diff --git a/lib/src/doc/index.rs b/lib/src/doc/index.rs index a774cfba..64b4240a 100644 --- a/lib/src/doc/index.rs +++ b/lib/src/doc/index.rs @@ -4,10 +4,11 @@ use crate::dbs::{Options, Transaction}; use crate::doc::{CursorDoc, Document}; use crate::err::Error; use crate::idx::ft::FtIndex; +use crate::idx::trees::mtree::MTreeIndex; use crate::idx::trees::store::TreeStoreType; use crate::idx::IndexKeyBase; use crate::sql::array::Array; -use crate::sql::index::{Index, SearchParams}; +use crate::sql::index::{Index, MTreeParams, SearchParams}; use crate::sql::statements::DefineIndexStatement; use crate::sql::{Part, Thing, Value}; use crate::{key, kvs}; @@ -55,11 +56,7 @@ impl<'a> Document<'a> { Index::Uniq => ic.index_unique(&mut run).await?, Index::Idx => ic.index_non_unique(&mut run).await?, Index::Search(p) => ic.index_full_text(&mut run, p).await?, - Index::MTree(_) => { - return Err(Error::FeatureNotYetImplemented { - feature: "MTree indexing".to_string(), - }) - } + Index::MTree(p) => ic.index_mtree(&mut run, p).await?, }; } } @@ -332,18 +329,36 @@ impl<'a> IndexOperation<'a> { } async fn index_full_text( - &self, + &mut self, run: &mut kvs::Transaction, p: &SearchParams, ) -> Result<(), Error> { let ikb = IndexKeyBase::new(self.opt, self.ix); let az = run.get_db_analyzer(self.opt.ns(), self.opt.db(), p.az.as_str()).await?; let mut ft = FtIndex::new(run, az, ikb, p, TreeStoreType::Write).await?; - if let Some(n) = &self.n { + if let Some(n) = self.n.take() { ft.index_document(run, self.rid, n).await?; } else { ft.remove_document(run, self.rid).await?; } ft.finish(run).await } + + async fn index_mtree( + &mut self, + run: &mut kvs::Transaction, + p: &MTreeParams, + ) -> Result<(), Error> { + let ikb = IndexKeyBase::new(self.opt, self.ix); + let mut mt = MTreeIndex::new(run, ikb, p, TreeStoreType::Write).await?; + // Delete the old index data + if let Some(o) = self.o.take() { + mt.remove_document(run, self.rid, o).await?; + } + // Create the new index data + if let Some(n) = self.n.take() { + mt.index_document(run, self.rid, n).await?; + } + mt.finish(run).await + } } diff --git a/lib/src/err/mod.rs b/lib/src/err/mod.rs index ae022e05..f4f1a154 100644 --- a/lib/src/err/mod.rs +++ b/lib/src/err/mod.rs @@ -208,6 +208,26 @@ pub enum Error { #[error("The URL `{0}` is invalid")] InvalidUrl(String), + /// The size of the vector is incorrect + #[error("Incorrect vector dimension ({current}). Expected a vector of {expected} dimension.")] + InvalidVectorDimension { + current: usize, + expected: usize, + }, + + /// The size of the vector is incorrect + #[error("The vector element ({current}) is not a number.")] + InvalidVectorType { + current: String, + expected: &'static str, + }, + + /// The size of the vector is incorrect + #[error("The value '{current}' is not a vector.")] + InvalidVectorValue { + current: String, + }, + /// The query timedout #[error("The query was not executed because it exceeded the timeout")] QueryTimedout, diff --git a/lib/src/fnc/operate.rs b/lib/src/fnc/operate.rs index 4cfbd16f..9d4cfc73 100644 --- a/lib/src/fnc/operate.rs +++ b/lib/src/fnc/operate.rs @@ -2,6 +2,7 @@ use crate::ctx::Context; use crate::dbs::Transaction; use crate::doc::CursorDoc; use crate::err::Error; +use crate::idx::planner::executor::QueryExecutor; use crate::sql::value::TryAdd; use crate::sql::value::TryDiv; use crate::sql::value::TryMul; @@ -9,7 +10,7 @@ use crate::sql::value::TryNeg; use crate::sql::value::TryPow; use crate::sql::value::TrySub; use crate::sql::value::Value; -use crate::sql::Expression; +use crate::sql::{Expression, Thing}; pub fn neg(a: Value) -> Result { a.try_neg() @@ -167,31 +168,58 @@ pub fn intersects(a: &Value, b: &Value) -> Result { Ok(a.intersects(b).into()) } +enum IndexOption<'a> { + PreMatch, + None, + Execute(&'a QueryExecutor, &'a Thing), +} + +fn get_index_option<'a>( + ctx: &'a Context<'_>, + doc: Option<&'a CursorDoc<'_>>, + exp: &'a Expression, +) -> IndexOption<'a> { + if let Some(doc) = doc { + if let Some(thg) = doc.rid { + if let Some(pla) = ctx.get_query_planner() { + if let Some(exe) = pla.get_query_executor(&thg.tb) { + if let Some(ir) = doc.ir { + if exe.is_iterator_expression(ir, exp) { + return IndexOption::PreMatch; + } + } + return IndexOption::Execute(exe, thg); + } + } + } + } + IndexOption::None +} + pub(crate) async fn matches( ctx: &Context<'_>, txn: &Transaction, doc: Option<&CursorDoc<'_>>, exp: &Expression, ) -> Result { - if let Some(doc) = doc { - if let Some(thg) = doc.rid { - if let Some(pla) = ctx.get_query_planner() { - if let Some(exe) = pla.get_query_executor(&thg.tb) { - // If we find the expression in `pre_match`, - // it means that we are using an Iterator::Index - // and we are iterating over documents that already matches the expression. - if let Some(ir) = doc.ir { - if exe.is_iterator_expression(ir, exp) { - return Ok(Value::Bool(true)); - } - } - // Evaluate the matches - return exe.matches(txn, thg, exp).await; - } - } - } + match get_index_option(ctx, doc, exp) { + IndexOption::PreMatch => Ok(Value::Bool(true)), + IndexOption::None => Ok(Value::Bool(false)), + IndexOption::Execute(exe, thg) => exe.matches(txn, thg, exp).await, + } +} + +pub(crate) async fn knn( + ctx: &Context<'_>, + txn: &Transaction, + doc: Option<&CursorDoc<'_>>, + exp: &Expression, +) -> Result { + match get_index_option(ctx, doc, exp) { + IndexOption::PreMatch => Ok(Value::Bool(true)), + IndexOption::None => Ok(Value::Bool(false)), + IndexOption::Execute(exe, thg) => exe.knn(txn, thg, exp).await, } - Ok(Value::Bool(false)) } #[cfg(test)] diff --git a/lib/src/fnc/script/modules/surrealdb/functions/array.rs b/lib/src/fnc/script/modules/surrealdb/functions/array.rs index e2aaef18..aabef852 100644 --- a/lib/src/fnc/script/modules/surrealdb/functions/array.rs +++ b/lib/src/fnc/script/modules/surrealdb/functions/array.rs @@ -30,6 +30,7 @@ impl_module_def!( "insert" => run, "intersect" => run, "join" => run, + "knn" => run, "last" => run, "len" => run, "logical_and" => run, diff --git a/lib/src/fnc/util/math/vector.rs b/lib/src/fnc/util/math/vector.rs index eb8e3a65..db15e1f7 100644 --- a/lib/src/fnc/util/math/vector.rs +++ b/lib/src/fnc/util/math/vector.rs @@ -132,11 +132,11 @@ impl ManhattanDistance for Vec { } pub trait MinkowskiDistance { - fn minkowski_distance(&self, other: &Self, order: Number) -> Result; + fn minkowski_distance(&self, other: &Self, order: &Number) -> Result; } impl MinkowskiDistance for Vec { - fn minkowski_distance(&self, other: &Self, order: Number) -> Result { + fn minkowski_distance(&self, other: &Self, order: &Number) -> Result { check_same_dimension("vector::distance::minkowski", self, other)?; let p = order.to_float(); let dist: f64 = self diff --git a/lib/src/fnc/vector.rs b/lib/src/fnc/vector.rs index 9c8e9c3e..6fdc9793 100644 --- a/lib/src/fnc/vector.rs +++ b/lib/src/fnc/vector.rs @@ -75,7 +75,7 @@ pub mod distance { } pub fn minkowski((a, b, o): (Vec, Vec, Number)) -> Result { - Ok(a.minkowski_distance(&b, o)?.into()) + Ok(a.minkowski_distance(&b, &o)?.into()) } } diff --git a/lib/src/idx/ft/docids.rs b/lib/src/idx/docids.rs similarity index 94% rename from lib/src/idx/ft/docids.rs rename to lib/src/idx/docids.rs index 35ae715d..7e8dd4e5 100644 --- a/lib/src/idx/ft/docids.rs +++ b/lib/src/idx/docids.rs @@ -25,7 +25,7 @@ pub(crate) struct DocIds { } impl DocIds { - pub(super) async fn new( + pub(in crate::idx) async fn new( tx: &mut Transaction, index_key_base: IndexKeyBase, default_btree_order: u32, @@ -78,7 +78,7 @@ impl DocIds { /// Returns the doc_id for the given doc_key. /// If the doc_id does not exists, a new one is created, and associated to the given key. - pub(super) async fn resolve_doc_id( + pub(in crate::idx) async fn resolve_doc_id( &mut self, tx: &mut Transaction, doc_key: Key, @@ -97,7 +97,7 @@ impl DocIds { Ok(Resolved::New(doc_id)) } - pub(super) async fn remove_doc( + pub(in crate::idx) async fn remove_doc( &mut self, tx: &mut Transaction, doc_key: Key, @@ -119,7 +119,7 @@ impl DocIds { } } - pub(super) async fn get_doc_key( + pub(in crate::idx) async fn get_doc_key( &self, tx: &mut Transaction, doc_id: DocId, @@ -132,12 +132,15 @@ impl DocIds { } } - pub(super) async fn statistics(&self, tx: &mut Transaction) -> Result { + pub(in crate::idx) async fn statistics( + &self, + tx: &mut Transaction, + ) -> Result { let mut store = self.store.lock().await; self.btree.statistics(tx, &mut store).await } - pub(super) async fn finish(&mut self, tx: &mut Transaction) -> Result<(), Error> { + pub(in crate::idx) async fn finish(&mut self, tx: &mut Transaction) -> Result<(), Error> { let updated = self.store.lock().await.finish(tx).await?; if self.updated || updated { let state = State { @@ -172,20 +175,20 @@ impl State { } #[derive(Debug, PartialEq)] -pub(super) enum Resolved { +pub(in crate::idx) enum Resolved { New(DocId), Existing(DocId), } impl Resolved { - pub(super) fn doc_id(&self) -> &DocId { + pub(in crate::idx) fn doc_id(&self) -> &DocId { match self { Resolved::New(doc_id) => doc_id, Resolved::Existing(doc_id) => doc_id, } } - pub(super) fn was_existing(&self) -> bool { + pub(in crate::idx) fn was_existing(&self) -> bool { match self { Resolved::New(_) => false, Resolved::Existing(_) => true, @@ -195,7 +198,7 @@ impl Resolved { #[cfg(test)] mod tests { - use crate::idx::ft::docids::{DocIds, Resolved}; + use crate::idx::docids::{DocIds, Resolved}; use crate::idx::trees::store::TreeStoreType; use crate::idx::IndexKeyBase; use crate::kvs::{Datastore, Transaction}; diff --git a/lib/src/idx/ft/analyzer/mod.rs b/lib/src/idx/ft/analyzer/mod.rs index f03c04c7..b705903d 100644 --- a/lib/src/idx/ft/analyzer/mod.rs +++ b/lib/src/idx/ft/analyzer/mod.rs @@ -64,7 +64,7 @@ impl Analyzer { &self, terms: &mut Terms, tx: &mut Transaction, - field_content: &[Value], + field_content: Vec, ) -> Result<(DocLength, Vec<(TermId, TermFrequency)>), Error> { let mut dl = 0; // Let's first collect all the inputs, and collect the tokens. @@ -101,7 +101,7 @@ impl Analyzer { &self, terms: &mut Terms, tx: &mut Transaction, - content: &[Value], + content: Vec, ) -> Result<(DocLength, Vec<(TermId, TermFrequency)>, Vec<(TermId, OffsetRecords)>), Error> { let mut dl = 0; // Let's first collect all the inputs, and collect the tokens. @@ -135,25 +135,25 @@ impl Analyzer { Ok((dl, tfid, osid)) } - fn analyze_content(&self, content: &[Value], tks: &mut Vec) -> Result<(), Error> { + fn analyze_content(&self, content: Vec, tks: &mut Vec) -> Result<(), Error> { for v in content { self.analyze_value(v, tks)?; } Ok(()) } - fn analyze_value(&self, val: &Value, tks: &mut Vec) -> Result<(), Error> { + fn analyze_value(&self, val: Value, tks: &mut Vec) -> Result<(), Error> { match val { - Value::Strand(s) => tks.push(self.analyze(s.0.clone())?), + Value::Strand(s) => tks.push(self.analyze(s.0)?), Value::Number(n) => tks.push(self.analyze(n.to_string())?), Value::Bool(b) => tks.push(self.analyze(b.to_string())?), Value::Array(a) => { - for v in &a.0 { + for v in a.0 { self.analyze_value(v, tks)?; } } Value::Object(o) => { - for v in o.0.values() { + for (_, v) in o.0 { self.analyze_value(v, tks)?; } } diff --git a/lib/src/idx/ft/doclength.rs b/lib/src/idx/ft/doclength.rs index 712851b3..deb2075e 100644 --- a/lib/src/idx/ft/doclength.rs +++ b/lib/src/idx/ft/doclength.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::trees::bkeys::TrieKeys; use crate::idx::trees::btree::{BState, BStatistics, BTree, BTreeNodeStore, Payload}; use crate::idx::trees::store::{TreeNodeProvider, TreeNodeStore, TreeStoreType}; @@ -72,9 +72,8 @@ impl DocLengths { } pub(super) async fn finish(&self, tx: &mut Transaction) -> Result<(), Error> { - if self.store.lock().await.finish(tx).await? { - tx.set(self.state_key.clone(), self.btree.get_state().try_to_val()?).await?; - } + self.store.lock().await.finish(tx).await?; + self.btree.get_state().finish(tx, &self.state_key).await?; Ok(()) } } diff --git a/lib/src/idx/ft/mod.rs b/lib/src/idx/ft/mod.rs index b01c8a8c..27f942c2 100644 --- a/lib/src/idx/ft/mod.rs +++ b/lib/src/idx/ft/mod.rs @@ -1,5 +1,4 @@ pub(crate) mod analyzer; -pub(crate) mod docids; mod doclength; mod highlighter; mod offsets; @@ -9,8 +8,8 @@ pub(super) mod termdocs; pub(crate) mod terms; use crate::err::Error; +use crate::idx::docids::{DocId, DocIds}; use crate::idx::ft::analyzer::Analyzer; -use crate::idx::ft::docids::{DocId, DocIds}; use crate::idx::ft::doclength::DocLengths; use crate::idx::ft::highlighter::{Highlighter, Offseter}; use crate::idx::ft::offsets::Offsets; @@ -198,7 +197,7 @@ impl FtIndex { &mut self, tx: &mut Transaction, rid: &Thing, - content: &[Value], + content: Vec, ) -> Result<(), Error> { // Resolve the doc_id let resolved = self.doc_ids.write().await.resolve_doc_id(tx, rid.into()).await?; @@ -481,7 +480,7 @@ mod tests { } assert_eq!(map.len(), e.len()); for (k, p) in e { - assert_eq!(map.get(k), Some(&p)); + assert_eq!(map.get(k), Some(&p), "{}", k); } } else { panic!("hits is none"); @@ -549,9 +548,7 @@ mod tests { // Add one document let (mut tx, mut fti) = tx_fti(&ds, TreeStoreType::Write, &az, btree_order, false).await; - fti.index_document(&mut tx, &doc1, &vec![Value::from("hello the world")]) - .await - .unwrap(); + fti.index_document(&mut tx, &doc1, vec![Value::from("hello the world")]).await.unwrap(); finish(tx, fti).await; } @@ -559,8 +556,8 @@ mod tests { // Add two documents let (mut tx, mut fti) = tx_fti(&ds, TreeStoreType::Write, &az, btree_order, false).await; - fti.index_document(&mut tx, &doc2, &vec![Value::from("a yellow hello")]).await.unwrap(); - fti.index_document(&mut tx, &doc3, &vec![Value::from("foo bar")]).await.unwrap(); + fti.index_document(&mut tx, &doc2, vec![Value::from("a yellow hello")]).await.unwrap(); + fti.index_document(&mut tx, &doc3, vec![Value::from("foo bar")]).await.unwrap(); finish(tx, fti).await; } @@ -575,7 +572,13 @@ mod tests { // Search & score let (hits, scr) = search(&mut tx, &fti, "hello").await; - check_hits(&mut tx, hits, scr, vec![(&doc1, Some(0.0)), (&doc2, Some(0.0))]).await; + check_hits( + &mut tx, + hits, + scr, + vec![(&doc1, Some(-0.4859746)), (&doc2, Some(-0.4859746))], + ) + .await; let (hits, scr) = search(&mut tx, &fti, "world").await; check_hits(&mut tx, hits, scr, vec![(&doc1, Some(0.4859746))]).await; @@ -597,7 +600,7 @@ mod tests { // Reindex one document let (mut tx, mut fti) = tx_fti(&ds, TreeStoreType::Write, &az, btree_order, false).await; - fti.index_document(&mut tx, &doc3, &vec![Value::from("nobar foo")]).await.unwrap(); + fti.index_document(&mut tx, &doc3, vec![Value::from("nobar foo")]).await.unwrap(); finish(tx, fti).await; let (mut tx, fti) = tx_fti(&ds, TreeStoreType::Read, &az, btree_order, false).await; @@ -655,28 +658,28 @@ mod tests { fti.index_document( &mut tx, &doc1, - &vec![Value::from("the quick brown fox jumped over the lazy dog")], + vec![Value::from("the quick brown fox jumped over the lazy dog")], ) .await .unwrap(); fti.index_document( &mut tx, &doc2, - &vec![Value::from("the fast fox jumped over the lazy dog")], + vec![Value::from("the fast fox jumped over the lazy dog")], ) .await .unwrap(); fti.index_document( &mut tx, &doc3, - &vec![Value::from("the dog sat there and did nothing")], + vec![Value::from("the dog sat there and did nothing")], ) .await .unwrap(); fti.index_document( &mut tx, &doc4, - &vec![Value::from("the other animals sat there watching")], + vec![Value::from("the other animals sat there watching")], ) .await .unwrap(); @@ -698,10 +701,10 @@ mod tests { hits, scr, vec![ - (&doc1, Some(0.0)), - (&doc2, Some(0.0)), - (&doc3, Some(0.0)), - (&doc4, Some(0.0)), + (&doc1, Some(-3.4388628)), + (&doc2, Some(-3.621457)), + (&doc3, Some(-2.258829)), + (&doc4, Some(-2.393017)), ], ) .await; @@ -711,7 +714,11 @@ mod tests { &mut tx, hits, scr, - vec![(&doc1, Some(0.0)), (&doc2, Some(0.0)), (&doc3, Some(0.0))], + vec![ + (&doc1, Some(-0.7832165)), + (&doc2, Some(-0.8248031)), + (&doc3, Some(-0.87105393)), + ], ) .await; diff --git a/lib/src/idx/ft/offsets.rs b/lib/src/idx/ft/offsets.rs index 21adfd0f..9c7732b6 100644 --- a/lib/src/idx/ft/offsets.rs +++ b/lib/src/idx/ft/offsets.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::terms::TermId; use crate::idx::IndexKeyBase; use crate::kvs::{Transaction, Val}; diff --git a/lib/src/idx/ft/postings.rs b/lib/src/idx/ft/postings.rs index f2a23dd2..0681c89e 100644 --- a/lib/src/idx/ft/postings.rs +++ b/lib/src/idx/ft/postings.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::terms::TermId; use crate::idx::trees::bkeys::TrieKeys; use crate::idx::trees::btree::{BState, BStatistics, BTree, BTreeNodeStore}; @@ -81,10 +81,8 @@ impl Postings { } pub(super) async fn finish(&self, tx: &mut Transaction) -> Result<(), Error> { - let updated = self.store.lock().await.finish(tx).await?; - if self.btree.is_updated() || updated { - tx.set(self.state_key.clone(), self.btree.get_state().try_to_val()?).await?; - } + self.store.lock().await.finish(tx).await?; + self.btree.get_state().finish(tx, &self.state_key).await?; Ok(()) } } diff --git a/lib/src/idx/ft/scorer.rs b/lib/src/idx/ft/scorer.rs index eea8457d..e20fe9e4 100644 --- a/lib/src/idx/ft/scorer.rs +++ b/lib/src/idx/ft/scorer.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::doclength::{DocLength, DocLengths}; use crate::idx::ft::postings::{Postings, TermFrequency}; use crate::idx::ft::termdocs::TermsDocs; @@ -76,8 +76,8 @@ impl BM25Scorer { // (N - n(qi) + 0.5) let numerator = self.doc_count - term_doc_count + 0.5; let idf = (numerator / denominator).ln(); - if idf.is_nan() || idf <= 0.0 { - return 0.0; + if idf.is_nan() { + return f32::NAN; } let tf_prim = 1.0 + term_freq.ln(); // idf * (k1 + 1) diff --git a/lib/src/idx/ft/termdocs.rs b/lib/src/idx/ft/termdocs.rs index 113e7fa5..6dd02ec2 100644 --- a/lib/src/idx/ft/termdocs.rs +++ b/lib/src/idx/ft/termdocs.rs @@ -1,5 +1,5 @@ use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::doclength::DocLength; use crate::idx::ft::terms::TermId; use crate::idx::IndexKeyBase; diff --git a/lib/src/idx/mod.rs b/lib/src/idx/mod.rs index 7a073556..9b888207 100644 --- a/lib/src/idx/mod.rs +++ b/lib/src/idx/mod.rs @@ -1,10 +1,11 @@ +pub(crate) mod docids; pub(crate) mod ft; pub(crate) mod planner; pub mod trees; use crate::dbs::Options; use crate::err::Error; -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::terms::TermId; use crate::idx::trees::store::NodeId; use crate::key::index::bc::Bc; @@ -18,6 +19,7 @@ use crate::key::index::bp::Bp; use crate::key::index::bs::Bs; use crate::key::index::bt::Bt; use crate::key::index::bu::Bu; +use crate::key::index::vm::Vm; use crate::kvs::{Key, Val}; use crate::sql::statements::DefineIndexStatement; use revision::Revisioned; @@ -171,6 +173,17 @@ impl IndexKeyBase { ) .into() } + + fn new_vm_key(&self, node_id: Option) -> Key { + Vm::new( + self.inner.ns.as_str(), + self.inner.db.as_str(), + self.inner.tb.as_str(), + self.inner.ix.as_str(), + node_id, + ) + .into() + } } /// This trait provides `Revision` based default implementations for serialization/deserialization diff --git a/lib/src/idx/planner/executor.rs b/lib/src/idx/planner/executor.rs index b81f16cf..b9e5018a 100644 --- a/lib/src/idx/planner/executor.rs +++ b/lib/src/idx/planner/executor.rs @@ -1,25 +1,27 @@ use crate::dbs::{Options, Transaction}; use crate::err::Error; -use crate::idx::ft::docids::{DocId, DocIds}; +use crate::idx::docids::{DocId, DocIds}; use crate::idx::ft::scorer::BM25Scorer; use crate::idx::ft::termdocs::TermsDocs; use crate::idx::ft::terms::TermId; use crate::idx::ft::{FtIndex, MatchRef}; use crate::idx::planner::iterators::{ - IndexEqualThingIterator, IndexRangeThingIterator, MatchesThingIterator, ThingIterator, - UniqueEqualThingIterator, UniqueRangeThingIterator, + IndexEqualThingIterator, IndexRangeThingIterator, KnnThingIterator, MatchesThingIterator, + ThingIterator, UniqueEqualThingIterator, UniqueRangeThingIterator, }; use crate::idx::planner::plan::IndexOperator::Matches; use crate::idx::planner::plan::{IndexOperator, IndexOption, RangeValue}; use crate::idx::planner::tree::{IndexMap, IndexRef}; +use crate::idx::trees::mtree::MTreeIndex; use crate::idx::trees::store::TreeStoreType; use crate::idx::IndexKeyBase; use crate::kvs; use crate::kvs::Key; use crate::sql::index::Index; use crate::sql::statements::DefineIndexStatement; -use crate::sql::{Expression, Object, Table, Thing, Value}; -use std::collections::{HashMap, HashSet}; +use crate::sql::{Array, Expression, Object, Table, Thing, Value}; +use roaring::RoaringTreemap; +use std::collections::{HashMap, HashSet, VecDeque}; use std::sync::Arc; use tokio::sync::RwLock; @@ -30,6 +32,7 @@ pub(crate) struct QueryExecutor { exp_entries: HashMap, FtEntry>, it_entries: Vec, index_definitions: HashMap, + mt_exp: HashMap, MtEntry>, } pub(crate) type IteratorRef = u16; @@ -66,39 +69,59 @@ impl QueryExecutor { let mut mr_entries = HashMap::default(); let mut exp_entries = HashMap::default(); let mut ft_map = HashMap::default(); + let mut mt_map: HashMap = HashMap::default(); + let mut mt_exp = HashMap::default(); // Create all the instances of FtIndex // Build the FtEntries and map them to Expressions and MatchRef for (exp, io) in im.options { - let mut entry = None; let ir = io.ir(); if let Some(idx_def) = im.definitions.get(&ir) { - if let Index::Search(p) = &idx_def.index { - if let Some(ft) = ft_map.get(&ir) { - if entry.is_none() { - entry = FtEntry::new(&mut run, ft, io).await?; + match &idx_def.index { + Index::Search(p) => { + let mut ft_entry = None; + if let Some(ft) = ft_map.get(&ir) { + if ft_entry.is_none() { + ft_entry = FtEntry::new(&mut run, ft, io).await?; + } + } else { + let ikb = IndexKeyBase::new(opt, idx_def); + let az = run.get_db_analyzer(opt.ns(), opt.db(), p.az.as_str()).await?; + let ft = + FtIndex::new(&mut run, az, ikb, p, TreeStoreType::Read).await?; + if ft_entry.is_none() { + ft_entry = FtEntry::new(&mut run, &ft, io).await?; + } + ft_map.insert(ir, ft); } - } else { - let ikb = IndexKeyBase::new(opt, idx_def); - let az = run.get_db_analyzer(opt.ns(), opt.db(), p.az.as_str()).await?; - let ft = FtIndex::new(&mut run, az, ikb, p, TreeStoreType::Read).await?; - if entry.is_none() { - entry = FtEntry::new(&mut run, &ft, io).await?; + if let Some(e) = ft_entry { + if let Matches(_, Some(mr)) = e.0.index_option.op() { + if mr_entries.insert(*mr, e.clone()).is_some() { + return Err(Error::DuplicatedMatchRef { + mr: *mr, + }); + } + } + exp_entries.insert(exp, e); } - ft_map.insert(ir, ft); } - } - } - - if let Some(e) = entry { - if let Matches(_, Some(mr)) = e.0.index_option.op() { - if mr_entries.insert(*mr, e.clone()).is_some() { - return Err(Error::DuplicatedMatchRef { - mr: *mr, - }); + Index::MTree(p) => { + if let IndexOperator::Knn(a, k) = io.op() { + let entry = if let Some(mt) = mt_map.get(&ir) { + MtEntry::new(&mut run, mt, a.clone(), *k).await? + } else { + let ikb = IndexKeyBase::new(opt, idx_def); + let mt = + MTreeIndex::new(&mut run, ikb, p, TreeStoreType::Read).await?; + let entry = MtEntry::new(&mut run, &mt, a.clone(), *k).await?; + mt_map.insert(ir, mt); + entry + }; + mt_exp.insert(exp, entry); + } } + _ => {} } - exp_entries.insert(exp, e); } } @@ -109,6 +132,19 @@ impl QueryExecutor { exp_entries, it_entries: Vec::new(), index_definitions: im.definitions, + mt_exp, + }) + } + + pub(crate) async fn knn( + &self, + _txn: &Transaction, + _thg: &Thing, + exp: &Expression, + ) -> Result { + // If no previous case were successful, we end up with a user error + Err(Error::NoIndexFoundForMatch { + value: exp.to_string(), }) } @@ -168,9 +204,7 @@ impl QueryExecutor { Index::Search { .. } => self.new_search_index_iterator(ir, io.clone()).await, - Index::MTree(_) => Err(Error::FeatureNotYetImplemented { - feature: "VectorSearch iterator".to_string(), - }), + Index::MTree(_) => Ok(self.new_mtree_index_knn_iterator(ir)), } } else { Ok(None) @@ -258,6 +292,16 @@ impl QueryExecutor { Ok(None) } + fn new_mtree_index_knn_iterator(&self, ir: IteratorRef) -> Option { + if let Some(IteratorEntry::Single(exp, ..)) = self.it_entries.get(ir as usize) { + if let Some(mte) = self.mt_exp.get(exp.as_ref()) { + let it = KnnThingIterator::new(mte.doc_ids.clone(), mte.res.clone()); + return Some(ThingIterator::Knn(it)); + } + } + None + } + pub(crate) async fn matches( &self, txn: &Transaction, @@ -406,3 +450,24 @@ impl FtEntry { } } } + +#[derive(Clone)] +pub(super) struct MtEntry { + doc_ids: Arc>, + res: VecDeque, +} + +impl MtEntry { + async fn new( + tx: &mut kvs::Transaction, + mt: &MTreeIndex, + a: Array, + k: u32, + ) -> Result { + let res = mt.knn_search(tx, a, k as usize).await?; + Ok(Self { + res, + doc_ids: mt.doc_ids(), + }) + } +} diff --git a/lib/src/idx/planner/iterators.rs b/lib/src/idx/planner/iterators.rs index 57d2cb8c..57aa7dd2 100644 --- a/lib/src/idx/planner/iterators.rs +++ b/lib/src/idx/planner/iterators.rs @@ -1,6 +1,6 @@ use crate::dbs::{Options, Transaction}; use crate::err::Error; -use crate::idx::ft::docids::{DocId, NO_DOC_ID}; +use crate::idx::docids::{DocId, DocIds, NO_DOC_ID}; use crate::idx::ft::termdocs::TermsDocs; use crate::idx::ft::{FtIndex, HitsIterator}; use crate::idx::planner::plan::RangeValue; @@ -8,6 +8,10 @@ use crate::key::index::Index; use crate::kvs::Key; use crate::sql::statements::DefineIndexStatement; use crate::sql::{Array, Thing, Value}; +use roaring::RoaringTreemap; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::RwLock; pub(crate) enum ThingIterator { IndexEqual(IndexEqualThingIterator), @@ -15,6 +19,7 @@ pub(crate) enum ThingIterator { UniqueEqual(UniqueEqualThingIterator), UniqueRange(UniqueRangeThingIterator), Matches(MatchesThingIterator), + Knn(KnnThingIterator), } impl ThingIterator { @@ -29,6 +34,7 @@ impl ThingIterator { ThingIterator::IndexRange(i) => i.next_batch(tx, size).await, ThingIterator::UniqueRange(i) => i.next_batch(tx, size).await, ThingIterator::Matches(i) => i.next_batch(tx, size).await, + ThingIterator::Knn(i) => i.next_batch(tx, size).await, } } } @@ -307,3 +313,52 @@ impl MatchesThingIterator { Ok(res) } } + +pub(crate) struct KnnThingIterator { + doc_ids: Arc>, + res: VecDeque, + current: Option, + skip: RoaringTreemap, +} + +impl KnnThingIterator { + pub(super) fn new(doc_ids: Arc>, mut res: VecDeque) -> Self { + let current = res.pop_front(); + Self { + doc_ids, + res, + current, + skip: RoaringTreemap::new(), + } + } + async fn next_batch( + &mut self, + txn: &Transaction, + mut limit: u32, + ) -> Result, Error> { + let mut res = vec![]; + let mut tx = txn.lock().await; + while self.current.is_some() && limit > 0 { + if let Some(docs) = &mut self.current { + if let Some(doc_id) = docs.iter().next() { + docs.remove(doc_id); + if self.skip.insert(doc_id) { + if let Some(doc_key) = + self.doc_ids.read().await.get_doc_key(&mut tx, doc_id).await? + { + res.push((doc_key.into(), doc_id)); + limit -= 1; + } + } + if docs.is_empty() { + self.current = None; + } + } + } + if self.current.is_none() { + self.current = self.res.pop_front(); + } + } + Ok(res) + } +} diff --git a/lib/src/idx/planner/plan.rs b/lib/src/idx/planner/plan.rs index 6f66b7e8..1d1fac76 100644 --- a/lib/src/idx/planner/plan.rs +++ b/lib/src/idx/planner/plan.rs @@ -149,6 +149,7 @@ pub(super) enum IndexOperator { Equality(Array), RangePart(Operator, Value), Matches(String, Option), + Knn(Array, u32), } impl IndexOption { @@ -191,6 +192,10 @@ impl IndexOption { e.insert("operator", Value::from(op.to_string())); e.insert("value", v.to_owned()); } + IndexOperator::Knn(a, k) => { + e.insert("operator", Value::from(format!("<{}>", k))); + e.insert("value", Value::Array(a.clone())); + } }; } } diff --git a/lib/src/idx/planner/tree.rs b/lib/src/idx/planner/tree.rs index 86ac4569..fe6f36a8 100644 --- a/lib/src/idx/planner/tree.rs +++ b/lib/src/idx/planner/tree.rs @@ -102,10 +102,10 @@ impl<'a> TreeBuilder<'a> { match v { Value::Expression(e) => self.eval_expression(e).await, Value::Idiom(i) => self.eval_idiom(i).await, - Value::Strand(_) => Ok(Node::Scalar(v.to_owned())), - Value::Number(_) => Ok(Node::Scalar(v.to_owned())), - Value::Bool(_) => Ok(Node::Scalar(v.to_owned())), - Value::Thing(_) => Ok(Node::Scalar(v.to_owned())), + Value::Strand(_) | Value::Number(_) | Value::Bool(_) | Value::Thing(_) => { + Ok(Node::Scalar(v.to_owned())) + } + Value::Array(a) => Ok(self.eval_array(a)), Value::Subquery(s) => self.eval_subquery(s).await, Value::Param(p) => { let v = p.compute(self.ctx, self.opt, self.txn, None).await?; @@ -115,6 +115,16 @@ impl<'a> TreeBuilder<'a> { } } + fn eval_array(&mut self, a: &Array) -> Node { + // Check if it is a numeric vector + for v in &a.0 { + if !v.is_number() { + return Node::Unsupported(format!("Unsupported array: {}", a)); + } + } + Node::Vector(a.to_owned()) + } + async fn eval_idiom(&mut self, i: &Idiom) -> Result { if let Some(irs) = self.find_indexes(i).await? { if !irs.is_empty() { @@ -165,45 +175,61 @@ impl<'a> TreeBuilder<'a> { irs: &[IndexRef], op: &Operator, id: &Idiom, - v: &Node, + n: &Node, e: &Expression, ) -> Option { - if let Some(v) = v.is_scalar() { - for ir in irs { - if let Some(ix) = self.index_map.definitions.get(ir) { - let op = match &ix.index { - Index::Idx => Self::eval_index_operator(op, v), - Index::Uniq => Self::eval_index_operator(op, v), - Index::Search { - .. - } => { + for ir in irs { + if let Some(ix) = self.index_map.definitions.get(ir) { + let op = match &ix.index { + Index::Idx => Self::eval_index_operator(op, n), + Index::Uniq => Self::eval_index_operator(op, n), + Index::Search { + .. + } => { + if let Some(v) = n.is_scalar() { if let Operator::Matches(mr) = op { Some(IndexOperator::Matches(v.clone().to_raw_string(), *mr)) } else { None } + } else { + None } - Index::MTree(_) => None, - }; - if let Some(op) = op { - let io = IndexOption::new(*ir, id.clone(), op); - self.index_map.options.insert(Arc::new(e.clone()), io.clone()); - return Some(io); } + Index::MTree(_) => { + if let Operator::Knn(k) = op { + if let Node::Vector(a) = n { + Some(IndexOperator::Knn(a.clone(), *k)) + } else { + None + } + } else { + None + } + } + }; + if let Some(op) = op { + let io = IndexOption::new(*ir, id.clone(), op); + self.index_map.options.insert(Arc::new(e.clone()), io.clone()); + return Some(io); } } } None } - fn eval_index_operator(op: &Operator, v: &Value) -> Option { - match op { - Operator::Equal => Some(IndexOperator::Equality(Array::from(v.clone()))), - Operator::LessThan - | Operator::LessThanOrEqual - | Operator::MoreThan - | Operator::MoreThanOrEqual => Some(IndexOperator::RangePart(op.clone(), v.clone())), - _ => None, + fn eval_index_operator(op: &Operator, n: &Node) -> Option { + if let Some(v) = n.is_scalar() { + match op { + Operator::Equal => Some(IndexOperator::Equality(Array::from(v.clone()))), + Operator::LessThan + | Operator::LessThanOrEqual + | Operator::MoreThan + | Operator::MoreThanOrEqual => Some(IndexOperator::RangePart(op.clone(), v.clone())), + _ => None, + } + } else { + None } } @@ -235,6 +261,7 @@ pub(super) enum Node { IndexedField(Idiom, Arc>), NonIndexedField, Scalar(Value), + Vector(Array), Unsupported(String), } diff --git a/lib/src/idx/trees/btree.rs b/lib/src/idx/trees/btree.rs index 6f5a33a7..abf0bbb5 100644 --- a/lib/src/idx/trees/btree.rs +++ b/lib/src/idx/trees/btree.rs @@ -21,7 +21,6 @@ where { state: BState, full_size: u32, - updated: bool, bk: PhantomData, } @@ -31,6 +30,8 @@ pub struct BState { minimum_degree: u32, root: Option, next_node_id: NodeId, + #[serde(skip)] + updated: bool, } impl VersionedSerdeState for BState {} @@ -42,8 +43,34 @@ impl BState { minimum_degree, root: None, next_node_id: 0, + updated: false, } } + + fn set_root(&mut self, node_id: Option) { + if node_id.ne(&self.root) { + self.root = node_id; + self.updated = true; + } + } + + fn new_node_id(&mut self) -> NodeId { + let new_node_id = self.next_node_id; + self.next_node_id += 1; + self.updated = true; + new_node_id + } + + pub(in crate::idx) async fn finish( + &self, + tx: &mut Transaction, + key: &Key, + ) -> Result<(), Error> { + if self.updated { + tx.set(key.clone(), self.try_to_val()?).await?; + } + Ok(()) + } } #[derive(Debug, Default, PartialEq)] @@ -166,7 +193,6 @@ where Self { full_size: state.minimum_degree * 2 - 1, state, - updated: false, bk: PhantomData, } } @@ -180,11 +206,11 @@ where let mut next_node = self.state.root; while let Some(node_id) = next_node.take() { let current = store.get_node(tx, node_id).await?; - if let Some(payload) = current.node.keys().get(searched_key) { + if let Some(payload) = current.n.keys().get(searched_key) { store.set_node(current, false)?; return Ok(Some(payload)); } - if let BTreeNode::Internal(keys, children) = ¤t.node { + if let BTreeNode::Internal(keys, children) = ¤t.n { let child_idx = keys.get_child_idx(searched_key); next_node.replace(children[child_idx]); } @@ -201,27 +227,30 @@ where payload: Payload, ) -> Result<(), Error> { if let Some(root_id) = self.state.root { + // We already have a root node let root = store.get_node(tx, root_id).await?; - if root.node.keys().len() == self.full_size { - let new_root_id = self.new_node_id(); + if root.n.keys().len() == self.full_size { + // The root node is full, let's split it + let new_root_id = self.state.new_node_id(); let new_root = store .new_node(new_root_id, BTreeNode::Internal(BK::default(), vec![root_id]))?; - self.state.root = Some(new_root.id); + self.state.set_root(Some(new_root.id)); self.split_child(store, new_root, 0, root).await?; self.insert_non_full(tx, store, new_root_id, key, payload).await?; } else { + // The root node has place, let's insert the value let root_id = root.id; store.set_node(root, false)?; self.insert_non_full(tx, store, root_id, key, payload).await?; } } else { - let new_root_id = self.new_node_id(); + // We don't have a root node, let's create id + let new_root_id = self.state.new_node_id(); let new_root_node = store.new_node(new_root_id, BTreeNode::Leaf(BK::with_key_val(key, payload)?))?; store.set_node(new_root_node, true)?; - self.state.root = Some(new_root_id); + self.state.set_root(Some(new_root_id)); } - self.updated = true; Ok(()) } @@ -237,7 +266,7 @@ where while let Some(node_id) = next_node_id.take() { let mut node = store.get_node(tx, node_id).await?; let key: Key = key.clone(); - match &mut node.node { + match &mut node.n { BTreeNode::Leaf(keys) => { keys.insert(key, payload); store.set_node(node, true)?; @@ -250,7 +279,7 @@ where } let child_idx = keys.get_child_idx(&key); let child = store.get_node(tx, children[child_idx]).await?; - let next_id = if child.node.keys().len() == self.full_size { + let next_id = if child.n.keys().len() == self.full_size { let split_result = self.split_child(store, node, child_idx, child).await?; if key.gt(&split_result.median_key) { split_result.right_node_id @@ -277,12 +306,12 @@ where idx: usize, child_node: BStoredNode, ) -> Result { - let (left_node, right_node, median_key, median_payload) = match child_node.node { + let (left_node, right_node, median_key, median_payload) = match child_node.n { BTreeNode::Internal(keys, children) => self.split_internal_node(keys, children)?, BTreeNode::Leaf(keys) => self.split_leaf_node(keys)?, }; - let right_node_id = self.new_node_id(); - match parent_node.node { + let right_node_id = self.state.new_node_id(); + match parent_node.n { BTreeNode::Internal(ref mut keys, ref mut children) => { keys.insert(median_key.clone(), median_payload); children.insert(idx + 1, right_node_id); @@ -329,12 +358,6 @@ where Ok((left_node, right_node, r.median_key, r.median_payload)) } - fn new_node_id(&mut self) -> NodeId { - let new_node_id = self.state.next_node_id; - self.state.next_node_id += 1; - new_node_id - } - pub(in crate::idx) async fn delete( &mut self, tx: &mut Transaction, @@ -348,7 +371,7 @@ where while let Some((is_main_key, key_to_delete, node_id)) = next_node.take() { let mut node = store.get_node(tx, node_id).await?; - match &mut node.node { + match &mut node.n { BTreeNode::Leaf(keys) => { // CLRS: 1 if let Some(payload) = keys.get(&key_to_delete) { @@ -361,12 +384,11 @@ where store.remove_node(node.id, node.key)?; // Check if this was the root node if Some(node_id) == self.state.root { - self.state.root = None; + self.state.set_root(None); } } else { store.set_node(node, true)?; } - self.updated = true; } else { store.set_node(node, false)?; } @@ -388,7 +410,6 @@ where .await?, ); store.set_node(node, true)?; - self.updated = true; } else { // CLRS: 3 let (node_update, is_main_key, key_to_delete, next_stored_node) = self @@ -409,11 +430,9 @@ where } } store.remove_node(node_id, node.key)?; - self.state.root = Some(next_stored_node); - self.updated = true; + self.state.set_root(Some(next_stored_node)); } else if node_update { store.set_node(node, true)?; - self.updated = true; } else { store.set_node(node, false)?; } @@ -437,9 +456,9 @@ where let left_idx = keys.get_child_idx(&key_to_delete); let left_id = children[left_idx]; let mut left_node = store.get_node(tx, left_id).await?; - if left_node.node.keys().len() >= self.state.minimum_degree { + if left_node.n.keys().len() >= self.state.minimum_degree { // CLRS: 2a -> left_node is named `y` in the book - if let Some((key_prim, payload_prim)) = left_node.node.keys().get_last_key() { + if let Some((key_prim, payload_prim)) = left_node.n.keys().get_last_key() { keys.remove(&key_to_delete); keys.insert(key_prim.clone(), payload_prim); store.set_node(left_node, true)?; @@ -450,9 +469,9 @@ where let right_idx = left_idx + 1; let right_id = children[right_idx]; let right_node = store.get_node(tx, right_id).await?; - if right_node.node.keys().len() >= self.state.minimum_degree { + if right_node.n.keys().len() >= self.state.minimum_degree { // CLRS: 2b -> right_node is name `z` in the book - if let Some((key_prim, payload_prim)) = right_node.node.keys().get_first_key() { + if let Some((key_prim, payload_prim)) = right_node.n.keys().get_first_key() { keys.remove(&key_to_delete); keys.insert(key_prim.clone(), payload_prim); store.set_node(left_node, false)?; @@ -464,7 +483,7 @@ where // CLRS: 2c // Merge children // The payload is set to 0. The value does not matter, as the key will be deleted after anyway. - left_node.node.append(key_to_delete.clone(), 0, right_node.node)?; + left_node.n.append(key_to_delete.clone(), 0, right_node.n)?; store.set_node(left_node, true)?; store.remove_node(right_id, right_node.key)?; keys.remove(&key_to_delete); @@ -485,11 +504,11 @@ where let child_idx = keys.get_child_idx(&key_to_delete); let child_id = children[child_idx]; let child_stored_node = store.get_node(tx, child_id).await?; - if child_stored_node.node.keys().len() < self.state.minimum_degree { + if child_stored_node.n.keys().len() < self.state.minimum_degree { // right child (successor) if child_idx < children.len() - 1 { let right_child_stored_node = store.get_node(tx, children[child_idx + 1]).await?; - return if right_child_stored_node.node.keys().len() >= self.state.minimum_degree { + return if right_child_stored_node.n.keys().len() >= self.state.minimum_degree { Self::delete_adjust_successor( store, keys, @@ -520,7 +539,7 @@ where if child_idx > 0 { let child_idx = child_idx - 1; let left_child_stored_node = store.get_node(tx, children[child_idx]).await?; - return if left_child_stored_node.node.keys().len() >= self.state.minimum_degree { + return if left_child_stored_node.n.keys().len() >= self.state.minimum_degree { Self::delete_adjust_predecessor( store, keys, @@ -562,12 +581,12 @@ where mut right_child_stored_node: BStoredNode, ) -> Result<(bool, bool, Key, NodeId), Error> { if let Some((ascending_key, ascending_payload)) = - right_child_stored_node.node.keys().get_first_key() + right_child_stored_node.n.keys().get_first_key() { - right_child_stored_node.node.keys_mut().remove(&ascending_key); + right_child_stored_node.n.keys_mut().remove(&ascending_key); if let Some(descending_key) = keys.get_key(child_idx) { if let Some(descending_payload) = keys.remove(&descending_key) { - child_stored_node.node.keys_mut().insert(descending_key, descending_payload); + child_stored_node.n.keys_mut().insert(descending_key, descending_payload); keys.insert(ascending_key, ascending_payload); let child_id = child_stored_node.id; store.set_node(child_stored_node, true)?; @@ -590,12 +609,12 @@ where mut left_child_stored_node: BStoredNode, ) -> Result<(bool, bool, Key, NodeId), Error> { if let Some((ascending_key, ascending_payload)) = - left_child_stored_node.node.keys().get_last_key() + left_child_stored_node.n.keys().get_last_key() { - left_child_stored_node.node.keys_mut().remove(&ascending_key); + left_child_stored_node.n.keys_mut().remove(&ascending_key); if let Some(descending_key) = keys.get_key(child_idx) { if let Some(descending_payload) = keys.remove(&descending_key) { - child_stored_node.node.keys_mut().insert(descending_key, descending_payload); + child_stored_node.n.keys_mut().insert(descending_key, descending_payload); keys.insert(ascending_key, ascending_payload); let child_id = child_stored_node.id; store.set_node(child_stored_node, true)?; @@ -623,7 +642,7 @@ where if let Some(descending_payload) = keys.remove(&descending_key) { children.remove(child_idx + 1); let left_id = left_child.id; - left_child.node.append(descending_key, descending_payload, right_child.node)?; + left_child.n.append(descending_key, descending_payload, right_child.n)?; store.set_node(left_child, true)?; store.remove_node(right_child.id, right_child.key)?; return Ok((true, is_main_key, key_to_delete, left_id)); @@ -645,13 +664,13 @@ where } while let Some((node_id, depth)) = node_queue.pop_front() { let stored = store.get_node(tx, node_id).await?; - stats.keys_count += stored.node.keys().len() as u64; + stats.keys_count += stored.n.keys().len() as u64; if depth > stats.max_depth { stats.max_depth = depth; } stats.nodes_count += 1; stats.total_size += stored.size as u64; - if let BTreeNode::Internal(_, children) = &stored.node { + if let BTreeNode::Internal(_, children) = &stored.n { let depth = depth + 1; for child_id in children.iter() { node_queue.push_front((*child_id, depth)); @@ -665,10 +684,6 @@ where pub(in crate::idx) fn get_state(&self) -> &BState { &self.state } - - pub(in crate::idx) fn is_updated(&self) -> bool { - self.updated - } } #[cfg(test)] @@ -1032,13 +1047,13 @@ mod tests { 0 => { assert_eq!(depth, 1); assert_eq!(node_id, 7); - check_is_internal_node(node.node, vec![("p", 16)], vec![1, 8]); + check_is_internal_node(node.n, vec![("p", 16)], vec![1, 8]); } 1 => { assert_eq!(depth, 2); assert_eq!(node_id, 1); check_is_internal_node( - node.node, + node.n, vec![("c", 3), ("g", 7), ("m", 13)], vec![0, 9, 2, 3], ); @@ -1046,42 +1061,42 @@ mod tests { 2 => { assert_eq!(depth, 2); assert_eq!(node_id, 8); - check_is_internal_node(node.node, vec![("t", 20), ("x", 24)], vec![4, 6, 5]); + check_is_internal_node(node.n, vec![("t", 20), ("x", 24)], vec![4, 6, 5]); } 3 => { assert_eq!(depth, 3); assert_eq!(node_id, 0); - check_is_leaf_node(node.node, vec![("a", 1), ("b", 2)]); + check_is_leaf_node(node.n, vec![("a", 1), ("b", 2)]); } 4 => { assert_eq!(depth, 3); assert_eq!(node_id, 9); - check_is_leaf_node(node.node, vec![("d", 4), ("e", 5), ("f", 6)]); + check_is_leaf_node(node.n, vec![("d", 4), ("e", 5), ("f", 6)]); } 5 => { assert_eq!(depth, 3); assert_eq!(node_id, 2); - check_is_leaf_node(node.node, vec![("j", 10), ("k", 11), ("l", 12)]); + check_is_leaf_node(node.n, vec![("j", 10), ("k", 11), ("l", 12)]); } 6 => { assert_eq!(depth, 3); assert_eq!(node_id, 3); - check_is_leaf_node(node.node, vec![("n", 14), ("o", 15)]); + check_is_leaf_node(node.n, vec![("n", 14), ("o", 15)]); } 7 => { assert_eq!(depth, 3); assert_eq!(node_id, 4); - check_is_leaf_node(node.node, vec![("q", 17), ("r", 18), ("s", 19)]); + check_is_leaf_node(node.n, vec![("q", 17), ("r", 18), ("s", 19)]); } 8 => { assert_eq!(depth, 3); assert_eq!(node_id, 6); - check_is_leaf_node(node.node, vec![("u", 21), ("v", 22)]); + check_is_leaf_node(node.n, vec![("u", 21), ("v", 22)]); } 9 => { assert_eq!(depth, 3); assert_eq!(node_id, 5); - check_is_leaf_node(node.node, vec![("y", 25), ("z", 26)]); + check_is_leaf_node(node.n, vec![("y", 25), ("z", 26)]); } _ => panic!("This node should not exist {}", count), }) @@ -1135,13 +1150,13 @@ mod tests { let nodes_count = t .inspect_nodes(&mut tx, |count, depth, node_id, node| { debug!("{} -> {}", depth, node_id); - node.node.debug(|k| Ok(String::from_utf8(k)?)).unwrap(); + node.n.debug(|k| Ok(String::from_utf8(k)?)).unwrap(); match count { 0 => { assert_eq!(depth, 1); assert_eq!(node_id, 1); check_is_internal_node( - node.node, + node.n, vec![("e", 5), ("l", 12), ("p", 16), ("t", 20), ("x", 24)], vec![0, 9, 3, 4, 6, 5], ); @@ -1149,32 +1164,32 @@ mod tests { 1 => { assert_eq!(depth, 2); assert_eq!(node_id, 0); - check_is_leaf_node(node.node, vec![("a", 1), ("c", 3)]); + check_is_leaf_node(node.n, vec![("a", 1), ("c", 3)]); } 2 => { assert_eq!(depth, 2); assert_eq!(node_id, 9); - check_is_leaf_node(node.node, vec![("j", 10), ("k", 11)]); + check_is_leaf_node(node.n, vec![("j", 10), ("k", 11)]); } 3 => { assert_eq!(depth, 2); assert_eq!(node_id, 3); - check_is_leaf_node(node.node, vec![("n", 14), ("o", 15)]); + check_is_leaf_node(node.n, vec![("n", 14), ("o", 15)]); } 4 => { assert_eq!(depth, 2); assert_eq!(node_id, 4); - check_is_leaf_node(node.node, vec![("q", 17), ("r", 18), ("s", 19)]); + check_is_leaf_node(node.n, vec![("q", 17), ("r", 18), ("s", 19)]); } 5 => { assert_eq!(depth, 2); assert_eq!(node_id, 6); - check_is_leaf_node(node.node, vec![("u", 21), ("v", 22)]); + check_is_leaf_node(node.n, vec![("u", 21), ("v", 22)]); } 6 => { assert_eq!(depth, 2); assert_eq!(node_id, 5); - check_is_leaf_node(node.node, vec![("y", 25), ("z", 26)]); + check_is_leaf_node(node.n, vec![("y", 25), ("z", 26)]); } _ => panic!("This node should not exist {}", count), } @@ -1316,7 +1331,7 @@ mod tests { debug!("----------------------------------"); t.inspect_nodes(tx, |_count, depth, node_id, node| { debug!("{} -> {}", depth, node_id); - node.node.debug(|k| Ok(String::from_utf8(k)?)).unwrap(); + node.n.debug(|k| Ok(String::from_utf8(k)?)).unwrap(); }) .await .unwrap(); @@ -1359,7 +1374,7 @@ mod tests { let mut s = TreeNodeStore::Traversal(TreeNodeProvider::Debug); while let Some((node_id, depth)) = node_queue.pop_front() { let stored_node = s.get_node(tx, node_id).await?; - if let BTreeNode::Internal(_, children) = &stored_node.node { + if let BTreeNode::Internal(_, children) = &stored_node.n { let depth = depth + 1; for child_id in children { node_queue.push_back((*child_id, depth)); diff --git a/lib/src/idx/trees/mod.rs b/lib/src/idx/trees/mod.rs index 6bc0fae7..d0247335 100644 --- a/lib/src/idx/trees/mod.rs +++ b/lib/src/idx/trees/mod.rs @@ -1,3 +1,4 @@ pub mod bkeys; pub mod btree; +pub mod mtree; pub mod store; diff --git a/lib/src/idx/trees/mtree.rs b/lib/src/idx/trees/mtree.rs new file mode 100644 index 00000000..9e2f83b7 --- /dev/null +++ b/lib/src/idx/trees/mtree.rs @@ -0,0 +1,1792 @@ +use crate::err::Error; +use crate::fnc::util::math::vector::{ + CosineSimilarity, EuclideanDistance, HammingDistance, ManhattanDistance, MinkowskiDistance, +}; +use crate::idx::docids::{DocId, DocIds}; +use crate::idx::trees::btree::BStatistics; +use crate::idx::trees::store::{ + NodeId, StoredNode, TreeNode, TreeNodeProvider, TreeNodeStore, TreeStoreType, +}; +use crate::idx::{IndexKeyBase, VersionedSerdeState}; +use crate::kvs::{Key, Transaction, Val}; +use crate::sql::index::{Distance, MTreeParams}; +use crate::sql::{Array, Number, Object, Thing, Value}; +use async_recursion::async_recursion; +use indexmap::map::Entry; +use indexmap::IndexMap; +use revision::revisioned; +use roaring::RoaringTreemap; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::collections::{BTreeMap, BinaryHeap, VecDeque}; +use std::io::Cursor; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; + +pub(crate) type Vector = Vec; + +type MTreeNodeStore = TreeNodeStore; + +type LeafIndexMap = IndexMap, ObjectProperties>; + +pub(crate) struct MTreeIndex { + state_key: Key, + dim: usize, + doc_ids: Arc>, + mtree: Arc>, + store: Arc>, +} + +impl MTreeIndex { + pub(crate) async fn new( + tx: &mut Transaction, + ikb: IndexKeyBase, + p: &MTreeParams, + st: TreeStoreType, + ) -> Result { + let doc_ids = + Arc::new(RwLock::new(DocIds::new(tx, ikb.clone(), p.doc_ids_order, st).await?)); + let state_key = ikb.new_vm_key(None); + let state: MState = if let Some(val) = tx.get(state_key.clone()).await? { + MState::try_from_val(val)? + } else { + MState::new(p.capacity) + }; + + let store = TreeNodeStore::new(TreeNodeProvider::Vector(ikb), st, 20); + let mtree = Arc::new(RwLock::new(MTree::new(state, p.distance.clone()))); + Ok(Self { + state_key, + dim: p.dimension as usize, + doc_ids, + mtree, + store, + }) + } + + pub(crate) async fn index_document( + &mut self, + tx: &mut Transaction, + rid: &Thing, + content: Vec, + ) -> Result<(), Error> { + // Resolve the doc_id + let resolved = self.doc_ids.write().await.resolve_doc_id(tx, rid.into()).await?; + let doc_id = *resolved.doc_id(); + // Index the values + let mut store = self.store.lock().await; + let mut mtree = self.mtree.write().await; + for v in content { + // Extract the vector + let vector = self.check_vector_value(v)?; + mtree.insert(tx, &mut store, vector, doc_id).await?; + } + Ok(()) + } + + pub(crate) async fn knn_search( + &self, + tx: &mut Transaction, + a: Array, + k: usize, + ) -> Result, Error> { + // Extract the vector + let vector = self.check_vector_array(a)?; + // Lock the store + let mut store = self.store.lock().await; + let res = self.mtree.read().await.knn_search(tx, &mut store, &vector, k).await?; + Ok(res.objects) + } + + fn check_vector_array(&self, a: Array) -> Result { + if a.0.len() != self.dim { + return Err(Error::InvalidVectorDimension { + current: a.0.len(), + expected: self.dim, + }); + } + let mut vec = Vec::with_capacity(a.len()); + for v in a.0 { + if let Value::Number(n) = v { + vec.push(n); + } else { + return Err(Error::InvalidVectorType { + current: v.clone().to_string(), + expected: "Number", + }); + } + } + Ok(vec) + } + + fn check_vector_value(&self, v: Value) -> Result { + if let Value::Array(a) = v { + self.check_vector_array(a) + } else { + Err(Error::InvalidVectorValue { + current: v.clone().to_raw_string(), + }) + } + } + + pub(crate) async fn remove_document( + &mut self, + tx: &mut Transaction, + rid: &Thing, + content: Vec, + ) -> Result<(), Error> { + if let Some(doc_id) = self.doc_ids.write().await.remove_doc(tx, rid.into()).await? { + // Index the values + let mut store = self.store.lock().await; + let mut mtree = self.mtree.write().await; + for v in content { + // Extract the vector + let vector = self.check_vector_value(v)?; + mtree.delete(tx, &mut store, vector, doc_id).await?; + } + } + Ok(()) + } + + pub(in crate::idx) fn doc_ids(&self) -> Arc> { + self.doc_ids.clone() + } + + pub(crate) async fn statistics(&self, tx: &mut Transaction) -> Result { + Ok(MtStatistics { + doc_ids: self.doc_ids.read().await.statistics(tx).await?, + }) + } + + pub(crate) async fn finish(self, tx: &mut Transaction) -> Result<(), Error> { + self.doc_ids.write().await.finish(tx).await?; + self.store.lock().await.finish(tx).await?; + self.mtree.write().await.finish(tx, self.state_key).await?; + Ok(()) + } +} + +struct KnnResult { + objects: VecDeque, + #[cfg(debug_assertions)] + #[allow(dead_code)] + visited_nodes: usize, +} + +// https://en.wikipedia.org/wiki/M-tree +// https://arxiv.org/pdf/1004.4216.pdf +struct MTree { + state: MState, + distance: Distance, + minimum: usize, + updated: bool, +} + +impl MTree { + fn new(state: MState, distance: Distance) -> Self { + let minimum = state.capacity as usize / 2; + Self { + state, + distance, + minimum, + updated: false, + } + } + + async fn knn_search( + &self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + v: &Vector, + k: usize, + ) -> Result { + let mut queue = BinaryHeap::new(); + let mut res = BTreeMap::new(); + if let Some(root_id) = self.state.root { + queue.push(PriorityNode(0.0, root_id)); + } + #[cfg(debug_assertions)] + let mut visited_nodes = 0; + while let Some(current) = queue.pop() { + #[cfg(debug_assertions)] + { + visited_nodes += 1; + } + let node = store.get_node(tx, current.1).await?; + match node.n { + MTreeNode::Leaf(ref n) => { + for (o, p) in n { + let d = self.calculate_distance(o.as_ref(), v); + if Self::check_add(k, d, &res) { + res.insert(PriorityResult(d, o.clone()), p.docs.clone()); + if res.len() > k { + res.pop_last(); + } + } + } + } + MTreeNode::Internal(ref n) => { + for entry in n { + let d = self.calculate_distance(entry.center.as_ref(), v); + let min_dist = (d - entry.radius).max(0.0); + if Self::check_add(k, min_dist, &res) { + queue.push(PriorityNode(min_dist, entry.node)); + } + } + } + } + store.set_node(node, false)?; + } + let mut objects = VecDeque::with_capacity(res.len()); + for (_, d) in res { + objects.push_back(d); + } + Ok(KnnResult { + objects, + #[cfg(debug_assertions)] + visited_nodes, + }) + } + + fn check_add(k: usize, dist: f64, res: &BTreeMap) -> bool { + if res.len() < k { + true + } else if let Some(l) = res.keys().last() { + dist < l.0 + } else { + true + } + } +} + +enum InsertionResult { + DocAdded, + CoveringRadius(f64), + PromotedEntries(RoutingEntry, RoutingEntry), +} + +enum DeletionResult { + NotFound, + DocRemoved, + CoveringRadius(f64), + Underflown(NodeId, Key, MTreeNode), +} + +// Insertion +impl MTree { + fn new_node_id(&mut self) -> NodeId { + let new_node_id = self.state.next_node_id; + self.state.next_node_id += 1; + new_node_id + } + + async fn insert( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + v: Vec, + id: DocId, + ) -> Result<(), Error> { + if let Some(root_id) = self.state.root { + let node = store.get_node(tx, root_id).await?; + if let InsertionResult::PromotedEntries(r1, r2) = + self.insert_at_node(tx, store, node, &None, Arc::new(v), id).await? + { + self.create_new_internal_root(store, r1, r2)?; + } + } else { + self.create_new_leaf_root(store, v, id)?; + } + Ok(()) + } + + fn create_new_leaf_root( + &mut self, + store: &mut MTreeNodeStore, + v: Vec, + id: DocId, + ) -> Result<(), Error> { + let new_root_id = self.new_node_id(); + let p = ObjectProperties::new_root(id); + let mut objects = LeafIndexMap::with_capacity(1); + objects.insert(Arc::new(v), p); + let new_root_node = store.new_node(new_root_id, MTreeNode::Leaf(objects))?; + store.set_node(new_root_node, true)?; + self.state.root = Some(new_root_id); + self.updated = true; + Ok(()) + } + + fn create_new_internal_root( + &mut self, + store: &mut MTreeNodeStore, + r1: RoutingEntry, + r2: RoutingEntry, + ) -> Result<(), Error> { + let new_root_id = self.new_node_id(); + let new_root_node = store.new_node(new_root_id, MTreeNode::Internal(vec![r1, r2]))?; + store.set_node(new_root_node, true)?; + self.state.root = Some(new_root_id); + self.updated = true; + Ok(()) + } + + #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] + #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + async fn insert_at_node( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + node: StoredNode, + parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + match node.n { + MTreeNode::Internal(n) => { + self.insert_node_internal( + tx, + store, + node.id, + node.key, + n, + parent_center, + object, + id, + ) + .await + } + MTreeNode::Leaf(n) => { + self.insert_node_leaf(store, node.id, node.key, n, parent_center, object, id) + } + } + } + + #[allow(clippy::too_many_arguments)] + async fn insert_node_internal( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + node_id: NodeId, + node_key: Key, + mut node: InternalNode, + parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + let best_entry_idx = self.find_closest(&node, &object)?; + let best_entry = &mut node[best_entry_idx]; + let best_node = store.get_node(tx, best_entry.node).await?; + match self + .insert_at_node(tx, store, best_node, &Some(best_entry.center.clone()), object, id) + .await? + { + InsertionResult::PromotedEntries(p1, p2) => { + node.remove(best_entry_idx); + node.push(p1); + node.push(p2); + if node.len() <= self.state.capacity as usize { + let max_dist = self.compute_internal_max_distance(&node, parent_center); + store.set_node( + StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), + true, + )?; + return Ok(InsertionResult::CoveringRadius(max_dist)); + } + self.split_node(store, node_id, node_key, node) + } + InsertionResult::DocAdded => { + store.set_node( + StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), + false, + )?; + Ok(InsertionResult::DocAdded) + } + InsertionResult::CoveringRadius(covering_radius) => { + let mut updated = false; + if covering_radius > best_entry.radius { + best_entry.radius = covering_radius; + updated = true; + } + let max_dist = self.compute_internal_max_distance(&node, parent_center); + store.set_node( + StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), + updated, + )?; + Ok(InsertionResult::CoveringRadius(max_dist)) + } + } + } + + fn find_closest(&self, node: &InternalNode, object: &Vector) -> Result { + let mut idx = 0; + let dist = f64::MAX; + for (i, e) in node.iter().enumerate() { + let d = self.calculate_distance(e.center.as_ref(), object); + if d < dist { + idx = i; + } + } + Ok(idx) + } + + #[allow(clippy::too_many_arguments)] + fn insert_node_leaf( + &mut self, + store: &mut MTreeNodeStore, + node_id: NodeId, + node_key: Key, + mut node: LeafNode, + parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + match node.entry(object) { + Entry::Occupied(mut e) => { + e.get_mut().docs.insert(id); + store.set_node( + StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), + true, + )?; + return Ok(InsertionResult::DocAdded); + } + Entry::Vacant(e) => { + let d = parent_center + .as_ref() + .map_or(0f64, |v| self.calculate_distance(v.as_ref(), e.key())); + e.insert(ObjectProperties::new(d, id)); + } + }; + if node.len() <= self.state.capacity as usize { + let max_dist = self.compute_leaf_max_distance(&node, parent_center); + store.set_node(StoredNode::new(node.into_mtree_node(), node_id, node_key, 0), true)?; + Ok(InsertionResult::CoveringRadius(max_dist)) + } else { + self.split_node(store, node_id, node_key, node) + } + } + + fn split_node( + &mut self, + store: &mut MTreeNodeStore, + node_id: NodeId, + node_key: Key, + node: N, + ) -> Result + where + N: NodeVectors, + { + let distances = self.compute_distance_matrix(&node)?; + let (p1_idx, p2_idx) = Self::select_promotion_objects(&distances); + let p1_obj = node.get_vector(p1_idx)?; + let p2_obj = node.get_vector(p2_idx)?; + + // Distribute entries, update parent_dist and calculate radius + let (node1, r1, node2, r2) = node.distribute_entries(&distances, p1_idx, p2_idx)?; + + // Create a new node + let new_node = self.new_node_id(); + + // Update the store/cache + let n = StoredNode::new(node1.into_mtree_node(), node_id, node_key, 0); + store.set_node(n, true)?; + let n = store.new_node(new_node, node2.into_mtree_node())?; + store.set_node(n, true)?; + + // Update the split node + let r1 = RoutingEntry { + node: node_id, + center: p1_obj, + radius: r1, + }; + let r2 = RoutingEntry { + node: new_node, + center: p2_obj, + radius: r2, + }; + Ok(InsertionResult::PromotedEntries(r1, r2)) + } + + fn select_promotion_objects(distances: &[Vec]) -> (usize, usize) { + let mut promo = (0, 1); + let mut max_distance = distances[0][1]; + // Compare each pair of objects + let n = distances.len(); + #[allow(clippy::needless_range_loop)] + for i in 0..n { + for j in i + 1..n { + let distance = distances[i][j]; + // If this pair is further apart than the current maximum, update the promotion objects + if distance > max_distance { + promo = (i, j); + max_distance = distance; + } + } + } + promo + } + + fn compute_internal_max_distance( + &self, + node: &InternalNode, + parent: &Option>, + ) -> f64 { + parent.as_ref().map_or(0.0, |p| { + let mut max_dist = 0f64; + for e in node { + max_dist = max_dist.max(self.calculate_distance(p.as_ref(), e.center.as_ref())); + } + max_dist + }) + } + + fn compute_leaf_max_distance(&self, node: &LeafNode, parent: &Option>) -> f64 { + parent.as_ref().map_or(0.0, |p| { + let mut max_dist = 0f64; + for o in node.keys() { + max_dist = max_dist.max(self.calculate_distance(p.as_ref(), o.as_ref())); + } + max_dist + }) + } + + fn compute_distance_matrix(&self, vectors: &N) -> Result>, Error> + where + N: NodeVectors, + { + let n = vectors.len(); + let mut distances = vec![vec![0.0; n]; n]; + for i in 0..n { + let v1 = vectors.get_vector(i)?; + for j in i + 1..n { + let v2 = vectors.get_vector(j)?; + let distance = self.calculate_distance(v1.as_ref(), v2.as_ref()); + distances[i][j] = distance; + distances[j][i] = distance; // Because the distance function is symmetric + } + } + Ok(distances) + } + + fn calculate_distance(&self, v1: &Vector, v2: &Vector) -> f64 { + match &self.distance { + Distance::Euclidean => v1.euclidean_distance(v2).unwrap().as_float(), + Distance::Manhattan => v1.manhattan_distance(v2).unwrap().as_float(), + Distance::Cosine => v1.cosine_similarity(v2).unwrap().as_float(), + Distance::Hamming => v1.hamming_distance(v2).unwrap().as_float(), + Distance::Mahalanobis => v1.manhattan_distance(v2).unwrap().as_float(), + Distance::Minkowski(order) => v1.minkowski_distance(v2, order).unwrap().as_float(), + } + } + + async fn delete( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + object: Vector, + doc_id: DocId, + ) -> Result { + if let Some(root_id) = self.state.root { + let node = store.get_node(tx, root_id).await?; + match self.delete_at_node(tx, store, node, &None, Arc::new(object), doc_id).await? { + DeletionResult::DocRemoved => Ok(true), + DeletionResult::CoveringRadius(_) | DeletionResult::NotFound => Ok(false), + DeletionResult::Underflown(id, key, n) => { + let sn = StoredNode::new(n, id, key, 0); + store.set_node(sn, true)?; + Ok(true) + } + } + } else { + Ok(false) + } + } + + #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] + #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + async fn delete_at_node( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + node: StoredNode, + parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + match node.n { + MTreeNode::Internal(n) => { + self.delete_node_internal( + tx, + store, + node.id, + node.key, + n, + parent_center, + object, + id, + ) + .await + } + MTreeNode::Leaf(n) => { + self.delete_node_leaf(store, node.id, node.key, n, parent_center, object, id).await + } + } + } + + #[allow(clippy::too_many_arguments)] + async fn delete_node_internal( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + node_id: NodeId, + node_key: Key, + mut internal_node: InternalNode, + _parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + let mut node_update = false; + let mut child_idx = None; + let mut child_radius = 0.0; + for (i, e) in internal_node.iter().enumerate() { + child_radius = self.calculate_distance(e.center.as_ref(), &object); + if child_radius <= e.radius { + child_idx = Some(i); + break; + } + } + let mut res = DeletionResult::NotFound; + if let Some(child_idx) = child_idx { + let child_entry = &mut internal_node[child_idx]; + let child_id = child_entry.node; + let child_center = child_entry.center.clone(); + let child_node = store.get_node(tx, child_id).await?; + match self + .delete_at_node(tx, store, child_node, &Some(child_center.clone()), object, id) + .await? + { + DeletionResult::NotFound => res = DeletionResult::NotFound, + DeletionResult::DocRemoved => res = DeletionResult::DocRemoved, + DeletionResult::CoveringRadius(r) => { + if r > child_radius { + internal_node[child_idx].radius = r; + node_update = true; + } + } + DeletionResult::Underflown(child_node_id, child_node_key, child_node) => { + if self + .deletion_underflown( + tx, + store, + &mut internal_node, + child_center.as_ref(), + child_node_id, + child_node_key, + child_node, + ) + .await? + { + node_update = true; + } + } + } + if internal_node.len() < self.minimum { + return Ok(DeletionResult::Underflown( + node_id, + node_key, + MTreeNode::Internal(internal_node), + )); + } + } + let sn = StoredNode::new(MTreeNode::Internal(internal_node), node_id, node_key, 0); + store.set_node(sn, node_update)?; + Ok(res) + } + + #[allow(unused_variables, unused_assignments, clippy::too_many_arguments)] + async fn deletion_underflown( + &mut self, + tx: &mut Transaction, + store: &mut MTreeNodeStore, + node: &mut InternalNode, + other_center: &Vector, + underflow_child_id: NodeId, + underflow_child_key: Key, + underflow_child_node: MTreeNode, + ) -> Result { + let min = f64::NAN; + let mut s_child_idx = None; + // Find node entre Onn € N, e <> 0, for which d(On, Onn) is a minimum + for (i, e) in node.iter().enumerate() { + if e.node != underflow_child_id { + let d = self.calculate_distance(other_center, e.center.as_ref()); + if min.is_nan() || d < min { + s_child_idx = Some(i); + } + } + } + if let Some(s_child_idx) = s_child_idx { + let mut node_updated = false; + let s_child_entry = &node[s_child_idx]; + let s_child_center = s_child_entry.center.clone(); + let s_child_node = store.get_node(tx, s_child_entry.node).await?; + if s_child_node.n.len() + underflow_child_node.len() <= self.state.capacity as usize { + node.remove(s_child_idx); + node_updated = true; + // + match underflow_child_node { + MTreeNode::Internal(n) => { + let mut s_child_node = s_child_node.n.internal()?; + for e in n { + let parent_dist = + self.calculate_distance(e.center.as_ref(), s_child_center.as_ref()); + s_child_node.push(e); + } + //TODO + return Err(Error::FeatureNotYetImplemented { + feature: "MTREE deletions".to_string(), + }); + } + MTreeNode::Leaf(n) => { + let mut s_child_node = s_child_node.n.leaf()?; + for (o, mut p) in n { + p.parent_dist = + self.calculate_distance(o.as_ref(), s_child_center.as_ref()); + s_child_node.insert(o, p); + } + //TODO + return Err(Error::FeatureNotYetImplemented { + feature: "MTREE deletions".to_string(), + }); + } + } + } else { + todo!() + } + #[allow(unreachable_code)] + Ok(node_updated) + } else { + Ok(false) + } + } + + #[allow(clippy::too_many_arguments)] + async fn delete_node_leaf( + &mut self, + store: &mut MTreeNodeStore, + node_id: NodeId, + node_key: Key, + mut leaf_node: LeafNode, + parent_center: &Option>, + object: Arc, + id: DocId, + ) -> Result { + let mut doc_removed = false; + let mut entry_removed = false; + if let Entry::Occupied(mut e) = leaf_node.entry(object) { + let p = e.get_mut(); + if p.docs.remove(id) { + doc_removed = true; + if p.docs.is_empty() { + e.remove(); + entry_removed = true; + } + } + } + if entry_removed && leaf_node.len() < self.minimum && self.state.root != Some(node_id) { + return Ok(DeletionResult::Underflown(node_id, node_key, MTreeNode::Leaf(leaf_node))); + } + if doc_removed { + let sn = StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0); + store.set_node(sn, true)?; + return Ok(DeletionResult::DocRemoved); + } + let max_dist = self.compute_leaf_max_distance(&leaf_node, parent_center); + let sn = StoredNode::new(MTreeNode::Leaf(leaf_node), node_id, node_key, 0); + store.set_node(sn, false)?; + Ok(DeletionResult::CoveringRadius(max_dist)) + } + + async fn finish(&self, tx: &mut Transaction, key: Key) -> Result<(), Error> { + if self.updated { + tx.set(key, self.state.try_to_val()?).await?; + } + Ok(()) + } +} + +#[derive(PartialEq)] +struct PriorityNode(f64, NodeId); + +impl Eq for PriorityNode {} + +fn partial_cmp_f64(a: f64, b: f64) -> Option { + let a = if a.is_nan() { + f64::NEG_INFINITY + } else { + a + }; + let b = if b.is_nan() { + f64::NEG_INFINITY + } else { + b + }; + a.partial_cmp(&b) +} + +impl PartialOrd for PriorityNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PriorityNode { + fn cmp(&self, other: &Self) -> Ordering { + match partial_cmp_f64(self.0, other.0).unwrap_or(Ordering::Equal) { + Ordering::Equal => self.1.cmp(&other.1), + other => other, + } + } +} + +#[derive(PartialEq)] +struct PriorityResult(f64, Arc); + +impl Eq for PriorityResult {} + +impl PartialOrd for PriorityResult { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PriorityResult { + fn cmp(&self, other: &Self) -> Ordering { + match partial_cmp_f64(self.0, other.0).unwrap_or(Ordering::Equal) { + Ordering::Equal => self.1.cmp(&other.1), + other => other, + } + } +} + +#[derive(Debug)] +enum MTreeNode { + Internal(InternalNode), + Leaf(LeafNode), +} + +impl MTreeNode { + fn len(&self) -> usize { + match self { + MTreeNode::Internal(e) => e.len(), + MTreeNode::Leaf(m) => m.len(), + } + } + + fn internal(self) -> Result { + match self { + MTreeNode::Internal(n) => Ok(n), + MTreeNode::Leaf(_) => Err(Error::Unreachable), + } + } + + fn leaf(self) -> Result { + match self { + MTreeNode::Internal(_) => Err(Error::Unreachable), + MTreeNode::Leaf(n) => Ok(n), + } + } +} +trait NodeVectors: Sized { + fn len(&self) -> usize; + fn get_vector(&self, i: usize) -> Result, Error>; + + fn distribute_entries( + self, + distances: &[Vec], + p1: usize, + p2: usize, + ) -> Result<(Self, f64, Self, f64), Error>; + + fn into_mtree_node(self) -> MTreeNode; +} + +impl NodeVectors for LeafNode { + fn len(&self) -> usize { + self.len() + } + + fn get_vector(&self, i: usize) -> Result, Error> { + self.get_index(i).ok_or(Error::Unreachable).map(|(v, _)| v.clone()) + } + + fn distribute_entries( + mut self, + distances: &[Vec], + p1: usize, + p2: usize, + ) -> Result<(Self, f64, Self, f64), Error> { + let mut leaf1 = LeafNode::new(); + let mut leaf2 = LeafNode::new(); + let (mut r1, mut r2) = (0f64, 0f64); + for (i, (v, mut p)) in self.drain(..).enumerate() { + let dist_p1 = distances[i][p1]; + let dist_p2 = distances[i][p2]; + if dist_p1 <= dist_p2 { + p.parent_dist = dist_p1; + leaf1.insert(v, p); + if dist_p1 > r1 { + r1 = dist_p1; + } + } else { + p.parent_dist = dist_p2; + leaf2.insert(v, p); + if dist_p2 > r2 { + r2 = dist_p2; + } + } + } + Ok((leaf1, r1, leaf2, r2)) + } + + fn into_mtree_node(self) -> MTreeNode { + MTreeNode::Leaf(self) + } +} + +impl NodeVectors for InternalNode { + fn len(&self) -> usize { + self.len() + } + + fn get_vector(&self, i: usize) -> Result, Error> { + self.get(i).ok_or(Error::Unreachable).map(|e| e.center.clone()) + } + + fn distribute_entries( + self, + distances: &[Vec], + p1: usize, + p2: usize, + ) -> Result<(Self, f64, Self, f64), Error> { + let mut internal1 = InternalNode::new(); + let mut internal2 = InternalNode::new(); + let (mut r1, mut r2) = (0f64, 0f64); + for (i, r) in self.into_iter().enumerate() { + let dist_p1 = distances[i][p1]; + let dist_p2 = distances[i][p2]; + if dist_p1 <= dist_p2 { + internal1.push(r); + if dist_p1 > r1 { + r1 = dist_p1; + } + } else { + internal2.push(r); + if dist_p2 > r2 { + r2 = dist_p2; + } + } + } + Ok((internal1, r1, internal2, r2)) + } + + fn into_mtree_node(self) -> MTreeNode { + MTreeNode::Internal(self) + } +} + +type InternalNode = Vec; +type LeafNode = LeafIndexMap; + +impl TreeNode for MTreeNode { + fn try_from_val(val: Val) -> Result { + let mut c: Cursor> = Cursor::new(val); + let node_type: u8 = bincode::deserialize_from(&mut c)?; + match node_type { + 1u8 => { + let objects: IndexMap, ObjectProperties> = + bincode::deserialize_from(c)?; + Ok(MTreeNode::Leaf(objects)) + } + 2u8 => { + let entries: Vec = bincode::deserialize_from(c)?; + Ok(MTreeNode::Internal(entries)) + } + _ => Err(Error::CorruptedIndex), + } + } + + fn try_into_val(&mut self) -> Result { + let mut c: Cursor> = Cursor::new(Vec::new()); + match self { + MTreeNode::Leaf(objects) => { + bincode::serialize_into(&mut c, &1u8)?; + bincode::serialize_into(&mut c, objects)?; + } + MTreeNode::Internal(entries) => { + bincode::serialize_into(&mut c, &2u8)?; + bincode::serialize_into(&mut c, entries)?; + } + }; + Ok(c.into_inner()) + } +} + +pub(crate) struct MtStatistics { + doc_ids: BStatistics, +} + +impl From for Value { + fn from(stats: MtStatistics) -> Self { + let mut res = Object::default(); + res.insert("doc_ids".to_owned(), Value::from(stats.doc_ids)); + Value::from(res) + } +} + +#[derive(Clone, Serialize, Deserialize)] +#[revisioned(revision = 1)] +struct MState { + capacity: u16, + root: Option, + next_node_id: NodeId, +} + +impl MState { + pub fn new(capacity: u16) -> Self { + assert!(capacity >= 2, "Capacity should be >= 2"); + Self { + capacity, + root: None, + next_node_id: 0, + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub(in crate::idx) struct RoutingEntry { + // Reference to the node + node: NodeId, + // Center of the node + center: Arc, + // Covering radius + radius: f64, +} + +#[derive(Serialize, Deserialize, Debug)] +pub(in crate::idx) struct ObjectProperties { + // Distance to its parent object + parent_dist: f64, + // The documents pointing to this vector + docs: RoaringTreemap, +} + +impl ObjectProperties { + fn new(parent_dist: f64, id: DocId) -> Self { + let mut docs = RoaringTreemap::new(); + docs.insert(id); + Self { + parent_dist, + docs, + } + } + + fn new_root(id: DocId) -> Self { + Self::new(0.0, id) + } +} + +impl VersionedSerdeState for MState {} + +#[cfg(test)] +mod tests { + use crate::idx::docids::DocId; + use crate::idx::trees::mtree::{ + MState, MTree, MTreeNode, MTreeNodeStore, ObjectProperties, RoutingEntry, Vector, + }; + use crate::idx::trees::store::{NodeId, TreeNodeProvider, TreeNodeStore, TreeStoreType}; + use crate::kvs::Datastore; + use crate::kvs::Transaction; + use crate::sql::index::Distance; + use indexmap::IndexMap; + use roaring::RoaringTreemap; + use std::collections::VecDeque; + use std::sync::Arc; + use test_log::test; + use tokio::sync::{Mutex, MutexGuard}; + + async fn new_operation( + ds: &Datastore, + t: TreeStoreType, + ) -> (Arc>>, Transaction) { + let s = TreeNodeStore::new(TreeNodeProvider::Debug, t, 20); + let tx = ds.transaction(t == TreeStoreType::Write, false).await.unwrap(); + (s, tx) + } + + async fn finish_operation( + mut tx: Transaction, + mut s: MutexGuard<'_, TreeNodeStore>, + commit: bool, + ) { + s.finish(&mut tx).await.unwrap(); + if commit { + tx.commit().await.unwrap(); + } else { + tx.cancel().await.unwrap(); + } + } + + #[test(tokio::test)] + async fn test_mtree_insertions() { + let mut t = MTree::new(MState::new(3), Distance::Euclidean); + let ds = Datastore::new("memory").await.unwrap(); + + let vec1 = vec![1.into()]; + // First the index is empty + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 0); + } + // Insert single element + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec1.clone(), 1).await.unwrap(); + assert_eq!(t.state.root, Some(0)); + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 1); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + }) + .await; + finish_operation(tx, s, true).await; + } + // Check KNN + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 1); + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(1), Some(1)).await; + } + + // insert second element + let vec2 = vec![2.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec2.clone(), 2).await.unwrap(); + finish_operation(tx, s, true).await; + } + // vec1 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![1], vec![2]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 1); + assert_eq!(t.state.root, Some(0)); + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 0.0, &[2]); + }) + .await; + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(2), Some(2)).await; + } + // vec2 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec2, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![2], vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 1); + } + + // insert new doc to existing vector + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec2.clone(), 3).await.unwrap(); + finish_operation(tx, s, true).await; + } + // vec2 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec2, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![2, 3], vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 1); + assert_eq!(t.state.root, Some(0)); + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 0.0, &[2, 3]); + }) + .await; + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(2), Some(2)).await; + } + + // insert third vector + let vec3 = vec![3.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec3.clone(), 3).await.unwrap(); + finish_operation(tx, s, true).await; + } + // vec3 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec3, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![3], vec![2, 3], vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 1); + assert_eq!(t.state.root, Some(0)); + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 3); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 0.0, &[2, 3]); + check_leaf_vec(m, 2, &vec3, 0.0, &[3]); + }) + .await; + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(3), Some(3)).await; + } + + // Check split leaf node + let vec4 = vec![4.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec4.clone(), 4).await.unwrap(); + finish_operation(tx, s, true).await; + } + // vec4 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec4, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![4], vec![3], vec![2, 3], vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 3); + assert_eq!(t.state.root, Some(2)); + check_internal(&mut tx, &mut s, 2, |m| { + assert_eq!(m.len(), 2); + check_routing_vec(m, 0, &vec1, 0, 1.0); + check_routing_vec(m, 1, &vec4, 1, 1.0); + }) + .await; + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 1.0, &[2, 3]); + }) + .await; + check_leaf(&mut tx, &mut s, 1, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec3, 1.0, &[3]); + check_leaf_vec(m, 1, &vec4, 0.0, &[4]); + }) + .await; + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(2)).await; + } + + // Insert vec extending the radius of the last node, calling compute_leaf_radius + let vec6 = vec![6.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec6.clone(), 6).await.unwrap(); + finish_operation(tx, s, true).await; + } + // vec6 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec6, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![6], vec![4], vec![3], vec![2, 3], vec![1]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 3); + assert_eq!(t.state.root, Some(2)); + check_internal(&mut tx, &mut s, 2, |m| { + assert_eq!(m.len(), 2); + check_routing_vec(m, 0, &vec1, 0, 1.0); + check_routing_vec(m, 1, &vec4, 1, 2.0); + }) + .await; + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 1.0, &[2, 3]); + }) + .await; + check_leaf(&mut tx, &mut s, 1, |m| { + assert_eq!(m.len(), 3); + check_leaf_vec(m, 0, &vec3, 1.0, &[3]); + check_leaf_vec(m, 1, &vec4, 0.0, &[4]); + check_leaf_vec(m, 2, &vec6, 2.0, &[6]); + }) + .await; + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(3)).await; + } + + // Insert check split internal node + let vec8 = vec![8.into()]; + let vec9 = vec![9.into()]; + let vec10 = vec![10.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec8.clone(), 8).await.unwrap(); + t.insert(&mut tx, &mut s, vec9.clone(), 9).await.unwrap(); + t.insert(&mut tx, &mut s, vec10.clone(), 10).await.unwrap(); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + check_tree_properties(&mut tx, &mut s, &t, 7, 3, Some(2), Some(2)).await; + assert_eq!(t.state.root, Some(6)); + // Check Root node (level 1) + check_internal(&mut tx, &mut s, 6, |m| { + assert_eq!(m.len(), 2); + check_routing_vec(m, 0, &vec1, 2, 2.0); + check_routing_vec(m, 1, &vec10, 5, 4.0); + }) + .await; + // Check level 2 + check_internal(&mut tx, &mut s, 2, |m| { + assert_eq!(m.len(), 2); + check_routing_vec(m, 0, &vec1, 0, 1.0); + check_routing_vec(m, 1, &vec3, 1, 1.0); + }) + .await; + check_internal(&mut tx, &mut s, 5, |m| { + assert_eq!(m.len(), 2); + check_routing_vec(m, 0, &vec6, 3, 2.0); + check_routing_vec(m, 1, &vec10, 4, 1.0); + }) + .await; + // Check level 3 + check_leaf(&mut tx, &mut s, 0, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec1, 0.0, &[1]); + check_leaf_vec(m, 1, &vec2, 1.0, &[2, 3]); + }) + .await; + check_leaf(&mut tx, &mut s, 1, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec3, 0.0, &[3]); + check_leaf_vec(m, 1, &vec4, 1.0, &[4]); + }) + .await; + check_leaf(&mut tx, &mut s, 3, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec6, 0.0, &[6]); + check_leaf_vec(m, 1, &vec8, 2.0, &[8]); + }) + .await; + check_leaf(&mut tx, &mut s, 4, |m| { + assert_eq!(m.len(), 2); + check_leaf_vec(m, 0, &vec9, 1.0, &[9]); + check_leaf_vec(m, 1, &vec10, 0.0, &[10]); + }) + .await; + } + // vec8 knn + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec8, 20).await.unwrap(); + check_knn( + &res.objects, + vec![vec![8], vec![9], vec![6], vec![10], vec![4], vec![3], vec![2, 3], vec![1]], + ); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 7); + } + // vec4 knn(2) + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec4, 2).await.unwrap(); + check_knn(&res.objects, vec![vec![4], vec![3]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 7); + } + + // vec10 knn(2) + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Read).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec10, 2).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![9]]); + #[cfg(debug_assertions)] + assert_eq!(res.visited_nodes, 7); + } + } + + #[test(tokio::test)] + #[ignore] + async fn test_mtree_deletion_doc_removed_and_none() { + let ds = Datastore::new("memory").await.unwrap(); + + let mut t = MTree::new(MState::new(4), Distance::Euclidean); + + let vec1 = vec![1.into()]; + let vec2 = vec![2.into()]; + + // Create the tree with vec1 and vec2 having two documents + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec1.clone(), 10).await.unwrap(); + t.insert(&mut tx, &mut s, vec2.clone(), 20).await.unwrap(); + t.insert(&mut tx, &mut s, vec2.clone(), 21).await.unwrap(); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20, 21]]); + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(2), Some(2)).await; + } + + // Remove the doc 21 + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(t.delete(&mut tx, &mut s, vec2.clone(), 21).await.unwrap()); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20]]); + check_tree_properties(&mut tx, &mut s, &t, 1, 1, Some(2), Some(2)).await; + } + + // Remove again vec2 / 21 => Deletion::None + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(!t.delete(&mut tx, &mut s, vec2.clone(), 21).await.unwrap()); + assert!(!t.delete(&mut tx, &mut s, vec2.clone(), 21).await.unwrap()); + finish_operation(tx, s, true).await; + } + + let vec3 = vec![3.into()]; + let vec4 = vec![4.into()]; + let vec5 = vec![5.into()]; + + // Add vec3, vec4 and vec5 having two documents + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, vec3.clone(), 30).await.unwrap(); + t.insert(&mut tx, &mut s, vec4.clone(), 40).await.unwrap(); + t.insert(&mut tx, &mut s, vec5.clone(), 50).await.unwrap(); + t.insert(&mut tx, &mut s, vec5.clone(), 51).await.unwrap(); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20], vec![30], vec![40], vec![51]]); + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(3)).await; + } + + // Remove the doc 51 + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(t.delete(&mut tx, &mut s, vec5.clone(), 51).await.unwrap()); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20], vec![30], vec![40], vec![50]]); + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(3)).await; + } + + // Remove again vec5 / 51 => Deletion::None + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(!t.delete(&mut tx, &mut s, vec5.clone(), 51).await.unwrap()); + assert!(!t.delete(&mut tx, &mut s, vec5.clone(), 51).await.unwrap()); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20], vec![30], vec![40], vec![50]]); + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(3)).await; + } + + // Remove vec5 / 50 => DeleteResult::UnderflownLeafIndexMap + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(!t.delete(&mut tx, &mut s, vec5.clone(), 50).await.unwrap()); + finish_operation(tx, s, true).await; + } + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &vec1, 10).await.unwrap(); + check_knn(&res.objects, vec![vec![10], vec![20], vec![30], vec![40]]); + check_tree_properties(&mut tx, &mut s, &t, 3, 2, Some(2), Some(3)).await; + } + } + + #[test(tokio::test)] + async fn test_mtree_deletions_merge_routing_node() { + let ds = Datastore::new("memory").await.unwrap(); + + let mut t = MTree::new(MState::new(4), Distance::Euclidean); + + let v0 = vec![0.into()]; + let v1 = vec![1.into()]; + let v2 = vec![2.into()]; + let v3 = vec![3.into()]; + let v4 = vec![4.into()]; + let v5 = vec![5.into()]; + let v6 = vec![6.into()]; + let v7 = vec![7.into()]; + let v8 = vec![8.into()]; + let v9 = vec![9.into()]; + let v10 = vec![10.into()]; + let v11 = vec![11.into()]; + let v12 = vec![12.into()]; + let v13 = vec![13.into()]; + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + t.insert(&mut tx, &mut s, v9.clone(), 90).await.unwrap(); + t.insert(&mut tx, &mut s, v10.clone(), 100).await.unwrap(); + t.insert(&mut tx, &mut s, v11.clone(), 110).await.unwrap(); + t.insert(&mut tx, &mut s, v12.clone(), 120).await.unwrap(); + t.insert(&mut tx, &mut s, v13.clone(), 130).await.unwrap(); + t.insert(&mut tx, &mut s, v1.clone(), 10).await.unwrap(); + t.insert(&mut tx, &mut s, v2.clone(), 20).await.unwrap(); + t.insert(&mut tx, &mut s, v3.clone(), 30).await.unwrap(); + t.insert(&mut tx, &mut s, v4.clone(), 40).await.unwrap(); + t.insert(&mut tx, &mut s, v5.clone(), 50).await.unwrap(); + t.insert(&mut tx, &mut s, v6.clone(), 60).await.unwrap(); + t.insert(&mut tx, &mut s, v7.clone(), 70).await.unwrap(); + t.insert(&mut tx, &mut s, v8.clone(), 80).await.unwrap(); + finish_operation(tx, s, true).await; + } + + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Traversal).await; + let mut s = s.lock().await; + let res = t.knn_search(&mut tx, &mut s, &v0, 20).await.unwrap(); + check_knn( + &res.objects, + vec![ + vec![10], + vec![20], + vec![30], + vec![40], + vec![50], + vec![60], + vec![70], + vec![80], + vec![90], + vec![100], + vec![110], + vec![120], + vec![130], + ], + ); + check_tree_properties(&mut tx, &mut s, &t, 8, 3, Some(2), Some(3)).await; + } + + // Remove -> + { + let (s, mut tx) = new_operation(&ds, TreeStoreType::Write).await; + let mut s = s.lock().await; + assert!(t.delete(&mut tx, &mut s, v9.clone(), 90).await.unwrap()); + finish_operation(tx, s, true).await; + } + } + + fn check_leaf_vec( + m: &IndexMap, ObjectProperties>, + idx: usize, + vec: &Vector, + parent_dist: f64, + docs: &[DocId], + ) { + let (v, p) = m.get_index(idx).unwrap(); + assert_eq!(v.as_ref(), vec); + assert_eq!(p.docs.len(), docs.len() as u64); + for doc in docs { + assert!(p.docs.contains(*doc)); + } + assert_eq!(p.parent_dist, parent_dist); + } + + fn check_routing_vec( + m: &Vec, + idx: usize, + center: &Vector, + node_id: NodeId, + radius: f64, + ) { + let p = &m[idx]; + assert_eq!(center, p.center.as_ref()); + assert_eq!(node_id, p.node); + assert_eq!(radius, p.radius); + } + + async fn check_node( + tx: &mut Transaction, + s: &mut MTreeNodeStore, + node_id: NodeId, + check_func: F, + ) where + F: FnOnce(&MTreeNode), + { + let n = s.get_node(tx, node_id).await.unwrap(); + check_func(&n.n); + s.set_node(n, false).unwrap(); + } + + async fn check_leaf( + tx: &mut Transaction, + s: &mut MTreeNodeStore, + node_id: NodeId, + check_func: F, + ) where + F: FnOnce(&IndexMap, ObjectProperties>), + { + check_node(tx, s, node_id, |n| { + if let MTreeNode::Leaf(m) = n { + check_func(m); + } else { + panic!("The node is not a leaf node: {node_id}") + } + }) + .await + } + + async fn check_internal( + tx: &mut Transaction, + s: &mut MTreeNodeStore, + node_id: NodeId, + check_func: F, + ) where + F: FnOnce(&Vec), + { + check_node(tx, s, node_id, |n| { + if let MTreeNode::Internal(m) = n { + check_func(m); + } else { + panic!("The node is not a routing node: {node_id}") + } + }) + .await + } + + fn check_knn(res: &VecDeque, expected: Vec>) { + assert_eq!(res.len(), expected.len(), "{:?}", res); + for (i, (a, b)) in res.iter().zip(expected.iter()).enumerate() { + for id in b { + assert!(a.contains(*id), "{}: {}", i, id); + } + } + } + + async fn check_tree_properties( + tx: &mut Transaction, + s: &mut MTreeNodeStore, + t: &MTree, + expected_node_count: usize, + expected_depth: usize, + expected_min_objects: Option, + expected_max_objects: Option, + ) { + println!("CheckTreeProperties"); + let mut node_count = 0; + let mut max_depth = 0; + let mut min_leaf_depth = None; + let mut max_leaf_depth = None; + let mut min_objects = None; + let mut max_objects = None; + let mut nodes = VecDeque::new(); + if let Some(root_id) = t.state.root { + nodes.push_back((root_id, 1)); + } + while let Some((node_id, depth)) = nodes.pop_front() { + node_count += 1; + if depth > max_depth { + max_depth = depth; + } + let node = s.get_node(tx, node_id).await.unwrap(); + println!( + "Node id: {} - depth: {} - len: {} - {:?}", + node.id, + depth, + node.n.len(), + node.n + ); + match node.n { + MTreeNode::Internal(entries) => { + let next_depth = depth + 1; + entries.iter().for_each(|p| nodes.push_back((p.node, next_depth))); + } + MTreeNode::Leaf(m) => { + update_min(&mut min_objects, m.len()); + update_max(&mut max_objects, m.len()); + update_min(&mut min_leaf_depth, depth); + update_max(&mut max_leaf_depth, depth); + } + } + } + assert_eq!(node_count, expected_node_count, "Node count"); + assert_eq!(max_depth, expected_depth, "Max depth"); + assert_eq!(min_leaf_depth, Some(expected_depth), "Min leaf depth"); + assert_eq!(max_leaf_depth, Some(expected_depth), "Max leaf depth"); + assert_eq!(min_objects, expected_min_objects, "Min objects"); + assert_eq!(max_objects, expected_max_objects, "Max objects"); + } + + fn update_min(min: &mut Option, val: usize) { + if let Some(m) = *min { + if val < m { + *min = Some(val); + } + } else { + *min = Some(val); + } + } + + fn update_max(max: &mut Option, val: usize) { + if let Some(m) = *max { + if val > m { + *max = Some(val); + } + } else { + *max = Some(val); + } + } +} diff --git a/lib/src/idx/trees/store.rs b/lib/src/idx/trees/store.rs index 911bef7c..ff2be872 100644 --- a/lib/src/idx/trees/store.rs +++ b/lib/src/idx/trees/store.rs @@ -9,7 +9,7 @@ use tokio::sync::Mutex; pub type NodeId = u64; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq)] pub enum TreeStoreType { Write, Read, @@ -151,7 +151,7 @@ where #[cfg(debug_assertions)] self.out.insert(id); StoredNode { - node, + n: node, id, key: self.np.get_key(id), size: 0, @@ -238,6 +238,7 @@ pub enum TreeNodeProvider { DocLengths(IndexKeyBase), Postings(IndexKeyBase), Terms(IndexKeyBase), + Vector(IndexKeyBase), Debug, } @@ -248,6 +249,7 @@ impl TreeNodeProvider { TreeNodeProvider::DocLengths(ikb) => ikb.new_bl_key(Some(node_id)), TreeNodeProvider::Postings(ikb) => ikb.new_bp_key(Some(node_id)), TreeNodeProvider::Terms(ikb) => ikb.new_bt_key(Some(node_id)), + TreeNodeProvider::Vector(ikb) => ikb.new_vm_key(Some(node_id)), TreeNodeProvider::Debug => node_id.to_be_bytes().to_vec(), } } @@ -261,7 +263,7 @@ impl TreeNodeProvider { let size = val.len() as u32; let node = N::try_from_val(val)?; Ok(StoredNode { - node, + n: node, id, key, size, @@ -275,19 +277,30 @@ impl TreeNodeProvider { where N: TreeNode, { - let val = node.node.try_into_val()?; + let val = node.n.try_into_val()?; tx.set(node.key, val).await?; Ok(()) } } pub(super) struct StoredNode { - pub(super) node: N, + pub(super) n: N, pub(super) id: NodeId, pub(super) key: Key, pub(super) size: u32, } +impl StoredNode { + pub(super) fn new(n: N, id: NodeId, key: Key, size: u32) -> Self { + Self { + n, + id, + key, + size, + } + } +} + pub trait TreeNode where Self: Sized, diff --git a/lib/src/key/index/bf.rs b/lib/src/key/index/bf.rs index d9a66858..1fa841d9 100644 --- a/lib/src/key/index/bf.rs +++ b/lib/src/key/index/bf.rs @@ -1,5 +1,5 @@ //! Stores Term/Doc frequency -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::terms::TermId; use derive::Key; use serde::{Deserialize, Serialize}; diff --git a/lib/src/key/index/bk.rs b/lib/src/key/index/bk.rs index bb3540d1..bcfe7bd9 100644 --- a/lib/src/key/index/bk.rs +++ b/lib/src/key/index/bk.rs @@ -1,5 +1,5 @@ //! Stores the term list for doc_ids -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use derive::Key; use serde::{Deserialize, Serialize}; diff --git a/lib/src/key/index/bo.rs b/lib/src/key/index/bo.rs index 68d1376e..739d3257 100644 --- a/lib/src/key/index/bo.rs +++ b/lib/src/key/index/bo.rs @@ -1,5 +1,5 @@ //! Stores the offsets -use crate::idx::ft::docids::DocId; +use crate::idx::docids::DocId; use crate::idx::ft::terms::TermId; use derive::Key; use serde::{Deserialize, Serialize}; diff --git a/lib/src/key/index/mod.rs b/lib/src/key/index/mod.rs index ab8da7d8..b3f7901c 100644 --- a/lib/src/key/index/mod.rs +++ b/lib/src/key/index/mod.rs @@ -11,6 +11,7 @@ pub mod bp; pub mod bs; pub mod bt; pub mod bu; +pub mod vm; use crate::sql::array::Array; use crate::sql::id::Id; diff --git a/lib/src/key/index/vm.rs b/lib/src/key/index/vm.rs new file mode 100644 index 00000000..c9d3c04d --- /dev/null +++ b/lib/src/key/index/vm.rs @@ -0,0 +1,68 @@ +//! Stores MTree state and nodes +use crate::idx::trees::store::NodeId; +use derive::Key; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)] +pub struct Vm<'a> { + __: u8, + _a: u8, + pub ns: &'a str, + _b: u8, + pub db: &'a str, + _c: u8, + pub tb: &'a str, + _d: u8, + pub ix: &'a str, + _e: u8, + _f: u8, + _g: u8, + pub node_id: Option, +} + +impl<'a> Vm<'a> { + pub fn new( + ns: &'a str, + db: &'a str, + tb: &'a str, + ix: &'a str, + node_id: Option, + ) -> Self { + Self { + __: b'/', + _a: b'*', + ns, + _b: b'*', + db, + _c: b'*', + tb, + _d: b'+', + ix, + _e: b'!', + _f: b'v', + _g: b'm', + node_id, + } + } +} + +#[cfg(test)] +mod tests { + #[test] + fn key() { + use super::*; + #[rustfmt::skip] + let val = Vm::new( + "testns", + "testdb", + "testtb", + "testix", + Some(8) + ); + let enc = Vm::encode(&val).unwrap(); + assert_eq!(enc, b"/*testns\0*testdb\0*testtb\0+testix\0!vm\x01\0\0\0\0\0\0\0\x08"); + + let dec = Vm::decode(&enc).unwrap(); + assert_eq!(val, dec); + } +} diff --git a/lib/src/kvs/ds.rs b/lib/src/kvs/ds.rs index 991e4472..dd27a787 100644 --- a/lib/src/kvs/ds.rs +++ b/lib/src/kvs/ds.rs @@ -316,6 +316,9 @@ impl Datastore { } /// Setup the initial credentials + /// Trigger the `unreachable definition` compilation error, probably due to this issue: + /// https://github.com/rust-lang/rust/issues/111370 + #[allow(unreachable_code, unused_variables)] pub async fn setup_initial_creds(&self, creds: Root<'_>) -> Result<(), Error> { // Start a new writeable transaction let txn = self.transaction(true, false).await?.rollback_with_panic().enclose(); diff --git a/lib/src/sql/expression.rs b/lib/src/sql/expression.rs index 4d3c60ac..f302adfd 100644 --- a/lib/src/sql/expression.rs +++ b/lib/src/sql/expression.rs @@ -191,6 +191,7 @@ impl Expression { Operator::Outside => fnc::operate::outside(&l, &r), Operator::Intersects => fnc::operate::intersects(&l, &r), Operator::Matches(_) => fnc::operate::matches(ctx, txn, doc, self).await, + Operator::Knn(_) => fnc::operate::knn(ctx, txn, doc, self).await, _ => unreachable!(), } } diff --git a/lib/src/sql/index.rs b/lib/src/sql/index.rs index 32232471..f6bf420e 100644 --- a/lib/src/sql/index.rs +++ b/lib/src/sql/index.rs @@ -49,7 +49,7 @@ pub struct MTreeParams { pub doc_ids_order: u32, } -#[derive(Default, Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] +#[derive(Clone, Default, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] #[revisioned(revision = 1)] pub enum Distance { #[default] @@ -182,6 +182,7 @@ pub fn search(i: &str) -> IResult<&str, Index> { pub fn distance(i: &str) -> IResult<&str, Distance> { let (i, _) = mightbespace(i)?; let (i, _) = tag_no_case("DIST")(i)?; + let (i, _) = shouldbespace(i)?; alt(( map(tag_no_case("EUCLIDEAN"), |_| Distance::Euclidean), map(tag_no_case("MANHATTAN"), |_| Distance::Manhattan), @@ -200,7 +201,7 @@ pub fn minkowski(i: &str) -> IResult<&str, Distance> { } pub fn dimension(i: &str) -> IResult<&str, u16> { - let (i, _) = shouldbespace(i)?; + let (i, _) = mightbespace(i)?; let (i, _) = tag_no_case("DIMENSION")(i)?; let (i, _) = shouldbespace(i)?; let (i, dim) = uint16(i)?; diff --git a/lib/src/sql/operator.rs b/lib/src/sql/operator.rs index f878b6b1..f7a43d48 100644 --- a/lib/src/sql/operator.rs +++ b/lib/src/sql/operator.rs @@ -6,6 +6,7 @@ use nom::branch::alt; use nom::bytes::complete::tag; use nom::bytes::complete::tag_no_case; use nom::character::complete::char; +use nom::character::complete::u32 as uint32; use nom::character::complete::u8 as uint8; use nom::combinator::cut; use nom::combinator::opt; @@ -67,6 +68,8 @@ pub enum Operator { // Outside, Intersects, + // + Knn(u32), // <{k}> } impl Default for Operator { @@ -141,6 +144,7 @@ impl fmt::Display for Operator { f.write_str("@@") } } + Self::Knn(k) => write!(f, "<{}>", k), } } } @@ -191,12 +195,14 @@ pub fn binary_symbols(i: &str) -> IResult<&str, Operator> { value(Operator::AnyLike, tag("?~")), value(Operator::Like, char('~')), matches, + knn, )), alt(( value(Operator::LessThanOrEqual, tag("<=")), value(Operator::LessThan, char('<')), value(Operator::MoreThanOrEqual, tag(">=")), value(Operator::MoreThan, char('>')), + knn, )), alt(( value(Operator::Pow, tag("**")), @@ -257,7 +263,6 @@ pub fn binary_phrases(i: &str) -> IResult<&str, Operator> { pub fn matches(i: &str) -> IResult<&str, Operator> { let (i, _) = char('@')(i)?; - // let (i, reference) = opt(|i| uint8(i))(i)?; cut(|i| { let (i, reference) = opt(uint8)(i)?; let (i, _) = char('@')(i)?; @@ -265,6 +270,13 @@ pub fn matches(i: &str) -> IResult<&str, Operator> { })(i) } +pub fn knn(i: &str) -> IResult<&str, Operator> { + let (i, _) = char('<')(i)?; + let (i, k) = uint32(i)?; + let (i, _) = char('>')(i)?; + Ok((i, Operator::Knn(k))) +} + #[cfg(test)] mod tests { use super::*; @@ -290,4 +302,13 @@ mod tests { let res = matches("@256@"); res.unwrap_err(); } + + #[test] + fn test_knn() { + let res = knn("<5>"); + assert!(res.is_ok()); + let out = res.unwrap().1; + assert_eq!("<5>", format!("{}", out)); + assert_eq!(out, Operator::Knn(5)); + } } diff --git a/lib/src/sql/statements/analyze.rs b/lib/src/sql/statements/analyze.rs index 9ec66ec1..0003bf62 100644 --- a/lib/src/sql/statements/analyze.rs +++ b/lib/src/sql/statements/analyze.rs @@ -5,6 +5,7 @@ use crate::doc::CursorDoc; use crate::err::Error; use crate::iam::{Action, ResourceKind}; use crate::idx::ft::FtIndex; +use crate::idx::trees::mtree::MTreeIndex; use crate::idx::trees::store::TreeStoreType; use crate::idx::IndexKeyBase; use crate::sql::comment::shouldbespace; @@ -56,6 +57,11 @@ impl AnalyzeStatement { FtIndex::new(&mut run, az, ikb, p, TreeStoreType::Traversal).await?; ft.statistics(&mut run).await?.into() } + Index::MTree(p) => { + let mt = + MTreeIndex::new(&mut run, ikb, p, TreeStoreType::Traversal).await?; + mt.statistics(&mut run).await?.into() + } _ => { return Err(Error::FeatureNotYetImplemented { feature: "Statistics on unique and non-unique indexes.".to_string(), diff --git a/lib/src/sql/statements/define/index.rs b/lib/src/sql/statements/define/index.rs index 50885198..2a0358a7 100644 --- a/lib/src/sql/statements/define/index.rs +++ b/lib/src/sql/statements/define/index.rs @@ -178,7 +178,7 @@ fn index_comment(i: &str) -> IResult<&str, DefineIndexOption> { mod tests { use super::*; - use crate::sql::index::SearchParams; + use crate::sql::index::{Distance, MTreeParams, SearchParams}; use crate::sql::Ident; use crate::sql::Idiom; use crate::sql::Idioms; @@ -275,4 +275,29 @@ mod tests { "DEFINE INDEX my_index ON my_table FIELDS my_col SEARCH ANALYZER my_analyzer VS DOC_IDS_ORDER 100 DOC_LENGTHS_ORDER 100 POSTINGS_ORDER 100 TERMS_ORDER 100" ); } + + #[test] + fn check_create_mtree_index() { + let sql = "INDEX my_index ON TABLE my_table COLUMNS my_col MTREE DIMENSION 4"; + let (_, idx) = index(sql).unwrap(); + assert_eq!( + idx, + DefineIndexStatement { + name: Ident("my_index".to_string()), + what: Ident("my_table".to_string()), + cols: Idioms(vec![Idiom(vec![Part::Field(Ident("my_col".to_string()))])]), + index: Index::MTree(MTreeParams { + dimension: 4, + distance: Distance::Euclidean, + capacity: 40, + doc_ids_order: 100, + }), + comment: None, + } + ); + assert_eq!( + idx.to_string(), + "DEFINE INDEX my_index ON my_table FIELDS my_col MTREE DIMENSION 4 DIST EUCLIDEAN CAPACITY 40 DOC_IDS_ORDER 100" + ); + } } diff --git a/lib/src/sql/value/serde/ser/distance/mod.rs b/lib/src/sql/value/serde/ser/distance/mod.rs index 3751119d..6bdf8b95 100644 --- a/lib/src/sql/value/serde/ser/distance/mod.rs +++ b/lib/src/sql/value/serde/ser/distance/mod.rs @@ -3,6 +3,7 @@ use crate::sql::index::Distance; use crate::sql::value::serde::ser; use serde::ser::Error as _; use serde::ser::Impossible; +use serde::Serialize; pub(super) struct Serializer; @@ -29,9 +30,34 @@ impl ser::Serializer for Serializer { ) -> Result { match variant { "Euclidean" => Ok(Distance::Euclidean), + "Manhattan" => Ok(Distance::Manhattan), + "Cosine" => Ok(Distance::Cosine), + "Hamming" => Ok(Distance::Hamming), + "Mahalanobis" => Ok(Distance::Mahalanobis), variant => Err(Error::custom(format!("unexpected unit variant `{name}::{variant}`"))), } } + + #[inline] + fn serialize_newtype_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + match variant { + "Minkowski" => { + Ok(Distance::Minkowski(value.serialize(ser::number::Serializer.wrap())?)) + } + variant => { + Err(Error::custom(format!("unexpected newtype variant `{name}::{variant}`"))) + } + } + } } #[cfg(test)] @@ -41,9 +67,44 @@ mod tests { use serde::Serialize; #[test] - fn euclidean() { + fn distance_euclidean() { let dist = Distance::Euclidean; let serialized = dist.serialize(Serializer.wrap()).unwrap(); assert_eq!(dist, serialized); } + + #[test] + fn distance_manhattan() { + let dist = Distance::Manhattan; + let serialized = dist.serialize(Serializer.wrap()).unwrap(); + assert_eq!(dist, serialized); + } + + #[test] + fn distance_mahalanobis() { + let dist = Distance::Mahalanobis; + let serialized = dist.serialize(Serializer.wrap()).unwrap(); + assert_eq!(dist, serialized); + } + + #[test] + fn distance_hamming() { + let dist = Distance::Hamming; + let serialized = dist.serialize(Serializer.wrap()).unwrap(); + assert_eq!(dist, serialized); + } + + #[test] + fn distance_cosine() { + let dist = Distance::Cosine; + let serialized = dist.serialize(Serializer.wrap()).unwrap(); + assert_eq!(dist, serialized); + } + + #[test] + fn distance_minkowski() { + let dist = Distance::Minkowski(7.into()); + let serialized = dist.serialize(Serializer.wrap()).unwrap(); + assert_eq!(dist, serialized); + } } diff --git a/lib/tests/changefeeds.rs b/lib/tests/changefeeds.rs index d9260c28..da3fab74 100644 --- a/lib/tests/changefeeds.rs +++ b/lib/tests/changefeeds.rs @@ -366,7 +366,7 @@ async fn changefeed_with_ts() -> Result<(), Error> { let Value::Object(a) = a else { unreachable!() }; - let Value::Number(versionstamp1) = a.get("versionstamp").unwrap() else { + let Value::Number(versionstamp2) = a.get("versionstamp").unwrap() else { unreachable!() }; let changes = a.get("changes").unwrap().to_owned(); @@ -389,10 +389,10 @@ async fn changefeed_with_ts() -> Result<(), Error> { let Value::Object(a) = a else { unreachable!() }; - let Value::Number(versionstamp2) = a.get("versionstamp").unwrap() else { + let Value::Number(versionstamp3) = a.get("versionstamp").unwrap() else { unreachable!() }; - assert!(versionstamp1 < versionstamp2); + assert!(versionstamp2 < versionstamp3); let changes = a.get("changes").unwrap().to_owned(); assert_eq!( changes, @@ -413,10 +413,10 @@ async fn changefeed_with_ts() -> Result<(), Error> { let Value::Object(a) = a else { unreachable!() }; - let Value::Number(versionstamp3) = a.get("versionstamp").unwrap() else { + let Value::Number(versionstamp4) = a.get("versionstamp").unwrap() else { unreachable!() }; - assert!(versionstamp2 < versionstamp3); + assert!(versionstamp3 < versionstamp4); let changes = a.get("changes").unwrap().to_owned(); assert_eq!( changes, @@ -437,10 +437,10 @@ async fn changefeed_with_ts() -> Result<(), Error> { let Value::Object(a) = a else { unreachable!() }; - let Value::Number(versionstamp4) = a.get("versionstamp").unwrap() else { + let Value::Number(versionstamp5) = a.get("versionstamp").unwrap() else { unreachable!() }; - assert!(versionstamp3 < versionstamp4); + assert!(versionstamp4 < versionstamp5); let changes = a.get("changes").unwrap().to_owned(); assert_eq!( changes, @@ -487,7 +487,7 @@ async fn changefeed_with_ts() -> Result<(), Error> { let Value::Number(versionstamp1b) = a.get("versionstamp").unwrap() else { unreachable!() }; - assert!(versionstamp1 == versionstamp1b); + assert!(versionstamp2 == versionstamp1b); let changes = a.get("changes").unwrap().to_owned(); assert_eq!( changes, diff --git a/lib/tests/define.rs b/lib/tests/define.rs index f085c256..6f090907 100644 --- a/lib/tests/define.rs +++ b/lib/tests/define.rs @@ -1211,7 +1211,9 @@ async fn define_statement_search_index() -> Result<(), Error> { events: {}, fields: {}, tables: {}, - indexes: { blog_title: 'DEFINE INDEX blog_title ON blog FIELDS title SEARCH ANALYZER simple BM25(1.2,0.75) DOC_IDS_ORDER 100 DOC_LENGTHS_ORDER 100 POSTINGS_ORDER 100 TERMS_ORDER 100 HIGHLIGHTS' }, + indexes: { blog_title: 'DEFINE INDEX blog_title ON blog FIELDS title \ + SEARCH ANALYZER simple BM25(1.2,0.75) \ + DOC_IDS_ORDER 100 DOC_LENGTHS_ORDER 100 POSTINGS_ORDER 100 TERMS_ORDER 100 HIGHLIGHTS' }, lives: {}, }", ); diff --git a/lib/tests/vector.rs b/lib/tests/vector.rs new file mode 100644 index 00000000..bd340cbc --- /dev/null +++ b/lib/tests/vector.rs @@ -0,0 +1,60 @@ +mod helpers; +mod parse; +use crate::helpers::new_ds; +use parse::Parse; +use surrealdb::dbs::Session; +use surrealdb::err::Error; +use surrealdb::sql::Value; + +#[tokio::test] +async fn select_where_mtree_knn() -> Result<(), Error> { + let sql = r" + CREATE pts:1 SET point = [1,2,3,4]; + CREATE pts:2 SET point = [4,5,6,7]; + CREATE pts:3 SET point = [8,9,10,11]; + DEFINE INDEX mt_pts ON pts FIELDS point MTREE DIMENSION 4; + LET $pt = [2,3,4,5]; + SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <2> $pt; + SELECT id FROM pts WHERE point <2> $pt EXPLAIN; + "; + let dbs = new_ds().await?; + let ses = Session::owner().with_ns("test").with_db("test"); + let res = &mut dbs.execute(sql, &ses, None).await?; + assert_eq!(res.len(), 7); + // + for _ in 0..5 { + let _ = res.remove(0).result?; + } + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + id: pts:1, + dist: 2f + }, + { + id: pts:2, + dist: 4f + } + ]", + ); + assert_eq!(format!("{:#}", tmp), format!("{:#}", val)); + let tmp = res.remove(0).result?; + let val = Value::parse( + "[ + { + detail: { + plan: { + index: 'mt_pts', + operator: '<2>', + value: [2,3,4,5] + }, + table: 'pts', + }, + operation: 'Iterate Index' + } + ]", + ); + assert_eq!(format!("{:#}", tmp), format!("{:#}", val)); + Ok(()) +}