Feature: Initial Hnsw implementation ()

Emmanuel Keller 2024-05-08 15:26:41 +01:00 committed by GitHub
parent 061ad8c712
commit 009486b2bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 3478 additions and 493 deletions

Cargo.lock (generated)

@ -351,6 +351,15 @@ version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
[[package]]
name = "approx"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278"
dependencies = [
"num-traits",
]
[[package]]
name = "approx"
version = "0.5.1"
@ -1559,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [
"cfg-if",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
@ -2346,7 +2355,7 @@ version = "0.7.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff16065e5720f376fbced200a5ae0f47ace85fd70b7e54269790281353b6d61"
dependencies = [
"approx",
"approx 0.5.1",
"arbitrary",
"num-traits",
"rstar 0.11.0",
@ -2481,12 +2490,13 @@ dependencies = [
[[package]]
name = "hashbrown"
version = "0.14.3"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash 0.8.11",
"allocator-api2",
"serde",
]
[[package]]
@ -2872,7 +2882,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"serde",
]
@ -3172,6 +3182,18 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linfa-linalg"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e7562b41c8876d3367897067013bb2884cc78e6893f092ecd26b305176ac82"
dependencies = [
"ndarray",
"num-traits",
"rand 0.8.5",
"thiserror",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.13"
@ -3232,7 +3254,7 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc"
dependencies = [
"hashbrown 0.14.3",
"hashbrown 0.14.5",
]
[[package]]
@ -3438,11 +3460,28 @@ version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"approx 0.4.0",
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
"serde",
]
[[package]]
name = "ndarray-stats"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af5a8477ac96877b5bd1fd67e0c28736c12943aba24eda92b127e036b0c8f400"
dependencies = [
"indexmap 1.9.3",
"itertools 0.10.5",
"ndarray",
"noisy_float",
"num-integer",
"num-traits",
"rand 0.8.5",
]
[[package]]
@ -3482,6 +3521,15 @@ dependencies = [
"libc",
]
[[package]]
name = "noisy_float"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af"
dependencies = [
"num-traits",
]
[[package]]
name = "nom"
version = "7.1.3"
@ -4370,7 +4418,7 @@ checksum = "b1380629287ed1247c1e0fcc6d43efdcec508b65382c9ab775cc8f3df7ca07b0"
dependencies = [
"ahash 0.8.11",
"equivalent",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"parking_lot",
]
@ -4382,7 +4430,7 @@ checksum = "347e1a588d1de074eeb3c00eadff93db4db65aeb62aee852b1efd0949fe65b6c"
dependencies = [
"ahash 0.8.11",
"equivalent",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"parking_lot",
]
@ -5631,7 +5679,7 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61addf9117b11d1f5b4bf6fe94242ba25f59d2d4b2080544b771bd647024fd00"
dependencies = [
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"num-traits",
"robust",
"smallvec",
@ -5843,10 +5891,12 @@ dependencies = [
"criterion",
"dmp",
"env_logger 0.10.2",
"flate2",
"flume",
"futures",
"futures-concurrency",
"geo 0.27.0",
"hashbrown 0.14.5",
"indexmap 2.2.6",
"native-tls",
"once_cell",
@ -5911,20 +5961,25 @@ dependencies = [
"echodb",
"env_logger 0.10.2",
"ext-sort",
"flate2",
"foundationdb",
"fst",
"futures",
"fuzzy-matcher",
"geo 0.27.0",
"geo-types",
"hashbrown 0.14.5",
"hex",
"indxdb",
"ipnet",
"lexicmp",
"linfa-linalg",
"md-5",
"nanoid",
"ndarray",
"ndarray-stats",
"nom",
"num-traits",
"num_cpus",
"object_store",
"once_cell",
@ -6052,7 +6107,7 @@ dependencies = [
"crossbeam",
"crossbeam-channel",
"futures",
"hashbrown 0.14.3",
"hashbrown 0.14.5",
"lru",
"parking_lot",
"quick_cache 0.4.2",
@ -6971,7 +7026,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c07dbb7140d0a8fa046a25eabaaad09f4082076ccb8f157d9fe2b13dc9ae570"
dependencies = [
"hashbrown 0.14.3",
"hashbrown 0.14.5",
]
[[package]]


@ -1309,3 +1309,27 @@ allow_apis = [
[pkg.phf_macros]
allow_proc_macro = true
[pkg.rawpointer]
allow_unsafe = true
[pkg.matrixmultiply]
allow_unsafe = true
[pkg.approx]
allow_unsafe = true
[pkg.num-complex]
allow_unsafe = true
[pkg.ndarray]
allow_unsafe = true
[pkg.ndarray-stats]
allow_unsafe = true
[pkg.noisy_float]
allow_unsafe = true
[pkg.linfa-linalg]
allow_unsafe = true


@ -40,7 +40,7 @@ kv-fdb-7_1 = ["foundationdb/fdb-7_1", "kv-fdb", "dep:tempfile", "dep:ext-sort"]
kv-surrealkv = ["dep:surrealkv", "tokio/time", "dep:tempfile", "dep:ext-sort"]
scripting = ["dep:js"]
http = ["dep:reqwest"]
ml = ["dep:surrealml", "dep:ndarray"]
ml = ["dep:surrealml"]
jwks = ["dep:reqwest"]
arbitrary = [
"dep:arbitrary",
@ -86,6 +86,7 @@ futures = "0.3.29"
fuzzy-matcher = "0.3.7"
geo = { version = "0.27.0", features = ["use-serde"] }
geo-types = { version = "0.7.12", features = ["arbitrary"] }
hashbrown = { version = "0.14.5", features = ["serde"] }
hex = { version = "0.4.3" }
indxdb = { version = "0.4.0", optional = true }
ipnet = "2.9.0"
@ -102,9 +103,12 @@ js = { version = "0.6.2", package = "rquickjs", features = [
], optional = true }
jsonwebtoken = { version = "8.3.0-surreal.1", package = "surrealdb-jsonwebtoken" }
lexicmp = "0.1.0"
linfa-linalg = "=0.1.0"
md-5 = "0.10.6"
nanoid = "0.4.0"
ndarray = { version = "0.15.6", optional = true }
ndarray = { version = "=0.15.6", features = ["serde"] }
ndarray-stats = "=0.5.1"
num-traits = "0.2.18"
nom = { version = "7.1.3", features = ["alloc"] }
num_cpus = "1.16.0"
object_store = { version = "0.8.0", optional = false }
@ -152,6 +156,7 @@ url = "2.5.0"
[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio"] }
env_logger = "0.10.1"
flate2 = "1.0.28"
pprof = { version = "0.13.0", features = ["flamegraph", "criterion"] }
serial_test = "2.0.0"
temp-dir = "0.1.11"
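(The new crates back the vector work: ndarray, now a required dependency with serde support, plus ndarray-stats and linfa-linalg for the distance computations, and hashbrown for the in-memory maps and sets; flate2 appears in the dev-dependencies presumably to decompress the .gz embedding fixtures exercised by the recall tests further down.)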


@ -9,7 +9,7 @@ use crate::idx::IndexKeyBase;
use crate::key;
use crate::kvs::TransactionType;
use crate::sql::array::Array;
use crate::sql::index::{Index, MTreeParams, SearchParams};
use crate::sql::index::{HnswParams, Index, MTreeParams, SearchParams};
use crate::sql::statements::DefineIndexStatement;
use crate::sql::{Part, Thing, Value};
use reblessive::tree::Stk;
@ -65,6 +65,7 @@ impl<'a> Document<'a> {
Index::Idx => ic.index_non_unique(txn).await?,
Index::Search(p) => ic.index_full_text(stk, ctx, txn, p).await?,
Index::MTree(p) => ic.index_mtree(stk, ctx, txn, p).await?,
Index::Hnsw(p) => ic.index_hnsw(ctx, p).await?,
};
}
}
@ -407,4 +408,18 @@ impl<'a> IndexOperation<'a> {
}
mt.finish(&mut tx).await
}
async fn index_hnsw(&mut self, ctx: &Context<'_>, p: &HnswParams) -> Result<(), Error> {
let hnsw = ctx.get_index_stores().get_index_hnsw(self.opt, self.ix, p).await;
let mut hnsw = hnsw.write().await;
// Delete the old index data
if let Some(o) = self.o.take() {
hnsw.remove_document(self.rid, &o)?;
}
// Create the new index data
if let Some(n) = self.n.take() {
hnsw.index_document(self.rid, &n)?;
}
Ok(())
}
}


@ -10,8 +10,9 @@ use crate::idx::ft::terms::Terms;
use crate::idx::ft::{FtIndex, MatchRef};
use crate::idx::planner::iterators::{
DocIdsIterator, IndexEqualThingIterator, IndexJoinThingIterator, IndexRangeThingIterator,
IndexUnionThingIterator, MatchesThingIterator, ThingIterator, UniqueEqualThingIterator,
UniqueJoinThingIterator, UniqueRangeThingIterator, UniqueUnionThingIterator,
IndexUnionThingIterator, MatchesThingIterator, ThingIterator, ThingsIterator,
UniqueEqualThingIterator, UniqueJoinThingIterator, UniqueRangeThingIterator,
UniqueUnionThingIterator,
};
use crate::idx::planner::knn::KnnPriorityList;
use crate::idx::planner::plan::IndexOperator::Matches;
@ -19,6 +20,7 @@ use crate::idx::planner::plan::{IndexOperator, IndexOption, RangeValue};
use crate::idx::planner::tree::{IdiomPosition, IndexRef, IndexesMap};
use crate::idx::planner::{IterationStage, KnnSet};
use crate::idx::trees::mtree::MTreeIndex;
use crate::idx::trees::store::hnsw::SharedHnswIndex;
use crate::idx::IndexKeyBase;
use crate::kvs;
use crate::kvs::{Key, TransactionType};
@ -26,12 +28,14 @@ use crate::sql::index::{Distance, Index};
use crate::sql::statements::DefineIndexStatement;
use crate::sql::{Array, Expression, Idiom, Number, Object, Table, Thing, Value};
use reblessive::tree::Stk;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;
pub(super) type KnnEntry = (KnnPriorityList, Idiom, Arc<Vec<Number>>, Distance);
pub(super) type KnnExpressions = HashMap<Arc<Expression>, (u32, Idiom, Arc<Vec<Number>>, Distance)>;
pub(super) type AnnExpressions = HashMap<Arc<Expression>, (usize, Idiom, Arc<Vec<Number>>, usize)>;
#[derive(Clone)]
pub(crate) struct QueryExecutor(Arc<InnerQueryExecutor>);
@ -44,6 +48,7 @@ pub(super) struct InnerQueryExecutor {
it_entries: Vec<IteratorEntry>,
index_definitions: Vec<DefineIndexStatement>,
mt_entries: HashMap<Arc<Expression>, MtEntry>,
hnsw_entries: HashMap<Arc<Expression>, HnswEntry>,
knn_entries: HashMap<Arc<Expression>, KnnEntry>,
}
@ -91,6 +96,8 @@ impl InnerQueryExecutor {
let mut ft_map = HashMap::default();
let mut mt_map: HashMap<IndexRef, MTreeIndex> = HashMap::default();
let mut mt_entries = HashMap::default();
let mut hnsw_map: HashMap<IndexRef, SharedHnswIndex> = HashMap::default();
let mut hnsw_entries = HashMap::default();
let mut knn_entries = HashMap::with_capacity(knns.len());
// Create all the instances of FtIndex
@ -100,28 +107,27 @@ impl InnerQueryExecutor {
if let Some(idx_def) = im.definitions.get(ix_ref as usize) {
match &idx_def.index {
Index::Search(p) => {
let mut ft_entry = None;
if let Some(ft) = ft_map.get(&ix_ref) {
if ft_entry.is_none() {
ft_entry = FtEntry::new(stk, ctx, opt, txn, ft, io).await?;
let ft_entry = match ft_map.entry(ix_ref) {
Entry::Occupied(e) => {
FtEntry::new(stk, ctx, opt, txn, e.get(), io).await?
}
} else {
let ikb = IndexKeyBase::new(opt, idx_def);
let ft = FtIndex::new(
ctx.get_index_stores(),
opt,
txn,
p.az.as_str(),
ikb,
p,
TransactionType::Read,
)
.await?;
if ft_entry.is_none() {
ft_entry = FtEntry::new(stk, ctx, opt, txn, &ft, io).await?;
Entry::Vacant(e) => {
let ikb = IndexKeyBase::new(opt, idx_def);
let ft = FtIndex::new(
ctx.get_index_stores(),
opt,
txn,
p.az.as_str(),
ikb,
p,
TransactionType::Read,
)
.await?;
let fte = FtEntry::new(stk, ctx, opt, txn, &ft, io).await?;
e.insert(ft);
fte
}
ft_map.insert(ix_ref, ft);
}
};
if let Some(e) = ft_entry {
if let Matches(_, Some(mr)) = e.0.index_option.op() {
if mr_entries.insert(*mr, e.clone()).is_some() {
@ -136,25 +142,45 @@ impl InnerQueryExecutor {
Index::MTree(p) => {
if let IndexOperator::Knn(a, k) = io.op() {
let mut tx = txn.lock().await;
let entry = if let Some(mt) = mt_map.get(&ix_ref) {
MtEntry::new(&mut tx, mt, a, *k).await?
} else {
let ikb = IndexKeyBase::new(opt, idx_def);
let mt = MTreeIndex::new(
ctx.get_index_stores(),
&mut tx,
ikb,
p,
TransactionType::Read,
)
.await?;
let entry = MtEntry::new(&mut tx, &mt, a, *k).await?;
mt_map.insert(ix_ref, mt);
entry
let entry = match mt_map.entry(ix_ref) {
Entry::Occupied(e) => MtEntry::new(&mut tx, e.get(), a, *k).await?,
Entry::Vacant(e) => {
let ikb = IndexKeyBase::new(opt, idx_def);
let mt = MTreeIndex::new(
ctx.get_index_stores(),
&mut tx,
ikb,
p,
TransactionType::Read,
)
.await?;
let entry = MtEntry::new(&mut tx, &mt, a, *k).await?;
e.insert(mt);
entry
}
};
mt_entries.insert(exp, entry);
}
}
Index::Hnsw(p) => {
if let IndexOperator::Ann(a, n, ef) = io.op() {
let entry = match hnsw_map.entry(ix_ref) {
Entry::Occupied(e) => {
HnswEntry::new(e.get().clone(), a, *n, *ef).await?
}
Entry::Vacant(e) => {
let hnsw = ctx
.get_index_stores()
.get_index_hnsw(opt, idx_def, p)
.await;
let entry = HnswEntry::new(hnsw.clone(), a, *n, *ef).await?;
e.insert(hnsw);
entry
}
};
hnsw_entries.insert(exp, entry);
}
}
_ => {}
}
}
@ -172,6 +198,7 @@ impl InnerQueryExecutor {
it_entries: Vec::new(),
index_definitions: im.definitions,
mt_entries,
hnsw_entries,
knn_entries,
})
}
@ -290,6 +317,7 @@ impl QueryExecutor {
..
} => self.new_search_index_iterator(it_ref, io.clone()).await,
Index::MTree(_) => Ok(self.new_mtree_index_knn_iterator(it_ref)),
Index::Hnsw(_) => Ok(self.new_hnsw_index_ann_iterator(it_ref)),
}
} else {
Ok(None)
@ -385,7 +413,7 @@ impl QueryExecutor {
if let Some(IteratorEntry::Single(exp, ..)) = self.0.it_entries.get(it_ref as usize) {
if let Matches(_, _) = io.op() {
if let Some(fti) = self.0.ft_map.get(&io.ix_ref()) {
if let Some(fte) = self.0.exp_entries.get(exp.as_ref()) {
if let Some(fte) = self.0.exp_entries.get(exp) {
let it = MatchesThingIterator::new(fti, fte.0.terms_docs.clone()).await?;
return Ok(Some(ThingIterator::Matches(it)));
}
@ -408,6 +436,16 @@ impl QueryExecutor {
None
}
fn new_hnsw_index_ann_iterator(&self, it_ref: IteratorRef) -> Option<ThingIterator> {
if let Some(IteratorEntry::Single(exp, ..)) = self.0.it_entries.get(it_ref as usize) {
if let Some(he) = self.0.hnsw_entries.get(exp) {
let it = ThingsIterator::new(he.res.iter().map(|(thg, _)| thg.clone()).collect());
return Some(ThingIterator::Things(it));
}
}
None
}
async fn build_iterators(
&self,
opt: &Options,
@ -660,3 +698,17 @@ impl MtEntry {
})
}
}
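// Holds the result of an Ann operator evaluated against an HNSW index:
// the matching records and their distances, computed eagerly via
// `knn_search` when the executor entry is built, then streamed by a
// `ThingsIterator`.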
#[derive(Clone)]
pub(super) struct HnswEntry {
res: VecDeque<(Thing, f64)>,
}
impl HnswEntry {
async fn new(h: SharedHnswIndex, a: &Array, n: usize, ef: usize) -> Result<Self, Error> {
let res = h.read().await.knn_search(a, n, ef)?;
Ok(Self {
res,
})
}
}


@ -24,6 +24,7 @@ pub(crate) enum ThingIterator {
UniqueJoin(Box<UniqueJoinThingIterator>),
Matches(MatchesThingIterator),
Knn(DocIdsIterator),
Things(ThingsIterator),
}
impl ThingIterator {
@ -44,6 +45,7 @@ impl ThingIterator {
Self::Knn(i) => i.next_batch(tx, size, collector).await,
Self::IndexJoin(i) => Box::pin(i.next_batch(tx, size, collector)).await,
Self::UniqueJoin(i) => Box::pin(i.next_batch(tx, size, collector)).await,
Self::Things(i) => Ok(i.next_batch(size, collector)),
}
}
}
@ -687,3 +689,27 @@ impl DocIdsIterator {
Ok(count as usize)
}
}
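// A trivial iterator over a precomputed list of record ids. Used to stream
// the results of an HNSW ANN search, which are fully materialized up front.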
pub(crate) struct ThingsIterator {
res: VecDeque<Thing>,
}
impl ThingsIterator {
pub(super) fn new(res: VecDeque<Thing>) -> Self {
Self {
res,
}
}
fn next_batch<T: ThingCollector>(&mut self, limit: u32, collector: &mut T) -> usize {
let mut count = 0;
while limit > count {
if let Some(thg) = self.res.pop_front() {
collector.add(thg, None);
count += 1;
} else {
break;
}
}
count as usize
}
}


@ -3,8 +3,8 @@ use crate::idx::ft::MatchRef;
use crate::idx::planner::tree::{GroupRef, IdiomPosition, IndexRef, Node};
use crate::sql::statements::DefineIndexStatement;
use crate::sql::with::With;
use crate::sql::{Array, Idiom, Object};
use crate::sql::{Expression, Operator, Value};
use crate::sql::{Array, Expression, Idiom, Object};
use crate::sql::{Operator, Value};
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::Hash;
@ -180,6 +180,7 @@ pub(super) enum IndexOperator {
RangePart(Operator, Value),
Matches(String, Option<MatchRef>),
Knn(Array, u32),
Ann(Array, usize, usize),
}
impl IndexOption {
@ -256,6 +257,10 @@ impl IndexOption {
e.insert("operator", Value::from(format!("<{}>", k)));
e.insert("value", Value::Array(a.clone()));
}
IndexOperator::Ann(a, n, ef) => {
e.insert("operator", Value::from(format!("<{},{}>", n, ef)));
e.insert("value", Value::Array(a.clone()));
}
};
Value::from(e)
}


@ -1,7 +1,7 @@
use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::err::Error;
use crate::idx::planner::executor::KnnExpressions;
use crate::idx::planner::executor::{AnnExpressions, KnnExpressions};
use crate::idx::planner::plan::{IndexOperator, IndexOption};
use crate::kvs;
use crate::sql::index::{Distance, Index};
@ -60,6 +60,7 @@ struct TreeBuilder<'a> {
index_map: IndexesMap,
with_indexes: Vec<IndexRef>,
knn_expressions: KnnExpressions,
ann_expressions: AnnExpressions,
idioms_record_options: HashMap<Idiom, RecordOptions>,
group_sequence: GroupRef,
}
@ -98,6 +99,7 @@ impl<'a> TreeBuilder<'a> {
index_map: Default::default(),
with_indexes,
knn_expressions: Default::default(),
ann_expressions: Default::default(),
idioms_record_options: Default::default(),
group_sequence: 0,
}
@ -377,6 +379,7 @@ impl<'a> TreeBuilder<'a> {
..
} => Self::eval_matches_operator(op, n),
Index::MTree(_) => self.eval_indexed_knn(e, op, n, id)?,
Index::Hnsw(_) => self.eval_indexed_ann(e, op, n, id)?,
};
if let Some(op) = op {
let io = IndexOption::new(*ir, id.clone(), p, op);
@ -436,6 +439,27 @@ impl<'a> TreeBuilder<'a> {
Ok(None)
}
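// Mirrors `eval_indexed_knn` for the HNSW case: when an Ann operator
// compares an indexed idiom with a computed vector, record the expression
// and emit an `IndexOperator::Ann` so the planner can route the query to
// the HNSW iterator.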
fn eval_indexed_ann(
&mut self,
exp: &Arc<Expression>,
op: &Operator,
nd: &Node,
id: &Idiom,
) -> Result<Option<IndexOperator>, Error> {
if let Operator::Ann(n, ef) = op {
if let Node::Computed(v) = nd {
let vec: Vec<Number> = v.as_ref().try_into()?;
let n = *n as usize;
let ef = *ef as usize;
self.ann_expressions.insert(exp.clone(), (n, id.clone(), Arc::new(vec), ef));
if let Value::Array(a) = v.as_ref() {
return Ok(Some(IndexOperator::Ann(a.clone(), n, ef)));
}
}
}
Ok(None)
}
fn eval_knn(&mut self, id: &Idiom, val: &Node, exp: &Arc<Expression>) -> Result<(), Error> {
if let Operator::Knn(k, d) = exp.operator() {
if let Node::Computed(v) = val {
@ -515,7 +539,7 @@ impl SchemaCache {
pub(super) type GroupRef = u16;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub(super) enum Node {
Expression {
group: GroupRef,


@ -4,10 +4,10 @@ use crate::idx::trees::store::{NodeId, StoreGeneration, StoredNode, TreeNode, Tr
use crate::idx::VersionedSerdeState;
use crate::kvs::{Key, Transaction, Val};
use crate::sql::{Object, Value};
#[cfg(debug_assertions)]
use hashbrown::HashSet;
use revision::{revisioned, Revisioned};
use serde::{Deserialize, Serialize};
#[cfg(debug_assertions)]
use std::collections::HashSet;
use std::collections::VecDeque;
use std::fmt::{Debug, Display, Formatter};
use std::io::Cursor;


@ -0,0 +1,183 @@
use hashbrown::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
pub trait DynamicSet<T>: Debug + Send + Sync
where
T: Eq + Hash + Clone + Default + 'static + Send + Sync,
{
fn with_capacity(capacity: usize) -> Self;
fn insert(&mut self, v: T) -> bool;
fn contains(&self, v: &T) -> bool;
fn remove(&mut self, v: &T) -> bool;
fn len(&self) -> usize;
fn is_empty(&self) -> bool;
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_>;
}
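// Two implementations are provided below: `HashBrownSet`, backed by a
// hashbrown `HashSet` for unbounded capacities, and `ArraySet`, a
// fixed-capacity array that avoids hashing and heap allocation when the
// maximum number of edges is small and known at compile time.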
#[derive(Debug)]
pub struct HashBrownSet<T>(HashSet<T>);
impl<T> DynamicSet<T> for HashBrownSet<T>
where
T: Eq + Hash + Clone + Default + Debug + 'static + Send + Sync,
{
#[inline]
fn with_capacity(capacity: usize) -> Self {
Self(HashSet::with_capacity(capacity))
}
#[inline]
fn insert(&mut self, v: T) -> bool {
self.0.insert(v)
}
#[inline]
fn contains(&self, v: &T) -> bool {
self.0.contains(v)
}
#[inline]
fn remove(&mut self, v: &T) -> bool {
self.0.remove(v)
}
#[inline]
fn len(&self) -> usize {
self.0.len()
}
#[inline]
fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[inline]
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_> {
Box::new(self.0.iter())
}
}
#[derive(Debug)]
pub struct ArraySet<T, const N: usize>
where
T: Eq + Hash + Clone + Default + 'static + Send + Sync,
{
array: [T; N],
size: usize,
}
impl<T, const N: usize> DynamicSet<T> for ArraySet<T, N>
where
T: Eq + Hash + Clone + Copy + Default + Debug + 'static + Send + Sync,
{
fn with_capacity(_capacity: usize) -> Self {
#[cfg(debug_assertions)]
assert!(_capacity <= N);
Self {
array: [T::default(); N],
size: 0,
}
}
#[inline]
fn insert(&mut self, v: T) -> bool {
if !self.contains(&v) {
self.array[self.size] = v;
self.size += 1;
true
} else {
false
}
}
#[inline]
fn contains(&self, v: &T) -> bool {
self.array[0..self.size].contains(v)
}
#[inline]
fn remove(&mut self, v: &T) -> bool {
if let Some(p) = self.array[0..self.size].iter().position(|e| e.eq(v)) {
self.array[p..].rotate_left(1);
self.size -= 1;
true
} else {
false
}
}
#[inline]
fn len(&self) -> usize {
self.size
}
#[inline]
fn is_empty(&self) -> bool {
self.size == 0
}
#[inline]
fn iter(&self) -> Box<dyn Iterator<Item = &T> + '_> {
Box::new(self.array[0..self.size].iter())
}
}
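// Example: an HNSW layer with m_max = 4 needs N = 5 slots, since an edge
// list may briefly hold m_max + 1 entries before being pruned:
//
// let mut s: ArraySet<u64, 5> = ArraySet::with_capacity(5);
// assert!(s.insert(42));
// assert!(s.contains(&42));
// assert!(s.remove(&42) && s.is_empty());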
#[cfg(test)]
mod tests {
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use hashbrown::HashSet;
fn test_dynamic_set<S: DynamicSet<usize>>(capacity: usize) {
let mut dyn_set = S::with_capacity(capacity);
let mut control = HashSet::new();
// Test insertions
for sample in 0..capacity {
assert_eq!(dyn_set.len(), control.len(), "{capacity} - {sample}");
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
// We should not have the element yet
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
// The first insertion returns true
assert_eq!(dyn_set.insert(sample), true);
assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
// The second insertion returns false
assert_eq!(dyn_set.insert(sample), false);
assert_eq!(dyn_set.contains(&sample), true, "{capacity} - {sample}");
// We update the control structure
control.insert(sample);
}
// Test removals
for sample in 0..capacity {
// The first removal returns true
assert_eq!(dyn_set.remove(&sample), true);
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
// The second removal returns false
assert_eq!(dyn_set.remove(&sample), false);
assert_eq!(dyn_set.contains(&sample), false, "{capacity} - {sample}");
// We update the control structure
control.remove(&sample);
// The control structure and the dyn_set should be identical
assert_eq!(dyn_set.len(), control.len(), "{capacity} - {sample}");
let v: HashSet<usize> = dyn_set.iter().cloned().collect();
assert_eq!(v, control, "{capacity} - {sample}");
}
}
#[test]
fn test_dynamic_set_hash() {
for capacity in 1..50 {
test_dynamic_set::<HashBrownSet<usize>>(capacity);
}
}
#[test]
fn test_dynamic_set_array() {
test_dynamic_set::<ArraySet<usize, 1>>(1);
test_dynamic_set::<ArraySet<usize, 2>>(2);
test_dynamic_set::<ArraySet<usize, 4>>(4);
test_dynamic_set::<ArraySet<usize, 10>>(10);
test_dynamic_set::<ArraySet<usize, 20>>(20);
test_dynamic_set::<ArraySet<usize, 30>>(30);
}
}

core/src/idx/trees/graph.rs (new file)

@ -0,0 +1,184 @@
use crate::idx::trees::dynamicset::DynamicSet;
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
#[cfg(test)]
use hashbrown::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
#[derive(Debug)]
pub(super) struct UndirectedGraph<T, S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Send + Sync,
S: DynamicSet<T>,
{
capacity: usize,
nodes: HashMap<T, S>,
}
impl<T, S> UndirectedGraph<T, S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Send + Sync,
S: DynamicSet<T>,
{
pub(super) fn new(capacity: usize) -> Self {
Self {
capacity,
nodes: HashMap::new(),
}
}
#[inline]
pub(super) fn new_edges(&self) -> S {
S::with_capacity(self.capacity)
}
#[inline]
pub(super) fn get_edges(&self, node: &T) -> Option<&S> {
self.nodes.get(node)
}
pub(super) fn add_empty_node(&mut self, node: T) -> bool {
if let Entry::Vacant(e) = self.nodes.entry(node) {
e.insert(S::with_capacity(self.capacity));
true
} else {
false
}
}
pub(super) fn add_node_and_bidirectional_edges(&mut self, node: T, edges: S) -> Vec<T> {
let mut r = Vec::with_capacity(edges.len());
for &e in edges.iter() {
self.nodes.entry(e).or_insert_with(|| S::with_capacity(self.capacity)).insert(node);
r.push(e);
}
self.nodes.insert(node, edges);
r
}
#[inline]
pub(super) fn set_node(&mut self, node: T, new_edges: S) {
self.nodes.insert(node, new_edges);
}
pub(super) fn remove_node_and_bidirectional_edges(&mut self, node: &T) -> Option<S> {
if let Some(edges) = self.nodes.remove(node) {
for edge in edges.iter() {
if let Some(edges_to_node) = self.nodes.get_mut(edge) {
edges_to_node.remove(node);
}
}
Some(edges)
} else {
None
}
}
}
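// Note: `add_node_and_bidirectional_edges` and
// `remove_node_and_bidirectional_edges` keep the edge lists symmetric,
// while `set_node` replaces one node's edges without updating the reverse
// direction; the HNSW pruning step relies on this and may leave an edge
// one-directional after a neighbour list is trimmed.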
#[cfg(test)]
impl<T, S> UndirectedGraph<T, S>
where
T: Eq + Hash + Clone + Copy + Default + 'static + Debug + Send + Sync,
S: DynamicSet<T>,
{
pub(in crate::idx::trees) fn len(&self) -> usize {
self.nodes.len()
}
pub(in crate::idx::trees) fn nodes(&self) -> &HashMap<T, S> {
&self.nodes
}
pub(in crate::idx::trees) fn check(&self, g: Vec<(T, Vec<T>)>) {
for (n, e) in g {
let edges: HashSet<T> = e.into_iter().collect();
let n_edges: Option<HashSet<T>> =
self.get_edges(&n).map(|e| e.iter().cloned().collect());
assert_eq!(n_edges, Some(edges), "{n:?}");
}
}
}
#[cfg(test)]
mod tests {
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use crate::idx::trees::graph::UndirectedGraph;
fn test_undirected_graph<S: DynamicSet<i32>>(m_max: usize) {
// Graph creation
let mut g = UndirectedGraph::<i32, S>::new(m_max);
assert_eq!(g.capacity, m_max);
// Adding an empty node
let res = g.add_empty_node(0);
assert!(res);
g.check(vec![(0, vec![])]);
// Adding the same node
let res = g.add_empty_node(0);
assert!(!res);
g.check(vec![(0, vec![])]);
// Adding a node with one edge
let mut e = g.new_edges();
e.insert(0);
let res = g.add_node_and_bidirectional_edges(1, e);
assert_eq!(res, vec![0]);
g.check(vec![(0, vec![1]), (1, vec![0])]);
// Adding a node with two edges
let mut e = g.new_edges();
e.insert(0);
e.insert(1);
let mut res = g.add_node_and_bidirectional_edges(2, e);
res.sort();
assert_eq!(res, vec![0, 1]);
g.check(vec![(0, vec![1, 2]), (1, vec![0, 2]), (2, vec![0, 1])]);
// Adding a node with two edges
let mut e = g.new_edges();
e.insert(1);
e.insert(2);
let mut res = g.add_node_and_bidirectional_edges(3, e);
res.sort();
assert_eq!(res, vec![1, 2]);
g.check(vec![(0, vec![1, 2]), (1, vec![0, 2, 3]), (2, vec![0, 1, 3]), (3, vec![1, 2])]);
// Change the edges of a node
let mut e = g.new_edges();
e.insert(0);
g.set_node(3, e);
g.check(vec![(0, vec![1, 2]), (1, vec![0, 2, 3]), (2, vec![0, 1, 3]), (3, vec![0])]);
// Remove a node
let res = g.remove_node_and_bidirectional_edges(&2);
assert_eq!(
res.map(|v| {
let mut v: Vec<i32> = v.iter().cloned().collect();
v.sort();
v
}),
Some(vec![0, 1, 3])
);
g.check(vec![(0, vec![1]), (1, vec![0, 3]), (3, vec![0])]);
// Remove again
let res = g.remove_node_and_bidirectional_edges(&2);
assert!(res.is_none());
// Set a non-existing node
let mut e = g.new_edges();
e.insert(1);
g.set_node(2, e);
g.check(vec![(0, vec![1]), (1, vec![0, 3]), (2, vec![1]), (3, vec![0])]);
}
#[test]
fn test_undirected_graph_array() {
test_undirected_graph::<ArraySet<i32, 10>>(10);
}
#[test]
fn test_undirected_graph_hash() {
test_undirected_graph::<HashBrownSet<i32>>(10);
}
}


@ -0,0 +1,175 @@
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::layer::HnswLayer;
use crate::idx::trees::hnsw::{ElementId, HnswElements};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use crate::sql::index::HnswParams;
#[derive(Debug)]
pub(super) enum Heuristic {
Standard,
Ext,
Keep,
ExtAndKeep,
}
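// These variants correspond to SELECT-NEIGHBORS-HEURISTIC (Algorithm 4 in
// the HNSW paper by Malkov & Yashunin) and its two boolean flags:
// extendCandidates (Ext) widens the candidate queue with the candidates'
// own neighbours, and keepPrunedConnections (Keep) back-fills the result
// with the closest pruned candidates up to m_max.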
impl From<&HnswParams> for Heuristic {
fn from(p: &HnswParams) -> Self {
if p.keep_pruned_connections {
if p.extend_candidates {
Self::ExtAndKeep
} else {
Self::Keep
}
} else if p.extend_candidates {
Self::Ext
} else {
Self::Standard
}
}
}
impl Heuristic {
pub(super) fn select<S>(
&self,
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
match self {
Self::Standard => Self::heuristic(elements, layer, c, res),
Self::Ext => Self::heuristic_ext(elements, layer, q_id, q_pt, c, res),
Self::Keep => Self::heuristic_keep(elements, layer, c, res),
Self::ExtAndKeep => Self::heuristic_ext_keep(elements, layer, q_id, q_pt, c, res),
}
}
fn heuristic<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
let m_max = layer.m_max();
if c.len() <= m_max {
c.to_dynamic_set(res);
} else {
while let Some((e_dist, e_id)) = c.pop_first() {
if Self::is_closer(elements, e_dist, e_id, res) && res.len() == m_max {
break;
}
}
}
}
fn heuristic_keep<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
let m_max = layer.m_max();
if c.len() <= m_max {
c.to_dynamic_set(res);
return;
}
let mut pruned = Vec::new();
while let Some((e_dist, e_id)) = c.pop_first() {
if Self::is_closer(elements, e_dist, e_id, res) {
if res.len() == m_max {
break;
}
} else {
pruned.push(e_id);
}
}
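// Back-fill with the closest pruned candidates. `drain(0..n)` cannot
// overrun: the queue held more than m_max entries (checked above), so once
// the loop has exhausted it, `pruned` contains more than `n` elements.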
let n = m_max - res.len();
if n > 0 {
for e_id in pruned.drain(0..n) {
res.insert(e_id);
}
}
}
fn extend_candidates<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
c: &mut DoublePriorityQueue,
) where
S: DynamicSet<ElementId>,
{
let m_max = layer.m_max();
let mut ex = c.to_set();
let mut ext = Vec::with_capacity(m_max.min(c.len()));
for (_, e_id) in c.to_vec().into_iter() {
for &e_adj in layer.get_edges(&e_id).unwrap_or_else(|| unreachable!()).iter() {
if e_adj != q_id && ex.insert(e_adj) {
if let Some(d) = elements.get_distance(q_pt, &e_adj) {
ext.push((d, e_adj));
}
}
}
}
for (e_dist, e_id) in ext {
c.push(e_dist, e_id);
}
}
fn heuristic_ext<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
Self::extend_candidates(elements, layer, q_id, q_pt, &mut c);
Self::heuristic(elements, layer, c, res)
}
fn heuristic_ext_keep<S>(
elements: &HnswElements,
layer: &HnswLayer<S>,
q_id: ElementId,
q_pt: &SharedVector,
mut c: DoublePriorityQueue,
res: &mut S,
) where
S: DynamicSet<ElementId>,
{
Self::extend_candidates(elements, layer, q_id, q_pt, &mut c);
Self::heuristic_keep(elements, layer, c, res)
}
fn is_closer<S>(elements: &HnswElements, e_dist: f64, e_id: ElementId, r: &mut S) -> bool
where
S: DynamicSet<ElementId>,
{
if let Some(current_vec) = elements.get_vector(&e_id) {
for r_id in r.iter() {
if let Some(r_dist) = elements.get_distance(current_vec, r_id) {
if e_dist > r_dist {
return false;
}
}
}
r.insert(e_id);
true
} else {
false
}
}
}


@ -0,0 +1,246 @@
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::graph::UndirectedGraph;
use crate::idx::trees::hnsw::heuristic::Heuristic;
use crate::idx::trees::hnsw::{ElementId, HnswElements};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use hashbrown::HashSet;
#[derive(Debug)]
pub(super) struct HnswLayer<S>
where
S: DynamicSet<ElementId>,
{
graph: UndirectedGraph<ElementId, S>,
m_max: usize,
}
impl<S> HnswLayer<S>
where
S: DynamicSet<ElementId>,
{
pub(super) fn new(m_max: usize) -> Self {
Self {
graph: UndirectedGraph::new(m_max + 1),
m_max,
}
}
pub(super) fn m_max(&self) -> usize {
self.m_max
}
pub(super) fn get_edges(&self, e_id: &ElementId) -> Option<&S> {
self.graph.get_edges(e_id)
}
pub(super) fn add_empty_node(&mut self, node: ElementId) -> bool {
self.graph.add_empty_node(node)
}
pub(super) fn search_single(
&self,
elements: &HnswElements,
q: &SharedVector,
ep_dist: f64,
ep_id: ElementId,
ef: usize,
) -> DoublePriorityQueue {
let visited = HashSet::from([ep_id]);
let candidates = DoublePriorityQueue::from(ep_dist, ep_id);
let w = candidates.clone();
self.search(elements, q, candidates, visited, w, ef)
}
pub(super) fn search_multi(
&self,
elements: &HnswElements,
q: &SharedVector,
candidates: DoublePriorityQueue,
ef: usize,
) -> DoublePriorityQueue {
let w = candidates.clone();
let visited = w.to_set();
self.search(elements, q, candidates, visited, w, ef)
}
pub(super) fn search_single_ignore_ep(
&self,
elements: &HnswElements,
q: &SharedVector,
ep_id: ElementId,
) -> Option<(f64, ElementId)> {
let visited = HashSet::from([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
let w = candidates.clone();
let q = self.search(elements, q, candidates, visited, w, 1);
q.peek_first()
}
pub(super) fn search_multi_ignore_ep(
&self,
elements: &HnswElements,
q: &SharedVector,
ep_id: ElementId,
ef: usize,
) -> DoublePriorityQueue {
let visited = HashSet::from([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
let w = DoublePriorityQueue::default();
self.search(elements, q, candidates, visited, w, ef)
}
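// Greedy best-first search within one layer, following SEARCH-LAYER
// (Algorithm 2 in the HNSW paper): `candidates` is a min-queue ordered by
// distance to `q`, `w` is the result set capped at `ef` entries, and the
// loop stops once the closest remaining candidate is farther than the
// worst current result.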
pub(super) fn search(
&self,
elements: &HnswElements,
q: &SharedVector,
mut candidates: DoublePriorityQueue,
mut visited: HashSet<ElementId>,
mut w: DoublePriorityQueue,
ef: usize,
) -> DoublePriorityQueue {
let mut f_dist = if let Some(d) = w.peek_last_dist() {
d
} else {
return w;
};
while let Some((dist, doc)) = candidates.pop_first() {
if dist > f_dist {
break;
}
if let Some(neighbourhood) = self.graph.get_edges(&doc) {
for &e_id in neighbourhood.iter() {
if visited.insert(e_id) {
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_dist = elements.distance(e_pt, q);
if e_dist < f_dist || w.len() < ef {
candidates.push(e_dist, e_id);
w.push(e_dist, e_id);
if w.len() > ef {
w.pop_last();
}
f_dist = w.peek_last_dist().unwrap(); // w can't be empty
}
}
}
}
}
}
w
}
pub(super) fn insert(
&mut self,
elements: &HnswElements,
heuristic: &Heuristic,
efc: usize,
q_id: ElementId,
q_pt: &SharedVector,
mut eps: DoublePriorityQueue,
) -> DoublePriorityQueue {
let w;
let mut neighbors = self.graph.new_edges();
{
w = self.search_multi(elements, q_pt, eps, efc);
eps = w.clone();
heuristic.select(elements, self, q_id, q_pt, w, &mut neighbors);
};
let neighbors = self.graph.add_node_and_bidirectional_edges(q_id, neighbors);
for e_id in neighbors {
let e_conn =
self.graph.get_edges(&e_id).unwrap_or_else(|| unreachable!("Element: {}", e_id));
if e_conn.len() > self.m_max {
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_c = self.build_priority_list(elements, e_id, e_conn);
let mut e_new_conn = self.graph.new_edges();
heuristic.select(elements, self, e_id, e_pt, e_c, &mut e_new_conn);
#[cfg(debug_assertions)]
assert!(!e_new_conn.contains(&e_id));
self.graph.set_node(e_id, e_new_conn);
}
}
}
eps
}
fn build_priority_list(
&self,
elements: &HnswElements,
e_id: ElementId,
neighbors: &S,
) -> DoublePriorityQueue {
let mut w = DoublePriorityQueue::default();
if let Some(e_pt) = elements.get_vector(&e_id) {
for n_id in neighbors.iter() {
if let Some(n_pt) = elements.get_vector(n_id) {
let dist = elements.distance(e_pt, n_pt);
w.push(dist, *n_id);
}
}
}
w
}
pub(super) fn remove(
&mut self,
elements: &HnswElements,
heuristic: &Heuristic,
e_id: ElementId,
efc: usize,
) -> bool {
if let Some(f_ids) = self.graph.remove_node_and_bidirectional_edges(&e_id) {
for &q_id in f_ids.iter() {
if let Some(q_pt) = elements.get_vector(&q_id) {
let c = self.search_multi_ignore_ep(elements, q_pt, q_id, efc);
let mut neighbors = self.graph.new_edges();
heuristic.select(elements, self, q_id, q_pt, c, &mut neighbors);
#[cfg(debug_assertions)]
{
assert!(
!neighbors.contains(&q_id),
"!neighbors.contains(&q_id) - q_id: {q_id} - f_ids: {neighbors:?}"
);
assert!(
!neighbors.contains(&e_id),
"!neighbors.contains(&e_id) - e_id: {e_id} - f_ids: {neighbors:?}"
);
assert!(neighbors.len() < self.m_max);
}
self.graph.set_node(q_id, neighbors);
}
}
true
} else {
false
}
}
}
#[cfg(test)]
impl<S> HnswLayer<S>
where
S: DynamicSet<ElementId>,
{
pub(in crate::idx::trees::hnsw) fn check_props(&self, elements: &HnswElements) {
assert!(
self.graph.len() <= elements.elements.len(),
"{} - {}",
self.graph.len(),
elements.elements.len()
);
for (e_id, f_ids) in self.graph.nodes() {
assert!(
f_ids.len() <= self.m_max,
"Foreign list e_id: {e_id} - len({}) <= m_layer({})",
f_ids.len(),
self.m_max,
);
assert!(!f_ids.contains(e_id), "!f_ids.contains(e_id) - el: {e_id} - f_ids: {f_ids:?}");
assert!(
elements.elements.contains_key(e_id),
"h.elements.contains_key(e_id) - el: {e_id} - f_ids: {f_ids:?}"
);
}
}
}


@ -0,0 +1,928 @@
mod heuristic;
mod layer;
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use crate::idx::trees::hnsw::heuristic::Heuristic;
use crate::idx::trees::hnsw::layer::HnswLayer;
use crate::idx::trees::knn::{DoublePriorityQueue, Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::kvs::Key;
use crate::sql::index::{Distance, HnswParams, VectorType};
use crate::sql::{Array, Thing, Value};
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use radix_trie::Trie;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use roaring::RoaringTreemap;
use std::collections::VecDeque;
pub struct HnswIndex {
dim: usize,
vector_type: VectorType,
hnsw: Box<dyn HnswMethods>,
docs: HnswDocs,
vec_docs: HashMap<SharedVector, (Ids64, ElementId)>,
}
type ASet<const N: usize> = ArraySet<ElementId, N>;
type HSet = HashBrownSet<ElementId>;
impl HnswIndex {
pub fn new(p: &HnswParams) -> Self {
Self {
dim: p.dimension as usize,
vector_type: p.vector_type,
hnsw: Self::new_hnsw(p),
docs: HnswDocs::default(),
vec_docs: HashMap::default(),
}
}
fn new_hnsw(p: &HnswParams) -> Box<dyn HnswMethods> {
match p.m {
1..=4 => match p.m0 {
1..=8 => Box::new(Hnsw::<ASet<9>, ASet<5>>::new(p)),
9..=16 => Box::new(Hnsw::<ASet<17>, ASet<5>>::new(p)),
17..=24 => Box::new(Hnsw::<ASet<25>, ASet<5>>::new(p)),
_ => Box::new(Hnsw::<HSet, ASet<5>>::new(p)),
},
5..=8 => match p.m0 {
1..=16 => Box::new(Hnsw::<ASet<17>, ASet<9>>::new(p)),
17..=24 => Box::new(Hnsw::<ASet<25>, ASet<9>>::new(p)),
_ => Box::new(Hnsw::<HSet, ASet<9>>::new(p)),
},
9..=12 => match p.m0 {
17..=24 => Box::new(Hnsw::<ASet<25>, ASet<13>>::new(p)),
_ => Box::new(Hnsw::<HSet, ASet<13>>::new(p)),
},
13..=16 => Box::new(Hnsw::<HSet, ASet<17>>::new(p)),
17..=20 => Box::new(Hnsw::<HSet, ASet<21>>::new(p)),
21..=24 => Box::new(Hnsw::<HSet, ASet<25>>::new(p)),
25..=28 => Box::new(Hnsw::<HSet, ASet<29>>::new(p)),
_ => Box::new(Hnsw::<HSet, HSet>::new(p)),
}
}
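// The match above monomorphizes the index over its edge-set types: for
// small `m` / `m0` the layers use fixed-capacity `ArraySet`s (sized
// m_max + 1, hence 9, 17, 25, ...), falling back to hash-based sets when
// the connectivity is too large for a stack-allocated array to pay off.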
pub fn index_document(&mut self, rid: &Thing, content: &Vec<Value>) -> Result<(), Error> {
// Resolve the doc_id
let doc_id = self.docs.resolve(rid);
// Index the values
for value in content {
// Extract the vector
let vector = Vector::try_from_value(self.vector_type, self.dim, value)?;
vector.check_dimension(self.dim)?;
self.insert(vector.into(), doc_id);
}
Ok(())
}
fn insert(&mut self, o: SharedVector, d: DocId) {
match self.vec_docs.entry(o) {
Entry::Occupied(mut e) => {
let (docs, element_id) = e.get_mut();
if let Some(new_docs) = docs.insert(d) {
let element_id = *element_id;
e.insert((new_docs, element_id));
}
}
Entry::Vacant(e) => {
let o = e.key().clone();
let element_id = self.hnsw.insert(o);
e.insert((Ids64::One(d), element_id));
}
}
}
fn remove(&mut self, o: SharedVector, d: DocId) {
if let Entry::Occupied(mut e) = self.vec_docs.entry(o) {
let (docs, e_id) = e.get_mut();
if let Some(new_docs) = docs.remove(d) {
let e_id = *e_id;
if new_docs.is_empty() {
e.remove();
self.hnsw.remove(e_id);
} else {
e.insert((new_docs, e_id));
}
}
}
}
pub(crate) fn remove_document(
&mut self,
rid: &Thing,
content: &Vec<Value>,
) -> Result<(), Error> {
if let Some(doc_id) = self.docs.remove(rid) {
for v in content {
// Extract the vector
let vector = Vector::try_from_value(self.vector_type, self.dim, v)?;
vector.check_dimension(self.dim)?;
// Remove the vector
self.remove(vector.into(), doc_id);
}
}
Ok(())
}
pub fn knn_search(
&self,
a: &Array,
n: usize,
ef: usize,
) -> Result<VecDeque<(Thing, f64)>, Error> {
// Extract the vector
let vector = Vector::try_from_array(self.vector_type, a)?;
vector.check_dimension(self.dim)?;
// Do the search
let res = self.search(&vector.into(), n, ef);
Ok(self.result(res))
}
fn result(&self, res: KnnResult) -> VecDeque<(Thing, f64)> {
res.docs
.into_iter()
.filter_map(|(doc_id, dist)| self.docs.get(doc_id).map(|t| (t.clone(), dist)))
.collect()
}
fn search(&self, o: &SharedVector, n: usize, ef: usize) -> KnnResult {
let neighbors = self.hnsw.knn_search(o, n, ef);
let mut builder = KnnResultBuilder::new(n);
for (e_dist, e_id) in neighbors {
if builder.check_add(e_dist) {
if let Some(v) = self.hnsw.get_vector(&e_id) {
if let Some((docs, _)) = self.vec_docs.get(v) {
builder.add(e_dist, docs);
}
}
}
}
builder.build(
#[cfg(debug_assertions)]
HashMap::new(),
)
}
}
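// Bidirectional mapping between record ids (`Thing`) and compact `DocId`s:
// `doc_ids` resolves a record key to its doc id, `ids_doc` is the reverse
// lookup, and `available` is a bitmap of ids freed by `remove` that
// `next_doc_id` recycles before growing the vector.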
#[derive(Default)]
struct HnswDocs {
doc_ids: Trie<Key, DocId>,
ids_doc: Vec<Option<Thing>>,
available: RoaringTreemap,
}
impl HnswDocs {
fn resolve(&mut self, rid: &Thing) -> DocId {
let doc_key: Key = rid.into();
if let Some(doc_id) = self.doc_ids.get(&doc_key) {
*doc_id
} else {
let doc_id = self.next_doc_id();
self.ids_doc.push(Some(rid.clone()));
self.doc_ids.insert(doc_key, doc_id);
doc_id
}
}
fn next_doc_id(&mut self) -> DocId {
if let Some(doc_id) = self.available.iter().next() {
self.available.remove(doc_id);
doc_id
} else {
self.ids_doc.len() as DocId
}
}
fn get(&self, doc_id: DocId) -> Option<Thing> {
if let Some(t) = self.ids_doc.get(doc_id as usize) {
t.clone()
} else {
None
}
}
fn remove(&mut self, rid: &Thing) -> Option<DocId> {
let doc_key: Key = rid.into();
if let Some(doc_id) = self.doc_ids.remove(&doc_key) {
let n = doc_id as usize;
if n < self.ids_doc.len() {
self.ids_doc[n] = None;
}
self.available.insert(doc_id);
Some(doc_id)
} else {
None
}
}
}
trait HnswMethods: Send + Sync {
fn insert(&mut self, q_pt: SharedVector) -> ElementId;
fn remove(&mut self, e_id: ElementId) -> bool;
fn knn_search(&self, q: &SharedVector, k: usize, efs: usize) -> Vec<(f64, ElementId)>;
fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector>;
#[cfg(test)]
fn check_hnsw_properties(&self, expected_count: usize);
}
#[cfg(test)]
fn check_hnsw_props<L0, L>(h: &Hnsw<L0, L>, expected_count: usize)
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
{
assert_eq!(h.elements.elements.len(), expected_count);
for layer in h.layers.iter() {
layer.check_props(&h.elements);
}
}
struct HnswElements {
elements: HashMap<ElementId, SharedVector>,
next_element_id: ElementId,
dist: Distance,
}
impl HnswElements {
fn new(dist: Distance) -> Self {
Self {
elements: Default::default(),
next_element_id: 0,
dist,
}
}
fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector> {
self.elements.get(e_id)
}
fn distance(&self, a: &SharedVector, b: &SharedVector) -> f64 {
self.dist.calculate(a, b)
}
fn get_distance(&self, q: &SharedVector, e_id: &ElementId) -> Option<f64> {
self.elements.get(e_id).map(|e_pt| self.dist.calculate(e_pt, q))
}
fn remove(&mut self, e_id: &ElementId) {
self.elements.remove(e_id);
}
}
struct Hnsw<L0, L>
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
{
m: usize,
efc: usize,
ml: f64,
layer0: HnswLayer<L0>,
layers: Vec<HnswLayer<L>>,
enter_point: Option<ElementId>,
elements: HnswElements,
rng: SmallRng,
heuristic: Heuristic,
}
pub(super) type ElementId = u64;
impl<L0, L> Hnsw<L0, L>
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
{
fn new(p: &HnswParams) -> Self {
let m0 = p.m0 as usize;
Self {
m: p.m as usize,
efc: p.ef_construction as usize,
ml: p.ml.to_float(),
enter_point: None,
layer0: HnswLayer::new(m0),
layers: Vec::default(),
elements: HnswElements::new(p.distance.clone()),
rng: SmallRng::from_entropy(),
heuristic: p.into(),
}
}
fn insert_level(&mut self, q_pt: SharedVector, q_level: usize) -> ElementId {
// Assign an ID to the vector
let q_id = self.elements.next_element_id;
let top_up_layers = self.layers.len();
// Be sure we have existing (up) layers if required
for _ in top_up_layers..q_level {
self.layers.push(HnswLayer::new(self.m));
}
// Store the vector
self.elements.elements.insert(q_id, q_pt.clone());
if let Some(ep_id) = self.enter_point {
// We already have an enter_point, let's insert the element in the layers
self.insert_element(q_id, &q_pt, q_level, ep_id, top_up_layers);
} else {
// Otherwise it is the first element
self.insert_first_element(q_id, q_level);
}
self.elements.next_element_id += 1;
q_id
}
fn get_random_level(&mut self) -> usize {
let unif: f64 = self.rng.gen(); // generate a uniform random number between 0 and 1
(-unif.ln() * self.ml).floor() as usize // calculate the layer
}
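// With `unif` uniform in (0, 1), floor(-ln(unif) * ml) follows the
// exponentially decaying level distribution from the HNSW paper; the tests
// below use ml = 1 / ln(m), so each layer holds roughly 1/m of the nodes
// of the layer beneath it.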
fn insert_first_element(&mut self, id: ElementId, level: usize) {
if level > 0 {
for layer in self.layers.iter_mut().take(level) {
layer.add_empty_node(id);
}
}
self.layer0.add_empty_node(id);
self.enter_point = Some(id);
}
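// Follows Algorithm 1 (INSERT) of the HNSW paper: greedily descend from
// the enter point through the layers above `q_level` with ef = 1, then,
// from `q_level` down to layer 0, search each layer with ef_construction
// candidates and connect the new element to the selected neighbours.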
fn insert_element(
&mut self,
q_id: ElementId,
q_pt: &SharedVector,
q_level: usize,
mut ep_id: ElementId,
top_up_layers: usize,
) {
let mut ep_dist =
self.elements.get_distance(q_pt, &ep_id).unwrap_or_else(|| unreachable!());
if q_level < top_up_layers {
for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
(ep_dist, ep_id) = layer
.search_single(&self.elements, q_pt, ep_dist, ep_id, 1)
.peek_first()
.unwrap_or_else(|| unreachable!())
}
}
let mut eps = DoublePriorityQueue::from(ep_dist, ep_id);
let insert_to_up_layers = q_level.min(top_up_layers);
if insert_to_up_layers > 0 {
for layer in self.layers.iter_mut().take(insert_to_up_layers).rev() {
eps = layer.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
}
}
self.layer0.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
if top_up_layers < q_level {
for layer in self.layers[top_up_layers..q_level].iter_mut() {
if !layer.add_empty_node(q_id) {
unreachable!("Already there {}", q_id);
}
}
}
if q_level > top_up_layers {
self.enter_point = Some(q_id);
}
}
}
impl<L0, L> HnswMethods for Hnsw<L0, L>
where
L0: DynamicSet<ElementId>,
L: DynamicSet<ElementId>,
{
fn insert(&mut self, q_pt: SharedVector) -> ElementId {
let q_level = self.get_random_level();
self.insert_level(q_pt, q_level)
}
fn remove(&mut self, e_id: ElementId) -> bool {
let mut removed = false;
let e_pt = self.elements.get_vector(&e_id).cloned();
// Do we have the vector?
if let Some(e_pt) = e_pt {
let layers = self.layers.len();
let mut new_enter_point = None;
// Are we deleting the current enter point?
if Some(e_id) == self.enter_point {
// Let's find a new enter point
new_enter_point = if layers == 0 {
self.layer0.search_single_ignore_ep(&self.elements, &e_pt, e_id)
} else {
self.layers[layers - 1].search_single_ignore_ep(&self.elements, &e_pt, e_id)
};
}
self.elements.remove(&e_id);
// Remove from the up layers
for layer in self.layers.iter_mut().rev() {
if layer.remove(&self.elements, &self.heuristic, e_id, self.efc) {
removed = true;
}
}
// Remove from layer 0
if self.layer0.remove(&self.elements, &self.heuristic, e_id, self.efc) {
removed = true;
}
if removed && new_enter_point.is_some() {
// Update the enter point
self.enter_point = new_enter_point.map(|(_, e_id)| e_id);
}
}
removed
}
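// Query path, following K-NN-SEARCH (Algorithm 5): greedily descend the
// upper layers with ef = 1 to find a good enter point, then search
// layer 0 with ef = efs and keep the k closest results.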
fn knn_search(&self, q: &SharedVector, k: usize, efs: usize) -> Vec<(f64, ElementId)> {
#[cfg(debug_assertions)]
let expected_w_len = self.elements.elements.len().min(k);
if let Some(mut ep_id) = self.enter_point {
let mut ep_dist =
self.elements.get_distance(q, &ep_id).unwrap_or_else(|| unreachable!());
for layer in self.layers.iter().rev() {
(ep_dist, ep_id) = layer
.search_single(&self.elements, q, ep_dist, ep_id, 1)
.peek_first()
.unwrap_or_else(|| unreachable!());
}
{
let w = self.layer0.search_single(&self.elements, q, ep_dist, ep_id, efs);
#[cfg(debug_assertions)]
if w.len() < expected_w_len {
debug!(
"0 search_layer - ep_id: {ep_id:?} - ef_search: {efs} - k: {k} - w.len: {} < {expected_w_len}",
w.len()
);
}
w.to_vec_limit(k)
}
} else {
vec![]
}
}
fn get_vector(&self, e_id: &ElementId) -> Option<&SharedVector> {
self.elements.get_vector(e_id)
}
#[cfg(test)]
fn check_hnsw_properties(&self, expected_count: usize) {
check_hnsw_props(self, expected_count);
}
}
#[cfg(test)]
mod tests {
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::trees::hnsw::{HnswIndex, HnswMethods};
use crate::idx::trees::knn::tests::{new_vectors_from_file, TestCollection};
use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::sql::index::{Distance, HnswParams, VectorType};
use hashbrown::{hash_map::Entry, HashMap, HashSet};
use ndarray::Array1;
use roaring::RoaringTreemap;
use std::sync::Arc;
use test_log::test;
fn insert_collection_hnsw(
h: &mut Box<dyn HnswMethods>,
collection: &TestCollection,
) -> HashSet<SharedVector> {
let mut set = HashSet::new();
for (_, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
h.insert(obj.clone());
set.insert(obj);
h.check_hnsw_properties(set.len());
}
set
}
fn find_collection_hnsw(h: &Box<dyn HnswMethods>, collection: &TestCollection) {
let max_knn = 20.min(collection.len());
for (_, obj) in collection.to_vec_ref() {
let obj = obj.clone().into();
for knn in 1..max_knn {
let res = h.knn_search(&obj, knn, 80);
if collection.is_unique() {
let mut found = false;
for (_, e_id) in &res {
if let Some(v) = h.get_vector(e_id) {
if obj.eq(v) {
found = true;
break;
}
}
}
assert!(
found,
"Search: {:?} - Knn: {} - Vector not found - Got: {:?} - Coll: {}",
obj,
knn,
res,
collection.len(),
);
}
let expected_len = collection.len().min(knn);
if expected_len != res.len() {
info!("expected_len != res.len()")
}
assert_eq!(
expected_len,
res.len(),
"Wrong knn count - Expected: {} - Got: {} - Collection: {} - - Res: {:?}",
expected_len,
res.len(),
collection.len(),
res,
)
}
}
}
fn test_hnsw_collection(p: &HnswParams, collection: &TestCollection) {
let mut h = HnswIndex::new_hnsw(p);
insert_collection_hnsw(&mut h, collection);
find_collection_hnsw(&h, &collection);
}
fn new_params(
dimension: usize,
vector_type: VectorType,
distance: Distance,
m: usize,
efc: usize,
extend_candidates: bool,
keep_pruned_connections: bool,
) -> HnswParams {
let m = m as u8;
let m0 = m * 2;
HnswParams::new(
dimension as u16,
distance,
vector_type,
m,
m0,
(1.0 / (m as f64).ln()).into(),
efc as u16,
extend_candidates,
keep_pruned_connections,
)
}
fn test_hnsw(collection_size: usize, p: HnswParams) {
info!("Collection size: {collection_size} - Params: {p:?}");
let collection = TestCollection::new(
true,
collection_size,
p.vector_type,
p.dimension as usize,
&p.distance,
);
test_hnsw_collection(&p, &collection);
}
#[test(tokio::test(flavor = "multi_thread"))]
async fn tests_hnsw() -> Result<(), Error> {
let mut futures = Vec::new();
for (dist, dim) in [
(Distance::Chebyshev, 5),
(Distance::Cosine, 5),
(Distance::Euclidean, 5),
(Distance::Hamming, 20),
// (Distance::Jaccard, 100),
(Distance::Manhattan, 5),
(Distance::Minkowski(2.into()), 5),
//(Distance::Pearson, 5),
] {
for vt in [
VectorType::F64,
VectorType::F32,
VectorType::I64,
VectorType::I32,
VectorType::I16,
] {
for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
let p = new_params(dim, vt, dist.clone(), 24, 500, extend, keep);
let f = tokio::spawn(async move {
test_hnsw(30, p);
});
futures.push(f);
}
}
}
for f in futures {
f.await.expect("Task error");
}
Ok(())
}
fn insert_collection_hnsw_index(
h: &mut HnswIndex,
collection: &TestCollection,
) -> HashMap<SharedVector, HashSet<DocId>> {
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::new();
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
h.insert(obj.clone(), *doc_id);
match map.entry(obj) {
Entry::Occupied(mut e) => {
e.get_mut().insert(*doc_id);
}
Entry::Vacant(e) => {
e.insert(HashSet::from([*doc_id]));
}
}
h.hnsw.check_hnsw_properties(map.len());
}
map
}
fn find_collection_hnsw_index(h: &mut HnswIndex, collection: &TestCollection) {
let max_knn = 20.min(collection.len());
for (doc_id, obj) in collection.to_vec_ref() {
for knn in 1..max_knn {
let obj: SharedVector = obj.clone().into();
let res = h.search(&obj, knn, 500);
if knn == 1 && res.docs.len() == 1 && res.docs[0].1 > 0.0 {
let docs: Vec<DocId> = res.docs.iter().map(|(d, _)| *d).collect();
if collection.is_unique() {
assert!(
docs.contains(doc_id),
"Search: {:?} - Knn: {} - Wrong Doc - Expected: {} - Got: {:?}",
obj,
knn,
doc_id,
res.docs
);
}
}
let expected_len = collection.len().min(knn);
assert_eq!(
expected_len,
res.docs.len(),
"Wrong knn count - Expected: {} - Got: {} - - Docs: {:?} - Collection: {}",
expected_len,
res.docs.len(),
res.docs,
collection.len(),
)
}
}
}
fn delete_hnsw_index_collection(
h: &mut HnswIndex,
collection: &TestCollection,
mut map: HashMap<SharedVector, HashSet<DocId>>,
) {
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone().into();
h.remove(obj.clone(), *doc_id);
if let Entry::Occupied(mut e) = map.entry(obj.clone()) {
let set = e.get_mut();
set.remove(doc_id);
if set.is_empty() {
e.remove();
}
}
h.hnsw.check_hnsw_properties(map.len());
}
}
fn test_hnsw_index(collection_size: usize, unique: bool, p: HnswParams) {
info!("test_hnsw_index - coll size: {collection_size} - params: {p:?}");
let collection = TestCollection::new(
unique,
collection_size,
p.vector_type,
p.dimension as usize,
&p.distance,
);
let mut h = HnswIndex::new(&p);
let map = insert_collection_hnsw_index(&mut h, &collection);
find_collection_hnsw_index(&mut h, &collection);
delete_hnsw_index_collection(&mut h, &collection, map);
}
#[test(tokio::test(flavor = "multi_thread"))]
async fn tests_hnsw_index() -> Result<(), Error> {
let mut futures = Vec::new();
for (dist, dim) in [
(Distance::Chebyshev, 5),
(Distance::Cosine, 5),
(Distance::Euclidean, 5),
(Distance::Hamming, 20),
// (Distance::Jaccard, 100),
(Distance::Manhattan, 5),
(Distance::Minkowski(2.into()), 5),
(Distance::Pearson, 5),
] {
for vt in [
VectorType::F64,
VectorType::F32,
VectorType::I64,
VectorType::I32,
VectorType::I16,
] {
for (extend, keep) in [(false, false), (true, false), (false, true), (true, true)] {
for unique in [false, true] {
let p = new_params(dim, vt, dist.clone(), 8, 150, extend, keep);
let f = tokio::spawn(async move {
test_hnsw_index(30, unique, p);
});
futures.push(f);
}
}
}
}
for f in futures {
f.await.expect("Task error");
}
Ok(())
}
#[test]
fn test_simple_hnsw() {
let collection = TestCollection::Unique(vec![
(0, new_i16_vec(-2, -3)),
(1, new_i16_vec(-2, 1)),
(2, new_i16_vec(-4, 3)),
(3, new_i16_vec(-3, 1)),
(4, new_i16_vec(-1, 1)),
(5, new_i16_vec(-2, 3)),
(6, new_i16_vec(3, 0)),
(7, new_i16_vec(-1, -2)),
(8, new_i16_vec(-2, 2)),
(9, new_i16_vec(-4, -2)),
(10, new_i16_vec(0, 3)),
]);
let p = new_params(2, VectorType::I16, Distance::Euclidean, 3, 500, true, true);
let mut h = HnswIndex::new_hnsw(&p);
insert_collection_hnsw(&mut h, &collection);
let pt = new_i16_vec(-2, -3);
let knn = 10;
let efs = 501;
let res = h.knn_search(&pt, knn, efs);
assert_eq!(res.len(), knn);
}
async fn test_recall(
embeddings_file: &str,
ingest_limit: usize,
queries_file: &str,
query_limit: usize,
p: HnswParams,
tests_ef_recall: &[(usize, f64)],
) -> Result<(), Error> {
info!("Build data collection");
let collection: Arc<TestCollection> =
Arc::new(TestCollection::NonUnique(new_vectors_from_file(
p.vector_type,
&format!("../tests/data/{embeddings_file}"),
Some(ingest_limit),
)?));
let mut h = HnswIndex::new(&p);
info!("Insert collection");
for (doc_id, obj) in collection.to_vec_ref() {
h.insert(obj.clone(), *doc_id);
}
let h = Arc::new(h);
info!("Build query collection");
let queries = Arc::new(TestCollection::NonUnique(new_vectors_from_file(
p.vector_type,
&format!("../tests/data/{queries_file}"),
Some(query_limit),
)?));
info!("Check recall");
let mut futures = Vec::with_capacity(tests_ef_recall.len());
for &(efs, expected_recall) in tests_ef_recall {
let queries = queries.clone();
let collection = collection.clone();
let h = h.clone();
let f = tokio::spawn(async move {
let mut total_recall = 0.0;
for (_, pt) in queries.to_vec_ref() {
let knn = 10;
let hnsw_res = h.search(pt, knn, efs);
assert_eq!(hnsw_res.docs.len(), knn, "Different size - knn: {knn}",);
let brute_force_res = collection.knn(pt, Distance::Euclidean, knn);
let rec = brute_force_res.recall(&hnsw_res);
if rec == 1.0 {
assert_eq!(brute_force_res.docs, hnsw_res.docs);
}
total_recall += rec;
}
let recall = total_recall / queries.to_vec_ref().len() as f64;
info!("EFS: {efs} - Recall: {recall}");
assert!(
recall >= expected_recall,
"EFS: {efs} - Recall: {recall} - Expected: {expected_recall}"
);
});
futures.push(f);
}
for f in futures {
f.await.expect("Task failure");
}
Ok(())
}
#[test(tokio::test(flavor = "multi_thread"))]
async fn test_recall_euclidean() -> Result<(), Error> {
let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, false);
test_recall(
"hnsw-random-9000-20-euclidean.gz",
3000,
"hnsw-random-5000-20-euclidean.gz",
500,
p,
&[(10, 0.98), (40, 1.0)],
)
.await
}
#[test(tokio::test(flavor = "multi_thread"))]
async fn test_recall_euclidean_keep_pruned_connections() -> Result<(), Error> {
let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, false, true);
test_recall(
"hnsw-random-9000-20-euclidean.gz",
3000,
"hnsw-random-5000-20-euclidean.gz",
500,
p,
&[(10, 0.98), (40, 1.0)],
)
.await
}
#[test(tokio::test(flavor = "multi_thread"))]
async fn test_recall_euclidean_full() -> Result<(), Error> {
let p = new_params(20, VectorType::F32, Distance::Euclidean, 8, 100, true, true);
test_recall(
"hnsw-random-9000-20-euclidean.gz",
1000,
"hnsw-random-5000-20-euclidean.gz",
200,
p,
&[(10, 0.98), (40, 1.0)],
)
.await
}
impl TestCollection {
fn knn(&self, pt: &SharedVector, dist: Distance, n: usize) -> KnnResult {
let mut b = KnnResultBuilder::new(n);
for (doc_id, doc_pt) in self.to_vec_ref() {
let d = dist.calculate(doc_pt, pt);
if b.check_add(d) {
b.add(d, &Ids64::One(*doc_id));
}
}
b.build(
#[cfg(debug_assertions)]
HashMap::new(),
)
}
}
impl KnnResult {
fn recall(&self, res: &KnnResult) -> f64 {
let mut bits = RoaringTreemap::new();
for &(doc_id, _) in &self.docs {
bits.insert(doc_id);
}
let mut found = 0;
for &(doc_id, _) in &res.docs {
if bits.contains(doc_id) {
found += 1;
}
}
found as f64 / bits.len() as f64
}
}
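Recall here is the fraction of the exact (brute-force) neighbours that the approximate search also returned. A minimal, self-contained sketch of the same computation, using a plain HashSet instead of RoaringTreemap and made-up document ids:

use std::collections::HashSet;

fn recall(truth: &[u64], approx: &[u64]) -> f64 {
    let truth: HashSet<u64> = truth.iter().copied().collect();
    let found = approx.iter().filter(|id| truth.contains(*id)).count();
    found as f64 / truth.len() as f64
}

fn main() {
    // Brute force found docs 1, 2, 3, 4; the HNSW search returned 1, 2, 4, 9.
    assert_eq!(recall(&[1, 2, 3, 4], &[1, 2, 4, 9]), 0.75);
}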
fn new_i16_vec(x: isize, y: isize) -> SharedVector {
let vec = Vector::I16(Array1::from_vec(vec![x as i16, y as i16]));
vec.into()
}
}

View file

@ -1,14 +1,18 @@
use crate::idx::docids::DocId;
use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::store::NodeId;
#[cfg(debug_assertions)]
use hashbrown::HashMap;
use hashbrown::HashSet;
use roaring::RoaringTreemap;
use std::cmp::{Ordering, Reverse};
use std::collections::btree_map::Entry;
#[cfg(debug_assertions)]
use std::collections::HashMap;
use std::collections::{BTreeMap, VecDeque};
#[derive(Debug, Clone, Copy, Ord, Eq, PartialEq, PartialOrd)]
pub(super) struct PriorityNode(Reverse<FloatKey>, NodeId);
impl PriorityNode {
pub(super) fn new(d: f64, id: NodeId) -> Self {
Self(Reverse(FloatKey::new(d)), id)
@ -19,6 +23,116 @@ impl PriorityNode {
}
}
#[derive(Default, Clone)]
pub(super) struct DoublePriorityQueue(BTreeMap<FloatKey, VecDeque<ElementId>>, usize);
impl DoublePriorityQueue {
pub(super) fn from(d: f64, e: ElementId) -> Self {
let mut q = DoublePriorityQueue::default();
q.push(d, e);
q
}
pub(super) fn len(&self) -> usize {
self.1
}
pub(super) fn push(&mut self, dist: f64, id: ElementId) {
match self.0.entry(FloatKey(dist)) {
Entry::Vacant(e) => {
e.insert(VecDeque::from([id]));
}
Entry::Occupied(mut e) => {
e.get_mut().push_back(id);
}
}
self.1 += 1;
}
pub(super) fn pop_first(&mut self) -> Option<(f64, ElementId)> {
if let Some(mut e) = self.0.first_entry() {
let d = e.key().0;
let q = e.get_mut();
if let Some(v) = q.pop_front() {
if q.is_empty() {
e.remove();
}
self.1 -= 1;
return Some((d, v));
}
}
None
}
pub(super) fn pop_last(&mut self) -> Option<(f64, ElementId)> {
if let Some(mut e) = self.0.last_entry() {
let d = e.key().0;
let q = e.get_mut();
if let Some(v) = q.pop_back() {
if q.is_empty() {
e.remove();
}
self.1 -= 1;
return Some((d, v));
}
}
None
}
pub(super) fn peek_first(&self) -> Option<(f64, ElementId)> {
self.0.first_key_value().map(|(k, q)| {
let k = k.0;
let v = *q.iter().next().unwrap(); // By design, an entry's queue always contains at least one element
(k, v)
})
}
pub(super) fn peek_last_dist(&self) -> Option<f64> {
self.0.last_key_value().map(|(k, _)| k.0)
}
pub(super) fn to_vec(&self) -> Vec<(f64, ElementId)> {
let mut v = Vec::with_capacity(self.1);
for (d, q) in &self.0 {
for e in q {
v.push((d.0, *e));
}
}
v
}
pub(super) fn to_vec_limit(&self, mut limit: usize) -> Vec<(f64, ElementId)> {
let mut v = Vec::with_capacity(self.1.min(limit));
for (d, q) in &self.0 {
for e in q {
v.push((d.0, *e));
limit -= 1;
if limit == 0 {
return v;
}
}
}
v
}
pub(super) fn to_set(&self) -> HashSet<ElementId> {
let mut s = HashSet::with_capacity(self.1);
for q in self.0.values() {
for v in q {
s.insert(*v);
}
}
s
}
pub(super) fn to_dynamic_set<S: DynamicSet<ElementId>>(&self, set: &mut S) {
for q in self.0.values() {
for v in q {
set.insert(*v);
}
}
}
}
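The queue maps each distance to a FIFO of element ids, so elements at equal distance pop in insertion order, and both ends are reachable in O(log n). A hypothetical candidate-list interaction, for illustration only, using the API above:

let mut candidates = DoublePriorityQueue::from(0.5, 7);
candidates.push(0.2, 3);
candidates.push(0.9, 1);
assert_eq!(candidates.len(), 3);
// The nearest candidate is expanded first...
assert_eq!(candidates.pop_first(), Some((0.2, 3)));
// ...while the furthest result is pruned when the list exceeds capacity.
assert_eq!(candidates.pop_last(), Some((0.9, 1)));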
/// Treats f64 as a totally ordered, sortable data type so that it can be
/// used as a key in a BTreeMap or a BTreeSet.
#[derive(Debug, Clone, Copy)]
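The struct definition and its Eq/Ord implementations fall outside this hunk. A hedged sketch of one common way to write them, using f64::total_cmp (an assumption for illustration, not necessarily the actual code):

use std::cmp::Ordering;

pub(super) struct FloatKey(f64);

impl FloatKey {
    pub(super) fn new(f: f64) -> Self {
        Self(f)
    }
}

impl PartialEq for FloatKey {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == Ordering::Equal
    }
}

impl Eq for FloatKey {}

impl PartialOrd for FloatKey {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for FloatKey {
    // total_cmp provides a total order over all f64 values, including NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}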
@ -91,6 +205,10 @@ impl Ids64 {
}
}
pub(super) fn is_empty(&self) -> bool {
matches!(self, Self::Empty)
}
fn append_to(&self, to: &mut RoaringTreemap) {
match &self {
Self::Empty => {}
@ -491,19 +609,26 @@ pub struct KnnResult {
#[cfg(test)]
pub(super) mod tests {
use crate::err::Error;
use crate::idx::docids::DocId;
use crate::idx::trees::knn::{FloatKey, Ids64, KnnResultBuilder};
use crate::idx::trees::knn::{DoublePriorityQueue, FloatKey, Ids64, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::sql::index::{Distance, VectorType};
use crate::sql::Number;
use crate::sql::{Array, Number};
use crate::syn::Parse;
use flate2::read::GzDecoder;
#[cfg(debug_assertions)]
use hashbrown::HashMap;
use hashbrown::HashSet;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use roaring::RoaringTreemap;
use rust_decimal::prelude::Zero;
use std::cmp::Reverse;
#[cfg(debug_assertions)]
use std::collections::HashMap;
use std::collections::{BTreeSet, BinaryHeap, VecDeque};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::time::SystemTime;
use test_log::test;
pub(crate) fn get_seed_rnd() -> SmallRng {
@ -534,16 +659,47 @@ pub(super) mod tests {
}
}
pub(in crate::idx::trees) fn new_vectors_from_file<V: From<Vector>>(
t: VectorType,
path: &str,
limit: Option<usize>,
) -> Result<Vec<(DocId, V)>, Error> {
// Open the gzip file
let file = File::open(path)?;
// Create a GzDecoder to read the file
let gz = GzDecoder::new(file);
// Wrap the decoder in a BufReader
let reader = BufReader::new(gz);
let mut res = Vec::new();
// Iterate over each line in the file
for (i, line_result) in reader.lines().enumerate() {
if let Some(l) = limit {
if l == i {
break;
}
}
let line = line_result?;
let array = Array::parse(&line);
let vec = Vector::try_from_array(t, &array)?.into();
res.push((i as DocId, vec));
}
Ok(res)
}
pub(in crate::idx::trees) fn new_random_vec(
rng: &mut SmallRng,
t: VectorType,
dim: usize,
gen: &RandomItemGenerator,
) -> SharedVector {
let mut vec = Vector::new(t, dim);
let mut vec: Vec<Number> = Vec::with_capacity(dim);
for _ in 0..dim {
vec.add(&gen.generate(rng));
vec.push(gen.generate(rng));
}
let vec = Vector::try_from_array(t, &Array::from(vec)).unwrap();
if vec.is_null() {
// Some similarities (e.g. cosine) are undefined for the null vector.
new_random_vec(rng, t, dim, gen)
@ -596,7 +752,7 @@ pub(super) mod tests {
gen: &RandomItemGenerator,
rng: &mut SmallRng,
) -> Self {
let mut vector_set = BTreeSet::new();
let mut vector_set = HashSet::new();
let mut attempts = collection_size * 2;
while vector_set.len() < collection_size {
vector_set.insert(new_random_vec(rng, vector_type, dimension, gen));
@ -735,4 +891,109 @@ pub(super) mod tests {
assert_eq!(q.pop(), Some(n2));
assert_eq!(q.pop(), Some(n3));
}
#[test]
fn test_double_priority_queue() {
let mut q = DoublePriorityQueue::from(2.0, 2);
q.push(3.0, 4);
q.push(3.0, 3);
q.push(1.0, 1);
assert_eq!(q.len(), 4);
assert_eq!(q.peek_first(), Some((1.0, 1)));
assert_eq!(q.peek_last_dist(), Some(3.0));
assert_eq!(q.pop_first(), Some((1.0, 1)));
assert_eq!(q.len(), 3);
assert_eq!(q.peek_first(), Some((2.0, 2)));
assert_eq!(q.peek_last_dist(), Some(3.0));
assert_eq!(q.pop_first(), Some((2.0, 2)));
assert_eq!(q.len(), 2);
assert_eq!(q.peek_first(), Some((3.0, 4)));
assert_eq!(q.peek_last_dist(), Some(3.0));
assert_eq!(q.pop_first(), Some((3.0, 4)));
assert_eq!(q.len(), 1);
assert_eq!(q.peek_first(), Some((3.0, 3)));
assert_eq!(q.peek_last_dist(), Some(3.0));
assert_eq!(q.pop_first(), Some((3.0, 3)));
assert_eq!(q.len(), 0);
assert_eq!(q.peek_first(), None);
assert_eq!(q.peek_last_dist(), None);
let mut q = DoublePriorityQueue::from(2.0, 2).clone();
q.push(3.0, 4);
q.push(3.0, 3);
q.push(1.0, 1);
assert_eq!(q.pop_last(), Some((3.0, 3)));
assert_eq!(q.len(), 3);
assert_eq!(q.peek_first(), Some((1.0, 1)));
assert_eq!(q.peek_last_dist(), Some(3.0));
assert_eq!(q.pop_last(), Some((3.0, 4)));
assert_eq!(q.len(), 2);
assert_eq!(q.peek_first(), Some((1.0, 1)));
assert_eq!(q.peek_last_dist(), Some(2.0));
assert_eq!(q.pop_last(), Some((2.0, 2)));
assert_eq!(q.len(), 1);
assert_eq!(q.peek_first(), Some((1.0, 1)));
assert_eq!(q.peek_last_dist(), Some(1.0));
assert_eq!(q.pop_last(), Some((1.0, 1)));
assert_eq!(q.len(), 0);
assert_eq!(q.peek_first(), None);
assert_eq!(q.peek_last_dist(), None);
}
#[test]
// In HNSW we maintain a candidate list that requires knowing both the first
// and the last element of a set.
// There are two possible options:
// 1. Using a BTreeSet, which provides first() and last() methods.
// 2. Maintaining two BinaryHeaps: one providing the min, the other the max.
// This test checks that option 2 is faster than option 1.
// In practice, option 2 is about 4 times faster than option 1.
fn confirm_binaryheaps_faster_than_btreeset() {
// Build samples
const TOTAL: usize = 500;
let mut pns = Vec::with_capacity(TOTAL);
for i in 0..TOTAL {
pns.push((FloatKey::new(i as f64), i as u64));
}
// Test BTreeSet
let duration_btree_set = {
let first = Some(&pns[0]);
let t = SystemTime::now();
let mut bt = BTreeSet::new();
for pn in &pns {
bt.insert(*pn);
assert_eq!(bt.first(), first);
assert_eq!(bt.last(), Some(pn));
}
t.elapsed().unwrap()
};
// Test double BinaryHeap
let duration_binary_heap = {
let r_first = Reverse(pns[0]);
let first = Some(&r_first);
let t = SystemTime::now();
let mut max = BinaryHeap::with_capacity(TOTAL);
let mut min = BinaryHeap::with_capacity(TOTAL);
for pn in &pns {
max.push(*pn);
min.push(Reverse(*pn));
assert_eq!(min.peek(), first);
assert_eq!(max.peek(), Some(pn));
}
t.elapsed().unwrap()
};
info!("{duration_btree_set:?} {duration_binary_heap:?}");
assert!(duration_btree_set > duration_binary_heap);
}
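For reference, a minimal sketch of the double-BinaryHeap pattern from option 2 (not the index code). It only supports peeking: popping from one heap would leave a stale entry in the other, which is presumably why DoublePriorityQueue above is BTreeMap-based instead.

use std::cmp::Reverse;
use std::collections::BinaryHeap;

struct MinMax<T: Ord + Copy> {
    min: BinaryHeap<Reverse<T>>,
    max: BinaryHeap<T>,
}

impl<T: Ord + Copy> MinMax<T> {
    fn new() -> Self {
        Self {
            min: BinaryHeap::new(),
            max: BinaryHeap::new(),
        }
    }
    // Every element is pushed twice, but both ends become an O(1) peek.
    fn push(&mut self, t: T) {
        self.min.push(Reverse(t));
        self.max.push(t);
    }
    fn first(&self) -> Option<T> {
        self.min.peek().map(|r| r.0)
    }
    fn last(&self) -> Option<T> {
        self.max.peek().copied()
    }
}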
}

View file

@ -1,6 +1,9 @@
pub mod bkeys;
pub mod btree;
pub mod knn;
pub mod dynamicset;
mod graph;
pub mod hnsw;
mod knn;
pub mod mtree;
pub mod store;
pub mod vector;

View file

@ -1,10 +1,10 @@
use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet};
use reblessive::tree::Stk;
use revision::revisioned;
use roaring::RoaringTreemap;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::{BinaryHeap, HashSet, VecDeque};
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{Debug, Display, Formatter};
use std::io::Cursor;
use std::sync::Arc;
@ -1430,8 +1430,9 @@ impl VersionedSerdeState for MState {}
#[cfg(test)]
mod tests {
use hashbrown::{HashMap, HashSet};
use reblessive::tree::Stk;
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::VecDeque;
use crate::err::Error;
use test_log::test;

View file

@ -4,8 +4,8 @@ use crate::idx::trees::store::{NodeId, StoreGeneration, StoredNode, TreeNode, Tr
use crate::kvs::{Key, Transaction};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use hashbrown::{HashMap, HashSet};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display};
use std::sync::Arc;

View file

@ -0,0 +1,49 @@
use crate::idx::trees::hnsw::HnswIndex;
use crate::idx::IndexKeyBase;
use crate::kvs::Key;
use crate::sql::index::HnswParams;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
pub(crate) type SharedHnswIndex = Arc<RwLock<HnswIndex>>;
pub(crate) struct HnswIndexes(Arc<RwLock<HashMap<Key, SharedHnswIndex>>>);
impl Default for HnswIndexes {
fn default() -> Self {
Self(Arc::new(RwLock::new(HashMap::new())))
}
}
impl HnswIndexes {
pub(super) async fn get(&self, ikb: &IndexKeyBase, p: &HnswParams) -> SharedHnswIndex {
let key = ikb.new_vm_key(None);
{
let r = self.0.read().await;
if let Some(h) = r.get(&key).cloned() {
return h;
}
}
let mut w = self.0.write().await;
match w.entry(key) {
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => {
let h = Arc::new(RwLock::new(HnswIndex::new(p)));
e.insert(h.clone());
h
}
}
}
pub(super) async fn remove(&self, ikb: &IndexKeyBase) {
let key = ikb.new_vm_key(None);
let mut w = self.0.write().await;
w.remove(&key);
}
pub(super) async fn is_empty(&self) -> bool {
self.0.read().await.is_empty()
}
}
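get() takes the read lock for the fast path and re-checks under the write lock via entry(), so two concurrent callers cannot both build the same index. The same double-checked pattern in miniature (a generic sketch, not SurrealDB API):

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use tokio::sync::RwLock;

// Generic double-checked cache: cheap shared reads, with creation
// serialized behind the write lock so each value is built only once.
async fn get_or_create<K: Hash + Eq, V, F: FnOnce() -> V>(
    cache: &RwLock<HashMap<K, Arc<V>>>,
    key: K,
    create: F,
) -> Arc<V> {
    if let Some(v) = cache.read().await.get(&key) {
        return v.clone();
    }
    let mut w = cache.write().await;
    // Re-check: another task may have inserted between the two locks.
    w.entry(key).or_insert_with(|| Arc::new(create())).clone()
}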

View file

@ -1,5 +1,5 @@
use futures::future::join_all;
use std::collections::HashMap;
use hashbrown::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use tokio::sync::Mutex;

View file

@ -1,4 +1,5 @@
pub mod cache;
pub(crate) mod hnsw;
mod lru;
pub(crate) mod tree;
@ -8,9 +9,11 @@ use crate::idx::trees::bkeys::{FstKeys, TrieKeys};
use crate::idx::trees::btree::{BTreeNode, BTreeStore};
use crate::idx::trees::mtree::{MTreeNode, MTreeStore};
use crate::idx::trees::store::cache::{TreeCache, TreeCaches};
use crate::idx::trees::store::hnsw::{HnswIndexes, SharedHnswIndex};
use crate::idx::trees::store::tree::{TreeRead, TreeWrite};
use crate::idx::IndexKeyBase;
use crate::kvs::{Key, Transaction, TransactionType, Val};
use crate::sql::index::HnswParams;
use crate::sql::statements::DefineIndexStatement;
use crate::sql::Index;
use std::fmt::{Debug, Display, Formatter};
@ -199,6 +202,7 @@ struct Inner {
btree_fst_caches: TreeCaches<BTreeNode<FstKeys>>,
btree_trie_caches: TreeCaches<BTreeNode<TrieKeys>>,
mtree_caches: TreeCaches<MTreeNode>,
hnsw_indexes: HnswIndexes,
}
impl Default for IndexStores {
fn default() -> Self {
@ -206,6 +210,7 @@ impl Default for IndexStores {
btree_fst_caches: TreeCaches::default(),
btree_trie_caches: TreeCaches::default(),
mtree_caches: TreeCaches::default(),
hnsw_indexes: HnswIndexes::default(),
}))
}
}
@ -256,6 +261,16 @@ impl IndexStores {
self.0.mtree_caches.new_cache(new_cache);
}
pub(crate) async fn get_index_hnsw(
&self,
opt: &Options,
ix: &DefineIndexStatement,
p: &HnswParams,
) -> SharedHnswIndex {
let ikb = IndexKeyBase::new(opt, ix);
self.0.hnsw_indexes.get(&ikb, p).await
}
pub(crate) async fn index_removed(
&self,
opt: &Options,
@ -267,6 +282,7 @@ impl IndexStores {
opt,
tx.get_and_cache_tb_index(opt.ns(), opt.db(), tb, ix).await?.as_ref(),
)
.await
}
pub(crate) async fn namespace_removed(
@ -287,12 +303,12 @@ impl IndexStores {
tb: &str,
) -> Result<(), Error> {
for ix in tx.all_tb_indexes(opt.ns(), opt.db(), tb).await?.iter() {
self.remove_index(opt, ix)?;
self.remove_index(opt, ix).await?;
}
Ok(())
}
fn remove_index(&self, opt: &Options, ix: &DefineIndexStatement) -> Result<(), Error> {
async fn remove_index(&self, opt: &Options, ix: &DefineIndexStatement) -> Result<(), Error> {
let ikb = IndexKeyBase::new(opt, ix);
match ix.index {
Index::Search(_) => {
@ -301,6 +317,9 @@ impl IndexStores {
Index::MTree(_) => {
self.remove_mtree_caches(ikb);
}
Index::Hnsw(_) => {
self.remove_hnsw_index(ikb).await;
}
_ => {}
}
Ok(())
@ -318,9 +337,14 @@ impl IndexStores {
self.0.mtree_caches.remove_caches(&TreeNodeProvider::Vector(ikb.clone()));
}
pub fn is_empty(&self) -> bool {
async fn remove_hnsw_index(&self, ikb: IndexKeyBase) {
self.0.hnsw_indexes.remove(&ikb).await;
}
pub async fn is_empty(&self) -> bool {
self.0.mtree_caches.is_empty()
&& self.0.btree_fst_caches.is_empty()
&& self.0.btree_trie_caches.is_empty()
&& self.0.hnsw_indexes.is_empty().await
}
}

View file

@ -2,7 +2,7 @@ use crate::err::Error;
use crate::idx::trees::store::cache::TreeCache;
use crate::idx::trees::store::{NodeId, StoredNode, TreeNode, TreeNodeProvider};
use crate::kvs::{Key, Transaction};
use std::collections::{HashMap, HashSet};
use hashbrown::{HashMap, HashSet};
use std::fmt::{Debug, Display};
use std::mem;
use std::sync::Arc;

View file

@ -1,28 +1,276 @@
use crate::err::Error;
use crate::fnc::util::math::deviation::deviation;
use crate::fnc::util::math::mean::Mean;
use crate::fnc::util::math::ToFloat;
use crate::sql::index::{Distance, VectorType};
use crate::sql::{Array, Number, Value};
use revision::revisioned;
use hashbrown::HashSet;
use linfa_linalg::norm::Norm;
use ndarray::{Array1, LinalgScalar, Zip};
use ndarray_stats::DeviationExt;
use num_traits::Zero;
use rust_decimal::prelude::FromPrimitive;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::cmp::PartialEq;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::ops::{Mul, Sub};
use std::ops::{Add, Deref, Div, Sub};
use std::sync::Arc;
/// In the context of a Symmetric MTree index, the term object refers to a vector, representing the indexed item.
#[revisioned(revision = 1)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Vector {
F64(Vec<f64>),
F32(Vec<f32>),
I64(Vec<i64>),
I32(Vec<i32>),
I16(Vec<i16>),
F64(Array1<f64>),
F32(Array1<f32>),
I64(Array1<i64>),
I32(Array1<i32>),
I16(Array1<i16>),
}
impl Vector {
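	/// Chebyshev (L∞) distance: the maximum absolute difference across dimensions.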
#[inline]
fn chebyshev<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: ToFloat,
{
a.iter()
.zip(b.iter())
.map(|(a, b)| (a.to_float() - b.to_float()).abs())
.fold(0.0_f64, f64::max)
}
fn chebyshev_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => a.linf_dist(b).unwrap_or(f64::INFINITY),
(Self::F32(a), Self::F32(b)) => {
a.linf_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY)
}
(Self::I64(a), Self::I64(b)) => {
a.linf_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY)
}
(Self::I32(a), Self::I32(b)) => {
a.linf_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY)
}
(Self::I16(a), Self::I16(b)) => Self::chebyshev(a, b),
_ => f64::NAN,
}
}
#[inline]
fn cosine_distance_f64(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
let dot_product = a.dot(b);
let norm_a = a.norm_l2();
let norm_b = b.norm_l2();
1.0 - dot_product / (norm_a * norm_b)
}
#[inline]
fn cosine_distance_f32(a: &Array1<f32>, b: &Array1<f32>) -> f64 {
let dot_product = a.dot(b) as f64;
let norm_a = a.norm_l2() as f64;
let norm_b = b.norm_l2() as f64;
1.0 - dot_product / (norm_a * norm_b)
}
#[inline]
fn cosine_dist<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: ToFloat + LinalgScalar,
{
let dot_product = a.dot(b).to_float();
let norm_a = a.mapv(|x| x.to_float() * x.to_float()).sum().sqrt();
let norm_b = b.mapv(|x| x.to_float() * x.to_float()).sum().sqrt();
1.0 - dot_product / (norm_a * norm_b)
}
fn cosine_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::cosine_distance_f64(a, b),
(Self::F32(a), Self::F32(b)) => Self::cosine_distance_f32(a, b),
(Self::I64(a), Self::I64(b)) => Self::cosine_dist(a, b),
(Self::I32(a), Self::I32(b)) => Self::cosine_dist(a, b),
(Self::I16(a), Self::I16(b)) => Self::cosine_dist(a, b),
_ => f64::INFINITY,
}
}
#[inline]
fn euclidean<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: ToFloat,
{
Zip::from(a).and(b).map_collect(|x, y| (x.to_float() - y.to_float()).powi(2)).sum().sqrt()
}
fn euclidean_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => a.l2_dist(b).unwrap_or(f64::INFINITY),
(Self::F32(a), Self::F32(b)) => a.l2_dist(b).unwrap_or(f64::INFINITY),
(Self::I64(a), Self::I64(b)) => a.l2_dist(b).unwrap_or(f64::INFINITY),
(Self::I32(a), Self::I32(b)) => a.l2_dist(b).unwrap_or(f64::INFINITY),
(Self::I16(a), Self::I16(b)) => Self::euclidean(a, b),
_ => f64::INFINITY,
}
}
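	/// Hamming distance: the number of positions at which the two vectors differ.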
#[inline]
fn hamming<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: PartialEq,
{
Zip::from(a).and(b).fold(0, |acc, a, b| {
if a != b {
acc + 1
} else {
acc
}
}) as f64
}
fn hamming_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::hamming(a, b),
(Self::F32(a), Self::F32(b)) => Self::hamming(a, b),
(Self::I64(a), Self::I64(b)) => Self::hamming(a, b),
(Self::I32(a), Self::I32(b)) => Self::hamming(a, b),
(Self::I16(a), Self::I16(b)) => Self::hamming(a, b),
_ => f64::INFINITY,
}
}
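	/// Jaccard similarity: |A ∩ B| / |A ∪ B|. For float vectors, the raw
	/// IEEE-754 bit patterns of the components are used as set members.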
#[inline]
fn jaccard_f64(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
let mut union: HashSet<u64> = a.iter().map(|f| f.to_bits()).collect();
let intersection_size = b.iter().fold(0, |acc, n| {
if !union.insert(n.to_bits()) {
acc + 1
} else {
acc
}
}) as f64;
intersection_size / union.len() as f64
}
#[inline]
fn jaccard_f32(a: &Array1<f32>, b: &Array1<f32>) -> f64 {
let mut union: HashSet<u32> = a.iter().map(|f| f.to_bits()).collect();
let intersection_size = b.iter().fold(0, |acc, n| {
if !union.insert(n.to_bits()) {
acc + 1
} else {
acc
}
}) as f64;
intersection_size / union.len() as f64
}
#[inline]
fn jaccard_integers<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: Eq + Hash + Clone,
{
let mut union: HashSet<T> = a.iter().cloned().collect();
let intersection_size = b.iter().cloned().fold(0, |acc, n| {
if !union.insert(n) {
acc + 1
} else {
acc
}
}) as f64;
intersection_size / union.len() as f64
}
pub(super) fn jaccard_similarity(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::jaccard_f64(a, b),
(Self::F32(a), Self::F32(b)) => Self::jaccard_f32(a, b),
(Self::I64(a), Self::I64(b)) => Self::jaccard_integers(a, b),
(Self::I32(a), Self::I32(b)) => Self::jaccard_integers(a, b),
(Self::I16(a), Self::I16(b)) => Self::jaccard_integers(a, b),
_ => f64::NAN,
}
}
#[inline]
fn manhattan<T>(a: &Array1<T>, b: &Array1<T>) -> f64
where
T: Sub<Output = T> + ToFloat + Copy,
{
a.iter().zip(b.iter()).map(|(&a, &b)| (a - b).to_float().abs()).sum()
}
pub(super) fn manhattan_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => a.l1_dist(b).unwrap_or(f64::INFINITY),
(Self::F32(a), Self::F32(b)) => a.l1_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY),
(Self::I64(a), Self::I64(b)) => a.l1_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY),
(Self::I32(a), Self::I32(b)) => a.l1_dist(b).map(|r| r as f64).unwrap_or(f64::INFINITY),
(Self::I16(a), Self::I16(b)) => Self::manhattan(a, b),
_ => f64::NAN,
}
}
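	/// Minkowski distance of order p: (Σ|aᵢ - bᵢ|ᵖ)^(1/p).
	/// p = 1 gives the Manhattan distance and p = 2 the Euclidean distance.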
#[inline]
fn minkowski<T>(a: &Array1<T>, b: &Array1<T>, order: f64) -> f64
where
T: ToFloat,
{
let dist: f64 = a
.iter()
.zip(b.iter())
.map(|(a, b)| (a.to_float() - b.to_float()).abs().powf(order))
.sum();
dist.powf(1.0 / order)
}
pub(super) fn minkowski_distance(&self, other: &Self, order: f64) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::minkowski(a, b, order),
(Self::F32(a), Self::F32(b)) => Self::minkowski(a, b, order),
(Self::I64(a), Self::I64(b)) => Self::minkowski(a, b, order),
(Self::I32(a), Self::I32(b)) => Self::minkowski(a, b, order),
(Self::I16(a), Self::I16(b)) => Self::minkowski(a, b, order),
_ => f64::NAN,
}
}
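	/// Pearson correlation: the covariance of the two vectors divided by the
	/// product of their standard deviations.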
#[inline]
fn pearson<T>(x: &Array1<T>, y: &Array1<T>) -> f64
where
T: ToFloat + Clone + FromPrimitive + Add<Output = T> + Div<Output = T> + Zero,
{
let mean_x = x.mean().unwrap().to_float();
let mean_y = y.mean().unwrap().to_float();
let mut sum_xy = 0.0;
let mut sum_x2 = 0.0;
let mut sum_y2 = 0.0;
for (xi, yi) in x.iter().zip(y.iter()) {
let diff_x = xi.to_float() - mean_x;
let diff_y = yi.to_float() - mean_y;
sum_xy += diff_x * diff_y;
sum_x2 += diff_x.powi(2);
sum_y2 += diff_y.powi(2);
}
let numerator = sum_xy;
let denominator = (sum_x2 * sum_y2).sqrt();
if denominator == 0.0 {
return 0.0; // Return 0 if the denominator is 0
}
numerator / denominator
}
fn pearson_similarity(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::pearson(a, b),
(Self::F32(a), Self::F32(b)) => Self::pearson(a, b),
(Self::I64(a), Self::I64(b)) => Self::pearson(a, b),
(Self::I32(a), Self::I32(b)) => Self::pearson(a, b),
(Self::I16(a), Self::I16(b)) => Self::pearson(a, b),
_ => f64::NAN,
}
}
}
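A quick standalone sanity check mirroring cosine_distance_f64 above (hypothetical example values): orthogonal vectors are at distance 1, identical vectors at distance 0.

use ndarray::Array1;

fn cosine_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    1.0 - a.dot(b) / (a.dot(a).sqrt() * b.dot(b).sqrt())
}

fn main() {
    let a = Array1::from_vec(vec![1.0, 0.0]);
    let b = Array1::from_vec(vec![0.0, 1.0]);
    assert_eq!(cosine_distance(&a, &b), 1.0); // orthogonal => distance 1
    assert_eq!(cosine_distance(&a, &a), 0.0); // identical  => distance 0
}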
/// As we want to support very large vectors, we avoid copying or cloning them.
@ -36,13 +284,15 @@ impl From<Vector> for SharedVector {
fn from(v: Vector) -> Self {
let mut h = DefaultHasher::new();
v.hash(&mut h);
Self(v.into(), h.finish())
Self(Arc::new(v), h.finish())
}
}
impl Borrow<Vector> for &SharedVector {
fn borrow(&self) -> &Vector {
self.0.as_ref()
impl Deref for SharedVector {
type Target = Vector;
fn deref(&self) -> &Self::Target {
&self.0
}
}
@ -59,18 +309,6 @@ impl PartialEq for SharedVector {
}
impl Eq for SharedVector {}
impl PartialOrd for SharedVector {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SharedVector {
fn cmp(&self, other: &Self) -> Ordering {
self.0.as_ref().cmp(other.0.as_ref())
}
}
impl Serialize for SharedVector {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@ -120,76 +358,51 @@ impl Hash for Vector {
}
}
impl PartialEq for Vector {
fn eq(&self, other: &Self) -> bool {
use Vector::*;
match (self, other) {
(F64(v), F64(v_o)) => v == v_o,
(F32(v), F32(v_o)) => v == v_o,
(I64(v), I64(v_o)) => v == v_o,
(I32(v), I32(v_o)) => v == v_o,
(I16(v), I16(v_o)) => v == v_o,
_ => false,
}
}
}
impl Eq for Vector {}
impl PartialOrd for Vector {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Vector {
fn cmp(&self, other: &Self) -> Ordering {
use Vector::*;
match (self, other) {
(F64(v), F64(v_o)) => v.partial_cmp(v_o).unwrap_or(Ordering::Equal),
(F32(v), F32(v_o)) => v.partial_cmp(v_o).unwrap_or(Ordering::Equal),
(I64(v), I64(v_o)) => v.cmp(v_o),
(I32(v), I32(v_o)) => v.cmp(v_o),
(I16(v), I16(v_o)) => v.cmp(v_o),
(F64(_), _) => Ordering::Less,
(_, F64(_)) => Ordering::Greater,
(F32(_), _) => Ordering::Less,
(_, F32(_)) => Ordering::Greater,
(I64(_), _) => Ordering::Less,
(_, I64(_)) => Ordering::Greater,
(I32(_), _) => Ordering::Less,
(_, I32(_)) => Ordering::Greater,
}
}
}
impl Vector {
pub(super) fn new(t: VectorType, d: usize) -> Self {
match t {
VectorType::F64 => Self::F64(Vec::with_capacity(d)),
VectorType::F32 => Self::F32(Vec::with_capacity(d)),
VectorType::I64 => Self::I64(Vec::with_capacity(d)),
VectorType::I32 => Self::I32(Vec::with_capacity(d)),
VectorType::I16 => Self::I16(Vec::with_capacity(d)),
}
}
pub(super) fn try_from_value(t: VectorType, d: usize, v: &Value) -> Result<Self, Error> {
let mut vec = Vector::new(t, d);
vec.check_vector_value(v)?;
Ok(vec)
let res = match t {
VectorType::F64 => {
let mut vec = Vec::with_capacity(d);
Self::check_vector_value(v, &mut vec)?;
Vector::F64(Array1::from_vec(vec))
}
VectorType::F32 => {
let mut vec = Vec::with_capacity(d);
Self::check_vector_value(v, &mut vec)?;
Vector::F32(Array1::from_vec(vec))
}
VectorType::I64 => {
let mut vec = Vec::with_capacity(d);
Self::check_vector_value(v, &mut vec)?;
Vector::I64(Array1::from_vec(vec))
}
VectorType::I32 => {
let mut vec = Vec::with_capacity(d);
Self::check_vector_value(v, &mut vec)?;
Vector::I32(Array1::from_vec(vec))
}
VectorType::I16 => {
let mut vec = Vec::with_capacity(d);
Self::check_vector_value(v, &mut vec)?;
Vector::I16(Array1::from_vec(vec))
}
};
Ok(res)
}
fn check_vector_value(&mut self, value: &Value) -> Result<(), Error> {
fn check_vector_value<T>(value: &Value, vec: &mut Vec<T>) -> Result<(), Error>
where
T: for<'a> TryFrom<&'a Number, Error = Error>,
{
match value {
Value::Array(a) => {
for v in a.0.iter() {
self.check_vector_value(v)?;
Self::check_vector_value(v, vec)?;
}
Ok(())
}
Value::Number(n) => {
self.add(n);
vec.push(n.try_into()?);
Ok(())
}
_ => Err(Error::InvalidVectorValue(value.clone().to_raw_string())),
@ -197,10 +410,43 @@ impl Vector {
}
pub fn try_from_array(t: VectorType, a: &Array) -> Result<Self, Error> {
let mut vec = Vector::new(t, a.len());
let res = match t {
VectorType::F64 => {
let mut vec = Vec::with_capacity(a.len());
Self::check_vector_array(a, &mut vec)?;
Vector::F64(Array1::from_vec(vec))
}
VectorType::F32 => {
let mut vec = Vec::with_capacity(a.len());
Self::check_vector_array(a, &mut vec)?;
Vector::F32(Array1::from_vec(vec))
}
VectorType::I64 => {
let mut vec = Vec::with_capacity(a.len());
Self::check_vector_array(a, &mut vec)?;
Vector::I64(Array1::from_vec(vec))
}
VectorType::I32 => {
let mut vec = Vec::with_capacity(a.len());
Self::check_vector_array(a, &mut vec)?;
Vector::I32(Array1::from_vec(vec))
}
VectorType::I16 => {
let mut vec = Vec::with_capacity(a.len());
Self::check_vector_array(a, &mut vec)?;
Vector::I16(Array1::from_vec(vec))
}
};
Ok(res)
}
fn check_vector_array<T>(a: &Array, vec: &mut Vec<T>) -> Result<(), Error>
where
T: for<'a> TryFrom<&'a Number, Error = Error>,
{
for v in &a.0 {
if let Value::Number(n) = v {
vec.add(n);
vec.push(n.try_into()?);
} else {
return Err(Error::InvalidVectorType {
current: v.clone().to_string(),
@ -208,17 +454,7 @@ impl Vector {
});
}
}
Ok(vec)
}
pub(super) fn add(&mut self, n: &Number) {
match self {
Self::F64(v) => v.push(n.to_float()),
Self::F32(v) => v.push(n.to_float() as f32),
Self::I64(v) => v.push(n.to_int()),
Self::I32(v) => v.push(n.to_int() as i32),
Self::I16(v) => v.push(n.to_int() as i16),
};
Ok(())
}
pub(super) fn len(&self) -> usize {
@ -242,238 +478,22 @@ impl Vector {
}
}
fn dot<T>(a: &[T], b: &[T]) -> f64
where
T: Mul<Output = T> + Copy + ToFloat,
{
a.iter().zip(b.iter()).map(|(&x, &y)| x.to_float() * y.to_float()).sum::<f64>()
}
fn magnitude<T>(v: &[T]) -> f64
where
T: ToFloat + Copy,
{
v.iter()
.map(|&x| {
let x = x.to_float();
x * x
})
.sum::<f64>()
.sqrt()
}
fn normalize<T>(v: &[T]) -> Vec<f64>
where
T: ToFloat + Copy,
{
let mag = Self::magnitude(v);
if mag == 0.0 || mag.is_nan() {
vec![0.0; v.len()] // Return a zero vector if magnitude is zero
} else {
v.iter().map(|&x| x.to_float() / mag).collect()
}
}
fn cosine<T>(a: &[T], b: &[T]) -> f64
where
T: ToFloat + Mul<Output = T> + Copy,
{
let norm_a = Self::normalize(a);
let norm_b = Self::normalize(b);
let mut s = Self::dot(&norm_a, &norm_b);
s = s.clamp(-1.0, 1.0);
1.0 - s
}
pub(crate) fn cosine_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::cosine(a, b),
(Self::F32(a), Self::F32(b)) => Self::cosine(a, b),
(Self::I64(a), Self::I64(b)) => Self::cosine(a, b),
(Self::I32(a), Self::I32(b)) => Self::cosine(a, b),
(Self::I16(a), Self::I16(b)) => Self::cosine(a, b),
_ => f64::NAN,
}
}
pub(super) fn check_dimension(&self, expected_dim: usize) -> Result<(), Error> {
Self::check_expected_dimension(self.len(), expected_dim)
}
fn chebyshev<T>(a: &[T], b: &[T]) -> f64
where
T: ToFloat,
{
a.iter()
.zip(b.iter())
.map(|(a, b)| (a.to_float() - b.to_float()).abs())
.fold(f64::MIN, f64::max)
}
pub(crate) fn chebyshev_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::chebyshev(a, b),
(Self::F32(a), Self::F32(b)) => Self::chebyshev(a, b),
(Self::I64(a), Self::I64(b)) => Self::chebyshev(a, b),
(Self::I32(a), Self::I32(b)) => Self::chebyshev(a, b),
(Self::I16(a), Self::I16(b)) => Self::chebyshev(a, b),
_ => f64::NAN,
}
}
fn euclidean<T>(a: &[T], b: &[T]) -> f64
where
T: ToFloat,
{
a.iter()
.zip(b.iter())
.map(|(a, b)| (a.to_float() - b.to_float()).powi(2))
.sum::<f64>()
.sqrt()
}
pub(crate) fn euclidean_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::euclidean(a, b),
(Self::F32(a), Self::F32(b)) => Self::euclidean(a, b),
(Self::I64(a), Self::I64(b)) => Self::euclidean(a, b),
(Self::I32(a), Self::I32(b)) => Self::euclidean(a, b),
(Self::I16(a), Self::I16(b)) => Self::euclidean(a, b),
_ => f64::INFINITY,
}
}
fn hamming<T>(a: &[T], b: &[T]) -> f64
where
T: PartialEq,
{
a.iter().zip(b.iter()).filter(|&(a, b)| a != b).count() as f64
}
pub(crate) fn hamming_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::hamming(a, b),
(Self::F32(a), Self::F32(b)) => Self::hamming(a, b),
(Self::I64(a), Self::I64(b)) => Self::hamming(a, b),
(Self::I32(a), Self::I32(b)) => Self::hamming(a, b),
(Self::I16(a), Self::I16(b)) => Self::hamming(a, b),
_ => f64::NAN,
}
}
fn jaccard_f64(a: &[f64], b: &[f64]) -> f64 {
let mut union: HashSet<u64> = HashSet::from_iter(a.iter().map(|f| f.to_bits()));
let intersection_size = b.iter().filter(|n| !union.insert(n.to_bits())).count() as f64;
intersection_size / union.len() as f64
}
fn jaccard_f32(a: &[f32], b: &[f32]) -> f64 {
let mut union: HashSet<u32> = HashSet::from_iter(a.iter().map(|f| f.to_bits()));
let intersection_size = b.iter().filter(|n| !union.insert(n.to_bits())).count() as f64;
intersection_size / union.len() as f64
}
fn jaccard_integers<T>(a: &[T], b: &[T]) -> f64
where
T: Eq + Hash,
{
let mut union: HashSet<&T> = HashSet::from_iter(a.iter());
let intersection_size = b.iter().filter(|n| !union.insert(n)).count() as f64;
intersection_size / union.len() as f64
}
pub(crate) fn jaccard_similarity(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::jaccard_f64(a, b),
(Self::F32(a), Self::F32(b)) => Self::jaccard_f32(a, b),
(Self::I64(a), Self::I64(b)) => Self::jaccard_integers(a, b),
(Self::I32(a), Self::I32(b)) => Self::jaccard_integers(a, b),
(Self::I16(a), Self::I16(b)) => Self::jaccard_integers(a, b),
_ => f64::NAN,
}
}
fn manhattan<T>(a: &[T], b: &[T]) -> f64
where
T: Sub<Output = T> + ToFloat + Copy,
{
a.iter().zip(b.iter()).map(|(&a, &b)| ((a - b).to_float()).abs()).sum()
}
pub(crate) fn manhattan_distance(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::manhattan(a, b),
(Self::F32(a), Self::F32(b)) => Self::manhattan(a, b),
(Self::I64(a), Self::I64(b)) => Self::manhattan(a, b),
(Self::I32(a), Self::I32(b)) => Self::manhattan(a, b),
(Self::I16(a), Self::I16(b)) => Self::manhattan(a, b),
_ => f64::NAN,
}
}
fn minkowski<T>(a: &[T], b: &[T], order: f64) -> f64
where
T: ToFloat,
{
let dist: f64 = a
.iter()
.zip(b.iter())
.map(|(a, b)| (a.to_float() - b.to_float()).abs().powf(order))
.sum();
dist.powf(1.0 / order)
}
pub(crate) fn minkowski_distance(&self, other: &Self, order: f64) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::minkowski(a, b, order),
(Self::F32(a), Self::F32(b)) => Self::minkowski(a, b, order),
(Self::I64(a), Self::I64(b)) => Self::minkowski(a, b, order),
(Self::I32(a), Self::I32(b)) => Self::minkowski(a, b, order),
(Self::I16(a), Self::I16(b)) => Self::minkowski(a, b, order),
_ => f64::NAN,
}
}
fn pearson<T>(a: &[T], b: &[T]) -> f64
where
T: ToFloat,
{
let m1 = a.mean();
let m2 = b.mean();
let covar: f64 =
a.iter().zip(b.iter()).map(|(x, y)| (x.to_float() - m1) * (y.to_float() - m2)).sum();
let covar = covar / a.len() as f64;
let std_dev1 = deviation(a, m1, false);
let std_dev2 = deviation(b, m2, false);
covar / (std_dev1 * std_dev2)
}
fn pearson_similarity(&self, other: &Self) -> f64 {
match (self, other) {
(Self::F64(a), Self::F64(b)) => Self::pearson(a, b),
(Self::F32(a), Self::F32(b)) => Self::pearson(a, b),
(Self::I64(a), Self::I64(b)) => Self::pearson(a, b),
(Self::I32(a), Self::I32(b)) => Self::pearson(a, b),
(Self::I16(a), Self::I16(b)) => Self::pearson(a, b),
_ => f64::NAN,
}
}
}
impl Distance {
pub(super) fn calculate<V>(&self, a: V, b: V) -> f64
where
V: Borrow<Vector>,
{
pub(super) fn calculate(&self, a: &Vector, b: &Vector) -> f64 {
match self {
Distance::Chebyshev => a.borrow().chebyshev_distance(b.borrow()),
Distance::Cosine => a.borrow().cosine_distance(b.borrow()),
Distance::Euclidean => a.borrow().euclidean_distance(b.borrow()),
Distance::Hamming => a.borrow().hamming_distance(b.borrow()),
Distance::Jaccard => a.borrow().jaccard_similarity(b.borrow()),
Distance::Manhattan => a.borrow().manhattan_distance(b.borrow()),
Distance::Minkowski(order) => {
a.borrow().minkowski_distance(b.borrow(), order.to_float())
}
Distance::Pearson => a.borrow().pearson_similarity(b.borrow()),
Distance::Chebyshev => a.chebyshev_distance(b),
Distance::Cosine => a.cosine_distance(b),
Distance::Euclidean => a.euclidean_distance(b),
Distance::Hamming => a.hamming_distance(b),
Distance::Jaccard => a.jaccard_similarity(b),
Distance::Manhattan => a.manhattan_distance(b),
Distance::Minkowski(order) => a.minkowski_distance(b, order.to_float()),
Distance::Pearson => a.pearson_similarity(b),
}
}
}
@ -481,7 +501,7 @@ impl Distance {
#[cfg(test)]
mod tests {
use crate::idx::trees::knn::tests::{get_seed_rnd, new_random_vec, RandomItemGenerator};
use crate::idx::trees::vector::Vector;
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::sql::index::{Distance, VectorType};
use crate::sql::Array;
@ -497,8 +517,10 @@ mod tests {
// Check the "Vector" optimised implementations
for t in [VectorType::F64] {
let v1 = Vector::try_from_array(t, &Array::from(v1.clone())).unwrap();
let v2 = Vector::try_from_array(t, &Array::from(v2.clone())).unwrap();
let v1: SharedVector =
Vector::try_from_array(t, &Array::from(v1.clone())).unwrap().into();
let v2: SharedVector =
Vector::try_from_array(t, &Array::from(v2.clone())).unwrap().into();
assert_eq!(dist.calculate(&v1, &v2), res);
}
}

View file

@ -2039,6 +2039,7 @@ impl Transaction {
) -> Result<LiveStatement, Error> {
let key = crate::key::table::lq::new(ns, db, tb, *lv);
let key_enc = crate::key::table::lq::Lq::encode(&key)?;
#[cfg(debug_assertions)]
trace!("Getting lv ({:?}) {}", lv, sprint_key(&key_enc));
let val = self.get(key_enc).await?.ok_or(Error::LvNotFound {
value: lv.to_string(),
@ -2056,6 +2057,7 @@ impl Transaction {
) -> Result<DefineEventStatement, Error> {
let key = crate::key::table::ev::new(ns, db, tb, ev);
let key_enc = crate::key::table::ev::Ev::encode(&key)?;
#[cfg(debug_assertions)]
trace!("Getting ev ({:?}) {}", ev, sprint_key(&key_enc));
let val = self.get(key_enc).await?.ok_or(Error::EvNotFound {
value: ev.to_string(),
@ -2073,6 +2075,7 @@ impl Transaction {
) -> Result<DefineFieldStatement, Error> {
let key = crate::key::table::fd::new(ns, db, tb, fd);
let key_enc = crate::key::table::fd::Fd::encode(&key)?;
#[cfg(debug_assertions)]
trace!("Getting fd ({:?}) {}", fd, sprint_key(&key_enc));
let val = self.get(key_enc).await?.ok_or(Error::FdNotFound {
value: fd.to_string(),
@ -2090,6 +2093,7 @@ impl Transaction {
) -> Result<DefineIndexStatement, Error> {
let key = crate::key::table::ix::new(ns, db, tb, ix);
let key_enc = crate::key::table::ix::Ix::encode(&key)?;
#[cfg(debug_assertions)]
trace!("Getting ix ({:?}) {}", ix, sprint_key(&key_enc));
let val = self.get(key_enc).await?.ok_or(Error::IxNotFound {
value: ix.to_string(),
@ -2920,6 +2924,7 @@ impl Transaction {
let ts_pairs: Vec<(Vec<u8>, Vec<u8>)> = self.getr(begin..end, u32::MAX).await?;
let latest_ts_pair = ts_pairs.last();
if let Some((k, _)) = latest_ts_pair {
#[cfg(debug_assertions)]
trace!(
"There already was a greater committed timestamp {} in ns: {}, db: {} found: {}",
ts,

View file

@ -190,7 +190,9 @@ impl Expression {
Operator::Matches(_) => {
fnc::operate::matches(stk, ctx, opt, txn, doc, self, l, r).await
}
Operator::Knn(_, _) => fnc::operate::knn(stk, ctx, opt, txn, doc, self).await,
Operator::Knn(_, _) | Operator::Ann(_, _) => {
fnc::operate::knn(stk, ctx, opt, txn, doc, self).await
}
_ => unreachable!(),
}
}

View file

@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize};
use std::fmt;
use std::fmt::{Display, Formatter};
#[revisioned(revision = 1)]
#[revisioned(revision = 2)]
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
@ -26,6 +26,9 @@ pub enum Index {
Search(SearchParams),
/// M-Tree index for distance based metrics
MTree(MTreeParams),
/// HNSW index for distance based metrics
#[revision(start = 2)]
Hnsw(HnswParams),
}
#[revisioned(revision = 2)]
@ -99,6 +102,49 @@ pub enum Distance1 {
Minkowski(Number),
}
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub struct HnswParams {
pub dimension: u16,
pub distance: Distance,
pub vector_type: VectorType,
pub m: u8,
pub m0: u8,
pub ef_construction: u16,
pub extend_candidates: bool,
pub keep_pruned_connections: bool,
pub ml: Number,
}
impl HnswParams {
#[allow(clippy::too_many_arguments)]
pub fn new(
dimension: u16,
distance: Distance,
vector_type: VectorType,
m: u8,
m0: u8,
ml: Number,
ef_construction: u16,
extend_candidates: bool,
keep_pruned_connections: bool,
) -> Self {
Self {
dimension,
distance,
vector_type,
m,
m0,
ef_construction,
ml,
extend_candidates,
keep_pruned_connections,
}
}
}
#[revisioned(revision = 1)]
#[derive(Clone, Default, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
@ -202,6 +248,20 @@ impl Display for Index {
p.dimension, p.distance, p.vector_type, p.capacity, p.doc_ids_order, p.doc_ids_cache, p.mtree_cache
)
}
Self::Hnsw(p) => {
write!(
f,
"HNSW DIMENSION {} DIST {} TYPE {} EFC {} M {} M0 {} ML {}",
p.dimension, p.distance, p.vector_type, p.ef_construction, p.m, p.m0, p.ml
)?;
if p.extend_candidates {
f.write_str(" EXTEND_CANDIDATES")?
}
if p.keep_pruned_connections {
f.write_str(" KEEP_PRUNED_CONNECTIONS")?
}
Ok(())
}
}
}
}
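Combining this Display output with the parser changes below, an index definition using the new syntax would look roughly like this (an illustrative sketch; the table and field names are made up):

DEFINE INDEX idx_pts ON pts FIELDS point
    HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 150 M 12 M0 24
    EXTEND_CANDIDATES KEEP_PRUNED_CONNECTIONS;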

View file

@ -2,6 +2,7 @@ use super::value::{TryAdd, TryDiv, TryMul, TryNeg, TryPow, TryRem, TrySub};
use crate::err::Error;
use crate::fnc::util::math::ToFloat;
use crate::sql::strand::Strand;
use crate::sql::Value;
use revision::revisioned;
use rust_decimal::prelude::*;
use serde::{Deserialize, Serialize};
@ -153,6 +154,54 @@ impl TryFrom<Number> for Decimal {
}
}
impl TryFrom<&Number> for f64 {
type Error = Error;
fn try_from(n: &Number) -> Result<Self, Self::Error> {
Ok(n.to_float())
}
}
impl TryFrom<&Number> for f32 {
type Error = Error;
fn try_from(n: &Number) -> Result<Self, Self::Error> {
n.to_float().to_f32().ok_or_else(|| Error::ConvertTo {
from: Value::Number(n.clone()),
into: "f32".to_string(),
})
}
}
impl TryFrom<&Number> for i64 {
type Error = Error;
fn try_from(n: &Number) -> Result<Self, Self::Error> {
Ok(n.to_int())
}
}
impl TryFrom<&Number> for i32 {
type Error = Error;
fn try_from(n: &Number) -> Result<Self, Self::Error> {
n.to_int().to_i32().ok_or_else(|| Error::ConvertTo {
from: Value::Number(n.clone()),
into: "i32".to_string(),
})
}
}
impl TryFrom<&Number> for i16 {
type Error = Error;
fn try_from(n: &Number) -> Result<Self, Self::Error> {
n.to_int().to_i16().ok_or_else(|| Error::ConvertTo {
from: Value::Number(n.clone()),
into: "i16".to_string(),
})
}
}
impl Display for Number {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {

View file

@ -6,7 +6,7 @@ use std::fmt;
use std::fmt::Write;
/// Binary operators.
#[revisioned(revision = 1)]
#[revisioned(revision = 2)]
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
@ -60,7 +60,9 @@ pub enum Operator {
Outside,
Intersects,
//
Knn(u32, Option<Distance>), // <{k}[,{dist}]>
Knn(u32, Option<Distance>), // <|{k}[,{dist}]|>
#[revision(start = 2)]
Ann(u32, u32), // <|{k},{ef}|>
//
Rem, // %
}
@ -146,6 +148,9 @@ impl fmt::Display for Operator {
write!(f, "<|{k}|>")
}
}
Self::Ann(k, ef) => {
write!(f, "<|{k},{ef}|>")
}
}
}
}
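For context, both operators share the <|...|> syntax; the second argument selects between a brute-force scan and the approximate HNSW search (illustrative queries with made-up names):

-- Brute-force 10-NN using an explicit distance:
SELECT id FROM pts WHERE point <|10,EUCLIDEAN|> [2, 3];
-- Approximate 10-NN over the HNSW index with ef = 40:
SELECT id FROM pts WHERE point <|10,40|> [2, 3];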

View file

@ -101,11 +101,13 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("DROP") => TokenKind::Keyword(Keyword::Drop),
UniCase::ascii("DUPLICATE") => TokenKind::Keyword(Keyword::Duplicate),
UniCase::ascii("EDGENGRAM") => TokenKind::Keyword(Keyword::Edgengram),
UniCase::ascii("EFC") => TokenKind::Keyword(Keyword::Efc),
UniCase::ascii("EVENT") => TokenKind::Keyword(Keyword::Event),
UniCase::ascii("ELSE") => TokenKind::Keyword(Keyword::Else),
UniCase::ascii("END") => TokenKind::Keyword(Keyword::End),
UniCase::ascii("EXISTS") => TokenKind::Keyword(Keyword::Exists),
UniCase::ascii("EXPLAIN") => TokenKind::Keyword(Keyword::Explain),
UniCase::ascii("EXTEND_CANDIDATES") => TokenKind::Keyword(Keyword::ExtendCandidates),
UniCase::ascii("false") => TokenKind::Keyword(Keyword::False),
UniCase::ascii("FETCH") => TokenKind::Keyword(Keyword::Fetch),
UniCase::ascii("FIELD") => TokenKind::Keyword(Keyword::Field),
@ -121,6 +123,7 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("FUNCTION") => TokenKind::Keyword(Keyword::Function),
UniCase::ascii("GROUP") => TokenKind::Keyword(Keyword::Group),
UniCase::ascii("HIGHLIGHTS") => TokenKind::Keyword(Keyword::Highlights),
UniCase::ascii("HNSW") => TokenKind::Keyword(Keyword::Hnsw),
UniCase::ascii("IGNORE") => TokenKind::Keyword(Keyword::Ignore),
UniCase::ascii("INCLUDE") => TokenKind::Keyword(Keyword::Include),
UniCase::ascii("INDEX") => TokenKind::Keyword(Keyword::Index),
@ -130,11 +133,15 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("IF") => TokenKind::Keyword(Keyword::If),
UniCase::ascii("IS") => TokenKind::Keyword(Keyword::Is),
UniCase::ascii("KEY") => TokenKind::Keyword(Keyword::Key),
UniCase::ascii("KEEP_PRUNED_CONNECTIONS") => TokenKind::Keyword(Keyword::KeepPrunedConnections),
UniCase::ascii("KILL") => TokenKind::Keyword(Keyword::Kill),
UniCase::ascii("LET") => TokenKind::Keyword(Keyword::Let),
UniCase::ascii("LIMIT") => TokenKind::Keyword(Keyword::Limit),
UniCase::ascii("LIVE") => TokenKind::Keyword(Keyword::Live),
UniCase::ascii("LOWERCASE") => TokenKind::Keyword(Keyword::Lowercase),
UniCase::ascii("M") => TokenKind::Keyword(Keyword::M),
UniCase::ascii("M0") => TokenKind::Keyword(Keyword::M0),
UniCase::ascii("ML") => TokenKind::Keyword(Keyword::ML),
UniCase::ascii("MERGE") => TokenKind::Keyword(Keyword::Merge),
UniCase::ascii("MODEL") => TokenKind::Keyword(Keyword::Model),
UniCase::ascii("MTREE") => TokenKind::Keyword(Keyword::MTree),
@ -248,7 +255,6 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("DURATION") => TokenKind::Keyword(Keyword::Duration),
UniCase::ascii("FLOAT") => TokenKind::Keyword(Keyword::Float),
UniCase::ascii("fn") => TokenKind::Keyword(Keyword::Fn),
UniCase::ascii("ml") => TokenKind::Keyword(Keyword::ML),
UniCase::ascii("INT") => TokenKind::Keyword(Keyword::Int),
UniCase::ascii("NUMBER") => TokenKind::Keyword(Keyword::Number),
UniCase::ascii("OBJECT") => TokenKind::Keyword(Keyword::Object),

View file

@ -353,6 +353,7 @@ impl Parser<'_> {
| TokenKind::Language(_)
| TokenKind::Algorithm(_)
| TokenKind::Distance(_)
| TokenKind::VectorType(_)
| TokenKind::Identifier
)
}

View file

@ -5,6 +5,7 @@ use reblessive::Stk;
use super::mac::unexpected;
use super::ParseError;
use crate::sql::{value::TryNeg, Cast, Expression, Number, Operator, Value};
use crate::syn::token::Token;
use crate::syn::{
parser::{mac::expected, ParseErrorKind, ParseResult, Parser},
token::{t, NumberKind, TokenKind},
@ -194,6 +195,36 @@ impl Parser<'_> {
})))
}
}
pub fn parse_knn(&mut self, token: Token) -> ParseResult<Operator> {
let amount = self.next_token_value()?;
let op = if self.eat(t!(",")) {
let token = self.next();
match &token.kind {
TokenKind::Distance(k) => {
let d = self.convert_distance(k).map(Some)?;
Operator::Knn(amount, d)
},
TokenKind::Number(NumberKind::Integer) => {
let ef = self.token_value(token)?;
Operator::Ann(amount, ef)
}
_ => {
return Err(ParseError::new(
ParseErrorKind::UnexpectedExplain {
found: token.kind,
expected: "a distance or an integer",
explain: "The NN operator accepts either a distance for a brute-force operation, or an EF value for approximate operations",
},
token.span,
))
}
}
} else {
Operator::Knn(amount, None)
};
self.expect_closing_delimiter(t!("|>"), token.span)?;
Ok(op)
}
async fn parse_infix_op(
&mut self,
@ -260,12 +291,7 @@ impl Parser<'_> {
Operator::NotInside
}
t!("IN") => Operator::Inside,
t!("<|") => {
let amount = self.next_token_value()?;
let dist = self.eat(t!(",")).then(|| self.parse_distance()).transpose()?;
self.expect_closing_delimiter(t!("|>"), token.span)?;
Operator::Knn(amount, dist)
}
t!("<|") => self.parse_knn(token)?,
// should be unreachable as we previously check if the token was a prefix op.
x => unreachable!("found non-operator token {x:?}"),

View file

@ -1,5 +1,6 @@
use reblessive::Stk;
use crate::sql::index::HnswParams;
use crate::{
sql::{
filter::Filter,
@ -685,6 +686,71 @@ impl Parser<'_> {
vector_type,
})
}
t!("HNSW") => {
self.pop_peek();
expected!(self, t!("DIMENSION"));
let dimension = self.next_token_value()?;
let mut distance = Distance::Euclidean;
let mut vector_type = VectorType::F64;
let mut m = None;
let mut m0 = None;
let mut ml = None;
let mut ef_construction = 150;
let mut extend_candidates = false;
let mut keep_pruned_connections = false;
loop {
match self.peek_kind() {
t!("DISTANCE") => {
self.pop_peek();
distance = self.parse_distance()?;
}
t!("TYPE") => {
self.pop_peek();
vector_type = self.parse_vector_type()?;
}
t!("M") => {
self.pop_peek();
m = Some(self.next_token_value()?);
}
t!("M0") => {
self.pop_peek();
m0 = Some(self.next_token_value()?);
}
t!("ML") => {
self.pop_peek();
ml = Some(self.next_token_value()?);
}
t!("EFC") => {
self.pop_peek();
ef_construction = self.next_token_value()?;
}
t!("EXTEND_CANDIDATES") => {
self.pop_peek();
extend_candidates = true;
}
t!("KEEP_PRUNED_CONNECTIONS") => {
self.pop_peek();
keep_pruned_connections = true;
}
_ => break,
}
}
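// Defaults follow the HNSW paper: m0 defaults to 2 * m and ml to 1 / ln(m).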
let m = m.unwrap_or(12);
let m0 = m0.unwrap_or(m * 2);
let ml = ml.unwrap_or(1.0 / (m as f64).ln()).into();
res.index = Index::Hnsw(HnswParams::new(
dimension,
distance,
vector_type,
m,
m0,
ml,
ef_construction,
extend_candidates,
keep_pruned_connections,
));
}
t!("COMMENT") => {
self.pop_peek();
res.comment = Some(self.next_token_value()?);

View file

@ -2,13 +2,11 @@
use reblessive::Stk;
use crate::sql::index::VectorType;
use crate::syn::token::VectorTypeKind;
use crate::{
sql::{
change_feed_include::ChangeFeedInclude, changefeed::ChangeFeed, index::Distance, Base,
Cond, Data, Duration, Fetch, Fetchs, Field, Fields, Group, Groups, Ident, Idiom, Output,
Permission, Permissions, Tables, Timeout, Value, View,
change_feed_include::ChangeFeedInclude, changefeed::ChangeFeed, index::Distance,
index::VectorType, Base, Cond, Data, Duration, Fetch, Fetchs, Field, Fields, Group, Groups,
Ident, Idiom, Output, Permission, Permissions, Tables, Timeout, Value, View,
},
syn::{
parser::{
@ -16,7 +14,7 @@ use crate::{
mac::{expected, unexpected},
ParseError, ParseErrorKind, ParseResult, Parser,
},
token::{t, DistanceKind, Span, TokenKind},
token::{t, DistanceKind, Span, TokenKind, VectorTypeKind},
},
};
@ -383,32 +381,29 @@ impl Parser<'_> {
})
}
pub fn parse_distance(&mut self) -> ParseResult<Distance> {
let dist = match self.next().kind {
TokenKind::Distance(x) => match x {
DistanceKind::Chebyshev => Distance::Chebyshev,
DistanceKind::Cosine => Distance::Cosine,
DistanceKind::Euclidean => Distance::Euclidean,
DistanceKind::Manhattan => Distance::Manhattan,
DistanceKind::Hamming => Distance::Hamming,
DistanceKind::Jaccard => Distance::Jaccard,
DistanceKind::Minkowski => {
let distance = self.next_token_value()?;
Distance::Minkowski(distance)
}
DistanceKind::Pearson => Distance::Pearson,
},
x => unexpected!(self, x, "a distance measure"),
pub fn convert_distance(&mut self, k: &DistanceKind) -> ParseResult<Distance> {
let dist = match k {
DistanceKind::Chebyshev => Distance::Chebyshev,
DistanceKind::Cosine => Distance::Cosine,
DistanceKind::Euclidean => Distance::Euclidean,
DistanceKind::Manhattan => Distance::Manhattan,
DistanceKind::Hamming => Distance::Hamming,
DistanceKind::Jaccard => Distance::Jaccard,
DistanceKind::Minkowski => {
let distance = self.next_token_value()?;
Distance::Minkowski(distance)
}
DistanceKind::Pearson => Distance::Pearson,
};
Ok(dist)
}
pub fn try_parse_distance(&mut self) -> ParseResult<Option<Distance>> {
if !self.eat(t!("DISTANCE")) {
return Ok(None);
pub fn parse_distance(&mut self) -> ParseResult<Distance> {
match self.next().kind {
TokenKind::Distance(k) => self.convert_distance(&k),
x => unexpected!(self, x, "a distance measure"),
}
self.parse_distance().map(Some)
}
pub fn parse_vector_type(&mut self) -> ParseResult<VectorType> {

View file

@ -64,12 +64,14 @@ keyword! {
DocLengthsOrder => "DOC_LENGTHS_ORDER",
Drop => "DROP",
Duplicate => "DUPLICATE",
Efc => "EFC",
Edgengram => "EDGENGRAM",
Event => "EVENT",
Else => "ELSE",
End => "END",
Exists => "EXISTS",
Explain => "EXPLAIN",
ExtendCandidates => "EXTEND_CANDIDATES",
False => "false",
Fetch => "FETCH",
Field => "FIELD",
@ -82,6 +84,7 @@ keyword! {
Function => "FUNCTION",
Group => "GROUP",
Highlights => "HIGHLIGHTS",
Hnsw => "HNSW",
Ignore => "IGNORE",
Include => "INCLUDE",
Index => "INDEX",
@ -91,13 +94,17 @@ keyword! {
If => "IF",
Is => "IS",
Key => "KEY",
KeepPrunedConnections => "KEEP_PRUNED_CONNECTIONS",
Kill => "KILL",
Let => "LET",
Limit => "LIMIT",
Live => "LIVE",
Lowercase => "LOWERCASE",
M => "M",
M0 => "M0",
Merge => "MERGE",
Model => "MODEL",
Ml => "ML",
MTree => "MTREE",
MTreeCache => "MTREE_CACHE",
Namespace => "NAMESPACE",

View file

@ -305,6 +305,23 @@ macro_rules! t {
$crate::syn::token::TokenKind::Distance($crate::syn::token::DistanceKind::Pearson)
};
// VectorType
("F64") => {
$crate::syn::token::TokenKind::VectorType($crate::syn::token::VectorTypeKind::F64)
};
("F32") => {
$crate::syn::token::TokenKind::VectorType($crate::syn::token::VectorTypeKind::F32)
};
("I64") => {
$crate::syn::token::TokenKind::VectorType($crate::syn::token::VectorTypeKind::I64)
};
("I32") => {
$crate::syn::token::TokenKind::VectorType($crate::syn::token::VectorTypeKind::I32)
};
("I16") => {
$crate::syn::token::TokenKind::VectorType($crate::syn::token::VectorTypeKind::I16)
};
($t:tt) => {
$crate::syn::token::TokenKind::Keyword($crate::syn::token::keyword_t!($t))
};

View file

@ -115,6 +115,8 @@ reblessive = { version = "0.3.5", features = ["tree"] }
[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio"] }
env_logger = "0.10.1"
flate2 = "1.0.28"
hashbrown = "0.14.5"
pprof = { version = "0.13.0", features = ["flamegraph", "criterion"] }
rand = "0.8.5"
radix_trie = "0.2.1"
@ -183,6 +185,10 @@ harness = false
name = "index_mtree"
harness = false
[[bench]]
name = "index_hnsw"
harness = false
[[bench]]
name = "move_vs_clone"
harness = false
@ -190,3 +196,7 @@ harness = false
[[bench]]
name = "sdb"
harness = false
[[bench]]
name = "hashset_vs_vector"
harness = false

View file

@ -0,0 +1,120 @@
use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput};
use std::collections::HashSet;
use std::time::{Duration, SystemTime};
use surrealdb_core::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
fn bench_hashset(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
let mut h = HashSet::with_capacity(samples.len());
for &s in samples {
h.insert(s);
}
for s in samples {
assert!(h.contains(s));
}
assert_eq!(h.len(), samples.len());
}
}
fn bench_hashbrown(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
let mut h = HashBrownSet::<u64>::with_capacity(samples.len());
for &s in samples {
h.insert(s);
}
for s in samples {
assert!(h.contains(s));
}
assert_eq!(h.len(), samples.len());
}
}
fn bench_vector(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
let mut v = Vec::with_capacity(samples.len());
for &s in samples {
// Same behaviour as a hash set: only insert if not already present
if !v.contains(&s) {
v.push(s);
}
}
for s in samples {
assert!(v.contains(s));
}
assert_eq!(v.len(), samples.len());
}
}
fn bench_array<const N: usize>(samples_vec: &Vec<Vec<u64>>) {
for samples in samples_vec {
let mut v = ArraySet::<u64, N>::with_capacity(samples.len());
for &s in samples {
v.insert(s);
}
for s in samples {
assert!(v.contains(s));
}
assert_eq!(v.len(), samples.len());
}
}
fn create_samples(capacity: usize, num_samples: usize) -> Vec<Vec<u64>> {
let mut s = SystemTime::now().elapsed().unwrap().as_secs();
let mut res = Vec::with_capacity(num_samples);
for _ in 0..num_samples {
let mut samples = Vec::with_capacity(capacity);
for _ in 0..capacity {
s += 1;
samples.push(s);
}
res.push(samples);
}
res
}
/// This bench compares the insert and lookup performance of small set collections.
/// It compares HashSet, HashBrownSet, Vec and ArraySet.
/// It is used to help choose the best option for the UndirectedGraph used by the HNSW index.
/// The ultimate goal is to make sure that the DynamicSet uses the best option for the expected capacity.
fn bench_hashset_vs_vector(c: &mut Criterion) {
const ITERATIONS: usize = 1_000_000;
let mut group = c.benchmark_group("hashset_vs_vector");
group.throughput(Throughput::Elements(ITERATIONS as u64));
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));
group_test::<4>(&mut group, ITERATIONS);
group_test::<8>(&mut group, ITERATIONS);
group_test::<16>(&mut group, ITERATIONS);
group_test::<24>(&mut group, ITERATIONS);
group_test::<28>(&mut group, ITERATIONS);
group_test::<30>(&mut group, ITERATIONS);
group_test::<32>(&mut group, ITERATIONS);
group.finish();
}
fn group_test<const N: usize>(group: &mut BenchmarkGroup<WallTime>, iterations: usize) {
let samples = create_samples(N, iterations);
group.bench_function(format!("hashset_{N}"), |b| {
b.iter(|| bench_hashset(&samples));
});
group.bench_function(format!("hashbrown_{N}"), |b| {
b.iter(|| bench_hashbrown(&samples));
});
group.bench_function(format!("vector_{N}"), |b| {
b.iter(|| bench_vector(&samples));
});
group.bench_function(format!("array_{N}"), |b| {
b.iter(|| bench_array::<N>(&samples));
});
}
criterion_group!(benches, bench_hashset_vs_vector);
criterion_main!(benches);
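
The trade-off probed above can be sketched in a few lines. The following is an illustrative, std-only sketch (not the crate's actual DynamicSet API) of the idea being validated: below some capacity threshold a linear scan over a plain vector tends to beat hashing, so the backend can be chosen from the expected capacity. The threshold of 24 is a placeholder; finding the real crossover point is exactly what the bench is for.

use std::collections::HashSet;

// Illustrative sketch only; not the crate's DynamicSet.
enum SmallSet {
    Vec(Vec<u64>),
    Hash(HashSet<u64>),
}

impl SmallSet {
    fn with_capacity(capacity: usize) -> Self {
        // Placeholder threshold; the bench above exists to find the real one.
        if capacity <= 24 {
            SmallSet::Vec(Vec::with_capacity(capacity))
        } else {
            SmallSet::Hash(HashSet::with_capacity(capacity))
        }
    }
    fn insert(&mut self, v: u64) -> bool {
        match self {
            // Linear scan keeps small sets allocation-light and cache-friendly.
            SmallSet::Vec(vec) => {
                if vec.contains(&v) {
                    false
                } else {
                    vec.push(v);
                    true
                }
            }
            SmallSet::Hash(h) => h.insert(v),
        }
    }
    fn contains(&self, v: &u64) -> bool {
        match self {
            SmallSet::Vec(vec) => vec.contains(v),
            SmallSet::Hash(h) => h.contains(v),
        }
    }
}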

238
lib/benches/index_hnsw.rs Normal file
View file

@ -0,0 +1,238 @@
use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput};
use flate2::read::GzDecoder;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::time::Duration;
use surrealdb::idx::trees::hnsw::HnswIndex;
use surrealdb::sql::index::Distance;
use surrealdb_core::dbs::Session;
use surrealdb_core::kvs::Datastore;
use surrealdb_core::sql::index::{HnswParams, VectorType};
use surrealdb_core::sql::{value, Array, Id, Thing, Value};
use tokio::runtime::{Builder, Runtime};
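// HNSW parameters; names follow the usual HNSW (Malkov & Yashunin) terminology:
// EF_CONSTRUCTION is the candidate-list size while building the graph,
// EF_SEARCH the candidate-list size at query time, NN the number of
// neighbours requested per query, M and M0 the maximum connections per node
// on the upper layers and on layer 0 (here M0 = 2*M, the usual default),
// and DIMENSION the vector dimensionality (matching the "-20-" in the
// dataset file names below).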
const EF_CONSTRUCTION: u16 = 150;
const EF_SEARCH: usize = 80;
const NN: usize = 10;
const M: u8 = 24;
const M0: u8 = 48;
const DIMENSION: u16 = 20;
const INGESTING_SOURCE: &str = "../tests/data/hnsw-random-9000-20-euclidean.gz";
const QUERYING_SOURCE: &str = "../tests/data/hnsw-random-5000-20-euclidean.gz";
fn bench_hnsw_no_db(c: &mut Criterion) {
const GROUP_NAME: &str = "hnsw_no_db";
let samples = new_vectors_from_file(INGESTING_SOURCE);
let samples: Vec<(Thing, Vec<Value>)> =
samples.into_iter().map(|(r, a)| (r, vec![Value::Array(a)])).collect();
// Indexing benchmark group
{
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("insert len: {}", samples.len());
group.bench_function(id, |b| {
b.iter(|| insert_objects(&samples));
});
group.finish();
}
// Create an HNSW instance with data
let hnsw = insert_objects(&samples);
let samples = new_vectors_from_file(QUERYING_SOURCE);
let samples: Vec<Array> = samples.into_iter().map(|(_, a)| a).collect();
// Knn lookup benchmark group
{
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("lookup len: {}", samples.len());
group.bench_function(id, |b| {
b.iter(|| knn_lookup_objects(&hnsw, &samples));
});
group.finish();
}
}
fn bench_hnsw_with_db(c: &mut Criterion) {
const GROUP_NAME: &str = "hnsw_with_db";
let samples = new_vectors_from_file(INGESTING_SOURCE);
let samples: Vec<String> =
samples.into_iter().map(|(r, a)| format!("CREATE {r} SET r={a} RETURN NONE;")).collect();
let session = &Session::owner().with_ns("ns").with_db("db");
// Indexing benchmark group
{
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("insert len: {}", samples.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap()).iter(|| insert_objects_db(session, true, &samples));
});
group.finish();
}
let b = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
let ds = b.block_on(insert_objects_db(session, true, &samples));
// Knn lookup benchmark group
let samples = new_vectors_from_file(QUERYING_SOURCE);
let selects: Vec<String> = samples
.into_iter()
.map(|(_, a)| format!("SELECT id FROM e WHERE r <|{NN},{EF_SEARCH}|> {a};"))
.collect();
{
let mut group = get_group(c, GROUP_NAME, selects.len(), 10);
let id = format!("lookup len: {}", selects.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap())
.iter(|| knn_lookup_objects_db(&ds, session, &selects));
});
group.finish();
}
}
fn bench_db_without_index(c: &mut Criterion) {
const GROUP_NAME: &str = "hnsw_without_index";
let samples = new_vectors_from_file(INGESTING_SOURCE);
let samples: Vec<String> =
samples.into_iter().map(|(r, a)| format!("CREATE {r} SET r={a} RETURN NONE;")).collect();
let session = &Session::owner().with_ns("ns").with_db("db");
// Ingesting benchmark group
{
let mut group = get_group(c, GROUP_NAME, samples.len(), 10);
let id = format!("insert len: {}", samples.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap())
.iter(|| insert_objects_db(session, false, &samples));
});
group.finish();
}
let b = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
let ds = b.block_on(insert_objects_db(session, false, &samples));
// Knn lookup benchmark group
let samples = new_vectors_from_file(QUERYING_SOURCE);
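// Each query fetches ten records by id, roughly matching the work of a 10-NN lookup (NN = 10).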
let selects: Vec<String> = samples
.into_iter()
.map(|(id, _)| format!("SELECT id FROM {id},{id},{id},{id},{id},{id},{id},{id},{id},{id};"))
.collect();
{
let mut group = get_group(c, GROUP_NAME, selects.len(), 10);
let id = format!("lookup len: {}", selects.len());
group.bench_function(id, |b| {
b.to_async(Runtime::new().unwrap())
.iter(|| knn_lookup_objects_db(&ds, session, &selects));
});
group.finish();
}
}
fn get_group<'a>(
c: &'a mut Criterion,
group_name: &str,
samples_len: usize,
measurement_secs: u64,
) -> BenchmarkGroup<'a, WallTime> {
let mut group = c.benchmark_group(group_name);
group.throughput(Throughput::Elements(samples_len as u64));
group.sample_size(10);
group.measurement_time(Duration::from_secs(measurement_secs));
group
}
fn new_vectors_from_file(path: &str) -> Vec<(Thing, Array)> {
// Open the gzip file
let file = File::open(path).unwrap();
// Create a GzDecoder to read the file
let gz = GzDecoder::new(file);
// Wrap the decoder in a BufReader
let reader = BufReader::new(gz);
let mut res = Vec::new();
// Iterate over each line in the file
for (i, line_result) in reader.lines().enumerate() {
let line = line_result.unwrap();
let value = value(&line).unwrap();
if let Value::Array(a) = value {
res.push((Thing::from(("e", Id::Number(i as i64))), a));
} else {
panic!("Wrong value");
}
}
res
}
async fn init_datastore(session: &Session, with_index: bool) -> Datastore {
let ds = Datastore::new("memory").await.unwrap();
if with_index {
let sql = format!("DEFINE INDEX ix ON e FIELDS r HNSW DIMENSION {DIMENSION} DIST EUCLIDEAN TYPE F32 EFC {EF_CONSTRUCTION} M {M};");
ds.execute(&sql, session, None).await.expect(&sql);
}
ds
}
fn hnsw() -> HnswIndex {
let p = HnswParams::new(
DIMENSION,
Distance::Euclidean,
VectorType::F32,
M,
M0,
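// Level-generation factor ml = 1/ln(M), the default suggested by the HNSW paper.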
(1.0 / (M as f64).ln()).into(),
EF_CONSTRUCTION,
false,
false,
);
HnswIndex::new(&p)
}
fn insert_objects(samples: &[(Thing, Vec<Value>)]) -> HnswIndex {
let mut h = hnsw();
for (id, content) in samples {
h.index_document(&id, content).unwrap();
}
h
}
async fn insert_objects_db(session: &Session, create_index: bool, inserts: &[String]) -> Datastore {
let ds = init_datastore(session, create_index).await;
for sql in inserts {
ds.execute(sql, session, None).await.expect(&sql);
}
ds
}
fn knn_lookup_objects(h: &HnswIndex, samples: &[Array]) {
for a in samples {
let r = h.knn_search(a, NN, EF_SEARCH).unwrap();
assert_eq!(r.len(), NN);
}
}
async fn knn_lookup_objects_db(ds: &Datastore, session: &Session, selects: &[String]) {
for sql in selects {
let mut res = ds.execute(sql, session, None).await.expect(&sql);
let res = res.remove(0).result.expect(&sql);
if let Value::Array(a) = &res {
assert_eq!(a.len(), NN);
} else {
panic!("{res:#}");
}
}
}
criterion_group!(benches, bench_hnsw_no_db, bench_hnsw_with_db, bench_db_without_index);
criterion_main!(benches);
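
Both new benches are registered via the [[bench]] entries added to Cargo.toml above, so each can be run on its own with the standard Cargo selector, e.g. cargo bench --bench index_hnsw or cargo bench --bench hashset_vs_vector.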

View file

@ -110,7 +110,7 @@ fn random_object(rng: &mut ThreadRng, vector_size: usize) -> Vector {
for _ in 0..vector_size {
vec.push(rng.gen_range(-1.0..=1.0));
}
Vector::F32(vec)
Vector::F32(vec.into())
}
fn mtree() -> MTree {

View file

@ -3,14 +3,14 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;
const ITERATIONS: usize = 1_000_000;
const ITERATIONS: u32 = 1_000_000;
fn bench_move() {
let mut value = Arc::new(AtomicU32::new(0));
for _ in 0..ITERATIONS {
value = do_something_with_move(value);
}
assert_eq!(value.load(Ordering::Relaxed), ITERATIONS as u32);
assert_eq!(value.load(Ordering::Relaxed), ITERATIONS);
}
fn do_something_with_move(value: Arc<AtomicU32>) -> Arc<AtomicU32> {
@ -23,7 +23,7 @@ fn bench_clone() {
for _ in 0..ITERATIONS {
do_something_with_clone(value.clone());
}
assert_eq!(value.load(Ordering::Relaxed), ITERATIONS as u32);
assert_eq!(value.load(Ordering::Relaxed), ITERATIONS);
}
fn do_something_with_clone(value: Arc<AtomicU32>) {
@ -32,7 +32,7 @@ fn do_something_with_clone(value: Arc<AtomicU32>) {
fn bench_move_vs_clone(c: &mut Criterion) {
let mut group = c.benchmark_group("move_vs_clone");
group.throughput(Throughput::Elements(1));
group.throughput(Throughput::Elements(ITERATIONS as u64));
group.sample_size(10);
group.measurement_time(Duration::from_secs(10));

View file

@ -125,7 +125,7 @@ async fn remove_statement_index() -> Result<(), Error> {
}
// Every index store cache has been removed
assert!(dbs.index_store().is_empty());
assert!(dbs.index_store().is_empty().await);
Ok(())
}

View file

@ -1,6 +1,6 @@
mod helpers;
mod parse;
use crate::helpers::new_ds;
use crate::helpers::{new_ds, skip_ok};
use parse::Parse;
use surrealdb::dbs::Session;
use surrealdb::err::Error;
@ -155,42 +155,23 @@ async fn index_embedding() -> Result<(), Error> {
}
#[tokio::test]
async fn select_where_brut_force_knn() -> Result<(), Error> {
async fn select_where_brute_force_knn() -> Result<(), Error> {
let sql = r"
CREATE pts:1 SET point = [1,2,3,4];
CREATE pts:2 SET point = [4,5,6,7];
CREATE pts:3 SET point = [8,9,10,11];
LET $pt = [2,3,4,5];
SELECT id FROM pts WHERE point <|2,EUCLIDEAN|> $pt EXPLAIN;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt PARALLEL;
SELECT id FROM pts WHERE point <|2|> $pt EXPLAIN;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
for _ in 0..4 {
let _ = res.remove(0).result?;
}
for _ in 0..2 {
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
}
skip_ok(res, 4)?;
//
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
@ -215,5 +196,81 @@ async fn select_where_brut_force_knn() -> Result<(), Error> {
]",
);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
//
for i in 0..2 {
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val), "{i}");
}
Ok(())
}
#[tokio::test]
async fn select_where_hnsw_knn() -> Result<(), Error> {
let sql = r"
CREATE pts:1 SET point = [1,2,3,4];
CREATE pts:2 SET point = [4,5,6,7];
CREATE pts:3 SET point = [8,9,10,11];
DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12;
LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,100|> $pt;
SELECT id FROM pts WHERE point <|2,100|> $pt EXPLAIN;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
for _ in 0..5 {
let _ = res.remove(0).result?;
}
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
id: pts:1,
dist: 2f
},
{
id: pts:2,
dist: 4f
}
]",
);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
let tmp = res.remove(0).result?;
let val = Value::parse(
"[
{
detail: {
plan: {
index: 'hnsw_pts',
operator: '<2,100>',
value: [2,3,4,5]
},
table: 'pts',
},
operation: 'Iterate Index'
},
{
detail: {
type: 'Memory'
},
operation: 'Collector'
}
]",
);
assert_eq!(format!("{:#}", tmp), format!("{:#}", val));
Ok(())
}
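
As the two tests illustrate, the brute-force form of the KNN operator names a metric (point <|2,EUCLIDEAN|> $pt), while the index-backed form passes a number (point <|2,100|> $pt) whose second argument, as the EXPLAIN output shows, is the HNSW search-time candidate-list size. The sketch below condenses the statements above into a minimal end-to-end flow; it is assembled only from calls that appear elsewhere in this diff (knn_demo itself is a hypothetical helper):

use surrealdb_core::dbs::Session;
use surrealdb_core::kvs::Datastore;
use surrealdb_core::sql::Value;

// Hypothetical demo; imports and calls mirror the bench and tests above.
async fn knn_demo() {
    let ds = Datastore::new("memory").await.unwrap();
    let ses = Session::owner().with_ns("test").with_db("test");
    let sql = r"
        CREATE pts:1 SET point = [1,2,3,4];
        CREATE pts:2 SET point = [4,5,6,7];
        DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12;
        SELECT id FROM pts WHERE point <|1,40|> [2,3,4,5];
    ";
    let mut res = ds.execute(sql, &ses, None).await.unwrap();
    assert_eq!(res.len(), 4); // one response per statement
    let knn = res.remove(3).result.unwrap(); // the SELECT's result
    if let Value::Array(a) = knn {
        assert_eq!(a.len(), 1); // one nearest neighbour requested
    }
}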

View file

@ -34,6 +34,12 @@ user-id = 145457 # Tobie Morgan Hitchcock (tobiemh)
start = "2022-01-27"
end = "2025-01-24"
[[trusted.hashbrown]]
criteria = "safe-to-deploy"
user-id = 2915 # Amanieu d'Antras (Amanieu)
start = "2019-04-02"
end = "2025-05-02"
[[trusted.indxdb]]
criteria = "safe-to-deploy"
user-id = 145457 # Tobie Morgan Hitchcock (tobiemh)

View file

@ -131,6 +131,10 @@ criteria = "safe-to-deploy"
version = "1.0.81"
criteria = "safe-to-deploy"
[[exemptions.approx]]
version = "0.4.0"
criteria = "safe-to-deploy"
[[exemptions.approx]]
version = "0.5.1"
criteria = "safe-to-deploy"
@ -751,10 +755,6 @@ criteria = "safe-to-deploy"
version = "0.3.1"
criteria = "safe-to-deploy"
[[exemptions.hashbrown]]
version = "0.14.3"
criteria = "safe-to-deploy"
[[exemptions.headers]]
version = "0.3.9"
criteria = "safe-to-deploy"
@ -923,10 +923,6 @@ criteria = "safe-to-deploy"
version = "0.7.4"
criteria = "safe-to-deploy"
[[exemptions.libloading]]
version = "0.8.3"
criteria = "safe-to-deploy"
[[exemptions.libm]]
version = "0.2.8"
criteria = "safe-to-deploy"
@ -951,6 +947,10 @@ criteria = "safe-to-deploy"
version = "1.1.16"
criteria = "safe-to-deploy"
[[exemptions.linfa-linalg]]
version = "0.1.0"
criteria = "safe-to-deploy"
[[exemptions.linux-raw-sys]]
version = "0.4.13"
criteria = "safe-to-deploy"
@ -1047,6 +1047,10 @@ criteria = "safe-to-deploy"
version = "0.15.6"
criteria = "safe-to-deploy"
[[exemptions.ndarray-stats]]
version = "0.5.1"
criteria = "safe-to-deploy"
[[exemptions.new_debug_unreachable]]
version = "1.0.6"
criteria = "safe-to-deploy"
@ -1063,6 +1067,10 @@ criteria = "safe-to-deploy"
version = "0.27.1"
criteria = "safe-to-deploy"
[[exemptions.noisy_float]]
version = "0.2.0"
criteria = "safe-to-deploy"
[[exemptions.num-bigint-dig]]
version = "0.8.4"
criteria = "safe-to-deploy"
@ -2187,14 +2195,6 @@ criteria = "safe-to-deploy"
version = "0.5.2"
criteria = "safe-to-run"
[[exemptions.zerocopy]]
version = "0.7.32"
criteria = "safe-to-deploy"
[[exemptions.zerocopy-derive]]
version = "0.7.32"
criteria = "safe-to-deploy"
[[exemptions.zeroize]]
version = "1.7.0"
criteria = "safe-to-deploy"

View file

@ -79,6 +79,13 @@ user-id = 4484
user-login = "hsivonen"
user-name = "Henri Sivonen"
[[publisher.hashbrown]]
version = "0.14.5"
when = "2024-04-28"
user-id = 2915
user-login = "Amanieu"
user-name = "Amanieu d'Antras"
[[publisher.indxdb]]
version = "0.4.0"
when = "2023-06-13"
@ -1333,6 +1340,12 @@ version = "1.4.0"
notes = "I have read over the macros, and audited the unsafe code."
aggregated-from = "https://raw.githubusercontent.com/mozilla/cargo-vet/main/supply-chain/audits.toml"
[[audits.mozilla.audits.libloading]]
who = "Erich Gubler <erichdongubler@gmail.com>"
criteria = "safe-to-deploy"
delta = "0.7.4 -> 0.8.3"
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.mozilla.audits.log]]
who = "Mike Hommey <mh+mozilla@glandium.org>"
criteria = "safe-to-deploy"
@ -1563,6 +1576,26 @@ criteria = "safe-to-deploy"
delta = "2.4.1 -> 2.5.0"
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.mozilla.audits.zerocopy]]
who = "Alex Franchuk <afranchuk@mozilla.com>"
criteria = "safe-to-deploy"
version = "0.7.32"
notes = """
This crate is `no_std` so doesn't use any side-effectful std functions. It
contains quite a lot of `unsafe` code, however. I verified portions of this. It
also has a large, thorough test suite. The project claims to run tests with
Miri to have stronger soundness checks, and also claims to use formal
verification tools to prove correctness.
"""
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.mozilla.audits.zerocopy-derive]]
who = "Alex Franchuk <afranchuk@mozilla.com>"
criteria = "safe-to-deploy"
version = "0.7.32"
notes = "Clean, safe macros for zerocopy."
aggregated-from = "https://hg.mozilla.org/mozilla-central/raw-file/tip/supply-chain/audits.toml"
[[audits.zcash.audits.base64]]
who = "Jack Grigg <jack@electriccoin.co>"
criteria = "safe-to-deploy"

tests/data/hnsw-random-5000-20-euclidean.gz: Binary file not shown.

tests/data/hnsw-random-9000-20-euclidean.gz: Binary file not shown.