Fixes wasm build (full-text indexing)

This commit is contained in:
Emmanuel Keller 2023-05-30 12:46:05 +01:00 committed by GitHub
parent fdec86ce3c
commit e1f8722b8c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 310 additions and 362 deletions

View file

@ -1,37 +1,23 @@
use crate::err::Error;
use crate::idx::btree::Payload;
use crate::kvs::{Key, Transaction};
use crate::kvs::Key;
use async_trait::async_trait;
use fst::{IntoStreamer, Map, MapBuilder, Streamer};
use radix_trie::{Trie, TrieCommon};
use radix_trie::{SubTrie, Trie, TrieCommon};
use serde::{de, ser, Deserialize, Serialize};
use std::collections::VecDeque;
use std::fmt::{Display, Formatter};
use std::io;
#[async_trait]
pub(super) trait KeyVisitor {
async fn visit(
&mut self,
tx: &mut Transaction,
key: &Key,
payload: Payload,
) -> Result<(), Error>;
}
#[async_trait]
pub(super) trait BKeys: Display + Sized {
fn with_key_val(key: Key, payload: Payload) -> Result<Self, Error>;
fn len(&self) -> usize;
fn get(&self, key: &Key) -> Option<Payload>;
async fn collect_with_prefix<V>(
&self,
tx: &mut Transaction,
prefix_key: &Key,
visitor: &mut V,
) -> Result<bool, Error>
where
V: KeyVisitor + Send;
// It is okay to return an owned VecDeque rather than an iterator,
// because BKeys are intended to be stored as a Node in the BTree.
// The size of a Node should be small, therefore one instance of
// BKeys would never store a large volume of keys.
fn collect_with_prefix(&self, prefix_key: &Key) -> VecDeque<(Key, Payload)>;
fn insert(&mut self, key: Key, payload: Payload);
fn append(&mut self, keys: Self);
fn remove(&mut self, key: &Key) -> Option<Payload>;
@ -65,7 +51,6 @@ pub(super) struct FstKeys {
len: usize,
}
#[async_trait]
impl BKeys for FstKeys {
fn with_key_val(key: Key, payload: Payload) -> Result<Self, Error> {
let mut builder = MapBuilder::memory();
@ -85,15 +70,7 @@ impl BKeys for FstKeys {
}
}
async fn collect_with_prefix<V>(
&self,
_tx: &mut Transaction,
_prefix_key: &Key,
_visitor: &mut V,
) -> Result<bool, Error>
where
V: KeyVisitor,
{
fn collect_with_prefix(&self, _prefix_key: &Key) -> VecDeque<(Key, Payload)> {
panic!("Not supported!")
}
@ -402,34 +379,13 @@ impl BKeys for TrieKeys {
self.keys.get(key).copied()
}
async fn collect_with_prefix<V>(
&self,
tx: &mut Transaction,
prefix: &Key,
visitor: &mut V,
) -> Result<bool, Error>
where
V: KeyVisitor + Send,
{
let mut node_queue = VecDeque::new();
if let Some(node) = self.keys.get_raw_descendant(prefix) {
node_queue.push_front(node);
fn collect_with_prefix(&self, prefix: &Key) -> VecDeque<(Key, Payload)> {
let mut i = KeysIterator::new(prefix, &self.keys);
let mut r = VecDeque::new();
while let Some((k, p)) = i.next() {
r.push_back((k.clone(), p))
}
let mut found = false;
while let Some(node) = node_queue.pop_front() {
if let Some(value) = node.value() {
if let Some(node_key) = node.key() {
if node_key.starts_with(prefix) {
found = true;
visitor.visit(tx, node_key, *value).await?;
}
}
}
for children in node.children() {
node_queue.push_front(children);
}
}
Ok(found)
r
}
fn insert(&mut self, key: Key, payload: Payload) {
@ -521,20 +477,59 @@ impl BKeys for TrieKeys {
}
}
impl From<Trie<Vec<u8>, u64>> for TrieKeys {
fn from(keys: Trie<Vec<u8>, u64>) -> Self {
impl From<Trie<Key, Payload>> for TrieKeys {
fn from(keys: Trie<Key, Payload>) -> Self {
Self {
keys,
}
}
}
/// Pull-style walker over a `Trie<Key, Payload>` that yields the stored
/// entries whose key starts with a given prefix.
///
/// The walk begins at whatever node `get_raw_descendant(prefix)` returns and
/// then visits that node's descendants.
struct KeysIterator<'a> {
    /// Prefix that every yielded key must start with.
    prefix: &'a Key,
    /// Nodes discovered but not yet visited. Entries are pushed and popped at
    /// the front, so the most recently discovered child is visited first.
    node_queue: VecDeque<SubTrie<'a, Key, Payload>>,
    /// Node to visit on the next step; `None` once it has been consumed.
    current_node: Option<SubTrie<'a, Key, Payload>>,
}

impl<'a> KeysIterator<'a> {
    /// Creates a walker positioned on the node returned by
    /// `keys.get_raw_descendant(prefix)` (no node at all if that is `None`).
    fn new(prefix: &'a Key, keys: &'a Trie<Key, Payload>) -> Self {
        Self {
            prefix,
            node_queue: VecDeque::new(),
            current_node: keys.get_raw_descendant(prefix),
        }
    }

    /// Returns the next `(key, payload)` entry whose key starts with the
    /// prefix, or `None` when every reachable node has been visited.
    fn next(&mut self) -> Option<(&Key, Payload)> {
        loop {
            // Use the node in hand if there is one, otherwise the next queued
            // node; with neither available the walk is over.
            let node = self.current_node.take().or_else(|| self.node_queue.pop_front())?;

            // Schedule the children before inspecting the node itself so they
            // are not lost when we return early with a match.
            for child in node.children() {
                self.node_queue.push_front(child);
            }

            // Only nodes carrying both a value and a key can be yielded.
            if let Some(value) = node.value() {
                if let Some(node_key) = node.key() {
                    if node_key.starts_with(self.prefix) {
                        return Some((node_key, *value));
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use crate::idx::bkeys::{BKeys, FstKeys, TrieKeys};
use crate::idx::tests::HashVisitor;
use crate::kvs::{Datastore, Key};
use std::collections::HashSet;
use crate::idx::btree::Payload;
use crate::kvs::Key;
use std::collections::{HashMap, HashSet, VecDeque};
#[test]
fn test_fst_keys_serde() {
@ -609,11 +604,19 @@ mod tests {
test_keys_deletions(TrieKeys::default())
}
/// Asserts that the collected entries `r` match the expected entries `e`,
/// ignoring order: both must contain the same number of distinct keys and
/// every expected key must map to the expected payload.
fn check_keys(r: VecDeque<(Key, Payload)>, e: Vec<(Key, Payload)>) {
    // Index the results by key; a repeated key keeps its last payload.
    let found: HashMap<Key, Payload> = r.into_iter().collect();
    assert_eq!(found.len(), e.len());
    for (expected_key, expected_payload) in &e {
        assert_eq!(found.get(expected_key), Some(expected_payload));
    }
}
#[tokio::test]
async fn test_tries_keys_collect_with_prefix() {
let ds = Datastore::new("memory").await.unwrap();
let mut tx = ds.transaction(true, false).await.unwrap();
let mut keys = TrieKeys::default();
keys.insert("apple".into(), 1);
keys.insert("applicant".into(), 2);
@ -629,18 +632,17 @@ mod tests {
keys.insert("there".into(), 10);
{
let mut visitor = HashVisitor::default();
keys.collect_with_prefix(&mut tx, &"appli".into(), &mut visitor).await.unwrap();
visitor.check(
let r = keys.collect_with_prefix(&"appli".into());
check_keys(
r,
vec![("applicant".into(), 2), ("application".into(), 3), ("applicative".into(), 4)],
"appli",
);
}
{
let mut visitor = HashVisitor::default();
keys.collect_with_prefix(&mut tx, &"the".into(), &mut visitor).await.unwrap();
visitor.check(
let r = keys.collect_with_prefix(&"the".into());
check_keys(
r,
vec![
("the".into(), 7),
("their".into(), 8),
@ -649,26 +651,22 @@ mod tests {
("these".into(), 11),
("theses".into(), 12),
],
"the",
);
}
{
let mut visitor = HashVisitor::default();
keys.collect_with_prefix(&mut tx, &"blue".into(), &mut visitor).await.unwrap();
visitor.check(vec![("blueberry".into(), 6)], "blue");
let r = keys.collect_with_prefix(&"blue".into());
check_keys(r, vec![("blueberry".into(), 6)]);
}
{
let mut visitor = HashVisitor::default();
keys.collect_with_prefix(&mut tx, &"apple".into(), &mut visitor).await.unwrap();
visitor.check(vec![("apple".into(), 1)], "apple");
let r = keys.collect_with_prefix(&"apple".into());
check_keys(r, vec![("apple".into(), 1)]);
}
{
let mut visitor = HashVisitor::default();
keys.collect_with_prefix(&mut tx, &"zz".into(), &mut visitor).await.unwrap();
visitor.check(vec![], "zz");
let r = keys.collect_with_prefix(&"zz".into());
check_keys(r, vec![]);
}
}

View file

@ -1,7 +1,8 @@
use crate::err::Error;
use crate::idx::bkeys::{BKeys, KeyVisitor};
use crate::idx::bkeys::BKeys;
use crate::idx::SerdeState;
use crate::kvs::{Key, Transaction};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
@ -12,16 +13,31 @@ use std::sync::Arc;
pub(crate) type NodeId = u64;
pub(super) type Payload = u64;
/// Maps B-tree node identifiers to storage keys and loads nodes through them.
///
/// The `async_trait` expansion is selected per target: the default form
/// everywhere except `wasm32`, where the `?Send` form is used instead.
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
pub(super) trait KeyProvider {
    /// Returns the storage key under which the node `node_id` is kept.
    fn get_node_key(&self, node_id: NodeId) -> Key;

    /// Returns the storage key of the tree state.
    ///
    /// # Panics
    /// The default implementation always panics with "Not supported".
    fn get_state_key(&self) -> Key {
        panic!("Not supported")
    }

    /// Reads the node `id` from `tx` and wraps it, together with its id,
    /// its storage key and the size reported by `Node::read`, in a `StoredNode`.
    ///
    /// # Errors
    /// Propagates any error returned by `Node::read`.
    async fn load_node<BK>(&self, tx: &mut Transaction, id: NodeId) -> Result<StoredNode<BK>, Error>
    where
        BK: BKeys + Serialize + DeserializeOwned,
    {
        let node_key = self.get_node_key(id);
        // `read` consumes a key, so hand it a copy and keep ours for the result.
        let loaded = Node::<BK>::read(tx, node_key.clone()).await;
        loaded.map(|(node, size)| StoredNode {
            node,
            id,
            key: node_key,
            size,
        })
    }
}
pub(super) struct BTree<K>
where
K: KeyProvider,
K: KeyProvider + Clone,
{
keys: K,
state: State,
@ -66,9 +82,9 @@ where
Leaf(BK),
}
impl<BK> Node<BK>
impl<'a, BK> Node<BK>
where
BK: BKeys + Serialize + DeserializeOwned,
BK: BKeys + Serialize + DeserializeOwned + 'a,
{
async fn read(tx: &mut Transaction, key: Key) -> Result<(Self, usize), Error> {
if let Some(val) = tx.get(key).await? {
@ -154,7 +170,7 @@ where
impl<K> BTree<K>
where
K: KeyProvider,
K: KeyProvider + Clone + Sync,
{
pub(super) fn new(keys: K, state: State) -> Self {
Self {
@ -175,7 +191,7 @@ where
{
let mut next_node = self.state.root;
while let Some(node_id) = next_node.take() {
let current = self.load_node::<BK>(tx, node_id).await?;
let current = self.keys.load_node::<BK>(tx, node_id).await?;
if let Some(payload) = current.node.keys().get(searched_key) {
return Ok(Some(payload));
}
@ -187,39 +203,8 @@ where
Ok(None)
}
pub(super) async fn search_by_prefix<BK, V>(
&self,
tx: &mut Transaction,
prefix_key: &Key,
visitor: &mut V,
) -> Result<(), Error>
where
BK: BKeys + Serialize + DeserializeOwned,
V: KeyVisitor + Send,
{
let mut node_queue = VecDeque::new();
if let Some(node_id) = self.state.root {
node_queue.push_front((node_id, Arc::new(AtomicBool::new(false))));
}
while let Some((node_id, matches_found)) = node_queue.pop_front() {
let current = self.load_node::<BK>(tx, node_id).await?;
if current.node.keys().collect_with_prefix(tx, prefix_key, visitor).await? {
matches_found.fetch_and(true, Ordering::Relaxed);
} else if matches_found.load(Ordering::Relaxed) {
// If we have found matches in previous (lower) nodes,
// but we don't find matches anymore, there is no chance we can find new matches
// in upper child nodes, therefore we can stop the traversal.
break;
}
if let Node::Internal(keys, children) = current.node {
let same_level_matches_found = Arc::new(AtomicBool::new(false));
let child_idx = keys.get_child_idx(prefix_key);
for i in child_idx..children.len() {
node_queue.push_front((children[i], same_level_matches_found.clone()));
}
}
}
Ok(())
pub(super) fn search_by_prefix(&self, prefix_key: Key) -> BTreeIterator<K> {
BTreeIterator::new(self.keys.clone(), prefix_key, self.state.root)
}
pub(super) async fn insert<BK>(
@ -232,7 +217,7 @@ where
BK: BKeys + Serialize + DeserializeOwned + Default,
{
if let Some(root_id) = self.state.root {
let root = self.load_node::<BK>(tx, root_id).await?;
let root = self.keys.load_node::<BK>(tx, root_id).await?;
if root.node.keys().len() == self.full_size {
let new_root_id = self.new_node_id();
let new_root =
@ -278,7 +263,7 @@ where
return Ok(());
}
let child_idx = keys.get_child_idx(&key);
let child = self.load_node::<BK>(tx, children[child_idx]).await?;
let child = self.keys.load_node::<BK>(tx, children[child_idx]).await?;
let next = if child.node.keys().len() == self.full_size {
let split_result =
self.split_child::<BK>(tx, node, child_idx, child).await?;
@ -379,7 +364,7 @@ where
let mut deleted_payload = None;
if let Some(root_id) = self.state.root {
let node = self.load_node::<BK>(tx, root_id).await?;
let node = self.keys.load_node::<BK>(tx, root_id).await?;
let mut next_node = Some((true, key_to_delete, node));
while let Some((is_main_key, key_to_delete, mut node)) = next_node.take() {
@ -453,7 +438,7 @@ where
{
let left_idx = keys.get_child_idx(&key_to_delete);
let left_id = children[left_idx];
let mut left_node = self.load_node::<BK>(tx, left_id).await?;
let mut left_node = self.keys.load_node::<BK>(tx, left_id).await?;
if left_node.node.keys().len() >= self.state.minimum_degree {
// CLRS: 2a -> left_node is named `y` in the book
if let Some((key_prim, payload_prim)) = left_node.node.keys().get_last_key() {
@ -464,7 +449,7 @@ where
}
let right_idx = left_idx + 1;
let right_node = self.load_node::<BK>(tx, children[right_idx]).await?;
let right_node = self.keys.load_node::<BK>(tx, children[right_idx]).await?;
if right_node.node.keys().len() >= self.state.minimum_degree {
// CLRS: 2b -> right_node is named `z` in the book
if let Some((key_prim, payload_prim)) = right_node.node.keys().get_first_key() {
@ -498,7 +483,7 @@ where
{
// CLRS 3a
let child_idx = keys.get_child_idx(&key_to_delete);
let child_stored_node = self.load_node::<BK>(tx, children[child_idx]).await?;
let child_stored_node = self.keys.load_node::<BK>(tx, children[child_idx]).await?;
// TODO: Remove once everything is stable
// debug!("** delete_traversal");
// child_stored_node.node.keys().debug(|k| Ok(String::from_utf8(k)?))?;
@ -506,7 +491,7 @@ where
// right child (successor)
if child_idx < children.len() - 1 {
let right_child_id = children[child_idx + 1];
let right_child_stored_node = self.load_node::<BK>(tx, right_child_id).await?;
let right_child_stored_node = self.keys.load_node::<BK>(tx, right_child_id).await?;
return if right_child_stored_node.node.keys().len() >= self.state.minimum_degree {
Self::delete_adjust_successor(
tx,
@ -538,7 +523,7 @@ where
if child_idx > 0 {
let child_idx = child_idx - 1;
let left_child_id = children[child_idx];
let left_child_stored_node = self.load_node::<BK>(tx, left_child_id).await?;
let left_child_stored_node = self.keys.load_node::<BK>(tx, left_child_id).await?;
return if left_child_stored_node.node.keys().len() >= self.state.minimum_degree {
Self::delete_adjust_predecessor(
tx,
@ -684,7 +669,7 @@ where
}
let mut count = 0;
while let Some((node_id, depth)) = node_queue.pop_front() {
let stored_node = self.load_node::<BK>(tx, node_id).await?;
let stored_node = self.keys.load_node::<BK>(tx, node_id).await?;
if let Node::Internal(_, children) = &stored_node.node {
let depth = depth + 1;
for child_id in children {
@ -707,7 +692,7 @@ where
node_queue.push_front((node_id, 1));
}
while let Some((node_id, depth)) = node_queue.pop_front() {
let stored = self.load_node::<BK>(tx, node_id).await?;
let stored = self.keys.load_node::<BK>(tx, node_id).await?;
stats.keys_count += stored.node.keys().len();
if depth > stats.max_depth {
stats.max_depth = depth;
@ -743,23 +728,97 @@ where
size: 0,
}
}
}
async fn load_node<BK>(&self, tx: &mut Transaction, id: NodeId) -> Result<StoredNode<BK>, Error>
/// Iteration state for the node a `BTreeIterator` is currently draining.
struct CurrentNode {
// Entries of this node that matched the prefix, still to be handed out.
keys: VecDeque<(Key, Payload)>,
// Whether this node had at least one matching entry when it was loaded.
matches_found: bool,
// Flag shared with the other nodes that were enqueued together with this one.
level_matches_found: Arc<AtomicBool>,
}
/// Pull-style iterator over the B-tree entries whose key matches a prefix.
///
/// Nodes are loaded one at a time through the key provider; call `next`
/// repeatedly until it returns `Ok(None)`.
pub(super) struct BTreeIterator<K>
where
K: KeyProvider,
{
// Used to load each node from the transaction.
key_provider: K,
// Prefix the returned keys are collected with.
prefix_key: Key,
// Nodes still to load, each paired with the flag shared by the nodes enqueued alongside it.
node_queue: VecDeque<(NodeId, Arc<AtomicBool>)>,
// Node currently being drained, if any.
current_node: Option<CurrentNode>,
}
impl<K> BTreeIterator<K>
where
K: KeyProvider + Sync,
{
fn new(key_provider: K, prefix_key: Key, start_node: Option<NodeId>) -> Self {
let mut node_queue = VecDeque::new();
if let Some(node_id) = start_node {
node_queue.push_front((node_id, Arc::new(AtomicBool::new(false))))
}
Self {
key_provider,
prefix_key,
node_queue,
current_node: None,
}
}
fn set_current_node<BK>(&mut self, node: Node<BK>, level_matches_found: Arc<AtomicBool>)
where
BK: BKeys + Serialize + DeserializeOwned,
{
let key = self.keys.get_node_key(id);
let (node, size) = Node::<BK>::read(tx, key.clone()).await?;
Ok(StoredNode {
node,
id,
key,
size,
})
if let Node::Internal(keys, children) = &node {
let same_level_matches_found = Arc::new(AtomicBool::new(false));
let child_idx = keys.get_child_idx(&self.prefix_key);
for i in child_idx..children.len() {
self.node_queue.push_front((children[i], same_level_matches_found.clone()));
}
}
let keys = node.keys().collect_with_prefix(&self.prefix_key);
let matches_found = !keys.is_empty();
if matches_found {
level_matches_found.fetch_and(true, Ordering::Relaxed);
}
self.current_node = Some(CurrentNode {
keys,
matches_found,
level_matches_found,
});
}
/// Returns the next `(key, payload)` entry matching the prefix.
///
/// Entries buffered for the current node are handed out first. When the
/// buffer is empty, the next queued node is loaded through the key provider
/// and becomes the current node.
///
/// Returns `Ok(None)` once the queue is exhausted (or the early exit below
/// is taken).
///
/// # Errors
/// Propagates any error returned by `KeyProvider::load_node`.
pub(super) async fn next<BK>(
&mut self,
tx: &mut Transaction,
) -> Result<Option<(Key, Payload)>, Error>
where
BK: BKeys + Serialize + DeserializeOwned,
{
loop {
if let Some(current) = &mut self.current_node {
// Drain the entries already collected for the current node.
if let Some((key, payload)) = current.keys.pop_front() {
return Ok(Some((key, payload)));
} else {
if !current.matches_found && current.level_matches_found.load(Ordering::Relaxed)
{
// If matches were found in previously visited nodes sharing this flag,
// but this node has none, no further matches are expected in the
// remaining nodes, so the traversal can stop.
// NOTE(review): in the visible code the flag is created as `false` and only
// updated through `fetch_and(true, ..)` in `set_current_node`, which cannot
// turn `false` into `true` -- confirm whether this branch can ever be taken.
break;
}
// The current node is exhausted; fall through to load the next one.
self.current_node = None;
}
} else if let Some((node_id, level_matches_found)) = self.node_queue.pop_front() {
// Load the next queued node and collect its matching entries.
let st = self.key_provider.load_node::<BK>(tx, node_id).await?;
self.set_current_node(st.node, level_matches_found);
} else {
// Nothing buffered and nothing queued: the iteration is over.
break;
}
}
Ok(None)
}
}
struct StoredNode<BK>
pub(super) struct StoredNode<BK>
where
BK: BKeys,
{
@ -782,8 +841,9 @@ where
#[cfg(test)]
mod tests {
use crate::idx::bkeys::{BKeys, FstKeys, TrieKeys};
use crate::idx::btree::{BTree, KeyProvider, Node, NodeId, Payload, State, Statistics};
use crate::idx::tests::HashVisitor;
use crate::idx::btree::{
BTree, BTreeIterator, KeyProvider, Node, NodeId, Payload, State, Statistics,
};
use crate::idx::SerdeState;
use crate::kvs::{Datastore, Key, Transaction};
use rand::prelude::{SliceRandom, ThreadRng};
@ -793,6 +853,7 @@ mod tests {
use std::collections::HashMap;
use test_log::test;
#[derive(Clone)]
struct TestKeyProvider {}
impl KeyProvider for TestKeyProvider {
@ -836,7 +897,7 @@ mod tests {
) where
F: Fn(usize) -> (Key, Payload),
BK: BKeys + Serialize + DeserializeOwned + Default,
K: KeyProvider,
K: KeyProvider + Clone + Sync,
{
for i in 0..samples_size {
let (key, payload) = sample_provider(i);
@ -1058,6 +1119,21 @@ mod tests {
(t, s)
}
/// Drains the iterator `i` and asserts that it produced exactly the expected
/// entries `e`, ignoring order.
async fn check_results(
    mut i: BTreeIterator<TestKeyProvider>,
    tx: &mut Transaction,
    e: Vec<(Key, Payload)>,
) {
    // Index everything the iterator yields by key.
    let mut found = HashMap::new();
    loop {
        match i.next::<TrieKeys>(tx).await.unwrap() {
            Some((key, payload)) => {
                found.insert(key, payload);
            }
            None => break,
        }
    }
    assert_eq!(found.len(), e.len());
    for (key, payload) in &e {
        assert_eq!(found.get(key), Some(payload));
    }
}
#[test(tokio::test)]
async fn test_btree_trie_keys_search_by_prefix() {
for _ in 0..50 {
@ -1086,17 +1162,18 @@ mod tests {
let mut tx = ds.transaction(false, false).await.unwrap();
// We should find all the keys prefixed with "bb"
let mut visitor = HashVisitor::default();
t.search_by_prefix::<TrieKeys, _>(&mut tx, &"bb".into(), &mut visitor).await.unwrap();
visitor.check(
let i = t.search_by_prefix("bb".into());
check_results(
i,
&mut tx,
vec![
("bb1".into(), 21),
("bb2".into(), 22),
("bb3".into(), 23),
("bb4".into(), 24),
],
"bb",
);
)
.await;
}
}
@ -1164,11 +1241,11 @@ mod tests {
("animals", 1),
("watching", 1),
] {
let mut visitor = HashVisitor::default();
t.search_by_prefix::<TrieKeys, _>(&mut tx, &prefix.into(), &mut visitor)
.await
.unwrap();
visitor.check_len(count, prefix);
let mut i = t.search_by_prefix(prefix.into());
for _ in 0..count {
assert!(i.next::<TrieKeys>(&mut tx).await.unwrap().is_some());
}
assert_eq!(i.next::<TrieKeys>(&mut tx).await.unwrap(), None);
}
}
}
@ -1460,7 +1537,7 @@ mod tests {
async fn print_tree<BK, K>(tx: &mut Transaction, t: &BTree<K>)
where
K: KeyProvider,
K: KeyProvider + Clone + Sync,
BK: BKeys + Serialize + DeserializeOwned,
{
debug!("----------------------------------");

View file

@ -169,6 +169,7 @@ impl Resolved {
}
}
/// Cloneable holder of the `IndexKeyBase` used for this module's keys.
// NOTE(review): presumably the `KeyProvider` handed to the doc-ids B-tree; `Clone` looks
// required because `BTree<K>` now bounds `K: KeyProvider + Clone` -- confirm against the impl.
#[derive(Clone)]
struct DocIdsKeyProvider {
index_key_base: IndexKeyBase,
}

View file

@ -70,6 +70,7 @@ impl DocLengths {
}
}
/// Cloneable holder of the `IndexKeyBase` used for this module's keys.
// NOTE(review): presumably the `KeyProvider` handed to the doc-lengths B-tree; `Clone` looks
// required because `BTree<K>` now bounds `K: KeyProvider + Clone` -- confirm against the impl.
#[derive(Clone)]
struct DocLengthsKeyProvider {
index_key_base: IndexKeyBase,
}

View file

@ -7,12 +7,11 @@ use crate::err::Error;
use crate::error::Db::AnalyzerError;
use crate::idx::ft::docids::{DocId, DocIds};
use crate::idx::ft::doclength::{DocLength, DocLengths};
use crate::idx::ft::postings::{Postings, PostingsVisitor, TermFrequency};
use crate::idx::ft::postings::{Postings, TermFrequency};
use crate::idx::ft::terms::Terms;
use crate::idx::{btree, IndexKeyBase, SerdeState};
use crate::kvs::{Key, Transaction};
use crate::sql::error::IResult;
use async_trait::async_trait;
use nom::bytes::complete::take_while;
use nom::character::complete::multispace0;
use roaring::RoaringTreemap;
@ -127,7 +126,7 @@ impl FtIndex {
for term_id in term_list {
p.remove_posting(tx, term_id, doc_id).await?;
// if the term is not present in any document in the index, we can remove it
if p.count_postings(tx, term_id).await? == 0 {
if p.get_doc_count(tx, term_id).await? == 0 {
t.remove_term_id(tx, term_id).await?;
}
}
@ -190,7 +189,7 @@ impl FtIndex {
for old_term_id in old_term_ids {
p.remove_posting(tx, old_term_id, doc_id).await?;
// if the term does not have anymore postings, we can remove the term
if p.count_postings(tx, old_term_id).await? == 0 {
if p.get_doc_count(tx, old_term_id).await? == 0 {
t.remove_term_id(tx, old_term_id).await?;
}
}
@ -272,10 +271,13 @@ impl FtIndex {
doc_ids,
self.state.total_docs_lengths,
self.state.doc_count,
term_doc_count,
term_doc_count as u64,
self.bm25.clone(),
);
postings.collect_postings(tx, term_id, &mut scorer).await?;
let mut it = postings.new_postings_iterator(term_id);
while let Some((doc_id, term_freq)) = it.next(tx).await? {
scorer.visit(tx, doc_id, term_freq).await?;
}
}
}
Ok(())
@ -305,30 +307,6 @@ where
bm25: Bm25Params,
}
#[async_trait]
impl<'a, V> PostingsVisitor for BM25Scorer<'a, V>
where
V: HitVisitor + Send,
{
async fn visit(
&mut self,
tx: &mut Transaction,
doc_id: DocId,
term_frequency: TermFrequency,
) -> Result<(), Error> {
if let Some(doc_key) = self.doc_ids.get_doc_key(tx, doc_id).await? {
let doc_length = self.doc_lengths.get_doc_length(tx, doc_id).await?.unwrap_or(0);
let bm25_score = self.compute_bm25_score(
term_frequency as f32,
self.term_doc_count,
doc_length as f32,
);
self.visitor.visit(tx, doc_key, bm25_score);
}
Ok(())
}
}
impl<'a, V> BM25Scorer<'a, V>
where
V: HitVisitor,
@ -372,6 +350,24 @@ where
// numerator / (k1 * denominator + 1)
numerator / (self.bm25.k1 * denominator + 1.0)
}
/// Scores one posting and forwards it to the hit visitor.
///
/// Postings whose `doc_id` has no document key are skipped silently. A missing
/// document length is treated as `0`.
///
/// # Errors
/// Propagates errors from the document-key and document-length lookups.
async fn visit(
    &mut self,
    tx: &mut Transaction,
    doc_id: DocId,
    term_frequency: TermFrequency,
) -> Result<(), Error> {
    // Nothing to report when the document key cannot be resolved.
    let doc_key = match self.doc_ids.get_doc_key(tx, doc_id).await? {
        Some(doc_key) => doc_key,
        None => return Ok(()),
    };
    let doc_length = self.doc_lengths.get_doc_length(tx, doc_id).await?.unwrap_or(0);
    let bm25_score =
        self.compute_bm25_score(term_frequency as f32, self.term_doc_count, doc_length as f32);
    self.visitor.visit(tx, doc_key, bm25_score);
    Ok(())
}
}
#[cfg(test)]

View file

@ -1,12 +1,11 @@
use crate::err::Error;
use crate::idx::bkeys::{KeyVisitor, TrieKeys};
use crate::idx::btree::{BTree, KeyProvider, NodeId, Payload, Statistics};
use crate::idx::bkeys::TrieKeys;
use crate::idx::btree::{BTree, BTreeIterator, KeyProvider, NodeId, Payload, Statistics};
use crate::idx::ft::docids::DocId;
use crate::idx::ft::terms::TermId;
use crate::idx::{btree, IndexKeyBase, SerdeState};
use crate::key::bf::Bf;
use crate::kvs::{Key, Transaction};
use async_trait::async_trait;
pub(super) type TermFrequency = u64;
@ -16,16 +15,6 @@ pub(super) struct Postings {
btree: BTree<PostingsKeyProvider>,
}
#[async_trait]
pub(super) trait PostingsVisitor {
async fn visit(
&mut self,
tx: &mut Transaction,
doc_id: DocId,
term_frequency: TermFrequency,
) -> Result<(), Error>;
}
impl Postings {
pub(super) async fn new(
tx: &mut Transaction,
@ -69,41 +58,23 @@ impl Postings {
self.btree.delete::<TrieKeys>(tx, key).await
}
/// Creates an iterator over the postings of `term_id`.
///
/// The iterator wraps a B-tree prefix search on the key produced by
/// `new_bf_prefix_key(term_id)`; nothing is read until `next` is called on it.
pub(super) fn new_postings_iterator(&self, term_id: TermId) -> PostingsIterator {
    PostingsIterator::new(
        self.btree.search_by_prefix(self.index_key_base.new_bf_prefix_key(term_id)),
    )
}
pub(super) async fn get_doc_count(
&self,
tx: &mut Transaction,
term_id: TermId,
) -> Result<u64, Error> {
let prefix_key = self.index_key_base.new_bf_prefix_key(term_id);
let mut counter = PostingsDocCount::default();
self.btree.search_by_prefix::<TrieKeys, _>(tx, &prefix_key, &mut counter).await?;
Ok(counter.doc_count)
}
pub(super) async fn collect_postings<V>(
&self,
tx: &mut Transaction,
term_id: TermId,
visitor: &mut V,
) -> Result<(), Error>
where
V: PostingsVisitor + Send,
{
let prefix_key = self.index_key_base.new_bf_prefix_key(term_id);
let mut key_visitor = PostingsAdapter {
visitor,
};
self.btree.search_by_prefix::<TrieKeys, _>(tx, &prefix_key, &mut key_visitor).await
}
pub(super) async fn count_postings(
&self,
tx: &mut Transaction,
term_id: TermId,
) -> Result<usize, Error> {
let mut counter = PostingCounter::default();
self.collect_postings(tx, term_id, &mut counter).await?;
Ok(counter.count)
let mut count = 0;
let mut it = self.new_postings_iterator(term_id);
while let Some((_, _)) = it.next(tx).await? {
count += 1;
}
Ok(count)
}
pub(super) async fn statistics(&self, tx: &mut Transaction) -> Result<Statistics, Error> {
@ -118,25 +89,8 @@ impl Postings {
}
}
#[derive(Default)]
struct PostingCounter {
count: usize,
}
#[async_trait]
impl PostingsVisitor for PostingCounter {
async fn visit(
&mut self,
_tx: &mut Transaction,
_doc_id: DocId,
_term_frequency: TermFrequency,
) -> Result<(), Error> {
self.count += 1;
Ok(())
}
}
struct PostingsKeyProvider {
#[derive(Clone)]
pub(super) struct PostingsKeyProvider {
index_key_base: IndexKeyBase,
}
@ -149,58 +103,53 @@ impl KeyProvider for PostingsKeyProvider {
}
}
struct PostingsAdapter<'a, V>
where
V: PostingsVisitor,
{
visitor: &'a mut V,
pub(super) struct PostingsIterator {
btree_iterator: BTreeIterator<PostingsKeyProvider>,
}
#[async_trait]
impl<'a, V> KeyVisitor for PostingsAdapter<'a, V>
where
V: PostingsVisitor + Send,
{
async fn visit(
impl PostingsIterator {
fn new(btree_iterator: BTreeIterator<PostingsKeyProvider>) -> Self {
Self {
btree_iterator,
}
}
pub(super) async fn next(
&mut self,
tx: &mut Transaction,
key: &Key,
payload: Payload,
) -> Result<(), Error> {
let posting_key: Bf = key.into();
self.visitor.visit(tx, posting_key.doc_id, payload).await
}
}
#[derive(Default)]
struct PostingsDocCount {
doc_count: u64,
}
#[async_trait]
impl KeyVisitor for PostingsDocCount {
async fn visit(
&mut self,
_tx: &mut Transaction,
_key: &Key,
_payload: Payload,
) -> Result<(), Error> {
self.doc_count += 1;
Ok(())
) -> Result<Option<(DocId, Payload)>, Error> {
Ok(self.btree_iterator.next::<TrieKeys>(tx).await?.map(|(k, p)| {
let posting_key: Bf = (&k).into();
(posting_key.doc_id, p)
}))
}
}
#[cfg(test)]
mod tests {
use crate::err::Error;
use crate::idx::btree::Payload;
use crate::idx::ft::docids::DocId;
use crate::idx::ft::postings::{Postings, PostingsVisitor, TermFrequency};
use crate::idx::ft::postings::{Postings, PostingsIterator};
use crate::idx::IndexKeyBase;
use crate::kvs::{Datastore, Transaction};
use async_trait::async_trait;
use std::collections::HashMap;
use test_log::test;
/// Drains the postings iterator `i` and asserts that it produced exactly the
/// expected `(doc id, payload)` pairs `e`, ignoring order.
async fn check_postings(
    mut i: PostingsIterator,
    tx: &mut Transaction,
    e: Vec<(DocId, Payload)>,
) {
    // Index every posting the iterator yields by document id.
    let mut found = HashMap::new();
    loop {
        let (doc_id, payload) = match i.next(tx).await.unwrap() {
            Some(posting) => posting,
            None => break,
        };
        found.insert(doc_id, payload);
    }
    assert_eq!(found.len(), e.len());
    for (doc_id, payload) in &e {
        assert_eq!(found.get(doc_id), Some(payload));
    }
}
#[test(tokio::test)]
async fn test_postings() {
const DEFAULT_BTREE_ORDER: usize = 5;
@ -225,52 +174,17 @@ mod tests {
Postings::new(&mut tx, IndexKeyBase::default(), DEFAULT_BTREE_ORDER).await.unwrap();
assert_eq!(p.statistics(&mut tx).await.unwrap().keys_count, 2);
let mut v = TestPostingVisitor::default();
p.collect_postings(&mut tx, 1, &mut v).await.unwrap();
v.check_len(2, "Postings");
v.check(vec![(2, 3), (4, 5)], "Postings");
let i = p.new_postings_iterator(1);
check_postings(i, &mut tx, vec![(2, 3), (4, 5)]).await;
// Check removal of doc 2
assert_eq!(p.remove_posting(&mut tx, 1, 2).await.unwrap(), Some(3));
assert_eq!(p.count_postings(&mut tx, 1).await.unwrap(), 1);
// Again the same
assert_eq!(p.remove_posting(&mut tx, 1, 2).await.unwrap(), None);
assert_eq!(p.count_postings(&mut tx, 1).await.unwrap(), 1);
// Remove doc 4
assert_eq!(p.remove_posting(&mut tx, 1, 4).await.unwrap(), Some(5));
assert_eq!(p.count_postings(&mut tx, 1).await.unwrap(), 0);
// The underlying b-tree should be empty now
assert_eq!(p.statistics(&mut tx).await.unwrap().keys_count, 0);
}
#[derive(Default)]
pub(super) struct TestPostingVisitor {
map: HashMap<DocId, TermFrequency>,
}
#[async_trait]
impl PostingsVisitor for TestPostingVisitor {
async fn visit(
&mut self,
_tx: &mut Transaction,
doc_id: DocId,
term_frequency: TermFrequency,
) -> Result<(), Error> {
assert_eq!(self.map.insert(doc_id, term_frequency), None);
Ok(())
}
}
impl TestPostingVisitor {
pub(super) fn check_len(&self, len: usize, info: &str) {
assert_eq!(self.map.len(), len, "len issue: {}", info);
}
pub(super) fn check(&self, res: Vec<(DocId, TermFrequency)>, info: &str) {
self.check_len(res.len(), info);
for (d, f) in res {
assert_eq!(self.map.get(&d), Some(&f));
}
}
}
}

View file

@ -153,6 +153,7 @@ impl State {
}
}
/// Cloneable holder of the `IndexKeyBase` used for this module's keys.
// NOTE(review): presumably the `KeyProvider` handed to the terms B-tree; `Clone` looks
// required because `BTree<K>` now bounds `K: KeyProvider + Clone` -- confirm against the impl.
#[derive(Clone)]
struct TermsKeyProvider {
index_key_base: IndexKeyBase,
}

View file

@ -118,43 +118,3 @@ where
}
impl SerdeState for RoaringTreemap {}
#[cfg(test)]
mod tests {
use crate::err::Error;
use crate::idx::bkeys::KeyVisitor;
use crate::idx::btree::Payload;
use crate::kvs::{Key, Transaction};
use async_trait::async_trait;
use std::collections::HashMap;
#[derive(Default)]
pub(super) struct HashVisitor {
map: HashMap<Key, Payload>,
}
#[async_trait]
impl KeyVisitor for HashVisitor {
async fn visit(
&mut self,
_tx: &mut Transaction,
key: &Key,
payload: Payload,
) -> Result<(), Error> {
self.map.insert(key.clone(), payload);
Ok(())
}
}
impl HashVisitor {
pub(super) fn check_len(&self, len: usize, info: &str) {
assert_eq!(self.map.len(), len, "len issue: {}", info);
}
pub(super) fn check(&self, res: Vec<(Key, Payload)>, info: &str) {
self.check_len(res.len(), info);
for (k, p) in res {
assert_eq!(self.map.get(&k), Some(&p));
}
}
}
}