[Feat] HNSW persistence ()

Co-authored-by: David Bottiau <david.bottiau@outlook.com>
Co-authored-by: Micha de Vries <micha@devrie.sh>
Co-authored-by: Micha de Vries <mt.dev@hotmail.com>
Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
Co-authored-by: ekgns33 <76658405+ekgns33@users.noreply.github.com>
Co-authored-by: Sergii Glushchenko <sergii.glushchenko@surrealdb.com>
Co-authored-by: Yusuke Kuoka <ykuoka@gmail.com>
Emmanuel Keller authored on 2024-08-20 11:42:58 +01:00; committed by GitHub
parent d86a734d04
commit 0a4801dcf8
43 changed files with 2071 additions and 720 deletions

.gitignore (vendored): 1 line changed
View file

@@ -56,4 +56,5 @@ Temporary Items
/store/
surreal
history.txt
/doc/tla/states/

View file

@@ -428,15 +428,15 @@ impl<'a> IndexOperation<'a> {
}
async fn index_hnsw(&mut self, ctx: &Context, p: &HnswParams) -> Result<(), Error> {
let hnsw = ctx.get_index_stores().get_index_hnsw(self.opt, self.ix, p).await?;
let hnsw = ctx.get_index_stores().get_index_hnsw(ctx, self.opt, self.ix, p).await?;
let mut hnsw = hnsw.write().await;
// Delete the old index data
if let Some(o) = self.o.take() {
hnsw.remove_document(self.rid, &o)?;
hnsw.remove_document(&ctx.tx(), self.rid.id.clone(), &o).await?;
}
// Create the new index data
if let Some(n) = self.n.take() {
hnsw.index_document(self.rid, &n)?;
hnsw.index_document(&ctx.tx(), self.rid.id.clone(), &n).await?;
}
Ok(())
}
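With persistence, HNSW mutations now run inside the surrounding datastore transaction: the index handle is resolved through the Context, and remove_document/index_document become async, taking the Transaction together with the record's Id (the table is fixed per index, so the full Thing is no longer needed).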

View file

@@ -2,7 +2,7 @@ use crate::err::Error;
use crate::idx::trees::bkeys::TrieKeys;
use crate::idx::trees::btree::{BState, BState1, BState1skip, BStatistics, BTree, BTreeStore};
use crate::idx::trees::store::{IndexStores, TreeNodeProvider};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, TransactionType, Val};
use revision::{revisioned, Revisioned};
use roaring::RoaringTreemap;
@@ -31,7 +31,7 @@ impl DocIds {
) -> Result<Self, Error> {
let state_key: Key = ikb.new_bd_key(None);
let state: State = if let Some(val) = tx.get(state_key.clone(), None).await? {
State::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
State::new(default_btree_order)
};
@@ -142,7 +142,7 @@ impl DocIds {
available_ids: self.available_ids.take(),
next_doc_id: self.next_doc_id,
};
tx.set(self.state_key.clone(), state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(&state)?).await?;
self.ixs.advance_cache_btree_trie(new_cache);
}
Ok(())
@@ -157,8 +157,8 @@ struct State {
next_doc_id: DocId,
}
impl VersionedSerdeState for State {
fn try_from_val(val: Val) -> Result<Self, Error> {
impl VersionedStore for State {
fn try_from(val: Val) -> Result<Self, Error> {
match Self::deserialize_revisioned(&mut val.as_slice()) {
Ok(r) => Ok(r),
// If it fails here, there is a chance it was an old version of BState
@@ -193,7 +193,7 @@ impl From<State1> for State {
}
}
impl VersionedSerdeState for State1 {}
impl VersionedStore for State1 {}
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
@@ -213,7 +213,7 @@ impl From<State1skip> for State {
}
}
impl VersionedSerdeState for State1skip {}
impl VersionedStore for State1skip {}
impl State {
fn new(default_btree_order: u32) -> Self {

View file

@@ -3,7 +3,7 @@ use crate::idx::docids::DocId;
use crate::idx::trees::bkeys::TrieKeys;
use crate::idx::trees::btree::{BState, BStatistics, BTree, BTreeStore, Payload};
use crate::idx::trees::store::{IndexStores, TreeNodeProvider};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, TransactionType};
pub(super) type DocLength = u64;
@@ -26,7 +26,7 @@ impl DocLengths {
) -> Result<Self, Error> {
let state_key: Key = ikb.new_bl_key(None);
let state: BState = if let Some(val) = tx.get(state_key.clone(), None).await? {
BState::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
BState::new(default_btree_order)
};
@@ -87,7 +87,7 @@ impl DocLengths {
pub(super) async fn finish(&mut self, tx: &Transaction) -> Result<(), Error> {
if let Some(new_cache) = self.store.finish(tx).await? {
let state = self.btree.inc_generation();
tx.set(self.state_key.clone(), state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(state)?).await?;
self.ixs.advance_cache_btree_trie(new_cache);
}
Ok(())

View file

@@ -21,7 +21,7 @@ use crate::idx::ft::termdocs::{TermDocs, TermsDocs};
use crate::idx::ft::terms::{TermId, TermLen, Terms};
use crate::idx::trees::btree::BStatistics;
use crate::idx::trees::store::IndexStores;
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::Transaction;
use crate::kvs::{Key, TransactionType};
use crate::sql::index::SearchParams;
@@ -94,7 +94,7 @@ struct State {
doc_count: u64,
}
impl VersionedSerdeState for State {}
impl VersionedStore for State {}
impl FtIndex {
pub(crate) async fn new(
@@ -119,7 +119,7 @@ impl FtIndex {
) -> Result<Self, Error> {
let state_key: Key = index_key_base.new_bs_key();
let state: State = if let Some(val) = txn.get(state_key.clone(), None).await? {
State::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
State::default()
};
@@ -343,7 +343,7 @@ impl FtIndex {
}
// Update the states
tx.set(self.state_key.clone(), self.state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(&self.state)?).await?;
drop(tx);
Ok(())
}

View file

@@ -4,7 +4,7 @@ use crate::idx::ft::terms::TermId;
use crate::idx::trees::bkeys::TrieKeys;
use crate::idx::trees::btree::{BState, BStatistics, BTree, BTreeStore};
use crate::idx::trees::store::{IndexStores, TreeNodeProvider};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, TransactionType};
pub(super) type TermFrequency = u64;
@@ -28,7 +28,7 @@ impl Postings {
) -> Result<Self, Error> {
let state_key: Key = index_key_base.new_bp_key(None);
let state: BState = if let Some(val) = tx.get(state_key.clone(), None).await? {
BState::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
BState::new(order)
};
@@ -87,7 +87,7 @@ impl Postings {
pub(super) async fn finish(&mut self, tx: &Transaction) -> Result<(), Error> {
if let Some(new_cache) = self.store.finish(tx).await? {
let state = self.btree.inc_generation();
tx.set(self.state_key.clone(), state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(state)?).await?;
self.ixs.advance_cache_btree_trie(new_cache);
}
Ok(())

View file

@@ -2,7 +2,7 @@ use crate::err::Error;
use crate::idx::trees::bkeys::FstKeys;
use crate::idx::trees::btree::{BState, BState1, BState1skip, BStatistics, BTree, BTreeStore};
use crate::idx::trees::store::{IndexStores, TreeNodeProvider};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, TransactionType, Val};
use revision::{revisioned, Revisioned};
use roaring::RoaringTreemap;
@@ -32,7 +32,7 @@ impl Terms {
) -> Result<Self, Error> {
let state_key: Key = index_key_base.new_bt_key(None);
let state: State = if let Some(val) = tx.get(state_key.clone(), None).await? {
State::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
State::new(default_btree_order)
};
@@ -129,7 +129,7 @@ impl Terms {
available_ids: self.available_ids.take(),
next_term_id: self.next_term_id,
};
tx.set(self.state_key.clone(), state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(&state)?).await?;
self.ixs.advance_store_btree_fst(new_cache);
}
Ok(())
@@ -152,7 +152,7 @@ struct State1 {
next_term_id: TermId,
}
impl VersionedSerdeState for State1 {}
impl VersionedStore for State1 {}
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
@@ -162,7 +162,7 @@ struct State1skip {
next_term_id: TermId,
}
impl VersionedSerdeState for State1skip {}
impl VersionedStore for State1skip {}
impl From<State1> for State {
fn from(state: State1) -> Self {
@@ -194,8 +194,8 @@ impl State {
}
}
impl VersionedSerdeState for State {
fn try_from_val(val: Val) -> Result<Self, Error> {
impl VersionedStore for State {
fn try_from(val: Val) -> Result<Self, Error> {
match Self::deserialize_revisioned(&mut val.as_slice()) {
Ok(r) => Ok(r),
// If it fails here, there is a chance it was an old version of BState
@@ -216,7 +216,7 @@ impl VersionedSerdeState for State {
mod tests {
use crate::idx::ft::postings::TermFrequency;
use crate::idx::ft::terms::{State, Terms};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::TransactionType::{Read, Write};
use crate::kvs::{Datastore, LockType::*, Transaction, TransactionType};
use rand::{thread_rng, Rng};
@@ -226,8 +226,8 @@ mod tests {
#[test]
fn test_state_serde() {
let s = State::new(3);
let val = s.try_to_val().unwrap();
let s = State::try_from_val(val).unwrap();
let val = VersionedStore::try_into(&s).unwrap();
let s: State = VersionedStore::try_from(val).unwrap();
assert_eq!(s.btree.generation(), 0);
assert_eq!(s.next_term_id, 0);
}

View file

@@ -183,15 +183,17 @@ impl<'a> IndexOperation<'a> {
}
async fn index_hnsw(&mut self, p: &HnswParams) -> Result<(), Error> {
let hnsw = self.ctx.get_index_stores().get_index_hnsw(self.opt, self.ix, p).await?;
let txn = self.ctx.tx();
let hnsw =
self.ctx.get_index_stores().get_index_hnsw(self.ctx, self.opt, self.ix, p).await?;
let mut hnsw = hnsw.write().await;
// Delete the old index data
if let Some(o) = self.o.take() {
hnsw.remove_document(self.rid, &o)?;
hnsw.remove_document(&txn, self.rid.id.clone(), &o).await?;
}
// Create the new index data
if let Some(n) = self.n.take() {
hnsw.index_document(self.rid, &n)?;
hnsw.index_document(&txn, self.rid.id.clone(), &n).await?;
}
Ok(())
}

View file

@@ -7,7 +7,9 @@ pub mod trees;
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::ft::terms::TermId;
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::store::NodeId;
use crate::idx::trees::vector::SerializedVector;
use crate::key::index::bc::Bc;
use crate::key::index::bd::Bd;
use crate::key::index::bf::Bf;
@@ -19,9 +21,16 @@ use crate::key::index::bp::Bp;
use crate::key::index::bs::Bs;
use crate::key::index::bt::Bt;
use crate::key::index::bu::Bu;
use crate::key::index::hd::Hd;
use crate::key::index::he::He;
use crate::key::index::hi::Hi;
use crate::key::index::hl::Hl;
use crate::key::index::hs::Hs;
use crate::key::index::hv::Hv;
use crate::key::index::vm::Vm;
use crate::kvs::{Key, Val};
use crate::sql::statements::DefineIndexStatement;
use crate::sql::{Id, Thing};
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::Serialize;
@@ -175,6 +184,72 @@ impl IndexKeyBase {
.into()
}
fn new_hd_key(&self, doc_id: Option<DocId>) -> Key {
Hd::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
doc_id,
)
.into()
}
fn new_he_key(&self, element_id: ElementId) -> Key {
He::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
element_id,
)
.into()
}
fn new_hi_key(&self, id: Id) -> Key {
Hi::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
id,
)
.into()
}
fn new_hl_key(&self, layer: u16, chunk: u32) -> Key {
Hl::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
layer,
chunk,
)
.into()
}
fn new_hv_key(&self, vec: Arc<SerializedVector>) -> Key {
Hv::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
vec,
)
.into()
}
fn new_hs_key(&self) -> Key {
Hs::new(
self.inner.ns.as_str(),
self.inner.db.as_str(),
self.inner.tb.as_str(),
self.inner.ix.as_str(),
)
.into()
}
fn new_vm_key(&self, node_id: Option<NodeId>) -> Key {
Vm::new(
self.inner.ns.as_str(),
@@ -188,17 +263,19 @@ impl IndexKeyBase {
}
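Together these builders lay out the persisted HNSW key space: hd maps a DocId back to its record Id (and, with no doc id, stores the HnswDocs state), he holds the serialized vector of each element, hi maps a record Id to its DocId, hl stores each layer's adjacency lists in chunks, hs holds the index-level HNSW state, and hv maps a serialized vector to its element id and document posting list (see the hnsw modules further down).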
/// This trait provides `Revision` based default implementations for serialization/deserialization
trait VersionedSerdeState
trait VersionedStore
where
Self: Sized + Serialize + DeserializeOwned + Revisioned,
{
fn try_to_val(&self) -> Result<Val, Error> {
fn try_into(&self) -> Result<Val, Error> {
let mut val = Vec::new();
self.serialize_revisioned(&mut val)?;
Ok(val)
}
fn try_from_val(val: Val) -> Result<Self, Error> {
fn try_from(val: Val) -> Result<Self, Error> {
Ok(Self::deserialize_revisioned(&mut val.as_slice())?)
}
}
impl VersionedStore for Thing {}
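The rename from VersionedSerdeState to VersionedStore keeps the same Revision-based encoding; note that callers throughout this commit use the fully qualified VersionedStore::try_into(..) / VersionedStore::try_from(..) forms so the method names do not collide with std::convert. A minimal round-trip sketch, with Counter a hypothetical revisioned type that is not part of this commit:

#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
struct Counter {
    count: u64,
}

// The trait's default implementations are all that is needed.
impl VersionedStore for Counter {}

fn roundtrip(c: &Counter) -> Result<Counter, Error> {
    // Serialize with a revision header, then decode according to the stored revision.
    let val: Val = VersionedStore::try_into(c)?;
    VersionedStore::try_from(val)
}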

View file

@@ -6,6 +6,7 @@ use crate::idx::docids::{DocId, DocIds};
use crate::idx::planner::iterators::KnnIteratorResult;
use crate::idx::trees::hnsw::docs::HnswDocs;
use crate::idx::trees::knn::Ids64;
use crate::kvs::Transaction;
use crate::sql::{Cond, Thing, Value};
use ahash::HashMap;
use reblessive::tree::Stk;
@@ -23,13 +24,11 @@ pub enum MTreeConditionChecker<'a> {
MTreeCondition(MTreeCondChecker<'a>),
}
impl<'a> Default for HnswConditionChecker<'a> {
fn default() -> Self {
impl<'a> HnswConditionChecker<'a> {
pub(in crate::idx) fn new() -> Self {
Self::Hnsw(HnswChecker {})
}
}
impl<'a> HnswConditionChecker<'a> {
pub(in crate::idx) fn new_cond(ctx: &'a Context, opt: &'a Options, cond: Arc<Cond>) -> Self {
Self::HnswCondition(HnswCondChecker {
ctx,
@@ -41,12 +40,13 @@ impl<'a> HnswConditionChecker<'a> {
pub(in crate::idx) async fn check_truthy(
&mut self,
tx: &Transaction,
stk: &mut Stk,
docs: &HnswDocs,
doc_ids: &Ids64,
doc_ids: Ids64,
) -> Result<bool, Error> {
match self {
Self::HnswCondition(c) => c.check_any_truthy(stk, docs, doc_ids).await,
Self::HnswCondition(c) => c.check_any_truthy(tx, stk, docs, doc_ids).await,
Self::Hnsw(_) => Ok(true),
}
}
@@ -65,11 +65,12 @@ impl<'a> HnswConditionChecker<'a> {
pub(in crate::idx) async fn convert_result(
&mut self,
tx: &Transaction,
docs: &HnswDocs,
res: VecDeque<(DocId, f64)>,
) -> Result<VecDeque<KnnIteratorResult>, Error> {
match self {
Self::Hnsw(c) => c.convert_result(docs, res).await,
Self::Hnsw(c) => c.convert_result(tx, docs, res).await,
Self::HnswCondition(c) => Ok(c.convert_result(res)),
}
}
@@ -78,12 +79,12 @@ impl<'a> HnswConditionChecker<'a> {
impl<'a> MTreeConditionChecker<'a> {
pub fn new_cond(ctx: &'a Context, opt: &'a Options, cond: Arc<Cond>) -> Self {
if Cond(Value::Bool(true)).ne(cond.as_ref()) {
return Self::MTreeCondition(MTreeCondChecker {
Self::MTreeCondition(MTreeCondChecker {
ctx,
opt,
cond,
cache: Default::default(),
});
})
} else {
Self::new(ctx)
}
@@ -253,9 +254,10 @@ impl<'a> MTreeCondChecker<'a> {
pub struct HnswChecker {}
impl<'a> HnswChecker {
impl HnswChecker {
async fn convert_result(
&self,
tx: &Transaction,
docs: &HnswDocs,
res: VecDeque<(DocId, f64)>,
) -> Result<VecDeque<KnnIteratorResult>, Error> {
@@ -264,7 +266,7 @@ impl<'a> HnswChecker {
}
let mut result = VecDeque::with_capacity(res.len());
for (doc_id, dist) in res {
if let Some(rid) = docs.get_thing(doc_id) {
if let Some(rid) = docs.get_thing(tx, doc_id).await? {
result.push_back((rid.clone().into(), dist, None));
}
}
@@ -286,16 +288,17 @@ impl<'a> HnswCondChecker<'a> {
async fn check_any_truthy(
&mut self,
tx: &Transaction,
stk: &mut Stk,
docs: &HnswDocs,
doc_ids: &Ids64,
doc_ids: Ids64,
) -> Result<bool, Error> {
let mut res = false;
for doc_id in doc_ids.iter() {
if match self.cache.entry(doc_id) {
Entry::Occupied(e) => e.get().truthy,
Entry::Vacant(e) => {
let rid: Option<Thing> = docs.get_thing(doc_id).cloned();
let rid = docs.get_thing(tx, doc_id).await?;
let ent =
CheckerCacheEntry::build(stk, self.ctx, self.opt, rid, self.cond.as_ref())
.await?;

View file

@@ -222,8 +222,11 @@ impl InnerQueryExecutor {
Entry::Vacant(e) => {
let hnsw = ctx
.get_index_stores()
.get_index_hnsw(opt, idx_def, p)
.get_index_hnsw(ctx, opt, idx_def, p)
.await?;
// Ensure the local HNSW index is up to date with the KVS
hnsw.write().await.check_state(&ctx.tx()).await?;
// Now we can execute the request
let entry = HnswEntry::new(
stk,
ctx,
@@ -788,11 +791,13 @@ impl HnswEntry {
let cond_checker = if let Some(cond) = cond {
HnswConditionChecker::new_cond(ctx, opt, cond)
} else {
HnswConditionChecker::default()
HnswConditionChecker::new()
};
let h = h.read().await;
let res = h.knn_search(v, n as usize, ef as usize, stk, cond_checker).await?;
drop(h);
let res = h
.read()
.await
.knn_search(&ctx.tx(), stk, v, n as usize, ef as usize, cond_checker)
.await?;
Ok(Self {
res,
})
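Because the graph now lives in the KVS, the shared in-memory index can lag behind what another writer has committed, so check_state(&ctx.tx()) is called before building the entry to reload any layers whose persisted state has advanced. The search itself also takes the transaction, since vectors and adjacency lists may be faulted in on demand, which is why the read guard can simply be held across the chained expression instead of being dropped explicitly.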

View file

@@ -1,7 +1,7 @@
use crate::err::Error;
use crate::idx::trees::bkeys::BKeys;
use crate::idx::trees::store::{NodeId, StoreGeneration, StoredNode, TreeNode, TreeStore};
use crate::idx::VersionedSerdeState;
use crate::idx::VersionedStore;
use crate::kvs::{Key, Transaction, Val};
use crate::sql::{Object, Value};
#[cfg(debug_assertions)]
@@ -39,8 +39,8 @@ pub struct BState {
generation: StoreGeneration,
}
impl VersionedSerdeState for BState {
fn try_from_val(val: Val) -> Result<Self, Error> {
impl VersionedStore for BState {
fn try_from(val: Val) -> Result<Self, Error> {
match Self::deserialize_revisioned(&mut val.as_slice()) {
Ok(r) => Ok(r),
// If it fails here, there is a chance it was an old version of BState
@@ -997,7 +997,7 @@ mod tests {
BState, BStatistics, BStoredNode, BTree, BTreeNode, BTreeStore, Payload,
};
use crate::idx::trees::store::{NodeId, TreeNode, TreeNodeProvider};
use crate::idx::VersionedSerdeState;
use crate::idx::VersionedStore;
use crate::kvs::{Datastore, Key, LockType::*, Transaction, TransactionType};
use rand::prelude::SliceRandom;
use rand::thread_rng;
@@ -1010,8 +1010,8 @@ mod tests {
#[test]
fn test_btree_state_serde() {
let s = BState::new(3);
let val = s.try_to_val().unwrap();
let s: BState = BState::try_from_val(val).unwrap();
let val = VersionedStore::try_into(&s).unwrap();
let s: BState = VersionedStore::try_from(val).unwrap();
assert_eq!(s.minimum_degree, 3);
assert_eq!(s.root, None);
assert_eq!(s.next_node_id, 0);

View file

@@ -1,44 +1,38 @@
use crate::idx::trees::hnsw::ElementId;
use ahash::{HashSet, HashSetExt};
use std::fmt::Debug;
use std::hash::Hash;
pub trait DynamicSet<T>: Debug + Send + Sync
where
T: Eq + Hash + Clone + Default + 'static + Send + Sync,
{
pub trait DynamicSet: Debug + Send + Sync {
fn with_capacity(capacity: usize) -> Self;
fn insert(&mut self, v: T) -> bool;
fn contains(&self, v: &T) -> bool;
fn remove(&mut self, v: &T) -> bool;
fn insert(&mut self, v: ElementId) -> bool;
fn contains(&self, v: &ElementId) -> bool;
fn remove(&mut self, v: &ElementId) -> bool;
fn len(&self) -> usize;
fn is_empty(&self) -> bool;
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_>;
fn iter(&self) -> impl Iterator<Item = &ElementId>;
}
#[derive(Debug)]
pub struct HashBrownSet<T>(HashSet<T>);
pub struct AHashSet(HashSet<ElementId>);
impl<T> DynamicSet<T> for HashBrownSet<T>
where
T: Eq + Hash + Clone + Default + Debug + 'static + Send + Sync,
{
impl DynamicSet for AHashSet {
#[inline]
fn with_capacity(capacity: usize) -> Self {
Self(HashSet::with_capacity(capacity))
}
#[inline]
fn insert(&mut self, v: T) -> bool {
fn insert(&mut self, v: ElementId) -> bool {
self.0.insert(v)
}
#[inline]
fn contains(&self, v: &T) -> bool {
fn contains(&self, v: &ElementId) -> bool {
self.0.contains(v)
}
#[inline]
fn remove(&mut self, v: &T) -> bool {
fn remove(&mut self, v: &ElementId) -> bool {
self.0.remove(v)
}
@@ -53,35 +47,29 @@ where
}
#[inline]
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_> {
Box::new(self.0.iter())
fn iter(&self) -> impl Iterator<Item = &ElementId> {
self.0.iter()
}
}
#[derive(Debug)]
pub struct ArraySet<T, const N: usize>
where
T: Eq + Hash + Clone + Default + 'static + Send + Sync,
{
array: [T; N],
pub struct ArraySet<const N: usize> {
array: [ElementId; N],
size: usize,
}
impl<T, const N: usize> DynamicSet<T> for ArraySet<T, N>
where
T: Eq + Hash + Clone + Copy + Default + Debug + 'static + Send + Sync,
{
impl<const N: usize> DynamicSet for ArraySet<N> {
fn with_capacity(_capacity: usize) -> Self {
#[cfg(debug_assertions)]
assert!(_capacity <= N);
Self {
array: [T::default(); N],
array: [0; N],
size: 0,
}
}
#[inline]
fn insert(&mut self, v: T) -> bool {
fn insert(&mut self, v: ElementId) -> bool {
if !self.contains(&v) {
self.array[self.size] = v;
self.size += 1;
@@ -92,12 +80,12 @@ where
}
#[inline]
fn contains(&self, v: &T) -> bool {
fn contains(&self, v: &ElementId) -> bool {
self.array[0..self.size].contains(v)
}
#[inline]
fn remove(&mut self, v: &T) -> bool {
fn remove(&mut self, v: &ElementId) -> bool {
if let Some(p) = self.array[0..self.size].iter().position(|e| e.eq(v)) {
self.array[p..].rotate_left(1);
self.size -= 1;
@@ -118,23 +106,24 @@ where
}
#[inline]
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_> {
Box::new(self.array[0..self.size].iter())
fn iter(&self) -> impl Iterator<Item = &ElementId> {
self.array[0..self.size].iter()
}
}
#[cfg(test)]
mod tests {
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use crate::idx::trees::dynamicset::{AHashSet, ArraySet, DynamicSet};
use crate::idx::trees::hnsw::ElementId;
use ahash::HashSet;
fn test_dynamic_set<S: DynamicSet<usize>>(capacity: usize) {
let mut dyn_set = S::with_capacity(capacity);
fn test_dynamic_set<S: DynamicSet>(capacity: ElementId) {
let mut dyn_set = S::with_capacity(capacity as usize);
let mut control = HashSet::default();
// Test insertions
for sample in 0..capacity {
assert_eq!(dyn_set.len(), control.len(), "{capacity} - {sample}");
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
let v: HashSet<ElementId> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
// We should not have the element yet
assert!(!dyn_set.contains(&sample), "{capacity} - {sample}");
@@ -159,7 +148,7 @@ mod tests {
control.remove(&sample);
// The control structure and the dyn_set should be identical
assert_eq!(dyn_set.len(), control.len(), "{capacity} - {sample}");
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
let v: HashSet<ElementId> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
}
}
@@ -167,17 +156,17 @@ mod tests {
#[test]
fn test_dynamic_set_hash() {
for capacity in 1..50 {
test_dynamic_set::<HashBrownSet<usize>>(capacity);
test_dynamic_set::<AHashSet>(capacity);
}
}
#[test]
fn test_dynamic_set_array() {
test_dynamic_set::<ArraySet<usize, 1>>(1);
test_dynamic_set::<ArraySet<usize, 2>>(2);
test_dynamic_set::<ArraySet<usize, 4>>(4);
test_dynamic_set::<ArraySet<usize, 10>>(10);
test_dynamic_set::<ArraySet<usize, 20>>(20);
test_dynamic_set::<ArraySet<usize, 30>>(30);
test_dynamic_set::<ArraySet<1>>(1);
test_dynamic_set::<ArraySet<2>>(2);
test_dynamic_set::<ArraySet<4>>(4);
test_dynamic_set::<ArraySet<10>>(10);
test_dynamic_set::<ArraySet<20>>(20);
test_dynamic_set::<ArraySet<30>>(30);
}
}
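Fixing the element type to ElementId (rather than a generic T) lets iter return an impl Iterator instead of a boxed trait object and drops a pile of trait bounds. A small illustrative sketch of the two implementations, assuming the dynamicset imports used by the tests above; this is not part of the commit:

fn demo() {
    // Fixed-capacity set: up to 10 ids stored inline, no hashing.
    let mut small: ArraySet<10> = DynamicSet::with_capacity(8);
    small.insert(42);
    assert!(small.contains(&42));
    assert!(small.remove(&42) && small.is_empty());

    // Heap-backed set for unbounded neighbour lists.
    let mut large = AHashSet::with_capacity(64);
    large.insert(42);
    assert_eq!(large.len(), 1);
}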

View file

@@ -1,25 +1,25 @@
use crate::err::Error;
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::ElementId;
use ahash::HashMap;
#[cfg(test)]
use ahash::HashSet;
use bytes::{Buf, BufMut, BytesMut};
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::hash::Hash;
#[derive(Debug)]
pub(super) struct UndirectedGraph<T, S>
pub(super) struct UndirectedGraph<S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Send + Sync,
S: DynamicSet<T>,
S: DynamicSet,
{
capacity: usize,
nodes: HashMap<T, S>,
nodes: HashMap<ElementId, S>,
}
impl<T, S> UndirectedGraph<T, S>
impl<S> UndirectedGraph<S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Send + Sync,
S: DynamicSet<T>,
S: DynamicSet,
{
pub(super) fn new(capacity: usize) -> Self {
Self {
@@ -34,11 +34,11 @@ where
}
#[inline]
pub(super) fn get_edges(&self, node: &T) -> Option<&S> {
pub(super) fn get_edges(&self, node: &ElementId) -> Option<&S> {
self.nodes.get(node)
}
pub(super) fn add_empty_node(&mut self, node: T) -> bool {
pub(super) fn add_empty_node(&mut self, node: ElementId) -> bool {
if let Entry::Vacant(e) = self.nodes.entry(node) {
e.insert(S::with_capacity(self.capacity));
true
@@ -47,7 +47,11 @@ where
}
}
pub(super) fn add_node_and_bidirectional_edges(&mut self, node: T, edges: S) -> Vec<T> {
pub(super) fn add_node_and_bidirectional_edges(
&mut self,
node: ElementId,
edges: S,
) -> Vec<ElementId> {
let mut r = Vec::with_capacity(edges.len());
for &e in edges.iter() {
self.nodes.entry(e).or_insert_with(|| S::with_capacity(self.capacity)).insert(node);
@@ -57,11 +61,11 @@ where
r
}
#[inline]
pub(super) fn set_node(&mut self, node: T, new_edges: S) {
pub(super) fn set_node(&mut self, node: ElementId, new_edges: S) {
self.nodes.insert(node, new_edges);
}
pub(super) fn remove_node_and_bidirectional_edges(&mut self, node: &T) -> Option<S> {
pub(super) fn remove_node_and_bidirectional_edges(&mut self, node: &ElementId) -> Option<S> {
if let Some(edges) = self.nodes.remove(node) {
for edge in edges.iter() {
if let Some(edges_to_node) = self.nodes.get_mut(edge) {
@@ -73,25 +77,53 @@ where
None
}
}
pub(super) fn to_val(&self) -> Result<BytesMut, Error> {
let mut buf = BytesMut::new();
buf.put_u32(self.nodes.len() as u32);
for (&e, s) in &self.nodes {
buf.put_u64(e);
buf.put_u16(s.len() as u16);
for &i in s.iter() {
buf.put_u64(i);
}
}
Ok(buf)
}
pub(super) fn reload(&mut self, val: &[u8]) -> Result<(), Error> {
let mut buf = BytesMut::from(val);
self.nodes.clear();
let len = buf.get_u32() as usize;
for _ in 0..len {
let e = buf.get_u64();
let s_len = buf.get_u16() as usize;
let mut s = S::with_capacity(s_len);
for _ in 0..s_len {
s.insert(buf.get_u64() as ElementId);
}
self.nodes.insert(e, s);
}
Ok(())
}
}
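The new to_val/reload pair gives each layer's adjacency lists a compact binary layout: a u32 node count, then per node a u64 element id, a u16 degree and one u64 per edge, so a node with three neighbours costs 8 + 2 + 3 * 8 = 34 bytes plus the shared 4-byte header. A hedged round-trip sketch (illustrative code within this module, not part of the commit):

fn graph_roundtrip() -> Result<(), Error> {
    let mut g: UndirectedGraph<AHashSet> = UndirectedGraph::new(4);
    g.add_empty_node(1);
    let mut edges = AHashSet::with_capacity(4);
    edges.insert(1);
    // Links 2 -> 1 and back-fills the reverse edge 1 -> 2.
    g.add_node_and_bidirectional_edges(2, edges);
    let bytes = g.to_val()?;
    let mut g2: UndirectedGraph<AHashSet> = UndirectedGraph::new(4);
    g2.reload(&bytes)?; // g2 now holds the same adjacency lists as g
    Ok(())
}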
#[cfg(test)]
impl<T, S> UndirectedGraph<T, S>
impl<S> UndirectedGraph<S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Debug + Send + Sync,
S: DynamicSet<T>,
S: DynamicSet,
{
pub(in crate::idx::trees) fn len(&self) -> usize {
self.nodes.len()
}
pub(in crate::idx::trees) fn nodes(&self) -> &HashMap<T, S> {
pub(in crate::idx::trees) fn nodes(&self) -> &HashMap<ElementId, S> {
&self.nodes
}
pub(in crate::idx::trees) fn check(&self, g: Vec<(T, Vec<T>)>) {
pub(in crate::idx::trees) fn check(&self, g: Vec<(ElementId, Vec<ElementId>)>) {
for (n, e) in g {
let edges: HashSet<T> = e.into_iter().collect();
let n_edges: Option<HashSet<T>> =
let edges: HashSet<ElementId> = e.into_iter().collect();
let n_edges: Option<HashSet<ElementId>> =
self.get_edges(&n).map(|e| e.iter().cloned().collect());
assert_eq!(n_edges, Some(edges), "{n:?}");
}
@@ -100,12 +132,13 @@ where
#[cfg(test)]
mod tests {
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use crate::idx::trees::dynamicset::{AHashSet, ArraySet, DynamicSet};
use crate::idx::trees::graph::UndirectedGraph;
use crate::idx::trees::hnsw::ElementId;
fn test_undirected_graph<S: DynamicSet<i32>>(m_max: usize) {
fn test_undirected_graph<S: DynamicSet>(m_max: usize) {
// Graph creation
let mut g = UndirectedGraph::<i32, S>::new(m_max);
let mut g = UndirectedGraph::<S>::new(m_max);
assert_eq!(g.capacity, 10);
// Adding an empty node
@@ -153,7 +186,7 @@ mod tests {
let res = g.remove_node_and_bidirectional_edges(&2);
assert_eq!(
res.map(|v| {
let mut v: Vec<i32> = v.iter().cloned().collect();
let mut v: Vec<ElementId> = v.iter().cloned().collect();
v.sort();
v
}),
@@ -174,11 +207,11 @@ mod tests {
#[test]
fn test_undirected_graph_array() {
test_undirected_graph::<ArraySet<i32, 10>>(10);
test_undirected_graph::<ArraySet<10>>(10);
}
#[test]
fn test_undirected_graph_hash() {
test_undirected_graph::<HashBrownSet<i32>>(10);
test_undirected_graph::<AHashSet>(10);
}
}

View file

@@ -1,56 +1,214 @@
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::kvs::Key;
use crate::sql::Thing;
use radix_trie::Trie;
use crate::idx::trees::hnsw::flavor::HnswFlavor;
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::knn::Ids64;
use crate::idx::trees::vector::{SerializedVector, Vector};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, Val};
use crate::sql::{Id, Thing};
use derive::Store;
use revision::revisioned;
use roaring::RoaringTreemap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Default)]
pub(in crate::idx) struct HnswDocs {
doc_ids: Trie<Key, DocId>,
ids_doc: Vec<Option<Thing>>,
available: RoaringTreemap,
tb: String,
ikb: IndexKeyBase,
state_key: Key,
state_updated: bool,
state: State,
}
#[revisioned(revision = 1)]
#[derive(Default, Clone, Serialize, Deserialize, Store)]
#[non_exhaustive]
struct State {
available: RoaringTreemap,
next_doc_id: DocId,
}
impl VersionedStore for State {}
impl HnswDocs {
pub(super) fn resolve(&mut self, rid: &Thing) -> DocId {
let doc_key: Key = rid.into();
if let Some(doc_id) = self.doc_ids.get(&doc_key) {
*doc_id
pub(in crate::idx) async fn new(
tx: &Transaction,
tb: String,
ikb: IndexKeyBase,
) -> Result<Self, Error> {
let state_key = ikb.new_hd_key(None);
let state = if let Some(k) = tx.get(state_key.clone(), None).await? {
VersionedStore::try_from(k)?
} else {
State::default()
};
Ok(Self {
tb,
ikb,
state_updated: false,
state_key,
state,
})
}
pub(super) async fn resolve(&mut self, tx: &Transaction, id: Id) -> Result<DocId, Error> {
let id_key = self.ikb.new_hi_key(id.clone());
if let Some(v) = tx.get(id_key.clone(), None).await? {
let doc_id = u64::from_be_bytes(v.try_into().unwrap());
Ok(doc_id)
} else {
let doc_id = self.next_doc_id();
self.ids_doc.push(Some(rid.clone()));
self.doc_ids.insert(doc_key, doc_id);
doc_id
tx.set(id_key, doc_id.to_be_bytes()).await?;
let doc_key = self.ikb.new_hd_key(Some(doc_id));
tx.set(doc_key, id).await?;
Ok(doc_id)
}
}
fn next_doc_id(&mut self) -> DocId {
if let Some(doc_id) = self.available.iter().next() {
self.available.remove(doc_id);
self.state_updated = true;
if let Some(doc_id) = self.state.available.iter().next() {
self.state.available.remove(doc_id);
doc_id
} else {
self.ids_doc.len() as DocId
let doc_id = self.state.next_doc_id;
self.state.next_doc_id += 1;
doc_id
}
}
pub(in crate::idx) fn get_thing(&self, doc_id: DocId) -> Option<&Thing> {
if let Some(r) = self.ids_doc.get(doc_id as usize) {
r.as_ref()
pub(in crate::idx) async fn get_thing(
&self,
tx: &Transaction,
doc_id: DocId,
) -> Result<Option<Thing>, Error> {
let doc_key = self.ikb.new_hd_key(Some(doc_id));
if let Some(val) = tx.get(doc_key, None).await? {
let id: Id = val.into();
Ok(Some(Thing::from((self.tb.to_owned(), id))))
} else {
None
Ok(None)
}
}
pub(super) fn remove(&mut self, rid: &Thing) -> Option<DocId> {
let doc_key: Key = rid.into();
if let Some(doc_id) = self.doc_ids.remove(&doc_key) {
let n = doc_id as usize;
if n < self.ids_doc.len() {
self.ids_doc[n] = None;
}
self.available.insert(doc_id);
Some(doc_id)
pub(super) async fn remove(
&mut self,
tx: &Transaction,
id: Id,
) -> Result<Option<DocId>, Error> {
let id_key = self.ikb.new_hi_key(id);
if let Some(v) = tx.get(id_key.clone(), None).await? {
let doc_id = u64::from_be_bytes(v.try_into().unwrap());
let doc_key = self.ikb.new_hd_key(Some(doc_id));
tx.del(doc_key).await?;
tx.del(id_key).await?;
self.state.available.insert(doc_id);
self.state_updated = true;
Ok(Some(doc_id))
} else {
None
Ok(None)
}
}
pub(in crate::idx) async fn finish(&mut self, tx: &Transaction) -> Result<(), Error> {
if self.state_updated {
tx.set(self.state_key.clone(), VersionedStore::try_into(&self.state)?).await?;
self.state_updated = false;
}
Ok(())
}
}
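HnswDocs now keeps its mapping in the KVS instead of an in-memory trie: an hi key maps a record Id to its DocId (stored as a big-endian u64), the matching hd key maps the DocId back to the Id, and the hd key with no doc id holds the state (the available-ids bitmap and next_doc_id). An illustrative flow, assuming an open transaction tx and that Id::from is available for strings:

let id = Id::from("alpha");
let doc_id = docs.resolve(&tx, id.clone()).await?; // writes the hi and hd entries
let rid = docs.get_thing(&tx, doc_id).await?;      // reads the hd entry back
assert_eq!(rid.map(|t| t.id), Some(id));
docs.finish(&tx).await?;                           // persists the state if it changed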
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
struct ElementDocs {
e_id: ElementId,
docs: Ids64,
}
impl VersionedStore for ElementDocs {}
pub(in crate::idx) struct VecDocs {
ikb: IndexKeyBase,
}
impl VecDocs {
pub(super) fn new(ikb: IndexKeyBase) -> Self {
Self {
ikb,
}
}
pub(super) async fn get_docs(
&self,
tx: &Transaction,
pt: &Vector,
) -> Result<Option<Ids64>, Error> {
let key = self.ikb.new_hv_key(Arc::new(pt.into()));
if let Some(val) = tx.get(key, None).await? {
let ed: ElementDocs = VersionedStore::try_from(val)?;
Ok(Some(ed.docs))
} else {
Ok(None)
}
}
pub(super) async fn insert(
&self,
tx: &Transaction,
o: Vector,
d: DocId,
h: &mut HnswFlavor,
) -> Result<(), Error> {
let ser_vec = Arc::new(SerializedVector::from(&o));
let key = self.ikb.new_hv_key(ser_vec);
if let Some(ed) = match tx.get(key.clone(), None).await? {
Some(val) => {
// We already have the vector
let mut ed: ElementDocs = VersionedStore::try_from(val)?;
ed.docs.insert(d).map(|new_docs| {
ed.docs = new_docs;
ed
})
}
None => {
// We don't have the vector yet, so we insert it into the graph
let element_id = h.insert(tx, o).await?;
let ed = ElementDocs {
e_id: element_id,
docs: Ids64::One(d),
};
Some(ed)
}
} {
let val: Val = VersionedStore::try_into(&ed)?;
tx.set(key, val).await?;
}
Ok(())
}
pub(super) async fn remove(
&self,
tx: &Transaction,
o: &Vector,
d: DocId,
h: &mut HnswFlavor,
) -> Result<(), Error> {
let key = self.ikb.new_hv_key(Arc::new(o.into()));
if let Some(val) = tx.get(key.clone(), None).await? {
let mut ed: ElementDocs = VersionedStore::try_from(val)?;
if let Some(new_docs) = ed.docs.remove(d) {
if new_docs.is_empty() {
tx.del(key).await?;
h.remove(tx, ed.e_id).await?;
} else {
ed.docs = new_docs;
let val: Val = VersionedStore::try_into(&ed)?;
tx.set(key, val).await?;
}
}
};
Ok(())
}
}
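VecDocs deduplicates vectors through the hv keys: each serialized vector maps to an ElementDocs record holding the graph element id plus a posting list of DocIds, so identical vectors from different records share one HNSW element, which is only deleted from the graph when its last document goes. An illustrative sequence, assuming a vector v, doc ids doc_a and doc_b, and a mutable HnswFlavor hnsw:

vec_docs.insert(&tx, v.clone(), doc_a, &mut hnsw).await?; // new element, docs = {doc_a}
vec_docs.insert(&tx, v.clone(), doc_b, &mut hnsw).await?; // same element, docs = {doc_a, doc_b}
vec_docs.remove(&tx, &v, doc_a, &mut hnsw).await?;        // element kept, docs = {doc_b}
vec_docs.remove(&tx, &v, doc_b, &mut hnsw).await?;        // posting list empty: element removed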

View file

@@ -1,27 +1,41 @@
use crate::err::Error;
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::vector::SharedVector;
use crate::idx::trees::vector::{SerializedVector, SharedVector, Vector};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::Transaction;
use crate::sql::index::Distance;
use ahash::HashMap;
use dashmap::DashMap;
pub(super) struct HnswElements {
elements: HashMap<ElementId, SharedVector>,
ikb: IndexKeyBase,
elements: DashMap<ElementId, SharedVector>,
next_element_id: ElementId,
dist: Distance,
}
impl HnswElements {
pub(super) fn new(dist: Distance) -> Self {
pub(super) fn new(ikb: IndexKeyBase, dist: Distance) -> Self {
Self {
ikb,
elements: Default::default(),
next_element_id: 0,
dist,
}
}
pub(super) fn set_next_element_id(&mut self, next: ElementId) {
self.next_element_id = next;
}
pub(super) fn next_element_id(&self) -> ElementId {
self.next_element_id
}
pub(super) fn inc_next_element_id(&mut self) -> ElementId {
self.next_element_id += 1;
self.next_element_id
}
#[cfg(test)]
pub(super) fn len(&self) -> usize {
self.elements.len()
@@ -32,27 +46,59 @@ impl HnswElements {
self.elements.contains_key(e_id)
}
pub(super) fn inc_next_element_id(&mut self) {
self.next_element_id += 1;
pub(super) async fn insert(
&mut self,
tx: &Transaction,
id: ElementId,
vec: Vector,
ser_vec: &SerializedVector,
) -> Result<SharedVector, Error> {
let key = self.ikb.new_he_key(id);
let val = VersionedStore::try_into(ser_vec)?;
tx.set(key, val).await?;
let pt: SharedVector = vec.into();
self.elements.insert(id, pt.clone());
Ok(pt)
}
pub(super) fn insert(&mut self, id: ElementId, pt: SharedVector) {
self.elements.insert(id, pt);
}
pub(super) fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector> {
self.elements.get(e_id)
pub(super) async fn get_vector(
&self,
tx: &Transaction,
e_id: &ElementId,
) -> Result<Option<SharedVector>, Error> {
if let Some(r) = self.elements.get(e_id) {
return Ok(Some(r.value().clone()));
}
let key = self.ikb.new_he_key(*e_id);
match tx.get(key, None).await? {
None => Ok(None),
Some(val) => {
let vec: SerializedVector = VersionedStore::try_from(val)?;
let vec = Vector::from(vec);
let vec: SharedVector = vec.into();
self.elements.insert(*e_id, vec.clone());
Ok(Some(vec))
}
}
}
pub(super) fn distance(&self, a: &SharedVector, b: &SharedVector) -> f64 {
self.dist.calculate(a, b)
}
pub(super) fn get_distance(&self, q: &SharedVector, e_id: &ElementId) -> Option<f64> {
self.elements.get(e_id).map(|e_pt| self.dist.calculate(e_pt, q))
pub(super) async fn get_distance(
&self,
tx: &Transaction,
q: &SharedVector,
e_id: &ElementId,
) -> Result<Option<f64>, Error> {
Ok(self.get_vector(tx, e_id).await?.map(|r| self.dist.calculate(&r, q)))
}
pub(super) fn remove(&mut self, e_id: &ElementId) {
self.elements.remove(e_id);
pub(super) async fn remove(&mut self, tx: &Transaction, e_id: ElementId) -> Result<(), Error> {
self.elements.remove(&e_id);
let key = self.ikb.new_he_key(e_id);
tx.del(key).await?;
Ok(())
}
}
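HnswElements now acts as a write-through cache over the he keys: insert persists the serialized vector and caches the deserialized form, get_vector serves hits from the DashMap and otherwise faults the vector in from the transaction and backfills the map, and remove evicts from both. Using a DashMap lets concurrent readers share the cache without an exclusive lock. A brief sketch of the read path, with names taken from this diff:

if let Some(vec) = elements.get_vector(&tx, &e_id).await? { // DashMap hit or KVS read
    let d = elements.distance(&vec, &query);                // delegates to Distance::calculate
    assert!(d.is_finite());
}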

View file

@@ -1,182 +1,217 @@
use crate::err::Error;
use crate::idx::planner::checker::HnswConditionChecker;
use crate::idx::trees::dynamicset::{ArraySet, HashBrownSet};
use crate::idx::trees::dynamicset::{AHashSet, ArraySet};
use crate::idx::trees::hnsw::docs::HnswDocs;
use crate::idx::trees::hnsw::index::VecDocs;
use crate::idx::trees::hnsw::docs::VecDocs;
use crate::idx::trees::hnsw::{ElementId, Hnsw, HnswSearch};
use crate::idx::trees::vector::SharedVector;
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::idx::IndexKeyBase;
use crate::kvs::Transaction;
use crate::sql::index::HnswParams;
use reblessive::tree::Stk;
pub(super) type ASet<const N: usize> = ArraySet<ElementId, N>;
pub(super) type HSet = HashBrownSet<ElementId>;
pub(super) enum HnswFlavor {
H5_9(Hnsw<ASet<9>, ASet<5>>),
H5_17(Hnsw<ASet<17>, ASet<5>>),
H5_25(Hnsw<ASet<25>, ASet<5>>),
H5set(Hnsw<HSet, ASet<5>>),
H9_17(Hnsw<ASet<17>, ASet<9>>),
H9_25(Hnsw<ASet<25>, ASet<9>>),
H9set(Hnsw<HSet, ASet<9>>),
H13_25(Hnsw<ASet<25>, ASet<13>>),
H13set(Hnsw<HSet, ASet<13>>),
H17set(Hnsw<HSet, ASet<17>>),
H21set(Hnsw<HSet, ASet<21>>),
H25set(Hnsw<HSet, ASet<25>>),
H29set(Hnsw<HSet, ASet<29>>),
Hset(Hnsw<HSet, HSet>),
H5_9(Hnsw<ArraySet<9>, ArraySet<5>>),
H5_17(Hnsw<ArraySet<17>, ArraySet<5>>),
H5_25(Hnsw<ArraySet<25>, ArraySet<5>>),
H5set(Hnsw<AHashSet, ArraySet<5>>),
H9_17(Hnsw<ArraySet<17>, ArraySet<9>>),
H9_25(Hnsw<ArraySet<25>, ArraySet<9>>),
H9set(Hnsw<AHashSet, ArraySet<9>>),
H13_25(Hnsw<ArraySet<25>, ArraySet<13>>),
H13set(Hnsw<AHashSet, ArraySet<13>>),
H17set(Hnsw<AHashSet, ArraySet<17>>),
H21set(Hnsw<AHashSet, ArraySet<21>>),
H25set(Hnsw<AHashSet, ArraySet<25>>),
H29set(Hnsw<AHashSet, ArraySet<29>>),
Hset(Hnsw<AHashSet, AHashSet>),
}
impl HnswFlavor {
pub(super) fn new(p: &HnswParams) -> Self {
pub(super) fn new(ibk: IndexKeyBase, p: &HnswParams) -> Self {
match p.m {
1..=4 => match p.m0 {
1..=8 => Self::H5_9(Hnsw::<ASet<9>, ASet<5>>::new(p)),
9..=16 => Self::H5_17(Hnsw::<ASet<17>, ASet<5>>::new(p)),
17..=24 => Self::H5_25(Hnsw::<ASet<25>, ASet<5>>::new(p)),
_ => Self::H5set(Hnsw::<HSet, ASet<5>>::new(p)),
1..=8 => Self::H5_9(Hnsw::<ArraySet<9>, ArraySet<5>>::new(ibk, p)),
9..=16 => Self::H5_17(Hnsw::<ArraySet<17>, ArraySet<5>>::new(ibk, p)),
17..=24 => Self::H5_25(Hnsw::<ArraySet<25>, ArraySet<5>>::new(ibk, p)),
_ => Self::H5set(Hnsw::<AHashSet, ArraySet<5>>::new(ibk, p)),
},
5..=8 => match p.m0 {
1..=16 => Self::H9_17(Hnsw::<ASet<17>, ASet<9>>::new(p)),
17..=24 => Self::H9_25(Hnsw::<ASet<25>, ASet<9>>::new(p)),
_ => Self::H9set(Hnsw::<HSet, ASet<9>>::new(p)),
1..=16 => Self::H9_17(Hnsw::<ArraySet<17>, ArraySet<9>>::new(ibk, p)),
17..=24 => Self::H9_25(Hnsw::<ArraySet<25>, ArraySet<9>>::new(ibk, p)),
_ => Self::H9set(Hnsw::<AHashSet, ArraySet<9>>::new(ibk, p)),
},
9..=12 => match p.m0 {
17..=24 => Self::H13_25(Hnsw::<ASet<25>, ASet<13>>::new(p)),
_ => Self::H13set(Hnsw::<HSet, ASet<13>>::new(p)),
17..=24 => Self::H13_25(Hnsw::<ArraySet<25>, ArraySet<13>>::new(ibk, p)),
_ => Self::H13set(Hnsw::<AHashSet, ArraySet<13>>::new(ibk, p)),
},
13..=16 => Self::H17set(Hnsw::<HSet, ASet<17>>::new(p)),
17..=20 => Self::H21set(Hnsw::<HSet, ASet<21>>::new(p)),
21..=24 => Self::H25set(Hnsw::<HSet, ASet<25>>::new(p)),
25..=28 => Self::H29set(Hnsw::<HSet, ASet<29>>::new(p)),
_ => Self::Hset(Hnsw::<HSet, HSet>::new(p)),
13..=16 => Self::H17set(Hnsw::<AHashSet, ArraySet<17>>::new(ibk, p)),
17..=20 => Self::H21set(Hnsw::<AHashSet, ArraySet<21>>::new(ibk, p)),
21..=24 => Self::H25set(Hnsw::<AHashSet, ArraySet<25>>::new(ibk, p)),
25..=28 => Self::H29set(Hnsw::<AHashSet, ArraySet<29>>::new(ibk, p)),
_ => Self::Hset(Hnsw::<AHashSet, AHashSet>::new(ibk, p)),
}
}
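Each flavor monomorphizes the two DynamicSet parameters, the first sized from m0 (layer-0 connectivity) and the second from m, so common small configurations run on inline arrays with no hashing while larger or unbounded values fall back to AHashSet. The array capacities sit one above the parameter's upper bound (for example m in 5..=8 with m0 in 1..=16 selects Hnsw<ArraySet<17>, ArraySet<9>>), presumably so that a full neighbour set can hold one extra candidate before the selection heuristic prunes it back down.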
pub(super) fn insert(&mut self, q_pt: SharedVector) -> ElementId {
pub(super) async fn check_state(&mut self, tx: &Transaction) -> Result<(), Error> {
match self {
HnswFlavor::H5_9(h) => h.insert(q_pt),
HnswFlavor::H5_17(h) => h.insert(q_pt),
HnswFlavor::H5_25(h) => h.insert(q_pt),
HnswFlavor::H5set(h) => h.insert(q_pt),
HnswFlavor::H9_17(h) => h.insert(q_pt),
HnswFlavor::H9_25(h) => h.insert(q_pt),
HnswFlavor::H9set(h) => h.insert(q_pt),
HnswFlavor::H13_25(h) => h.insert(q_pt),
HnswFlavor::H13set(h) => h.insert(q_pt),
HnswFlavor::H17set(h) => h.insert(q_pt),
HnswFlavor::H21set(h) => h.insert(q_pt),
HnswFlavor::H25set(h) => h.insert(q_pt),
HnswFlavor::H29set(h) => h.insert(q_pt),
HnswFlavor::Hset(h) => h.insert(q_pt),
HnswFlavor::H5_9(h) => h.check_state(tx).await,
HnswFlavor::H5_17(h) => h.check_state(tx).await,
HnswFlavor::H5_25(h) => h.check_state(tx).await,
HnswFlavor::H5set(h) => h.check_state(tx).await,
HnswFlavor::H9_17(h) => h.check_state(tx).await,
HnswFlavor::H9_25(h) => h.check_state(tx).await,
HnswFlavor::H9set(h) => h.check_state(tx).await,
HnswFlavor::H13_25(h) => h.check_state(tx).await,
HnswFlavor::H13set(h) => h.check_state(tx).await,
HnswFlavor::H17set(h) => h.check_state(tx).await,
HnswFlavor::H21set(h) => h.check_state(tx).await,
HnswFlavor::H25set(h) => h.check_state(tx).await,
HnswFlavor::H29set(h) => h.check_state(tx).await,
HnswFlavor::Hset(h) => h.check_state(tx).await,
}
}
pub(super) fn remove(&mut self, e_id: ElementId) -> bool {
pub(super) async fn insert(
&mut self,
tx: &Transaction,
q_pt: Vector,
) -> Result<ElementId, Error> {
match self {
HnswFlavor::H5_9(h) => h.remove(e_id),
HnswFlavor::H5_17(h) => h.remove(e_id),
HnswFlavor::H5_25(h) => h.remove(e_id),
HnswFlavor::H5set(h) => h.remove(e_id),
HnswFlavor::H9_17(h) => h.remove(e_id),
HnswFlavor::H9_25(h) => h.remove(e_id),
HnswFlavor::H9set(h) => h.remove(e_id),
HnswFlavor::H13_25(h) => h.remove(e_id),
HnswFlavor::H13set(h) => h.remove(e_id),
HnswFlavor::H17set(h) => h.remove(e_id),
HnswFlavor::H21set(h) => h.remove(e_id),
HnswFlavor::H25set(h) => h.remove(e_id),
HnswFlavor::H29set(h) => h.remove(e_id),
HnswFlavor::Hset(h) => h.remove(e_id),
HnswFlavor::H5_9(h) => h.insert(tx, q_pt).await,
HnswFlavor::H5_17(h) => h.insert(tx, q_pt).await,
HnswFlavor::H5_25(h) => h.insert(tx, q_pt).await,
HnswFlavor::H5set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H9_17(h) => h.insert(tx, q_pt).await,
HnswFlavor::H9_25(h) => h.insert(tx, q_pt).await,
HnswFlavor::H9set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H13_25(h) => h.insert(tx, q_pt).await,
HnswFlavor::H13set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H17set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H21set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H25set(h) => h.insert(tx, q_pt).await,
HnswFlavor::H29set(h) => h.insert(tx, q_pt).await,
HnswFlavor::Hset(h) => h.insert(tx, q_pt).await,
}
}
pub(super) fn knn_search(&self, search: &HnswSearch) -> Vec<(f64, ElementId)> {
pub(super) async fn remove(
&mut self,
tx: &Transaction,
e_id: ElementId,
) -> Result<bool, Error> {
match self {
HnswFlavor::H5_9(h) => h.knn_search(search),
HnswFlavor::H5_17(h) => h.knn_search(search),
HnswFlavor::H5_25(h) => h.knn_search(search),
HnswFlavor::H5set(h) => h.knn_search(search),
HnswFlavor::H9_17(h) => h.knn_search(search),
HnswFlavor::H9_25(h) => h.knn_search(search),
HnswFlavor::H9set(h) => h.knn_search(search),
HnswFlavor::H13_25(h) => h.knn_search(search),
HnswFlavor::H13set(h) => h.knn_search(search),
HnswFlavor::H17set(h) => h.knn_search(search),
HnswFlavor::H21set(h) => h.knn_search(search),
HnswFlavor::H25set(h) => h.knn_search(search),
HnswFlavor::H29set(h) => h.knn_search(search),
HnswFlavor::Hset(h) => h.knn_search(search),
HnswFlavor::H5_9(h) => h.remove(tx, e_id).await,
HnswFlavor::H5_17(h) => h.remove(tx, e_id).await,
HnswFlavor::H5_25(h) => h.remove(tx, e_id).await,
HnswFlavor::H5set(h) => h.remove(tx, e_id).await,
HnswFlavor::H9_17(h) => h.remove(tx, e_id).await,
HnswFlavor::H9_25(h) => h.remove(tx, e_id).await,
HnswFlavor::H9set(h) => h.remove(tx, e_id).await,
HnswFlavor::H13_25(h) => h.remove(tx, e_id).await,
HnswFlavor::H13set(h) => h.remove(tx, e_id).await,
HnswFlavor::H17set(h) => h.remove(tx, e_id).await,
HnswFlavor::H21set(h) => h.remove(tx, e_id).await,
HnswFlavor::H25set(h) => h.remove(tx, e_id).await,
HnswFlavor::H29set(h) => h.remove(tx, e_id).await,
HnswFlavor::Hset(h) => h.remove(tx, e_id).await,
}
}
pub(super) async fn knn_search(
&self,
tx: &Transaction,
search: &HnswSearch,
) -> Result<Vec<(f64, ElementId)>, Error> {
match self {
HnswFlavor::H5_9(h) => h.knn_search(tx, search).await,
HnswFlavor::H5_17(h) => h.knn_search(tx, search).await,
HnswFlavor::H5_25(h) => h.knn_search(tx, search).await,
HnswFlavor::H5set(h) => h.knn_search(tx, search).await,
HnswFlavor::H9_17(h) => h.knn_search(tx, search).await,
HnswFlavor::H9_25(h) => h.knn_search(tx, search).await,
HnswFlavor::H9set(h) => h.knn_search(tx, search).await,
HnswFlavor::H13_25(h) => h.knn_search(tx, search).await,
HnswFlavor::H13set(h) => h.knn_search(tx, search).await,
HnswFlavor::H17set(h) => h.knn_search(tx, search).await,
HnswFlavor::H21set(h) => h.knn_search(tx, search).await,
HnswFlavor::H25set(h) => h.knn_search(tx, search).await,
HnswFlavor::H29set(h) => h.knn_search(tx, search).await,
HnswFlavor::Hset(h) => h.knn_search(tx, search).await,
}
}
pub(super) async fn knn_search_checked(
&self,
tx: &Transaction,
stk: &mut Stk,
search: &HnswSearch,
hnsw_docs: &HnswDocs,
vec_docs: &VecDocs,
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<Vec<(f64, ElementId)>, Error> {
match self {
HnswFlavor::H5_9(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H5_17(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H5_25(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H5set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H9_17(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H9_25(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H9set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H13_25(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H13set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H17set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H21set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H25set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::H29set(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
HnswFlavor::Hset(h) => {
h.knn_search_checked(search, hnsw_docs, vec_docs, stk, chk).await
h.knn_search_checked(tx, stk, search, hnsw_docs, vec_docs, chk).await
}
}
}
pub(super) fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector> {
pub(super) async fn get_vector(
&self,
tx: &Transaction,
e_id: &ElementId,
) -> Result<Option<SharedVector>, Error> {
match self {
HnswFlavor::H5_9(h) => h.get_vector(e_id),
HnswFlavor::H5_17(h) => h.get_vector(e_id),
HnswFlavor::H5_25(h) => h.get_vector(e_id),
HnswFlavor::H5set(h) => h.get_vector(e_id),
HnswFlavor::H9_17(h) => h.get_vector(e_id),
HnswFlavor::H9_25(h) => h.get_vector(e_id),
HnswFlavor::H9set(h) => h.get_vector(e_id),
HnswFlavor::H13_25(h) => h.get_vector(e_id),
HnswFlavor::H13set(h) => h.get_vector(e_id),
HnswFlavor::H17set(h) => h.get_vector(e_id),
HnswFlavor::H21set(h) => h.get_vector(e_id),
HnswFlavor::H25set(h) => h.get_vector(e_id),
HnswFlavor::H29set(h) => h.get_vector(e_id),
HnswFlavor::Hset(h) => h.get_vector(e_id),
HnswFlavor::H5_9(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H5_17(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H5_25(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H5set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H9_17(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H9_25(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H9set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H13_25(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H13set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H17set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H21set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H25set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::H29set(h) => h.get_vector(tx, e_id).await,
HnswFlavor::Hset(h) => h.get_vector(tx, e_id).await,
}
}
#[cfg(test)]

View file

@@ -1,8 +1,10 @@
use crate::err::Error;
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::layer::HnswLayer;
use crate::idx::trees::hnsw::{ElementId, HnswElements};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use crate::kvs::Transaction;
use crate::sql::index::HnswParams;
#[derive(Debug)]
@@ -30,61 +32,72 @@ impl From<&HnswParams> for Heuristic {
}
impl Heuristic {
pub(super) fn select<S>(
#[allow(clippy::too_many_arguments)]
pub(super) async fn select<S>(
&self,
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
c: DoublePriorityQueue,
ignore: Option<ElementId>,
res: &mut S,
) where
S: DynamicSet<ElementId>,
) -> Result<(), Error>
where
S: DynamicSet,
{
match self {
Self::Standard => Self::heuristic(elements, layer, c, res),
Self::Ext => Self::heuristic_ext(elements, layer, q_id, q_pt, c, res),
Self::Keep => Self::heuristic_keep(elements, layer, c, res),
Self::ExtAndKeep => Self::heuristic_ext_keep(elements, layer, q_id, q_pt, c, res),
Self::Standard => Self::heuristic(tx, elements, layer, c, res).await,
Self::Ext => Self::heuristic_ext(tx, elements, layer, q_id, q_pt, c, ignore, res).await,
Self::Keep => Self::heuristic_keep(tx, elements, layer, c, res).await,
Self::ExtAndKeep => {
Self::heuristic_ext_keep(tx, elements, layer, q_id, q_pt, c, ignore, res).await
}
}
}
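These are the neighbour-selection variants from the HNSW paper: Standard admits a candidate only while it stays closer to the query than to every already-selected element (see is_closer below), the Keep variants top the result up with otherwise-pruned candidates, and the Ext variants first widen the pool with the candidates' own neighbours via extend_candidates. Every variant is now async because any distance evaluation may fault a vector in from the KVS, and the new ignore parameter lets a caller exclude one element from the extension step, presumably the element currently being removed or re-linked.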
fn heuristic<S>(
async fn heuristic<S>(
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
) -> Result<(), Error>
where
S: DynamicSet,
{
let m_max = layer.m_max();
if c.len() <= m_max {
c.to_dynamic_set(res);
} else {
while let Some((e_dist, e_id)) = c.pop_first() {
if Self::is_closer(elements, e_dist, e_id, res) && res.len() == m_max {
if Self::is_closer(tx, elements, e_dist, e_id, res).await? && res.len() == m_max {
break;
}
}
}
Ok(())
}
fn heuristic_keep<S>(
async fn heuristic_keep<S>(
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
) -> Result<(), Error>
where
S: DynamicSet,
{
let m_max = layer.m_max();
if c.len() <= m_max {
c.to_dynamic_set(res);
return;
return Ok(());
}
let mut pruned = Vec::new();
while let Some((e_dist, e_id)) = c.pop_first() {
if Self::is_closer(elements, e_dist, e_id, res) {
if Self::is_closer(tx, elements, e_dist, e_id, res).await? {
if res.len() == m_max {
break;
}
@@ -98,25 +111,32 @@ impl Heuristic {
res.insert(e_id);
}
}
Ok(())
}
fn extend_candidates<S>(
async fn extend_candidates<S>(
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
c: &mut DoublePriorityQueue,
) where
S: DynamicSet<ElementId>,
ignore: Option<ElementId>,
) -> Result<(), Error>
where
S: DynamicSet,
{
let m_max = layer.m_max();
let mut ex = c.to_set();
if let Some(i) = ignore {
ex.insert(i);
}
let mut ext = Vec::with_capacity(m_max.min(c.len()));
for (_, e_id) in c.to_vec().into_iter() {
if let Some(e_conn) = layer.get_edges(&e_id) {
for &e_adj in e_conn.iter() {
if e_adj != q_id && ex.insert(e_adj) {
if let Some(d) = elements.get_distance(q_pt, &e_adj) {
if let Some(d) = elements.get_distance(tx, q_pt, &e_adj).await? {
ext.push((d, e_adj));
}
}
@@ -129,52 +149,67 @@ impl Heuristic {
for (e_dist, e_id) in ext {
c.push(e_dist, e_id);
}
Ok(())
}
fn heuristic_ext<S>(
#[allow(clippy::too_many_arguments)]
async fn heuristic_ext<S>(
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
mut c: DoublePriorityQueue,
ignore: Option<ElementId>,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
Self::extend_candidates(elements, layer, q_id, q_pt, &mut c);
Self::heuristic(elements, layer, c, res)
}
fn heuristic_ext_keep<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
Self::extend_candidates(elements, layer, q_id, q_pt, &mut c);
Self::heuristic_keep(elements, layer, c, res)
}
fn is_closer<S>(elements: &HnswElements, e_dist: f64, e_id: ElementId, r: &mut S) -> bool
) -> Result<(), Error>
where
S: DynamicSet<ElementId>,
S: DynamicSet,
{
if let Some(current_vec) = elements.get_vector(&e_id) {
Self::extend_candidates(tx, elements, layer, q_id, q_pt, &mut c, ignore).await?;
Self::heuristic(tx, elements, layer, c, res).await
}
#[allow(clippy::too_many_arguments)]
async fn heuristic_ext_keep<S>(
tx: &Transaction,
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
mut c: DoublePriorityQueue,
ignore: Option<ElementId>,
res: &mut S,
) -> Result<(), Error>
where
S: DynamicSet,
{
Self::extend_candidates(tx, elements, layer, q_id, q_pt, &mut c, ignore).await?;
Self::heuristic_keep(tx, elements, layer, c, res).await
}
async fn is_closer<S>(
tx: &Transaction,
elements: &HnswElements,
e_dist: f64,
e_id: ElementId,
r: &mut S,
) -> Result<bool, Error>
where
S: DynamicSet,
{
if let Some(current_vec) = elements.get_vector(tx, &e_id).await? {
for r_id in r.iter() {
if let Some(r_dist) = elements.get_distance(current_vec, r_id) {
if let Some(r_dist) = elements.get_distance(tx, &current_vec, r_id).await? {
if e_dist > r_dist {
return false;
return Ok(false);
}
}
}
r.insert(e_id);
true
Ok(true)
} else {
false
Ok(false)
}
}
}

View file

@@ -1,18 +1,19 @@
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::planner::checker::HnswConditionChecker;
use crate::idx::planner::iterators::KnnIteratorResult;
use crate::idx::trees::hnsw::docs::HnswDocs;
use crate::idx::trees::hnsw::docs::{HnswDocs, VecDocs};
use crate::idx::trees::hnsw::elements::HnswElements;
use crate::idx::trees::hnsw::flavor::HnswFlavor;
use crate::idx::trees::hnsw::{ElementId, HnswSearch};
use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::knn::{KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::idx::IndexKeyBase;
use crate::kvs::Transaction;
use crate::sql::index::{HnswParams, VectorType};
use crate::sql::{Number, Thing, Value};
use crate::sql::{Id, Number, Value};
#[cfg(debug_assertions)]
use ahash::HashMap;
use reblessive::tree::Stk;
use std::collections::hash_map::Entry;
use std::collections::VecDeque;
pub struct HnswIndex {
@@ -23,8 +24,6 @@ pub struct HnswIndex {
vec_docs: VecDocs,
}
pub(super) type VecDocs = HashMap<SharedVector, (Ids64, ElementId)>;
pub(super) struct HnswCheckedSearchContext<'a> {
elements: &'a HnswElements,
docs: &'a HnswDocs,
@@ -62,8 +61,8 @@ impl<'a> HnswCheckedSearchContext<'a> {
self.docs
}
pub(super) fn get_docs(&self, pt: &SharedVector) -> Option<&Ids64> {
self.vec_docs.get(pt).map(|(doc_ids, _)| doc_ids)
pub(super) fn vec_docs(&self) -> &VecDocs {
self.vec_docs
}
pub(super) fn elements(&self) -> &HnswElements {
@@ -72,84 +71,76 @@ impl<'a> HnswCheckedSearchContext<'a> {
}
impl HnswIndex {
pub fn new(p: &HnswParams) -> Self {
Self {
pub async fn new(
tx: &Transaction,
ikb: IndexKeyBase,
tb: String,
p: &HnswParams,
) -> Result<Self, Error> {
Ok(Self {
dim: p.dimension as usize,
vector_type: p.vector_type,
hnsw: HnswFlavor::new(p),
docs: HnswDocs::default(),
vec_docs: HashMap::default(),
}
hnsw: HnswFlavor::new(ikb.clone(), p),
docs: HnswDocs::new(tx, tb, ikb.clone()).await?,
vec_docs: VecDocs::new(ikb),
})
}
pub fn index_document(&mut self, rid: &Thing, content: &Vec<Value>) -> Result<(), Error> {
pub async fn index_document(
&mut self,
tx: &Transaction,
id: Id,
content: &Vec<Value>,
) -> Result<(), Error> {
// Ensure the layers are up-to-date
self.hnsw.check_state(tx).await?;
// Resolve the doc_id
let doc_id = self.docs.resolve(rid);
let doc_id = self.docs.resolve(tx, id).await?;
// Index the values
for value in content {
// Extract the vector
let vector = Vector::try_from_value(self.vector_type, self.dim, value)?;
vector.check_dimension(self.dim)?;
self.insert(vector.into(), doc_id);
// Insert the vector
self.vec_docs.insert(tx, vector, doc_id, &mut self.hnsw).await?;
}
self.docs.finish(tx).await?;
Ok(())
}
pub(super) fn insert(&mut self, o: SharedVector, d: DocId) {
match self.vec_docs.entry(o) {
Entry::Occupied(mut e) => {
let (docs, element_id) = e.get_mut();
if let Some(new_docs) = docs.insert(d) {
let element_id = *element_id;
e.insert((new_docs, element_id));
}
}
Entry::Vacant(e) => {
let o = e.key().clone();
let element_id = self.hnsw.insert(o);
e.insert((Ids64::One(d), element_id));
}
}
}
pub(super) fn remove(&mut self, o: SharedVector, d: DocId) {
if let Entry::Occupied(mut e) = self.vec_docs.entry(o) {
let (docs, e_id) = e.get_mut();
if let Some(new_docs) = docs.remove(d) {
let e_id = *e_id;
if new_docs.is_empty() {
e.remove();
self.hnsw.remove(e_id);
} else {
e.insert((new_docs, e_id));
}
}
}
}
pub(crate) fn remove_document(
pub(crate) async fn remove_document(
&mut self,
rid: &Thing,
tx: &Transaction,
id: Id,
content: &Vec<Value>,
) -> Result<(), Error> {
if let Some(doc_id) = self.docs.remove(rid) {
if let Some(doc_id) = self.docs.remove(tx, id).await? {
// Ensure the layers are up-to-date
self.hnsw.check_state(tx).await?;
for v in content {
// Extract the vector
let vector = Vector::try_from_value(self.vector_type, self.dim, v)?;
vector.check_dimension(self.dim)?;
// Remove the vector
self.remove(vector.into(), doc_id);
self.vec_docs.remove(tx, &vector, doc_id, &mut self.hnsw).await?;
}
self.docs.finish(tx).await?;
}
Ok(())
}
// Ensure the layers are up-to-date
pub async fn check_state(&mut self, tx: &Transaction) -> Result<(), Error> {
self.hnsw.check_state(tx).await
}
pub async fn knn_search(
&self,
tx: &Transaction,
stk: &mut Stk,
pt: &[Number],
k: usize,
ef: usize,
stk: &mut Stk,
mut chk: HnswConditionChecker<'_>,
) -> Result<VecDeque<KnnIteratorResult>, Error> {
// Extract the vector
@ -157,48 +148,52 @@ impl HnswIndex {
vector.check_dimension(self.dim)?;
let search = HnswSearch::new(vector, k, ef);
// Do the search
let result = self.search(&search, stk, &mut chk).await?;
let res = chk.convert_result(&self.docs, result.docs).await?;
let result = self.search(tx, stk, &search, &mut chk).await?;
let res = chk.convert_result(tx, &self.docs, result.docs).await?;
Ok(res)
}
pub(super) async fn search(
&self,
search: &HnswSearch,
tx: &Transaction,
stk: &mut Stk,
search: &HnswSearch,
chk: &mut HnswConditionChecker<'_>,
) -> Result<KnnResult, Error> {
// Do the search
let neighbors = match chk {
HnswConditionChecker::Hnsw(_) => self.hnsw.knn_search(search),
HnswConditionChecker::Hnsw(_) => self.hnsw.knn_search(tx, search).await?,
HnswConditionChecker::HnswCondition(_) => {
self.hnsw.knn_search_checked(search, &self.docs, &self.vec_docs, stk, chk).await?
self.hnsw
.knn_search_checked(tx, stk, search, &self.docs, &self.vec_docs, chk)
.await?
}
};
Ok(self.build_result(neighbors, search.k, chk))
self.build_result(tx, neighbors, search.k, chk).await
}
fn build_result(
async fn build_result(
&self,
tx: &Transaction,
neighbors: Vec<(f64, ElementId)>,
n: usize,
chk: &mut HnswConditionChecker<'_>,
) -> KnnResult {
) -> Result<KnnResult, Error> {
let mut builder = KnnResultBuilder::new(n);
for (e_dist, e_id) in neighbors {
if builder.check_add(e_dist) {
if let Some(v) = self.hnsw.get_vector(&e_id) {
if let Some((docs, _)) = self.vec_docs.get(v) {
if let Some(v) = self.hnsw.get_vector(tx, &e_id).await? {
if let Some(docs) = self.vec_docs.get_docs(tx, &v).await? {
let evicted_docs = builder.add(e_dist, docs);
chk.expires(evicted_docs);
}
}
}
}
builder.build(
Ok(builder.build(
#[cfg(debug_assertions)]
HashMap::default(),
)
))
}
#[cfg(test)]

View file

@ -7,24 +7,40 @@ use crate::idx::trees::hnsw::index::HnswCheckedSearchContext;
use crate::idx::trees::hnsw::{ElementId, HnswElements};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use crate::idx::IndexKeyBase;
use crate::kvs::Transaction;
use ahash::HashSet;
use reblessive::tree::Stk;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::mem;
#[revisioned(revision = 1)]
#[derive(Default, Debug, Serialize, Deserialize)]
pub(super) struct LayerState {
pub(super) version: u64,
pub(super) chunks: u32,
}
#[derive(Debug)]
pub(super) struct HnswLayer<S>
where
S: DynamicSet<ElementId>,
S: DynamicSet,
{
graph: UndirectedGraph<ElementId, S>,
ikb: IndexKeyBase,
level: u16,
graph: UndirectedGraph<S>,
m_max: usize,
}
impl<S> HnswLayer<S>
where
S: DynamicSet<ElementId>,
S: DynamicSet,
{
pub(super) fn new(m_max: usize) -> Self {
pub(super) fn new(ikb: IndexKeyBase, level: usize, m_max: usize) -> Self {
Self {
ikb,
level: level as u16,
graph: UndirectedGraph::new(m_max + 1),
m_max,
}
@ -38,93 +54,115 @@ where
self.graph.get_edges(e_id)
}
pub(super) fn add_empty_node(&mut self, node: ElementId) -> bool {
self.graph.add_empty_node(node)
pub(super) async fn add_empty_node(
&mut self,
tx: &Transaction,
node: ElementId,
st: &mut LayerState,
) -> Result<bool, Error> {
if !self.graph.add_empty_node(node) {
return Ok(false);
}
self.save(tx, st).await?;
Ok(true)
}
pub(super) fn search_single(
pub(super) async fn search_single(
&self,
tx: &Transaction,
elements: &HnswElements,
pt: &SharedVector,
ep_dist: f64,
ep_id: ElementId,
ef: usize,
) -> DoublePriorityQueue {
) -> Result<DoublePriorityQueue, Error> {
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(ep_dist, ep_id);
let w = candidates.clone();
self.search(elements, pt, candidates, visited, w, ef)
self.search(tx, elements, pt, candidates, visited, w, ef).await
}
pub(super) async fn search_single_with_ignore(
&self,
tx: &Transaction,
elements: &HnswElements,
pt: &SharedVector,
ignore_id: ElementId,
ef: usize,
) -> Result<Option<ElementId>, Error> {
let visited = HashSet::from_iter([ignore_id]);
let mut candidates = DoublePriorityQueue::default();
if let Some(dist) = elements.get_distance(tx, pt, &ignore_id).await? {
candidates.push(dist, ignore_id);
}
let w = DoublePriorityQueue::default();
let q = self.search(tx, elements, pt, candidates, visited, w, ef).await?;
Ok(q.peek_first().map(|(_, e_id)| e_id))
}
#[allow(clippy::too_many_arguments)]
pub(super) async fn search_single_checked(
&self,
tx: &Transaction,
stk: &mut Stk,
search: &HnswCheckedSearchContext<'_>,
ep_pt: &SharedVector,
ep_dist: f64,
ep_id: ElementId,
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<DoublePriorityQueue, Error> {
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(ep_dist, ep_id);
let mut w = DoublePriorityQueue::default();
Self::add_if_truthy(search, &mut w, ep_pt, ep_dist, ep_id, stk, chk).await?;
self.search_checked(search, candidates, visited, w, stk, chk).await
Self::add_if_truthy(tx, stk, search, &mut w, ep_pt, ep_dist, ep_id, chk).await?;
self.search_checked(tx, stk, search, candidates, visited, w, chk).await
}
pub(super) fn search_multi(
pub(super) async fn search_multi(
&self,
tx: &Transaction,
elements: &HnswElements,
pt: &SharedVector,
candidates: DoublePriorityQueue,
ef: usize,
) -> DoublePriorityQueue {
) -> Result<DoublePriorityQueue, Error> {
let w = candidates.clone();
let visited = w.to_set();
self.search(elements, pt, candidates, visited, w, ef)
self.search(tx, elements, pt, candidates, visited, w, ef).await
}
pub(super) fn search_single_ignore_ep(
pub(super) async fn search_multi_with_ignore(
&self,
tx: &Transaction,
elements: &HnswElements,
pt: &SharedVector,
ep_id: ElementId,
) -> Option<(f64, ElementId)> {
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
let w = candidates.clone();
let q = self.search(elements, pt, candidates, visited, w, 1);
q.peek_first()
}
pub(super) fn search_multi_ignore_ep(
&self,
elements: &HnswElements,
pt: &SharedVector,
ep_id: ElementId,
ignore_ids: Vec<ElementId>,
efc: usize,
) -> DoublePriorityQueue {
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
) -> Result<DoublePriorityQueue, Error> {
let mut candidates = DoublePriorityQueue::default();
for id in &ignore_ids {
if let Some(dist) = elements.get_distance(tx, pt, id).await? {
candidates.push(dist, *id);
}
}
let visited = HashSet::from_iter(ignore_ids);
let w = DoublePriorityQueue::default();
self.search(elements, pt, candidates, visited, w, efc)
self.search(tx, elements, pt, candidates, visited, w, efc).await
}
pub(super) fn search(
#[allow(clippy::too_many_arguments)]
pub(super) async fn search(
&self,
tx: &Transaction,
elements: &HnswElements,
pt: &SharedVector,
mut candidates: DoublePriorityQueue,
mut visited: HashSet<ElementId>,
mut w: DoublePriorityQueue,
q: &SharedVector,
mut candidates: DoublePriorityQueue, // set of candidates
mut visited: HashSet<ElementId>, // set of visited elements
mut w: DoublePriorityQueue, // dynamic list of found nearest neighbors
ef: usize,
) -> DoublePriorityQueue {
let mut f_dist = if let Some(d) = w.peek_last_dist() {
d
} else {
return w;
};
while let Some((dist, doc)) = candidates.pop_first() {
if dist > f_dist {
) -> Result<DoublePriorityQueue, Error> {
let mut fq_dist = w.peek_last_dist().unwrap_or(f64::MAX);
while let Some((cq_dist, doc)) = candidates.pop_first() {
if cq_dist > fq_dist {
break;
}
if let Some(neighbourhood) = self.graph.get_edges(&doc) {
@ -133,30 +171,32 @@ where
if !visited.insert(e_id) {
continue;
}
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_dist = elements.distance(e_pt, pt);
if e_dist < f_dist || w.len() < ef {
if let Some(e_pt) = elements.get_vector(tx, &e_id).await? {
let e_dist = elements.distance(&e_pt, q);
if e_dist < fq_dist || w.len() < ef {
candidates.push(e_dist, e_id);
w.push(e_dist, e_id);
if w.len() > ef {
w.pop_last();
}
f_dist = w.peek_last_dist().unwrap(); // w can't be empty
fq_dist = w.peek_last_dist().unwrap_or(f64::MAX);
}
}
}
}
}
w
Ok(w)
}
#[allow(clippy::too_many_arguments)]
pub(super) async fn search_checked(
&self,
tx: &Transaction,
stk: &mut Stk,
search: &HnswCheckedSearchContext<'_>,
mut candidates: DoublePriorityQueue,
mut visited: HashSet<ElementId>,
mut w: DoublePriorityQueue,
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<DoublePriorityQueue, Error> {
let mut f_dist = w.peek_last_dist().unwrap_or(f64::MAX);
@ -175,12 +215,14 @@ where
if !visited.insert(e_id) {
continue;
}
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_dist = elements.distance(e_pt, pt);
if let Some(e_pt) = elements.get_vector(tx, &e_id).await? {
let e_dist = elements.distance(&e_pt, pt);
if e_dist < f_dist || w.len() < ef {
candidates.push(e_dist, e_id);
if Self::add_if_truthy(search, &mut w, e_pt, e_dist, e_id, stk, chk)
.await?
if Self::add_if_truthy(
tx, stk, search, &mut w, &e_pt, e_dist, e_id, chk,
)
.await?
{
f_dist = w.peek_last_dist().unwrap(); // w can't be empty
}
@ -192,17 +234,19 @@ where
Ok(w)
}
#[allow(clippy::too_many_arguments)]
pub(super) async fn add_if_truthy(
tx: &Transaction,
stk: &mut Stk,
search: &HnswCheckedSearchContext<'_>,
w: &mut DoublePriorityQueue,
e_pt: &SharedVector,
e_dist: f64,
e_id: ElementId,
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<bool, Error> {
if let Some(docs) = search.get_docs(e_pt) {
if chk.check_truthy(stk, search.docs(), docs).await? {
if let Some(docs) = search.vec_docs().get_docs(tx, e_pt).await? {
if chk.check_truthy(tx, stk, search.docs(), docs).await? {
w.push(e_dist, e_id);
if w.len() > search.ef() {
if let Some((_, id)) = w.pop_last() {
@ -215,21 +259,21 @@ where
Ok(false)
}
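/// Inserts an element into this layer: the nearest neighbours are found from
/// the entry points, connections are selected with the configured heuristic,
/// bidirectional edges are created, and any neighbour exceeding m_max
/// connections is re-pruned before the layer is saved.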
pub(super) fn insert(
pub(super) async fn insert(
&mut self,
(tx, st): (&Transaction, &mut LayerState),
elements: &HnswElements,
heuristic: &Heuristic,
efc: usize,
q_id: ElementId,
q_pt: &SharedVector,
(q_id, q_pt): (ElementId, &SharedVector),
mut eps: DoublePriorityQueue,
) -> DoublePriorityQueue {
) -> Result<DoublePriorityQueue, Error> {
let w;
let mut neighbors = self.graph.new_edges();
{
w = self.search_multi(elements, q_pt, eps, efc);
w = self.search_multi(tx, elements, q_pt, eps, efc).await?;
eps = w.clone();
heuristic.select(elements, self, q_id, q_pt, w, &mut neighbors);
heuristic.select(tx, elements, self, q_id, q_pt, w, None, &mut neighbors).await?;
};
let neighbors = self.graph.add_node_and_bidirectional_edges(q_id, neighbors);
@ -237,10 +281,12 @@ where
for e_id in neighbors {
if let Some(e_conn) = self.graph.get_edges(&e_id) {
if e_conn.len() > self.m_max {
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_c = self.build_priority_list(elements, e_id, e_conn);
if let Some(e_pt) = elements.get_vector(tx, &e_id).await? {
let e_c = self.build_priority_list(tx, elements, e_id, e_conn).await?;
let mut e_new_conn = self.graph.new_edges();
heuristic.select(elements, self, e_id, e_pt, e_c, &mut e_new_conn);
heuristic
.select(tx, elements, self, e_id, &e_pt, e_c, None, &mut e_new_conn)
.await?;
#[cfg(debug_assertions)]
assert!(!e_new_conn.contains(&e_id));
self.graph.set_node(e_id, e_new_conn);
@ -251,66 +297,111 @@ where
unreachable!("Element: {}", e_id);
}
}
eps
self.save(tx, st).await?;
Ok(eps)
}
fn build_priority_list(
async fn build_priority_list(
&self,
tx: &Transaction,
elements: &HnswElements,
e_id: ElementId,
neighbors: &S,
) -> DoublePriorityQueue {
) -> Result<DoublePriorityQueue, Error> {
let mut w = DoublePriorityQueue::default();
if let Some(e_pt) = elements.get_vector(&e_id) {
if let Some(e_pt) = elements.get_vector(tx, &e_id).await? {
for n_id in neighbors.iter() {
if let Some(n_pt) = elements.get_vector(n_id) {
let dist = elements.distance(e_pt, n_pt);
if let Some(n_pt) = elements.get_vector(tx, n_id).await? {
let dist = elements.distance(&e_pt, &n_pt);
w.push(dist, *n_id);
}
}
}
w
Ok(w)
}
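/// Removes an element from this layer: the node and its bidirectional edges
/// are deleted, and each former neighbour is re-wired by searching for new
/// candidates (ignoring the removed element) and re-selecting its connections
/// with the heuristic.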
pub(super) fn remove(
pub(super) async fn remove(
&mut self,
tx: &Transaction,
st: &mut LayerState,
elements: &HnswElements,
heuristic: &Heuristic,
e_id: ElementId,
efc: usize,
) -> bool {
) -> Result<bool, Error> {
if let Some(f_ids) = self.graph.remove_node_and_bidirectional_edges(&e_id) {
for &q_id in f_ids.iter() {
if let Some(q_pt) = elements.get_vector(&q_id) {
let c = self.search_multi_ignore_ep(elements, q_pt, q_id, efc);
let mut neighbors = self.graph.new_edges();
heuristic.select(elements, self, q_id, q_pt, c, &mut neighbors);
if let Some(q_pt) = elements.get_vector(tx, &q_id).await? {
let c = self
.search_multi_with_ignore(tx, elements, &q_pt, vec![q_id, e_id], efc)
.await?;
let mut q_new_conn = self.graph.new_edges();
heuristic
.select(tx, elements, self, q_id, &q_pt, c, Some(e_id), &mut q_new_conn)
.await?;
#[cfg(debug_assertions)]
{
assert!(
!neighbors.contains(&q_id),
"!neighbors.contains(&q_id) - q_id: {q_id} - f_ids: {neighbors:?}"
!q_new_conn.contains(&q_id),
"!q_new_conn.contains(&q_id) - q_id: {q_id} - f_ids: {q_new_conn:?}"
);
assert!(
!neighbors.contains(&e_id),
"!neighbors.contains(&e_id) - e_id: {e_id} - f_ids: {neighbors:?}"
!q_new_conn.contains(&e_id),
"!q_new_conn.contains(&e_id) - e_id: {e_id} - f_ids: {q_new_conn:?}"
);
assert!(neighbors.len() < self.m_max);
assert!(q_new_conn.len() <= self.m_max);
}
self.graph.set_node(q_id, neighbors);
self.graph.set_node(q_id, q_new_conn);
}
}
true
self.save(tx, st).await?;
Ok(true)
} else {
false
Ok(false)
}
}
// Based on the FoundationDB max value size (100K)
// https://apple.github.io/foundationdb/known-limitations.html#large-keys-and-values
const CHUNK_SIZE: usize = 100_000;
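/// Persists the layer graph: the serialised bytes are split into chunks of at
/// most CHUNK_SIZE, each written under its own key, leftover chunks from a
/// previous, larger serialisation are deleted, and the layer version is
/// bumped so that other nodes can detect the change.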
async fn save(&mut self, tx: &Transaction, st: &mut LayerState) -> Result<(), Error> {
// Serialise the graph
let val = self.graph.to_val()?;
// Split it into chunks
let chunks = val.chunks(Self::CHUNK_SIZE);
let old_chunks_len = mem::replace(&mut st.chunks, chunks.len() as u32);
for (i, chunk) in chunks.enumerate() {
let key = self.ikb.new_hl_key(self.level, i as u32);
tx.set(key, chunk).await?;
}
// Delete leftover chunks if they exist
for i in st.chunks..old_chunks_len {
let key = self.ikb.new_hl_key(self.level, i);
tx.del(key).await?;
}
// Increase the version
st.version += 1;
Ok(())
}
pub(super) async fn load(&mut self, tx: &Transaction, st: &LayerState) -> Result<(), Error> {
let mut val = Vec::new();
// Load the chunks
for i in 0..st.chunks {
let key = self.ikb.new_hl_key(self.level, i);
let chunk =
tx.get(key, None).await?.ok_or_else(|| Error::Unreachable("Missing chunk"))?;
val.extend(chunk);
}
// Rebuild the graph
self.graph.reload(&val)
}
}
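The save/load pair above gives the layer graph a simple chunked persistence
format. Below is a minimal stand-alone sketch of the same round-trip, with a
plain vector standing in for the transaction's key/value calls (the names are
illustrative only, not part of the codebase):

// Split a serialised value into 100K chunks and reassemble them in order,
// mirroring tx.set(ikb.new_hl_key(level, i), chunk) and the matching tx.get.
fn chunk_roundtrip(val: &[u8]) {
    const CHUNK_SIZE: usize = 100_000;
    // "Write" the chunks, keyed by their index
    let mut store: Vec<(u32, Vec<u8>)> = Vec::new();
    for (i, chunk) in val.chunks(CHUNK_SIZE).enumerate() {
        store.push((i as u32, chunk.to_vec()));
    }
    // "Read" them back in order and reassemble the original value
    let mut rebuilt = Vec::with_capacity(val.len());
    for (_, chunk) in &store {
        rebuilt.extend_from_slice(chunk);
    }
    assert_eq!(rebuilt, val);
}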
#[cfg(test)]
impl<S> HnswLayer<S>
where
S: DynamicSet<ElementId>,
S: DynamicSet,
{
pub(in crate::idx::trees::hnsw) fn check_props(&self, elements: &HnswElements) {
assert!(self.graph.len() <= elements.len(), "{} - {}", self.graph.len(), elements.len());

View file

@ -9,16 +9,22 @@ use crate::err::Error;
use crate::idx::planner::checker::HnswConditionChecker;
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::docs::HnswDocs;
use crate::idx::trees::hnsw::docs::VecDocs;
use crate::idx::trees::hnsw::elements::HnswElements;
use crate::idx::trees::hnsw::heuristic::Heuristic;
use crate::idx::trees::hnsw::index::{HnswCheckedSearchContext, VecDocs};
use crate::idx::trees::hnsw::layer::HnswLayer;
use crate::idx::trees::hnsw::index::HnswCheckedSearchContext;
use crate::idx::trees::hnsw::layer::{HnswLayer, LayerState};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use crate::idx::trees::vector::{SerializedVector, SharedVector, Vector};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, Val};
use crate::sql::index::HnswParams;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use reblessive::tree::Stk;
use revision::revisioned;
use serde::{Deserialize, Serialize};
struct HnswSearch {
pt: SharedVector,
@ -35,67 +41,125 @@ impl HnswSearch {
}
}
}
#[revisioned(revision = 1)]
#[derive(Default, Serialize, Deserialize)]
pub(super) struct HnswState {
enter_point: Option<ElementId>,
next_element_id: ElementId,
layer0: LayerState,
layers: Vec<LayerState>,
}
impl VersionedStore for HnswState {}
struct Hnsw<L0, L>
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
L0: DynamicSet,
L: DynamicSet,
{
ikb: IndexKeyBase,
state_key: Key,
state: HnswState,
m: usize,
efc: usize,
ml: f64,
layer0: HnswLayer<L0>,
layers: Vec<HnswLayer<L>>,
enter_point: Option<ElementId>,
elements: HnswElements,
rng: SmallRng,
heuristic: Heuristic,
}
pub(super) type ElementId = u64;
pub(crate) type ElementId = u64;
impl<L0, L> Hnsw<L0, L>
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
L0: DynamicSet,
L: DynamicSet,
{
fn new(p: &HnswParams) -> Self {
fn new(ikb: IndexKeyBase, p: &HnswParams) -> Self {
let m0 = p.m0 as usize;
let state_key = ikb.new_hs_key();
Self {
state_key,
state: Default::default(),
m: p.m as usize,
efc: p.ef_construction as usize,
ml: p.ml.to_float(),
enter_point: None,
layer0: HnswLayer::new(m0),
layer0: HnswLayer::new(ikb.clone(), 0, m0),
layers: Vec::default(),
elements: HnswElements::new(p.distance.clone()),
elements: HnswElements::new(ikb.clone(), p.distance.clone()),
rng: SmallRng::from_entropy(),
heuristic: p.into(),
ikb,
}
}
fn insert_level(&mut self, q_pt: SharedVector, q_level: usize) -> ElementId {
// Attribute an ID to the vector
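/// Synchronises the in-memory layers with the persisted state: each layer
/// tracks a version number, and a layer's graph is reloaded from the KV store
/// only when the stored version differs from the cached one (this refresh
/// protocol is modelled by the versioned_index TLA+ module below).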
async fn check_state(&mut self, tx: &Transaction) -> Result<(), Error> {
// Read the state
let st: HnswState = if let Some(val) = tx.get(self.state_key.clone(), None).await? {
VersionedStore::try_from(val)?
} else {
Default::default()
};
// Compare versions
if st.layer0.version != self.state.layer0.version {
self.layer0.load(tx, &st.layer0).await?;
}
for ((new_stl, stl), layer) in
st.layers.iter().zip(self.state.layers.iter_mut()).zip(self.layers.iter_mut())
{
if new_stl.version != stl.version {
layer.load(tx, new_stl).await?;
}
}
// Retrieve missing layers
for i in self.layers.len()..st.layers.len() {
let mut l = HnswLayer::new(self.ikb.clone(), i + 1, self.m);
l.load(tx, &st.layers[i]).await?;
self.layers.push(l);
}
// Remove any layers that no longer exist in the persisted state
for _ in st.layers.len()..self.layers.len() {
self.layers.pop();
}
// Update the next element id and adopt the new state
self.elements.set_next_element_id(st.next_element_id);
self.state = st;
Ok(())
}
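/// Inserts a vector at the given level: any missing upper layers are created
/// first, the vector is persisted, and the element is then linked into the
/// graph, either as the first element or starting from the current enter
/// point.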
async fn insert_level(
&mut self,
tx: &Transaction,
q_pt: Vector,
q_level: usize,
) -> Result<ElementId, Error> {
// Attributes an ID to the vector
let q_id = self.elements.next_element_id();
let top_up_layers = self.layers.len();
// Make sure the required upper layers exist
for _ in top_up_layers..q_level {
self.layers.push(HnswLayer::new(self.m));
for i in top_up_layers..q_level {
self.layers.push(HnswLayer::new(self.ikb.clone(), i + 1, self.m));
self.state.layers.push(LayerState::default());
}
// Store the vector
self.elements.insert(q_id, q_pt.clone());
let pt_ser = SerializedVector::from(&q_pt);
let q_pt = self.elements.insert(tx, q_id, q_pt, &pt_ser).await?;
if let Some(ep_id) = self.enter_point {
if let Some(ep_id) = self.state.enter_point {
// We already have an enter_point, let's insert the element in the layers
self.insert_element(q_id, &q_pt, q_level, ep_id, top_up_layers);
self.insert_element(tx, q_id, &q_pt, q_level, ep_id, top_up_layers).await?;
} else {
// Otherwise it is the first element
self.insert_first_element(q_id, q_level);
self.insert_first_element(tx, q_id, q_level).await?;
}
self.elements.inc_next_element_id();
q_id
self.state.next_element_id = self.elements.inc_next_element_id();
Ok(q_id)
}
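/// Draws the level for a new element from a geometric distribution: with
/// `unif` uniform in (0, 1), `floor(-ln(unif) * ml)` yields level l with
/// P(level >= l) = exp(-l / ml), as in the original HNSW paper.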
fn get_random_level(&mut self) -> usize {
@ -103,29 +167,44 @@ where
(-unif.ln() * self.ml).floor() as usize // calculate the layer
}
fn insert_first_element(&mut self, id: ElementId, level: usize) {
async fn insert_first_element(
&mut self,
tx: &Transaction,
id: ElementId,
level: usize,
) -> Result<(), Error> {
if level > 0 {
for layer in self.layers.iter_mut().take(level) {
layer.add_empty_node(id);
// Insert in up levels
for (layer, state) in
self.layers.iter_mut().zip(self.state.layers.iter_mut()).take(level)
{
layer.add_empty_node(tx, id, state).await?;
}
}
self.layer0.add_empty_node(id);
self.enter_point = Some(id);
// Insert in layer 0
self.layer0.add_empty_node(tx, id, &mut self.state.layer0).await?;
// Update the enter point
self.state.enter_point = Some(id);
//
Ok(())
}
fn insert_element(
async fn insert_element(
&mut self,
tx: &Transaction,
q_id: ElementId,
q_pt: &SharedVector,
q_level: usize,
mut ep_id: ElementId,
top_up_layers: usize,
) {
if let Some(mut ep_dist) = self.elements.get_distance(q_pt, &ep_id) {
) -> Result<(), Error> {
if let Some(mut ep_dist) = self.elements.get_distance(tx, q_pt, &ep_id).await? {
if q_level < top_up_layers {
for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
if let Some(ep_dist_id) =
layer.search_single(&self.elements, q_pt, ep_dist, ep_id, 1).peek_first()
if let Some(ep_dist_id) = layer
.search_single(tx, &self.elements, q_pt, ep_dist, ep_id, 1)
.await?
.peek_first()
{
(ep_dist, ep_id) = ep_dist_id;
} else {
@ -139,16 +218,43 @@ where
let insert_to_up_layers = q_level.min(top_up_layers);
if insert_to_up_layers > 0 {
for layer in self.layers.iter_mut().take(insert_to_up_layers).rev() {
eps = layer.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
for (layer, st) in self
.layers
.iter_mut()
.zip(self.state.layers.iter_mut())
.take(insert_to_up_layers)
.rev()
{
eps = layer
.insert(
(tx, st),
&self.elements,
&self.heuristic,
self.efc,
(q_id, q_pt),
eps,
)
.await?;
}
}
self.layer0.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
self.layer0
.insert(
(tx, &mut self.state.layer0),
&self.elements,
&self.heuristic,
self.efc,
(q_id, q_pt),
eps,
)
.await?;
if top_up_layers < q_level {
for layer in self.layers[top_up_layers..q_level].iter_mut() {
if !layer.add_empty_node(q_id) {
for (layer, st) in self.layers[top_up_layers..q_level]
.iter_mut()
.zip(self.state.layers[top_up_layers..q_level].iter_mut())
{
if !layer.add_empty_node(tx, q_id, st).await? {
#[cfg(debug_assertions)]
unreachable!("Already there {}", q_id);
}
@ -156,80 +262,105 @@ where
}
if q_level > top_up_layers {
self.enter_point = Some(q_id);
self.state.enter_point = Some(q_id);
}
} else {
#[cfg(debug_assertions)]
unreachable!()
}
Ok(())
}
fn insert(&mut self, q_pt: SharedVector) -> ElementId {
async fn save_state(&self, tx: &Transaction) -> Result<(), Error> {
let val: Val = VersionedStore::try_into(&self.state)?;
tx.set(self.state_key.clone(), val).await?;
Ok(())
}
async fn insert(&mut self, tx: &Transaction, q_pt: Vector) -> Result<ElementId, Error> {
let q_level = self.get_random_level();
self.insert_level(q_pt, q_level)
let res = self.insert_level(tx, q_pt, q_level).await?;
self.save_state(tx).await?;
Ok(res)
}
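/// Removes an element: a replacement enter point is looked up while walking
/// down the layers (ignoring the element being removed), the element is
/// unlinked from every layer, and its vector is finally deleted.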
fn remove(&mut self, e_id: ElementId) -> bool {
async fn remove(&mut self, tx: &Transaction, e_id: ElementId) -> Result<bool, Error> {
let mut removed = false;
let e_pt = self.elements.get_vector(&e_id).cloned();
// Do we have the vector?
if let Some(e_pt) = e_pt {
let layers = self.layers.len();
let mut new_enter_point = None;
// Are we deleting the current enter point?
if Some(e_id) == self.enter_point {
// Let's find a new enter point
new_enter_point = if layers == 0 {
self.layer0.search_single_ignore_ep(&self.elements, &e_pt, e_id)
} else {
self.layers[layers - 1].search_single_ignore_ep(&self.elements, &e_pt, e_id)
};
}
self.elements.remove(&e_id);
if let Some(e_pt) = self.elements.get_vector(tx, &e_id).await? {
// Check whether we are deleting the current enter_point
let mut new_enter_point = if Some(e_id) == self.state.enter_point {
None
} else {
self.state.enter_point
};
// Remove from the up layers
for layer in self.layers.iter_mut().rev() {
if layer.remove(&self.elements, &self.heuristic, e_id, self.efc) {
for (layer, st) in self.layers.iter_mut().zip(self.state.layers.iter_mut()).rev() {
if new_enter_point.is_none() {
new_enter_point = layer
.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
.await?;
}
if layer.remove(tx, st, &self.elements, &self.heuristic, e_id, self.efc).await? {
removed = true;
}
}
// Check possible new enter_point at layer0
if new_enter_point.is_none() {
new_enter_point = self
.layer0
.search_single_with_ignore(tx, &self.elements, &e_pt, e_id, self.efc)
.await?;
}
// Remove from layer 0
if self.layer0.remove(&self.elements, &self.heuristic, e_id, self.efc) {
if self
.layer0
.remove(tx, &mut self.state.layer0, &self.elements, &self.heuristic, e_id, self.efc)
.await?
{
removed = true;
}
if removed && new_enter_point.is_some() {
// Update the enter point
self.enter_point = new_enter_point.map(|(_, e_id)| e_id);
}
self.elements.remove(tx, e_id).await?;
self.state.enter_point = new_enter_point;
}
removed
self.save_state(tx).await?;
Ok(removed)
}
fn knn_search(&self, search: &HnswSearch) -> Vec<(f64, ElementId)> {
if let Some((ep_dist, ep_id)) = self.search_ep(&search.pt) {
let w =
self.layer0.search_single(&self.elements, &search.pt, ep_dist, ep_id, search.ef);
w.to_vec_limit(search.k)
async fn knn_search(
&self,
tx: &Transaction,
search: &HnswSearch,
) -> Result<Vec<(f64, ElementId)>, Error> {
if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
let w = self
.layer0
.search_single(tx, &self.elements, &search.pt, ep_dist, ep_id, search.ef)
.await?;
Ok(w.to_vec_limit(search.k))
} else {
vec![]
Ok(vec![])
}
}
async fn knn_search_checked(
&self,
tx: &Transaction,
stk: &mut Stk,
search: &HnswSearch,
hnsw_docs: &HnswDocs,
vec_docs: &VecDocs,
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<Vec<(f64, ElementId)>, Error> {
if let Some((ep_dist, ep_id)) = self.search_ep(&search.pt) {
if let Some(ep_pt) = self.elements.get_vector(&ep_id) {
if let Some((ep_dist, ep_id)) = self.search_ep(tx, &search.pt).await? {
if let Some(ep_pt) = self.elements.get_vector(tx, &ep_id).await? {
let search_ctx = HnswCheckedSearchContext::new(
&self.elements,
hnsw_docs,
@ -239,7 +370,7 @@ where
);
let w = self
.layer0
.search_single_checked(&search_ctx, ep_pt, ep_dist, ep_id, stk, chk)
.search_single_checked(tx, stk, &search_ctx, &ep_pt, ep_dist, ep_id, chk)
.await?;
return Ok(w.to_vec_limit(search.k));
}
@ -247,12 +378,18 @@ where
Ok(vec![])
}
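/// Finds the entry point for a search: starting from the stored enter point,
/// the upper layers are descended greedily (ef = 1), keeping the closest
/// element found so far.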
fn search_ep(&self, pt: &SharedVector) -> Option<(f64, ElementId)> {
if let Some(mut ep_id) = self.enter_point {
if let Some(mut ep_dist) = self.elements.get_distance(pt, &ep_id) {
async fn search_ep(
&self,
tx: &Transaction,
pt: &SharedVector,
) -> Result<Option<(f64, ElementId)>, Error> {
if let Some(mut ep_id) = self.state.enter_point {
if let Some(mut ep_dist) = self.elements.get_distance(tx, pt, &ep_id).await? {
for layer in self.layers.iter().rev() {
if let Some(ep_dist_id) =
layer.search_single(&self.elements, pt, ep_dist, ep_id, 1).peek_first()
if let Some(ep_dist_id) = layer
.search_single(tx, &self.elements, pt, ep_dist, ep_id, 1)
.await?
.peek_first()
{
(ep_dist, ep_id) = ep_dist_id;
} else {
@ -260,17 +397,21 @@ where
unreachable!()
}
}
return Some((ep_dist, ep_id));
return Ok(Some((ep_dist, ep_id)));
} else {
#[cfg(debug_assertions)]
unreachable!()
}
}
None
Ok(None)
}
fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector> {
self.elements.get_vector(e_id)
async fn get_vector(
&self,
tx: &Transaction,
e_id: &ElementId,
) -> Result<Option<SharedVector>, Error> {
self.elements.get_vector(tx, e_id).await
}
#[cfg(test)]
fn check_hnsw_properties(&self, expected_count: usize) {
@ -281,8 +422,8 @@ where
#[cfg(test)]
fn check_hnsw_props<L0, L>(h: &Hnsw<L0, L>, expected_count: usize)
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
L0: DynamicSet,
L: DynamicSet,
{
assert_eq!(h.elements.len(), expected_count);
for layer in h.layers.iter() {
@ -292,48 +433,56 @@ where
#[cfg(test)]
mod tests {
use crate::ctx::{Context, MutableContext};
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::planner::checker::HnswConditionChecker;
use crate::idx::trees::hnsw::flavor::HnswFlavor;
use crate::idx::trees::hnsw::index::HnswIndex;
use crate::idx::trees::hnsw::HnswSearch;
use crate::idx::trees::hnsw::{ElementId, HnswSearch};
use crate::idx::trees::knn::tests::{new_vectors_from_file, TestCollection};
use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::idx::IndexKeyBase;
use crate::kvs::LockType::Optimistic;
use crate::kvs::{Datastore, Transaction, TransactionType};
use crate::sql::index::{Distance, HnswParams, VectorType};
use crate::sql::{Id, Value};
use ahash::{HashMap, HashSet};
use ndarray::Array1;
use reblessive::tree::Stk;
use roaring::RoaringTreemap;
use std::collections::hash_map::Entry;
use std::ops::Deref;
use std::sync::Arc;
use test_log::test;
fn insert_collection_hnsw(
async fn insert_collection_hnsw(
tx: &Transaction,
h: &mut HnswFlavor,
collection: &TestCollection,
) -> HashSet<SharedVector> {
let mut set = HashSet::default();
) -> HashMap<ElementId, SharedVector> {
let mut map = HashMap::default();
for (_, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone();
h.insert(obj.clone());
set.insert(obj);
h.check_hnsw_properties(set.len());
let e_id = h.insert(tx, obj.clone_vector()).await.unwrap();
map.insert(e_id, obj);
h.check_hnsw_properties(map.len());
}
set
map
}
fn find_collection_hnsw(h: &HnswFlavor, collection: &TestCollection) {
async fn find_collection_hnsw(tx: &Transaction, h: &HnswFlavor, collection: &TestCollection) {
let max_knn = 20.min(collection.len());
for (_, obj) in collection.to_vec_ref() {
for knn in 1..max_knn {
let search = HnswSearch::new(obj.clone(), knn, 80);
let res = h.knn_search(&search);
let res = h.knn_search(tx, &search).await.unwrap();
if collection.is_unique() {
let mut found = false;
for (_, e_id) in &res {
if let Some(v) = h.get_vector(e_id) {
if obj.eq(v) {
if let Some(v) = h.get_vector(tx, e_id).await.unwrap() {
if v.eq(obj) {
found = true;
break;
}
@ -365,10 +514,38 @@ mod tests {
}
}
fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
let mut h = HnswFlavor::new(p);
insert_collection_hnsw(&mut h, collection);
find_collection_hnsw(&h, collection);
async fn delete_collection_hnsw(
tx: &Transaction,
h: &mut HnswFlavor,
mut map: HashMap<ElementId, SharedVector>,
) {
let element_ids: Vec<ElementId> = map.keys().copied().collect();
for e_id in element_ids {
assert!(h.remove(tx, e_id).await.unwrap());
map.remove(&e_id);
h.check_hnsw_properties(map.len());
}
}
async fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
let ds = Datastore::new("memory").await.unwrap();
let mut h = HnswFlavor::new(IndexKeyBase::default(), p);
let map = {
let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
let map = insert_collection_hnsw(&tx, &mut h, collection).await;
tx.commit().await.unwrap();
map
};
{
let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
find_collection_hnsw(&tx, &h, collection).await;
tx.cancel().await.unwrap();
}
{
let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
delete_collection_hnsw(&tx, &mut h, map).await;
tx.commit().await.unwrap();
}
}
fn new_params(
@ -395,7 +572,7 @@ mod tests {
)
}
fn test_hnsw(collection_size: usize, p: HnswParams) {
async fn test_hnsw(collection_size: usize, p: HnswParams) {
info!("Collection size: {collection_size} - Params: {p:?}");
let collection = TestCollection::new(
true,
@ -404,7 +581,7 @@ mod tests {
p.dimension as usize,
&p.distance,
);
test_hnsw_collection(&p, &collection);
test_hnsw_collection(&p, &collection).await;
}
#[test(tokio::test(flavor = "multi_thread"))]
@ -418,7 +595,7 @@ mod tests {
// (Distance::Jaccard, 100),
(Distance::Manhattan, 5),
(Distance::Minkowski(2.into()), 5),
//(Distance::Pearson, 5),
// (Distance::Pearson, 5),
] {
for vt in [
VectorType::F64,
@ -430,7 +607,7 @@ mod tests {
for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
let p = new_params(dim, vt, dist.clone(), 24, 500, extend, keep);
let f = tokio::spawn(async move {
test_hnsw(30, p);
test_hnsw(30, p).await;
});
futures.push(f);
}
@ -442,15 +619,16 @@ mod tests {
Ok(())
}
fn insert_collection_hnsw_index(
async fn insert_collection_hnsw_index(
tx: &Transaction,
h: &mut HnswIndex,
collection: &TestCollection,
) -> HashMap<SharedVector, HashSet<DocId>> {
) -> Result<HashMap<SharedVector, HashSet<DocId>>, Error> {
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::default();
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone();
h.insert(obj.clone(), *doc_id);
match map.entry(obj) {
let content = vec![Value::from(obj.deref())];
h.index_document(tx, Id::Number(*doc_id as i64), &content).await.unwrap();
match map.entry(obj.clone()) {
Entry::Occupied(mut e) => {
e.get_mut().insert(*doc_id);
}
@ -460,10 +638,11 @@ mod tests {
}
h.check_hnsw_properties(map.len());
}
map
Ok(map)
}
async fn find_collection_hnsw_index(
tx: &Transaction,
stk: &mut Stk,
h: &mut HnswIndex,
collection: &TestCollection,
@ -471,9 +650,9 @@ mod tests {
let max_knn = 20.min(collection.len());
for (doc_id, obj) in collection.to_vec_ref() {
for knn in 1..max_knn {
let mut chk = HnswConditionChecker::default();
let mut chk = HnswConditionChecker::new();
let search = HnswSearch::new(obj.clone(), knn, 500);
let res = h.search(&search, stk, &mut chk).await.unwrap();
let res = h.search(tx, stk, &search, &mut chk).await.unwrap();
if knn == 1 && res.docs.len() == 1 && res.docs[0].1 > 0.0 {
let docs: Vec<DocId> = res.docs.iter().map(|(d, _)| *d).collect();
if collection.is_unique() {
@ -501,14 +680,15 @@ mod tests {
}
}
fn delete_hnsw_index_collection(
async fn delete_hnsw_index_collection(
tx: &Transaction,
h: &mut HnswIndex,
collection: &TestCollection,
mut map: HashMap<SharedVector, HashSet<DocId>>,
) {
) -> Result<(), Error> {
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone();
h.remove(obj.clone(), *doc_id);
let content = vec![Value::from(obj.deref())];
h.remove_document(tx, Id::Number(*doc_id as i64), &content).await?;
if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
let set = e.get_mut();
set.remove(doc_id);
@ -516,12 +696,24 @@ mod tests {
e.remove();
}
}
// Check properties
h.check_hnsw_properties(map.len());
}
Ok(())
}
async fn new_ctx(ds: &Datastore, tt: TransactionType) -> Context {
let tx = Arc::new(ds.transaction(tt, Optimistic).await.unwrap());
let mut ctx = MutableContext::default();
ctx.set_transaction(tx);
ctx.freeze()
}
async fn test_hnsw_index(collection_size: usize, unique: bool, p: HnswParams) {
info!("test_hnsw_index - coll size: {collection_size} - params: {p:?}");
let ds = Datastore::new("memory").await.unwrap();
let collection = TestCollection::new(
unique,
collection_size,
@ -529,16 +721,39 @@ mod tests {
p.dimension as usize,
&p.distance,
);
let mut h = HnswIndex::new(&p);
let map = insert_collection_hnsw_index(&mut h, &collection);
let mut stack = reblessive::tree::TreeStack::new();
stack
.enter(|stk| async {
find_collection_hnsw_index(stk, &mut h, &collection).await;
})
.finish()
.await;
delete_hnsw_index_collection(&mut h, &collection, map);
// Create index
let (mut h, map) = {
let ctx = new_ctx(&ds, TransactionType::Write).await;
let tx = ctx.tx();
let mut h =
HnswIndex::new(&tx, IndexKeyBase::default(), "test".to_string(), &p).await.unwrap();
// Fill index
let map = insert_collection_hnsw_index(&tx, &mut h, &collection).await.unwrap();
tx.commit().await.unwrap();
(h, map)
};
// Search index
{
let mut stack = reblessive::tree::TreeStack::new();
let ctx = new_ctx(&ds, TransactionType::Read).await;
let tx = ctx.tx();
stack
.enter(|stk| async {
find_collection_hnsw_index(&tx, stk, &mut h, &collection).await;
})
.finish()
.await;
}
// Delete collection
{
let ctx = new_ctx(&ds, TransactionType::Write).await;
let tx = ctx.tx();
delete_hnsw_index_collection(&tx, &mut h, &collection, map).await.unwrap();
tx.commit().await.unwrap();
}
}
#[test(tokio::test(flavor = "multi_thread"))]
@ -552,7 +767,7 @@ mod tests {
// (Distance::Jaccard, 100),
(Distance::Manhattan, 5),
(Distance::Minkowski(2.into()), 5),
(Distance::Pearson, 5),
// (Distance::Pearson, 5),
] {
for vt in [
VectorType::F64,
@ -562,7 +777,7 @@ mod tests {
VectorType::I16,
] {
for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
for unique in [false, true] {
for unique in [true, false] {
let p = new_params(dim, vt, dist.clone(), 8, 150, extend, keep);
let f = tokio::spawn(async move {
test_hnsw_index(30, unique, p).await;
@ -578,8 +793,8 @@ mod tests {
Ok(())
}
#[test]
fn test_simple_hnsw() {
#[test(tokio::test(flavor = "multi_thread"))]
async fn test_simple_hnsw() {
let collection = TestCollection::Unique(vec![
(0, new_i16_vec(-2, -3)),
(1, new_i16_vec(-2, 1)),
@ -593,12 +808,21 @@ mod tests {
(9, new_i16_vec(-4, -2)),
(10, new_i16_vec(0, 3)),
]);
let ikb = IndexKeyBase::default();
let p = new_params(2, VectorType::I16, Distance::Euclidean, 3, 500, true, true);
let mut h = HnswFlavor::new(&p);
insert_collection_hnsw(&mut h, &collection);
let search = HnswSearch::new(new_i16_vec(-2, -3), 10, 501);
let res = h.knn_search(&search);
assert_eq!(res.len(), 10);
let mut h = HnswFlavor::new(ikb, &p);
let ds = Arc::new(Datastore::new("memory").await.unwrap());
{
let tx = ds.transaction(TransactionType::Write, Optimistic).await.unwrap();
insert_collection_hnsw(&tx, &mut h, &collection).await;
tx.commit().await.unwrap();
}
{
let tx = ds.transaction(TransactionType::Read, Optimistic).await.unwrap();
let search = HnswSearch::new(new_i16_vec(-2, -3), 10, 501);
let res = h.knn_search(&tx, &search).await.unwrap();
assert_eq!(res.len(), 10);
}
}
async fn test_recall(
@ -610,6 +834,9 @@ mod tests {
tests_ef_recall: &[(usize, f64)],
) -> Result<(), Error> {
info!("Build data collection");
let ds = Arc::new(Datastore::new("memory").await?);
let collection: Arc<TestCollection> =
Arc::new(TestCollection::NonUnique(new_vectors_from_file(
p.vector_type,
@ -617,11 +844,15 @@ mod tests {
Some(ingest_limit),
)?));
let mut h = HnswIndex::new(&p);
let ctx = new_ctx(&ds, TransactionType::Write).await;
let tx = ctx.tx();
let mut h = HnswIndex::new(&tx, IndexKeyBase::default(), "Index".to_string(), &p).await?;
info!("Insert collection");
for (doc_id, obj) in collection.to_vec_ref() {
h.insert(obj.clone(), *doc_id);
let content = vec![Value::from(obj.deref())];
h.index_document(&tx, Id::Number(*doc_id as i64), &content).await?;
}
tx.commit().await?;
let h = Arc::new(h);
@ -638,6 +869,7 @@ mod tests {
let queries = queries.clone();
let collection = collection.clone();
let h = h.clone();
let ds = ds.clone();
let f = tokio::spawn(async move {
let mut stack = reblessive::tree::TreeStack::new();
stack
@ -645,9 +877,11 @@ mod tests {
let mut total_recall = 0.0;
for (_, pt) in queries.to_vec_ref() {
let knn = 10;
let mut chk = HnswConditionChecker::default();
let mut chk = HnswConditionChecker::new();
let search = HnswSearch::new(pt.clone(), knn, efs);
let hnsw_res = h.search(&search, stk, &mut chk).await.unwrap();
let ctx = new_ctx(&ds, TransactionType::Read).await;
let tx = ctx.tx();
let hnsw_res = h.search(&tx, stk, &search, &mut chk).await.unwrap();
assert_eq!(hnsw_res.docs.len(), knn, "Different size - knn: {knn}",);
let brute_force_res = collection.knn(pt, Distance::Euclidean, knn);
let rec = brute_force_res.recall(&hnsw_res);
@ -722,7 +956,7 @@ mod tests {
for (doc_id, doc_pt) in self.to_vec_ref() {
let d = dist.calculate(doc_pt, pt);
if b.check_add(d) {
b.add(d, &Ids64::One(*doc_id));
b.add(d, Ids64::One(*doc_id));
}
}
b.build(

View file

@ -5,7 +5,9 @@ use crate::idx::trees::store::NodeId;
#[cfg(debug_assertions)]
use ahash::HashMap;
use ahash::{HashSet, HashSetExt};
use revision::revisioned;
use roaring::RoaringTreemap;
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, Reverse};
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, VecDeque};
@ -23,7 +25,7 @@ impl PriorityNode {
}
}
#[derive(Default, Clone)]
#[derive(Default, Debug, Clone)]
pub(super) struct DoublePriorityQueue(BTreeMap<FloatKey, VecDeque<ElementId>>, usize);
impl DoublePriorityQueue {
@ -86,6 +88,7 @@ impl DoublePriorityQueue {
(k, v)
})
}
pub(super) fn peek_last_dist(&self) -> Option<f64> {
self.0.last_key_value().map(|(k, _)| k.0)
}
@ -124,7 +127,7 @@ impl DoublePriorityQueue {
s
}
pub(super) fn to_dynamic_set<S: DynamicSet<ElementId>>(&self, set: &mut S) {
pub(super) fn to_dynamic_set<S: DynamicSet>(&self, set: &mut S) {
for q in self.0.values() {
for v in q {
set.insert(*v);
@ -175,8 +178,10 @@ impl Ord for FloatKey {
/// When identifiers are added or removed, the method returned the most appropriate
/// variant (if required).
#[derive(Debug, Clone, PartialEq)]
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
#[non_exhaustive]
pub(in crate::idx) enum Ids64 {
#[allow(dead_code)] // Will be used with HNSW
Empty,
One(u64),
Vec2([u64; 2]),
@ -408,7 +413,6 @@ impl Ids64 {
}
}
#[allow(dead_code)] // Will be used with HNSW
pub(super) fn remove(&mut self, d: DocId) -> Option<Self> {
match self {
Self::Empty => None,
@ -541,7 +545,7 @@ impl KnnResultBuilder {
true
}
pub(super) fn add(&mut self, dist: f64, docs: &Ids64) -> Ids64 {
pub(super) fn add(&mut self, dist: f64, docs: Ids64) -> Ids64 {
let pr = FloatKey(dist);
docs.append_to(&mut self.docs);
match self.priority_list.entry(pr) {
@ -550,7 +554,7 @@ impl KnnResultBuilder {
}
Entry::Occupied(mut e) => {
let d = e.get_mut();
if let Some(n) = d.append_from(docs) {
if let Some(n) = d.append_from(&docs) {
e.insert(n);
}
}
@ -815,10 +819,10 @@ pub(super) mod tests {
#[test]
fn knn_result_builder_test() {
let mut b = KnnResultBuilder::new(7);
b.add(0.0, &Ids64::One(5));
b.add(0.2, &Ids64::Vec3([0, 1, 2]));
b.add(0.2, &Ids64::One(3));
b.add(0.2, &Ids64::Vec2([6, 8]));
b.add(0.0, Ids64::One(5));
b.add(0.2, Ids64::Vec3([0, 1, 2]));
b.add(0.2, Ids64::One(3));
b.add(0.2, Ids64::Vec2([6, 8]));
let res = b.build(
#[cfg(debug_assertions)]
HashMap::default(),

View file

@ -22,7 +22,7 @@ use crate::idx::trees::store::{
IndexStores, NodeId, StoredNode, TreeNode, TreeNodeProvider, TreeStore,
};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::idx::{IndexKeyBase, VersionedSerdeState};
use crate::idx::{IndexKeyBase, VersionedStore};
use crate::kvs::{Key, Transaction, TransactionType, Val};
use crate::sql::index::{Distance, MTreeParams, VectorType};
use crate::sql::{Number, Object, Thing, Value};
@ -58,7 +58,7 @@ impl MTreeIndex {
));
let state_key = ikb.new_vm_key(None);
let state: MState = if let Some(val) = txn.get(state_key.clone(), None).await? {
MState::try_from_val(val)?
VersionedStore::try_from(val)?
} else {
MState::new(p.capacity)
};
@ -175,7 +175,7 @@ impl MTreeIndex {
let mut mtree = self.mtree.write().await;
if let Some(new_cache) = self.store.finish(tx).await? {
mtree.state.generation += 1;
tx.set(self.state_key.clone(), mtree.state.try_to_val()?).await?;
tx.set(self.state_key.clone(), VersionedStore::try_into(&mtree.state)?).await?;
self.ixs.advance_store_mtree(new_cache);
}
drop(mtree);
@ -246,7 +246,7 @@ impl MTree {
}
}
if !docs.is_empty() {
let evicted_docs = res.add(d, &docs);
let evicted_docs = res.add(d, docs);
chk.expires(evicted_docs);
}
}
@ -1468,7 +1468,7 @@ impl ObjectProperties {
}
}
impl VersionedSerdeState for MState {}
impl VersionedStore for MState {}
#[cfg(test)]
mod tests {

View file

@ -1,3 +1,5 @@
use crate::ctx::Context;
use crate::err::Error;
use crate::idx::trees::hnsw::index::HnswIndex;
use crate::idx::IndexKeyBase;
use crate::kvs::Key;
@ -18,38 +20,38 @@ impl Default for HnswIndexes {
}
impl HnswIndexes {
pub(super) async fn get(&self, ikb: &IndexKeyBase, p: &HnswParams) -> SharedHnswIndex {
pub(super) async fn get(
&self,
ctx: &Context,
tb: &str,
ikb: &IndexKeyBase,
p: &HnswParams,
) -> Result<SharedHnswIndex, Error> {
let key = ikb.new_vm_key(None);
let r = self.0.read().await;
let h = r.get(&key).cloned();
drop(r);
let h = self.0.read().await.get(&key).cloned();
if let Some(h) = h {
return h;
return Ok(h);
}
let mut w = self.0.write().await;
let ix = match w.entry(key) {
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => {
let h = Arc::new(RwLock::new(HnswIndex::new(p)));
let h = Arc::new(RwLock::new(
HnswIndex::new(&ctx.tx(), ikb.clone(), tb.to_string(), p).await?,
));
e.insert(h.clone());
h
}
};
drop(w);
ix
Ok(ix)
}
pub(super) async fn remove(&self, ikb: &IndexKeyBase) {
let key = ikb.new_vm_key(None);
let mut w = self.0.write().await;
w.remove(&key);
drop(w);
self.0.write().await.remove(&key);
}
pub(super) async fn is_empty(&self) -> bool {
let h = self.0.read().await;
let r = h.is_empty();
drop(h);
r
self.0.read().await.is_empty()
}
}

View file

@ -278,12 +278,13 @@ impl IndexStores {
pub(crate) async fn get_index_hnsw(
&self,
ctx: &Context,
opt: &Options,
ix: &DefineIndexStatement,
p: &HnswParams,
) -> Result<SharedHnswIndex, Error> {
let ikb = IndexKeyBase::new(opt.ns()?, opt.db()?, ix)?;
Ok(self.0.hnsw_indexes.get(&ikb, p).await)
self.0.hnsw_indexes.get(ctx, &ix.what, &ikb, p).await
}
pub(crate) async fn index_removed(

View file

@ -1,5 +1,6 @@
use crate::err::Error;
use crate::fnc::util::math::ToFloat;
use crate::idx::VersionedStore;
use crate::sql::index::{Distance, VectorType};
use crate::sql::{Number, Value};
use ahash::AHasher;
@ -17,7 +18,7 @@ use std::ops::{Add, Deref, Div, Sub};
use std::sync::Arc;
/// In the context of a Symmetric MTree index, the term object refers to a vector, representing the indexed item.
#[derive(Debug, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum Vector {
F64(Array1<f64>),
@ -28,9 +29,9 @@ pub enum Vector {
}
#[revisioned(revision = 1)]
#[derive(Serialize, Deserialize)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
enum SerializedVector {
pub enum SerializedVector {
F64(Vec<f64>),
F32(Vec<f32>),
I64(Vec<i64>),
@ -38,6 +39,8 @@ enum SerializedVector {
I16(Vec<i16>),
}
impl VersionedStore for SerializedVector {}
impl From<&Vector> for SerializedVector {
fn from(value: &Vector) -> Self {
match value {
@ -395,6 +398,27 @@ impl Hash for Vector {
}
}
#[cfg(test)]
impl SharedVector {
pub(crate) fn clone_vector(&self) -> Vector {
self.0.as_ref().clone()
}
}
#[cfg(test)]
impl From<&Vector> for Value {
fn from(v: &Vector) -> Self {
let vec: Vec<Number> = match v {
Vector::F64(a) => a.iter().map(|i| Number::Float(*i)).collect(),
Vector::F32(a) => a.iter().map(|i| Number::Float(*i as f64)).collect(),
Vector::I64(a) => a.iter().map(|i| Number::Int(*i)).collect(),
Vector::I32(a) => a.iter().map(|i| Number::Int(*i as i64)).collect(),
Vector::I16(a) => a.iter().map(|i| Number::Int(*i as i64)).collect(),
};
Value::from(vec)
}
}
impl Vector {
pub(super) fn try_from_value(t: VectorType, d: usize, v: &Value) -> Result<Self, Error> {
let res = match t {

View file

@ -124,6 +124,14 @@ pub enum Category {
IndexBTreeNodeTerms,
/// crate::key::index::bu /*{ns}*{db}*{tb}+{ix}!bu{id}
IndexTerms,
/// crate::key::index::he /*{ns}*{db}*{tb}+{ix}!he{id}
IndexHnswElements,
/// crate::key::index::hd /*{ns}*{db}*{tb}+{ix}!hd{id}
IndexHnswDocIds,
/// crate::key::index::hi /*{ns}*{db}*{tb}+{ix}!hi{id}
IndexHnswThings,
/// crate::key::index::hv /*{ns}*{db}*{tb}+{ix}!hv{vec}
IndexHnswVec,
/// crate::key::index /*{ns}*{db}*{tb}+{ix}*{fd}{id}
Index,
///
@ -194,6 +202,10 @@ impl Display for Category {
Self::IndexFullTextState => "IndexFullTextState",
Self::IndexBTreeNodeTerms => "IndexBTreeNodeTerms",
Self::IndexTerms => "IndexTerms",
Self::IndexHnswElements => "IndexHnswElements",
Self::IndexHnswDocIds => "IndexHnswDocIds",
Self::IndexHnswThings => "IndexHnswThings",
Self::IndexHnswVec => "IndexHnswVec",
Self::Index => "Index",
Self::ChangeFeed => "ChangeFeed",
Self::Thing => "Thing",

64
core/src/key/index/hd.rs Normal file
View file

@ -0,0 +1,64 @@
//! Stores the DocId -> Thing mapping of an HNSW index
use crate::idx::docids::DocId;
use derive::Key;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct Hd<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
pub doc_id: Option<DocId>,
}
impl<'a> Hd<'a> {
pub fn new(ns: &'a str, db: &'a str, tb: &'a str, ix: &'a str, doc_id: Option<DocId>) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b'd',
doc_id,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
#[rustfmt::skip]
let val = Hd::new(
"testns",
"testdb",
"testtb",
"testix",
Some(7)
);
let enc = Hd::encode(&val).unwrap();
assert_eq!(enc, b"/*testns\0*testdb\0*testtb\0+testix\0!hd\x01\0\0\0\0\0\0\0\x07");
let dec = Hd::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

64
core/src/key/index/he.rs Normal file
View file

@ -0,0 +1,64 @@
//! Stores the ElementId -> Vector mapping of an HNSW index
use crate::idx::trees::hnsw::ElementId;
use derive::Key;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct He<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
pub element_id: ElementId,
}
impl<'a> He<'a> {
pub fn new(ns: &'a str, db: &'a str, tb: &'a str, ix: &'a str, element_id: ElementId) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b'e',
element_id,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
#[rustfmt::skip]
let val = He::new(
"testns",
"testdb",
"testtb",
"testix",
7
);
let enc = He::encode(&val).unwrap();
assert_eq!(enc, b"/*testns\0*testdb\0*testtb\0+testix\0!he\0\0\0\0\0\0\0\x07");
let dec = He::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

62
core/src/key/index/hi.rs Normal file
View file

@ -0,0 +1,62 @@
//! Stores the Id -> DocId mapping of an HNSW index
use crate::sql::Id;
use derive::Key;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct Hi<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
pub id: Id,
}
impl<'a> Hi<'a> {
pub fn new(ns: &'a str, db: &'a str, tb: &'a str, ix: &'a str, id: Id) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b'i',
id,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
let val = Hi::new("testns", "testdb", "testtb", "testix", Id::String("testid".to_string()));
let enc = Hi::encode(&val).unwrap();
assert_eq!(
enc,
b"/*testns\0*testdb\0*testtb\0+testix\0!hi\0\0\0\x01testid\0",
"{}",
String::from_utf8_lossy(&enc)
);
let dec = Hi::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

64
core/src/key/index/hl.rs Normal file
View file

@ -0,0 +1,64 @@
//! Stores the chunked layers of an HNSW index
use derive::Key;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct Hl<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
pub layer: u16,
pub chunk: u32,
}
impl<'a> Hl<'a> {
pub fn new(ns: &'a str, db: &'a str, tb: &'a str, ix: &'a str, layer: u16, chunk: u32) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b'l',
layer,
chunk,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
let val = Hl::new("testns", "testdb", "testtb", "testix", 7, 8);
let enc = Hl::encode(&val).unwrap();
assert_eq!(
enc,
b"/*testns\0*testdb\0*testtb\0+testix\0!hl\0\x07\0\0\0\x08",
"{}",
String::from_utf8_lossy(&enc)
);
let dec = Hl::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

60
core/src/key/index/hs.rs Normal file
View file

@ -0,0 +1,60 @@
//! Stores the state of an HNSW index
use derive::Key;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct Hs<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
}
impl<'a> Hs<'a> {
pub fn new(ns: &'a str, db: &'a str, tb: &'a str, ix: &'a str) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b's',
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
let val = Hs::new("testns", "testdb", "testtb", "testix");
let enc = Hs::encode(&val).unwrap();
assert_eq!(
enc,
b"/*testns\0*testdb\0*testtb\0+testix\0!hs",
"{}",
String::from_utf8_lossy(&enc)
);
let dec = Hs::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

76
core/src/key/index/hv.rs Normal file
View file

@ -0,0 +1,76 @@
//! Stores the Vector -> ElementId mapping of an HNSW index
use crate::idx::trees::vector::SerializedVector;
use derive::Key;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Key)]
#[non_exhaustive]
pub struct Hv<'a> {
__: u8,
_a: u8,
pub ns: &'a str,
_b: u8,
pub db: &'a str,
_c: u8,
pub tb: &'a str,
_d: u8,
pub ix: &'a str,
_e: u8,
_f: u8,
_g: u8,
pub vec: Arc<SerializedVector>,
}
impl<'a> Hv<'a> {
pub fn new(
ns: &'a str,
db: &'a str,
tb: &'a str,
ix: &'a str,
vec: Arc<SerializedVector>,
) -> Self {
Self {
__: b'/',
_a: b'*',
ns,
_b: b'*',
db,
_c: b'*',
tb,
_d: b'+',
ix,
_e: b'!',
_f: b'h',
_g: b'v',
vec,
}
}
}
#[cfg(test)]
mod tests {
#[test]
fn key() {
use super::*;
let val = Hv::new(
"testns",
"testdb",
"testtb",
"testix",
Arc::new(SerializedVector::I16(vec![2])),
);
let enc = Hv::encode(&val).unwrap();
assert_eq!(
enc,
b"/*testns\0*testdb\0*testtb\0+testix\0!hv\0\0\0\x04\x80\x02\x01",
"{}",
String::from_utf8_lossy(&enc)
);
let dec = Hv::decode(&enc).unwrap();
assert_eq!(val, dec);
}
}

View file

@ -11,6 +11,12 @@ pub mod bp;
pub mod bs;
pub mod bt;
pub mod bu;
pub mod hd;
pub mod he;
pub mod hi;
pub mod hl;
pub mod hs;
pub mod hv;
pub mod vm;
use crate::key::category::Categorise;

View file

@ -240,10 +240,39 @@ impl Datastore {
/// # Ok(())
/// # }
/// ```
pub async fn new(path: &str) -> Result<Datastore, Error> {
pub async fn new(path: &str) -> Result<Self, Error> {
Self::new_with_clock(path, None).await
}
#[cfg(debug_assertions)]
/// Create a new datastore with the same persistent data (inner), but with
/// flushed caches, simulating a server restart.
pub fn restart(self) -> Self {
Self {
id: self.id,
strict: self.strict,
auth_enabled: self.auth_enabled,
query_timeout: self.query_timeout,
transaction_timeout: self.transaction_timeout,
capabilities: self.capabilities,
notification_channel: self.notification_channel,
index_stores: Default::default(),
#[cfg(not(target_arch = "wasm32"))]
index_builder: IndexBuilder::new(self.transaction_factory.clone()),
#[cfg(feature = "jwks")]
jwks_cache: Arc::new(Default::default()),
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
))]
temporary_directory: self.temporary_directory,
transaction_factory: self.transaction_factory,
}
}
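A hedged usage sketch, not part of this commit: since `restart` is gated on debug_assertions it is only callable from test builds. Data written before the simulated restart stays visible afterwards, while in-memory caches start cold. The table and record names below are illustrative.
// Hypothetical test: data survives a simulated restart with cold caches.
#[tokio::test]
async fn restart_keeps_data() -> Result<(), Error> {
    let ds = Datastore::new("memory").await?;
    let ses = Session::owner().with_ns("test").with_db("test");
    ds.execute("CREATE pts:1 SET point = [1,2,3,4];", &ses, None).await?;
    // Same underlying storage, flushed caches.
    let ds = ds.restart();
    let res = ds.execute("SELECT * FROM pts;", &ses, None).await?;
    assert_eq!(res.len(), 1);
    Ok(())
}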
#[allow(unused_variables)]
pub async fn new_with_clock(
path: &str,

View file

@ -4,6 +4,7 @@ use crate::dbs::Options;
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::sql::{escape::escape_rid, Array, Number, Object, Strand, Thing, Uuid, Value};
use derive::Key;
use nanoid::nanoid;
use reblessive::tree::Stk;
use revision::revisioned;
@ -23,7 +24,7 @@ pub enum Gen {
}
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Key, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Id {

View file

@ -0,0 +1,8 @@
INIT Init
NEXT Next
CONSTANT
Nodes = {"Node1", "Node2"}
Clients = {"Client1", "Client2", "Client3"}
MaxVersion = 5
MaxWrites = 10
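A hedged operational note: with the standard TLA+ tools, this model can typically be checked by running TLC as `java -jar tla2tools.jar versioned_index.tla` from the directory holding both files (TLC picks up `versioned_index.cfg` by name). For TLC to verify the module's `Invariant`, it must additionally be declared in this configuration with an `INVARIANT Invariant` line.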

View file

@ -0,0 +1,94 @@
---- MODULE versioned_index ----
EXTENDS Naturals, FiniteSets, Sequences
CONSTANTS Nodes, Clients, MaxVersion, MaxWrites
VARIABLES remoteIndex, remoteVersion, localIndex, localVersion, clientVersion, clientOps
(*
The `Init` predicate defines the initial state of the system:
- remoteIndex / remoteVersion: the remote index and its version, both 0.
- localIndex / localVersion: each node's local copy and its version, both 0.
- clientVersion: the version last observed by each client, initially 0.
- clientOps: each client's operation count, initially 0.
*)
Init ==
/\ remoteIndex = 0
/\ remoteVersion = 0
/\ localIndex = [n \in Nodes |-> 0]
/\ localVersion = [n \in Nodes |-> 0]
/\ clientVersion = [c \in Clients |-> 0]
/\ clientOps = [c \in Clients |-> 0]
(*
The `UpdateToLatest` action updates the local index and version to the latest remote version if outdated.
*)
UpdateToLatest(n) ==
/\ localVersion[n] < remoteVersion
/\ localIndex' = [localIndex EXCEPT ![n] = remoteIndex]
/\ localVersion' = [localVersion EXCEPT ![n] = remoteVersion]
/\ UNCHANGED <<remoteIndex, remoteVersion, clientVersion, clientOps>>
(*
The `Read` action represents a node serving a read for a client.
- If the local version is outdated, the node first updates its local index and version.
- The value is then read from the local index.
- The client's observed version is set to the node's local version.
- The remote index and version remain unchanged.
*)
Read(n, c) ==
/\ (localVersion[n] < remoteVersion => UpdateToLatest(n))
/\ UNCHANGED <<localIndex, localVersion, remoteIndex, remoteVersion, clientOps>>
/\ clientVersion' = [clientVersion EXCEPT ![c] = localVersion[n]]
(*
The `Write` action represents a node publishing a new index to the remote copy.
- First brings the local index and version up to date if they are outdated.
- Stores the new index locally, increments the version, and updates the remote index and version accordingly.
- Sets the client's observed version to the new version.
- Increments the client's operation count.
*)
Write(n, c, newIndex) ==
/\ clientOps[c] < MaxWrites
/\ remoteVersion < MaxVersion (* Ensure the remote version does not exceed the maximum allowed version *)
/\ (localVersion[n] < remoteVersion => UpdateToLatest(n)) (* Update if the local version is outdated *)
/\ localIndex' = [localIndex EXCEPT ![n] = newIndex] (* Update the local index with the new index *)
/\ localVersion' = [localVersion EXCEPT ![n] = localVersion[n] + 1] (* Increment the local version *)
/\ remoteIndex' = newIndex (* Update the remote index with the new index *)
/\ remoteVersion' = localVersion[n] + 1 (* Increment the remote version *)
/\ clientVersion' = [clientVersion EXCEPT ![c] = localVersion[n] + 1] (* Update the client version *)
/\ clientOps' = [clientOps EXCEPT ![c] = clientOps[c] + 1]
(*
The `Client` action simulates multiple clients calling `Read` and `Write` and collecting the returned versions.
- The intent is that subsequent calls observe identical or larger versions.
*)
Client ==
\E c \in Clients:
clientOps[c] < MaxWrites /\
\E n \in Nodes:
(clientOps[c] < MaxWrites => (Read(n, c) \/ \E newIndex \in 0..MaxVersion: Write(n, c, newIndex)))
(*
The `Next` relation defines the possible state transitions in the system.
- Includes the `Client` action.
*)
Next ==
Client
(*
The `Invariant` defines the properties that must always hold.
- The local version of any node must be at least as recent as the remote version.
- No client's observed version may exceed the remote version.
- No node's local index may exceed the remote index.
*)
Invariant ==
/\ \A n \in Nodes: localVersion[n] >= remoteVersion
/\ \A c \in Clients: clientVersion[c] <= remoteVersion
/\ \A n \in Nodes: localIndex[n] <= remoteIndex
====
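The module above models a read-through, versioned refresh of a node-local index against a shared remote copy: readers catch up before serving, writers catch up, bump the version, and publish. A minimal Rust sketch of that pattern, illustrative only (the `Remote` and `Local` types and all names below are hypothetical, not part of this commit):
use std::sync::Mutex;

// Hypothetical shared state: the "remote" index and its version.
struct Remote {
    index: u64,
    version: u64,
}

// Hypothetical node-local cache, refreshed on demand (`UpdateToLatest`).
struct Local {
    index: u64,
    version: u64,
}

impl Local {
    // Read: refresh from the remote copy if the local version is stale,
    // then serve the value and version from the local cache.
    fn read(&mut self, remote: &Mutex<Remote>) -> (u64, u64) {
        let r = remote.lock().unwrap();
        if self.version < r.version {
            self.index = r.index;
            self.version = r.version;
        }
        (self.index, self.version)
    }

    // Write: catch up if stale, then publish the new index under an
    // incremented version, mirroring the `Write` action of the spec.
    fn write(&mut self, remote: &Mutex<Remote>, new_index: u64) -> u64 {
        let mut r = remote.lock().unwrap();
        if self.version < r.version {
            self.version = r.version;
        }
        self.index = new_index;
        self.version += 1;
        r.index = new_index;
        r.version = self.version;
        self.version
    }
}

fn main() {
    let remote = Mutex::new(Remote { index: 0, version: 0 });
    let mut n1 = Local { index: 0, version: 0 };
    let mut n2 = Local { index: 0, version: 0 };
    let v1 = n1.write(&remote, 42);
    let (idx, v2) = n2.read(&remote);
    // A reader on another node observes the published index and a version
    // no smaller than the one returned to the writer.
    assert_eq!((idx, v2), (42, v1));
}
Run sequentially this sketch shows a client never observing a smaller version than one it was previously returned; the model-checked spec additionally explores the interleavings of these operations across multiple nodes and clients.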

View file

@ -2,7 +2,7 @@ use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput};
use std::collections::HashSet;
use std::time::{Duration, SystemTime};
use surrealdb_core::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use surrealdb_core::idx::trees::dynamicset::{AHashSet, ArraySet, DynamicSet};
fn bench_hashset(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
@ -19,7 +19,7 @@ fn bench_hashset(samples_vec: &Vec<Vec<u64>>) {
fn bench_hashbrown(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
let mut h = HashBrownSet::<u64>::with_capacity(samples.len());
let mut h = AHashSet::<u64>::with_capacity(samples.len());
for &s in samples {
h.insert(s);
}

View file

@ -1,15 +1,19 @@
use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput};
use flate2::read::GzDecoder;
use futures::executor::block_on;
use reblessive::TreeStack;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::time::Duration;
use surrealdb::sql::index::Distance;
use surrealdb_core::dbs::Session;
use surrealdb_core::idx::planner::checker::HnswConditionChecker;
use surrealdb_core::idx::planner::checker::{HnswChecker, HnswConditionChecker};
use surrealdb_core::idx::trees::hnsw::index::HnswIndex;
use surrealdb_core::kvs::Datastore;
use surrealdb_core::idx::IndexKeyBase;
use surrealdb_core::kvs::LockType::Optimistic;
use surrealdb_core::kvs::TransactionType::{Read, Write};
use surrealdb_core::kvs::{Datastore, Transaction};
use surrealdb_core::sql::index::{HnswParams, VectorType};
use surrealdb_core::sql::{value, Array, Id, Number, Thing, Value};
use tokio::runtime::{Builder, Runtime};
@ -38,13 +42,13 @@ fn bench_hnsw_no_db(c: &mut Criterion) {
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("insert len: {}", samples.len());
group.bench_function(id, |b| {
b.iter(|| insert_objects(&samples));
b.to_async(Runtime::new().unwrap()).iter(|| insert_objects(&samples));
});
group.finish();
}
// Create an HNSW instance with data
let hnsw = insert_objects(&samples);
let (ds, hnsw) = block_on(insert_objects(&samples));
let samples = new_vectors_from_file(QUERYING_SOURCE);
let samples: Vec<Vec<Number>> =
@ -55,7 +59,7 @@ fn bench_hnsw_no_db(c: &mut Criterion) {
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("lookup len: {}", samples.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap()).iter(|| knn_lookup_objects(&hnsw, &samples));
b.to_async(Runtime::new().unwrap()).iter(|| knn_lookup_objects(&ds, &hnsw, &samples));
});
group.finish();
}
@ -74,7 +78,6 @@ fn bench_hnsw_with_db(c: &mut Criterion) {
{
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("insert len: {}", samples.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap()).iter(|| insert_objects_db(session, true, &samples));
});
@ -200,7 +203,7 @@ async fn init_datastore(session: &Session, with_index: bool) -> Datastore {
ds
}
fn hnsw() -> HnswIndex {
async fn hnsw(tx: &Transaction) -> HnswIndex {
let p = HnswParams::new(
DIMENSION,
Distance::Euclidean,
@ -212,15 +215,18 @@ fn hnsw() -> HnswIndex {
false,
false,
);
HnswIndex::new(&p)
HnswIndex::new(tx, IndexKeyBase::default(), "test".to_string(), &p).await.unwrap()
}
fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
let mut h = hnsw();
for (id, content) in samples {
h.index_document(id, content).unwrap();
async fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> (Datastore, HnswIndex) {
let ds = Datastore::new("memory").await.unwrap();
let tx = ds.transaction(Write, Optimistic).await.unwrap();
let mut h = hnsw(&tx).await;
for (thg, content) in samples {
h.index_document(&tx, thg.id.clone(), content).await.unwrap();
}
h
tx.commit().await.unwrap();
(ds, h)
}
async fn insert_objects_db(session: &Session, create_index: bool, inserts: &[String]) -> Datastore {
@ -231,13 +237,21 @@ async fn insert_objects_db(session: &Session, create_index: bool, inserts: &[Str
ds
}
async fn knn_lookup_objects(h: &HnswIndex, samples: &[Vec<Number>]) {
async fn knn_lookup_objects(ds: &Datastore, h: &HnswIndex, samples: &[Vec<Number>]) {
let mut stack = TreeStack::new();
stack
.enter(|stk| async {
let tx = ds.transaction(Read, Optimistic).await.unwrap();
for v in samples {
let r = h
.knn_search(v, NN, EF_SEARCH, stk, HnswConditionChecker::default())
.knn_search(
&tx,
stk,
v,
NN,
EF_SEARCH,
HnswConditionChecker::Hnsw(HnswChecker {}),
)
.await
.unwrap();
assert_eq!(r.len(), NN);

View file

@ -22,21 +22,21 @@ use tokio::task;
fn bench_index_mtree_combinations(c: &mut Criterion) {
for (samples, dimension, cache) in [
(2500, 3, 100),
(2500, 3, 2500),
(2500, 3, 0),
(1000, 50, 100),
(1000, 50, 1000),
(1000, 50, 0),
(500, 300, 100),
(500, 300, 500),
(500, 300, 0),
(250, 1024, 75),
(250, 1024, 250),
(250, 1024, 0),
(100, 2048, 50),
(100, 2048, 100),
(100, 2048, 0),
(1000, 3, 100),
(1000, 3, 1000),
(1000, 3, 0),
(300, 50, 100),
(300, 50, 300),
(300, 50, 0),
(150, 300, 50),
(150, 300, 150),
(150, 300, 0),
(75, 1024, 25),
(75, 1024, 75),
(75, 1024, 0),
(50, 2048, 20),
(50, 2048, 50),
(50, 2048, 0),
] {
bench_index_mtree(c, samples, dimension, cache);
}

View file

@ -243,8 +243,7 @@ impl Test {
/// Arguments `sql` - A string slice representing the SQL query.
/// Panics if an error occurs.
#[allow(dead_code)]
pub async fn new(sql: &str) -> Result<Self, Error> {
let ds = new_ds().await?;
pub async fn with_ds(ds: Datastore, sql: &str) -> Result<Self, Error> {
let session = Session::owner().with_ns("test").with_db("test");
let responses = ds.execute(sql, &session, None).await?;
Ok(Self {
@ -255,6 +254,17 @@ impl Test {
})
}
pub async fn new(sql: &str) -> Result<Self, Error> {
Self::with_ds(new_ds().await?, sql).await
}
/// Simulates restarting the Datastore:
/// - data is persisted (including for the memory store),
/// - caches are flushed (jwks, IndexStore, ...).
pub async fn restart(self, sql: &str) -> Result<Self, Error> {
Self::with_ds(self.ds.restart(), sql).await
}
/// Checks if the number of responses matches the expected size.
/// Panics if the number of responses does not match the expected size
#[allow(dead_code)]
@ -412,14 +422,3 @@ impl Test {
Ok(self)
}
}
impl Drop for Test {
/// Drops the instance of the struct
/// This method will panic if there are remaining responses that have not been checked.
fn drop(&mut self) {
// Check for a panic to make sure the test doesn't cause a double panic.
if !std::thread::panicking() && !self.responses.is_empty() {
panic!("Not every response has been checked");
}
}
}

View file

@ -520,3 +520,66 @@ async fn select_bruteforce_knn_with_condition() -> Result<(), Error> {
//
Ok(())
}
#[tokio::test]
async fn check_hnsw_persistence() -> Result<(), Error> {
let sql = r"
CREATE pts:1 SET point = [1,2,3,4];
CREATE pts:2 SET point = [4,5,6,7];
CREATE pts:4 SET point = [12,13,14,15];
DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12;
CREATE pts:3 SET point = [8,9,10,11];
SELECT id, vector::distance::knn() AS dist FROM pts WHERE point <|2,100|> [2,3,4,5];
DELETE pts:4;
SELECT id, vector::distance::knn() AS dist FROM pts WHERE point <|2,100|> [2,3,4,5];
";
// Ingest the data into the datastore.
let mut t = Test::new(sql).await?;
t.skip_ok(5)?;
t.expect_val(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
)?;
t.skip_ok(1)?;
t.expect_val(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
)?;
// Restart the datastore and execute the SELECT query
let sql =
"SELECT id, vector::distance::knn() AS dist FROM pts WHERE point <|2,100|> [2,3,4,5];";
let mut t = t.restart(sql).await?;
// We should find results
t.expect_val(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
)?;
Ok(())
}