MTree bench improvements (replaces hashbrown with ahash) (#4408)

This commit is contained in:
Emmanuel Keller 2024-07-23 09:24:00 +01:00 committed by GitHub
parent 07610d9411
commit 08f4ad6c82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 105 additions and 139 deletions

2
Cargo.lock generated
View file

@ -2493,7 +2493,6 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash 0.8.11",
"allocator-api2",
"serde",
]
[[package]]
@ -6018,7 +6017,6 @@ dependencies = [
"fuzzy-matcher",
"geo 0.27.0",
"geo-types",
"hashbrown 0.14.5",
"hex",
"indxdb",
"ipnet",

View file

@ -81,7 +81,6 @@ futures = "0.3.29"
fuzzy-matcher = "0.3.7"
geo = { version = "0.27.0", features = ["use-serde"] }
geo-types = { version = "0.7.12", features = ["arbitrary"] }
hashbrown = { version = "0.14.5", features = ["serde"] }
hex = { version = "0.4.3" }
indxdb = { version = "0.5.0", optional = true }
ipnet = "2.9.0"

View file

@ -7,10 +7,10 @@ use crate::idx::planner::iterators::KnnIteratorResult;
use crate::idx::trees::hnsw::docs::HnswDocs;
use crate::idx::trees::knn::Ids64;
use crate::sql::{Cond, Thing, Value};
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use ahash::HashMap;
use reblessive::tree::Stk;
use std::borrow::Cow;
use std::collections::hash_map::Entry;
use std::collections::VecDeque;
use std::sync::Arc;

View file

@ -1,5 +1,5 @@
use crate::sql::{Expression, Number, Thing};
use hashbrown::{HashMap, HashSet};
use ahash::{HashMap, HashMapExt, HashSet, HashSetExt};
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::sync::Arc;
@ -16,7 +16,7 @@ impl KnnPriorityList {
pub(super) fn new(knn: usize) -> Self {
Self(Arc::new(Mutex::new(Inner {
knn,
docs: HashSet::new(),
docs: HashSet::with_capacity(knn),
priority_list: BTreeMap::default(),
})))
}

View file

@ -5,7 +5,7 @@ use crate::idx::VersionedSerdeState;
use crate::kvs::{Key, Transaction, Val};
use crate::sql::{Object, Value};
#[cfg(debug_assertions)]
use hashbrown::HashSet;
use ahash::HashSet;
use revision::{revisioned, Revisioned};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
@ -954,7 +954,7 @@ where
) -> Result<BStatistics, Error> {
let mut stats = BStatistics::default();
#[cfg(debug_assertions)]
let mut keys = HashSet::new();
let mut keys = HashSet::default();
let mut node_queue = VecDeque::new();
if let Some(node_id) = self.state.root {
node_queue.push_front((node_id, 1));

View file

@ -1,4 +1,4 @@
use hashbrown::HashSet;
use ahash::{HashSet, HashSetExt};
use std::fmt::Debug;
use std::hash::Hash;
@ -126,11 +126,11 @@ where
#[cfg(test)]
mod tests {
use crate::idx::trees::dynamicset::{ArraySet, DynamicSet, HashBrownSet};
use hashbrown::HashSet;
use ahash::HashSet;
fn test_dynamic_set<S: DynamicSet<usize>>(capacity: usize) {
let mut dyn_set = S::with_capacity(capacity);
let mut control = HashSet::new();
let mut control = HashSet::default();
// Test insertions
for sample in 0..capacity {
assert_eq!(dyn_set.len(), control.len(), "{capacity} - {sample}");

View file

@ -1,8 +1,8 @@
use crate::idx::trees::dynamicset::DynamicSet;
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use ahash::HashMap;
#[cfg(test)]
use hashbrown::HashSet;
use ahash::HashSet;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::hash::Hash;
@ -24,7 +24,7 @@ where
pub(super) fn new(capacity: usize) -> Self {
Self {
capacity,
nodes: HashMap::new(),
nodes: HashMap::default(),
}
}

View file

@ -1,7 +1,7 @@
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::vector::SharedVector;
use crate::sql::index::Distance;
use hashbrown::HashMap;
use ahash::HashMap;
pub(super) struct HnswElements {
elements: HashMap<ElementId, SharedVector>,

View file

@ -10,9 +10,9 @@ use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::sql::index::{HnswParams, VectorType};
use crate::sql::{Number, Thing, Value};
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use ahash::HashMap;
use reblessive::tree::Stk;
use std::collections::hash_map::Entry;
use std::collections::VecDeque;
pub struct HnswIndex {
@ -197,7 +197,7 @@ impl HnswIndex {
}
builder.build(
#[cfg(debug_assertions)]
HashMap::new(),
HashMap::default(),
)
}

View file

@ -7,7 +7,7 @@ use crate::idx::trees::hnsw::index::HnswCheckedSearchContext;
use crate::idx::trees::hnsw::{ElementId, HnswElements};
use crate::idx::trees::knn::DoublePriorityQueue;
use crate::idx::trees::vector::SharedVector;
use hashbrown::HashSet;
use ahash::HashSet;
use reblessive::tree::Stk;
#[derive(Debug)]
@ -49,7 +49,7 @@ where
ep_id: ElementId,
ef: usize,
) -> DoublePriorityQueue {
let visited = HashSet::from([ep_id]);
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(ep_dist, ep_id);
let w = candidates.clone();
self.search(elements, pt, candidates, visited, w, ef)
@ -64,7 +64,7 @@ where
stk: &mut Stk,
chk: &mut HnswConditionChecker<'_>,
) -> Result<DoublePriorityQueue, Error> {
let visited = HashSet::from([ep_id]);
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(ep_dist, ep_id);
let mut w = DoublePriorityQueue::default();
Self::add_if_truthy(search, &mut w, ep_pt, ep_dist, ep_id, stk, chk).await?;
@ -89,7 +89,7 @@ where
pt: &SharedVector,
ep_id: ElementId,
) -> Option<(f64, ElementId)> {
let visited = HashSet::from([ep_id]);
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
let w = candidates.clone();
let q = self.search(elements, pt, candidates, visited, w, 1);
@ -103,7 +103,7 @@ where
ep_id: ElementId,
efc: usize,
) -> DoublePriorityQueue {
let visited = HashSet::from([ep_id]);
let visited = HashSet::from_iter([ep_id]);
let candidates = DoublePriorityQueue::from(0.0, ep_id);
let w = DoublePriorityQueue::default();
self.search(elements, pt, candidates, visited, w, efc)

View file

@ -302,10 +302,11 @@ mod tests {
use crate::idx::trees::knn::{Ids64, KnnResult, KnnResultBuilder};
use crate::idx::trees::vector::{SharedVector, Vector};
use crate::sql::index::{Distance, HnswParams, VectorType};
use hashbrown::{hash_map::Entry, HashMap, HashSet};
use ahash::{HashMap, HashSet};
use ndarray::Array1;
use reblessive::tree::Stk;
use roaring::RoaringTreemap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use test_log::test;
@ -313,7 +314,7 @@ mod tests {
h: &mut HnswFlavor,
collection: &TestCollection,
) -> HashSet<SharedVector> {
let mut set = HashSet::new();
let mut set = HashSet::default();
for (_, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone();
h.insert(obj.clone());
@ -445,7 +446,7 @@ mod tests {
h: &mut HnswIndex,
collection: &TestCollection,
) -> HashMap<SharedVector, HashSet<DocId>> {
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::new();
let mut map: HashMap<SharedVector, HashSet<DocId>> = HashMap::default();
for (doc_id, obj) in collection.to_vec_ref() {
let obj: SharedVector = obj.clone();
h.insert(obj.clone(), *doc_id);
@ -454,7 +455,7 @@ mod tests {
e.get_mut().insert(*doc_id);
}
Entry::Vacant(e) => {
e.insert(HashSet::from([*doc_id]));
e.insert(HashSet::from_iter([*doc_id]));
}
}
h.check_hnsw_properties(map.len());
@ -726,7 +727,7 @@ mod tests {
}
b.build(
#[cfg(debug_assertions)]
HashMap::new(),
HashMap::default(),
)
}
}

View file

@ -3,8 +3,8 @@ use crate::idx::trees::dynamicset::DynamicSet;
use crate::idx::trees::hnsw::ElementId;
use crate::idx::trees::store::NodeId;
#[cfg(debug_assertions)]
use hashbrown::HashMap;
use hashbrown::HashSet;
use ahash::HashMap;
use ahash::{HashSet, HashSetExt};
use roaring::RoaringTreemap;
use std::cmp::{Ordering, Reverse};
use std::collections::btree_map::Entry;
@ -619,10 +619,10 @@ pub(super) mod tests {
use crate::sql::index::{Distance, VectorType};
use crate::sql::{Array, Number, Value};
use crate::syn::Parse;
use flate2::read::GzDecoder;
#[cfg(debug_assertions)]
use hashbrown::HashMap;
use hashbrown::HashSet;
use ahash::HashMap;
use ahash::HashSet;
use flate2::read::GzDecoder;
use rand::prelude::SmallRng;
use rand::{Rng, SeedableRng};
use roaring::RoaringTreemap;
@ -755,7 +755,7 @@ pub(super) mod tests {
gen: &RandomItemGenerator,
rng: &mut SmallRng,
) -> Self {
let mut vector_set = HashSet::new();
let mut vector_set = HashSet::default();
let mut attempts = collection_size * 2;
while vector_set.len() < collection_size {
vector_set.insert(new_random_vec(rng, vector_type, dimension, gen));
@ -821,7 +821,7 @@ pub(super) mod tests {
b.add(0.2, &Ids64::Vec2([6, 8]));
let res = b.build(
#[cfg(debug_assertions)]
HashMap::new(),
HashMap::default(),
);
assert_eq!(
res.docs,

View file

@ -1,10 +1,10 @@
use crate::ctx::Context;
use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet};
use ahash::{HashMap, HashMapExt, HashSet};
use reblessive::tree::Stk;
use revision::revisioned;
use roaring::RoaringTreemap;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::Entry;
use std::collections::{BinaryHeap, VecDeque};
use std::fmt::{Debug, Display, Formatter};
use std::io::Cursor;
@ -217,7 +217,7 @@ impl MTree {
queue.push(PriorityNode::new(0.0, root_id));
}
#[cfg(debug_assertions)]
let mut visited_nodes = HashMap::new();
let mut visited_nodes = HashMap::default();
while let Some(e) = queue.pop() {
let id = e.id();
let node = search.store.get_node_txn(search.ctx, id).await?;
@ -330,7 +330,7 @@ impl MTree {
) -> Result<(), Error> {
let new_root_id = self.new_node_id();
let p = ObjectProperties::new_root(id);
let mut objects = LeafMap::new();
let mut objects = LeafMap::with_capacity(1);
objects.insert(obj, p);
let new_root_node = store.new_node(new_root_id, MTreeNode::Leaf(objects))?;
store.set_node(new_root_node, true).await?;
@ -1486,7 +1486,7 @@ mod tests {
use crate::kvs::Transaction;
use crate::kvs::{Datastore, TransactionType};
use crate::sql::index::{Distance, VectorType};
use hashbrown::{HashMap, HashSet};
use ahash::{HashMap, HashMapExt, HashSet};
use reblessive::tree::Stk;
use std::collections::VecDeque;
use test_log::test;
@ -2080,13 +2080,13 @@ mod tests {
t: &MTree,
) -> Result<CheckedProperties, Error> {
debug!("CheckTreeProperties");
let mut node_ids = HashSet::new();
let mut node_ids = HashSet::default();
let mut checks = CheckedProperties::default();
let mut nodes: VecDeque<(NodeId, f64, Option<SharedVector>, usize)> = VecDeque::new();
if let Some(root_id) = t.state.root {
nodes.push_back((root_id, 0.0, None, 1));
}
let mut leaf_objects = HashSet::new();
let mut leaf_objects = HashSet::default();
while let Some((node_id, radius, center, depth)) = nodes.pop_front() {
assert!(node_ids.insert(node_id), "Node already exist: {}", node_id);
checks.node_count += 1;

View file

@ -2,9 +2,9 @@ use crate::err::Error;
use crate::idx::trees::store::lru::{CacheKey, ConcurrentLru};
use crate::idx::trees::store::{NodeId, StoreGeneration, StoredNode, TreeNode, TreeNodeProvider};
use crate::kvs::{Key, Transaction};
use ahash::{HashMap, HashSet};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use hashbrown::{HashMap, HashSet};
use std::cmp::Ordering;
use std::fmt::{Debug, Display};
use std::sync::Arc;
@ -117,7 +117,7 @@ where
if cache_size == 0 {
Self::Full(cache_key, generation, TreeFullCache::new(keys))
} else {
Self::Lru(cache_key, generation, TreeLruCache::new(keys, cache_size))
Self::Lru(cache_key, generation, TreeLruCache::with_capacity(keys, cache_size))
}
}
@ -198,8 +198,8 @@ impl<N> TreeLruCache<N>
where
N: TreeNode + Debug + Clone,
{
fn new(keys: TreeNodeProvider, size: usize) -> Self {
let lru = ConcurrentLru::new(size);
fn with_capacity(keys: TreeNodeProvider, size: usize) -> Self {
let lru = ConcurrentLru::with_capacity(size);
Self {
keys,
lru,

View file

@ -1,5 +1,5 @@
use ahash::{HashMap, HashMapExt};
use futures::future::join_all;
use hashbrown::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use tokio::sync::Mutex;
@ -26,8 +26,9 @@ impl<V> ConcurrentLru<V>
where
V: Clone,
{
pub(super) fn new(capacity: usize) -> Self {
let shards_count = num_cpus::get().min(capacity);
pub(super) fn with_capacity(capacity: usize) -> Self {
// slightly more than the number of CPU cores
let shards_count = (num_cpus::get() * 4 / 3).min(capacity);
let mut shards = Vec::with_capacity(shards_count);
let mut lengths = Vec::with_capacity(shards_count);
for _ in 0..shards_count {
@ -47,10 +48,7 @@ where
// Locate the shard
let n = key as usize % self.shards_count;
// Get and promote the key
let mut shard = self.shards[n].lock().await;
let v = shard.get_and_promote(key);
drop(shard);
v
self.shards[n].lock().await.get_and_promote(key)
}
pub(super) async fn insert<K: Into<CacheKey>>(&self, key: K, val: V) {
@ -58,9 +56,7 @@ where
// Locate the shard
let shard = key as usize % self.shards_count;
// Insert the key/object in the shard and get the new length
let mut s = self.shards[shard].lock().await;
let new_length = s.insert(key, val, self.full.load(Relaxed));
drop(s);
let new_length = self.shards[shard].lock().await.insert(key, val, self.full.load(Relaxed));
// Update lengths
self.check_length(new_length, shard);
}
@ -70,9 +66,7 @@ where
// Locate the shard
let shard = key as usize % self.shards_count;
// Remove the key
let mut s = self.shards[shard].lock().await;
let new_length = s.remove(key);
drop(s);
let new_length = self.shards[shard].lock().await.remove(key);
// Update lengths
self.check_length(new_length, shard);
}
@ -101,9 +95,7 @@ where
.shards
.iter()
.map(|s| async {
let s = s.lock().await;
let shard = s.duplicate(filter);
drop(s);
let shard = s.lock().await.duplicate(filter);
(shard.map.len(), Mutex::new(shard))
})
.collect();
@ -139,7 +131,7 @@ where
{
fn new() -> Self {
Self {
map: HashMap::new(),
map: HashMap::default(),
vec: Vec::new(),
}
}
@ -242,7 +234,7 @@ mod tests {
#[test(tokio::test)]
async fn test_minimal_tree_lru() {
let lru = ConcurrentLru::new(1);
let lru = ConcurrentLru::with_capacity(1);
assert_eq!(lru.len(), 0);
//
lru.insert(1u64, 'a').await;
@ -270,7 +262,7 @@ mod tests {
#[test(tokio::test)]
async fn test_tree_lru() {
let lru = ConcurrentLru::new(4);
let lru = ConcurrentLru::with_capacity(4);
//
lru.insert(1u64, 'a').await;
lru.insert(2u64, 'b').await;
@ -302,7 +294,7 @@ mod tests {
#[test(tokio::test(flavor = "multi_thread"))]
async fn concurrent_lru_test() {
let num_threads = 4;
let lru = ConcurrentLru::new(100);
let lru = ConcurrentLru::with_capacity(100);
let futures: Vec<_> = (0..num_threads)
.map(|_| async {

View file

@ -2,7 +2,7 @@ use crate::err::Error;
use crate::idx::trees::store::cache::TreeCache;
use crate::idx::trees::store::{NodeId, StoredNode, TreeNode, TreeNodeProvider};
use crate::kvs::{Key, Transaction};
use hashbrown::{HashMap, HashSet};
use ahash::{HashMap, HashSet};
use std::fmt::{Debug, Display};
use std::mem;
use std::sync::Arc;
@ -30,12 +30,12 @@ where
Self {
np,
cache,
cached: HashSet::new(),
nodes: HashMap::new(),
updated: HashSet::new(),
removed: HashMap::new(),
cached: Default::default(),
nodes: Default::default(),
updated: Default::default(),
removed: Default::default(),
#[cfg(debug_assertions)]
out: HashSet::new(),
out: Default::default(),
}
}

View file

@ -3,7 +3,7 @@ use crate::fnc::util::math::ToFloat;
use crate::sql::index::{Distance, VectorType};
use crate::sql::{Number, Value};
use ahash::AHasher;
use hashbrown::HashSet;
use ahash::HashSet;
use linfa_linalg::norm::Norm;
use ndarray::{Array1, LinalgScalar, Zip};
use ndarray_stats::DeviationExt;

View file

@ -1,6 +1,7 @@
use criterion::measurement::WallTime;
use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion, Throughput};
use futures::executor::block_on;
use futures::future::join_all;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use reblessive::TreeStack;
@ -19,36 +20,26 @@ use surrealdb_core::sql::{Id, Number, Thing, Value};
use tokio::runtime::{Builder, Runtime};
use tokio::task;
fn bench_index_mtree_dim_3(c: &mut Criterion) {
bench_index_mtree(c, 250, 2500, 3, 100);
}
fn bench_index_mtree_dim_3_full_cache(c: &mut Criterion) {
bench_index_mtree(c, 250, 2500, 3, 0);
}
fn bench_index_mtree_dim_50(c: &mut Criterion) {
bench_index_mtree(c, 100, 1000, 50, 100);
}
fn bench_index_mtree_dim_50_full_cache(c: &mut Criterion) {
bench_index_mtree(c, 100, 1000, 50, 0);
}
fn bench_index_mtree_dim_300(c: &mut Criterion) {
bench_index_mtree(c, 50, 500, 300, 100);
}
fn bench_index_mtree_dim_300_full_cache(c: &mut Criterion) {
bench_index_mtree(c, 50, 500, 300, 0);
}
fn bench_index_mtree_dim_2048(c: &mut Criterion) {
bench_index_mtree(c, 10, 100, 2048, 100);
}
fn bench_index_mtree_dim_2048_full_cache(c: &mut Criterion) {
bench_index_mtree(c, 10, 100, 2048, 0);
fn bench_index_mtree_combinations(c: &mut Criterion) {
for (samples, dimension, cache) in [
(2500, 3, 100),
(2500, 3, 2500),
(2500, 3, 0),
(1000, 50, 100),
(1000, 50, 1000),
(1000, 50, 0),
(500, 300, 100),
(500, 300, 500),
(500, 300, 0),
(250, 1024, 75),
(250, 1024, 250),
(250, 1024, 0),
(100, 2048, 50),
(100, 2048, 100),
(100, 2048, 0),
] {
bench_index_mtree(c, samples, dimension, cache);
}
}
async fn mtree_index(
@ -76,15 +67,14 @@ fn runtime() -> Runtime {
fn bench_index_mtree(
c: &mut Criterion,
debug_samples_len: usize,
release_samples_len: usize,
samples_len: usize,
vector_dimension: usize,
cache_size: usize,
) {
let samples_len = if cfg!(debug_assertions) {
debug_samples_len // Debug is slow
samples_len / 10 // Debug is slow
} else {
release_samples_len // Release is fast
samples_len // Release is fast
};
// Both benchmark groups are sharing the same datastore
@ -111,7 +101,7 @@ fn bench_index_mtree(
);
group.bench_function(id, |b| {
b.to_async(runtime()).iter(|| {
knn_lookup_objects(&ds, samples_len, vector_dimension, cache_size, knn)
knn_lookup_objects(&ds, samples_len / 5, vector_dimension, cache_size, knn)
});
});
}
@ -180,39 +170,25 @@ async fn knn_lookup_objects(
let (ctx, mt, counter) = (ctx.clone(), mt.clone(), counter.clone());
let c = task::spawn(async move {
let mut rng = StdRng::from_entropy();
while counter.fetch_add(1, Ordering::Relaxed) < samples_size {
let object = random_object(&mut rng, vector_size);
knn_lookup_object(mt.as_ref(), &ctx, object, knn).await;
}
});
consumers.push(c);
}
for c in consumers {
c.await.unwrap();
}
}
async fn knn_lookup_object(mt: &MTreeIndex, ctx: &Context<'_>, object: Vec<Number>, knn: usize) {
let mut stack = TreeStack::new();
stack
.enter(|stk| async {
let chk = MTreeConditionChecker::new(ctx);
let r = mt.knn_search(stk, ctx, &object, knn, chk).await.unwrap();
while counter.fetch_add(1, Ordering::Relaxed) < samples_size {
let object = random_object(&mut rng, vector_size);
let chk = MTreeConditionChecker::new(ctx.as_ref());
let r = mt.knn_search(stk, ctx.as_ref(), &object, knn, chk).await.unwrap();
assert_eq!(r.len(), knn);
}
})
.finish()
.await;
});
consumers.push(c);
}
for c in join_all(consumers).await {
c.unwrap();
}
}
criterion_group!(
benches,
bench_index_mtree_dim_3,
bench_index_mtree_dim_3_full_cache,
bench_index_mtree_dim_50,
bench_index_mtree_dim_50_full_cache,
bench_index_mtree_dim_300,
bench_index_mtree_dim_300_full_cache,
bench_index_mtree_dim_2048,
bench_index_mtree_dim_2048_full_cache
);
criterion_group!(benches, bench_index_mtree_combinations);
criterion_main!(benches);