diff --git a/core/src/idx/planner/executor.rs b/core/src/idx/planner/executor.rs index 758a1242..d6dc696c 100644 --- a/core/src/idx/planner/executor.rs +++ b/core/src/idx/planner/executor.rs @@ -13,8 +13,8 @@ use crate::idx::planner::checker::{HnswConditionChecker, MTreeConditionChecker}; use crate::idx::planner::iterators::{ IndexEqualThingIterator, IndexJoinThingIterator, IndexRangeThingIterator, IndexUnionThingIterator, IteratorRecord, IteratorRef, KnnIterator, KnnIteratorResult, - MatchesThingIterator, ThingIterator, UniqueEqualThingIterator, UniqueJoinThingIterator, - UniqueRangeThingIterator, UniqueUnionThingIterator, + MatchesThingIterator, MultipleIterators, ThingIterator, UniqueEqualThingIterator, + UniqueJoinThingIterator, UniqueRangeThingIterator, UniqueUnionThingIterator, }; use crate::idx::planner::knn::{KnnBruteForceResult, KnnPriorityList}; use crate::idx::planner::plan::IndexOperator::Matches; @@ -28,7 +28,9 @@ use crate::kvs::{Key, TransactionType}; use crate::sql::index::{Distance, Index}; use crate::sql::statements::DefineIndexStatement; use crate::sql::{Cond, Expression, Idiom, Number, Object, Table, Thing, Value}; +use num_traits::{FromPrimitive, ToPrimitive}; use reblessive::tree::Stk; +use rust_decimal::Decimal; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; use std::sync::Arc; @@ -388,14 +390,16 @@ impl QueryExecutor { ) -> Result, Error> { Ok(match io.op() { IndexOperator::Equality(value) | IndexOperator::Exactness(value) => { - Some(ThingIterator::IndexEqual(IndexEqualThingIterator::new( - irf, - opt.ns()?, - opt.db()?, - &ix.what, - &ix.name, - value, - ))) + if let Value::Number(n) = value.as_ref() { + let values = Self::get_number_variants(n); + if values.len() == 1 { + Some(Self::new_index_equal_iterator(irf, opt, ix, &values[0])?) + } else { + Some(Self::new_multiple_index_equal_iterators(irf, opt, ix, values)?) + } + } else { + Some(Self::new_index_equal_iterator(irf, opt, ix, value)?) + } } IndexOperator::Union(value) => Some(ThingIterator::IndexUnion( IndexUnionThingIterator::new(irf, opt.ns()?, opt.db()?, &ix.what, &ix.name, value), @@ -422,6 +426,76 @@ impl QueryExecutor { }) } + fn new_index_equal_iterator( + irf: IteratorRef, + opt: &Options, + ix: &DefineIndexStatement, + value: &Value, + ) -> Result { + Ok(ThingIterator::IndexEqual(IndexEqualThingIterator::new( + irf, + opt.ns()?, + opt.db()?, + &ix.what, + &ix.name, + value, + ))) + } + + fn new_multiple_index_equal_iterators( + irf: IteratorRef, + opt: &Options, + ix: &DefineIndexStatement, + values: Vec, + ) -> Result { + let mut iterators = VecDeque::with_capacity(values.len()); + for value in values { + iterators.push_back(Self::new_index_equal_iterator(irf, opt, ix, &value)?); + } + Ok(ThingIterator::Multiples(Box::new(MultipleIterators::new(iterators)))) + } + + /// This function takes a reference to a `Number` enum and returns a vector of `Value` enum. + /// The `Number` enum can be either an `Int`, `Float`, or `Decimal`. + /// The function first initializes an empty vector with a capacity of 3 to store the converted values. + /// It then matches on the input number and performs the appropriate conversions. + /// For `Int`, it pushes the original `Int` value, the equivalent `Float` value, and if possible, the equivalent `Decimal` value. + /// For `Float`, it pushes the original `Float` value, the truncated `Int` value if it is a whole number, and if possible, the equivalent `Decimal` value. + /// For `Decimal`, it pushes the equivalent `Int` value if it is representable as an `i64`, and the equivalent `Float` value if it is representable as an `f64`. + /// Finally, it returns the vector of converted values. + fn get_number_variants(n: &Number) -> Vec { + let mut values = Vec::with_capacity(3); + match n { + Number::Int(i) => { + values.push(Number::Int(*i).into()); + values.push(Number::Float(*i as f64).into()); + if let Some(d) = Decimal::from_i64(*i) { + values.push(Number::Decimal(d.normalize()).into()); + } + } + Number::Float(f) => { + values.push(Number::Float(*f).into()); + if f.trunc().eq(f) { + values.push(Number::Int(*f as i64).into()); + } + if let Some(d) = Decimal::from_f64(*f) { + values.push(Number::Decimal(d.normalize()).into()); + } + } + Number::Decimal(d) => { + values.push(Number::Decimal(d.normalize()).into()); + if let Some(i) = d.to_i64() { + values.push(Number::Int(i).into()); + } + if let Some(f) = d.to_f64() { + values.push(Number::Float(f).into()); + } + } + }; + println!("VALUES: {:?}", values); + values + } + fn new_range_iterator( &self, opt: &Options, @@ -468,14 +542,16 @@ impl QueryExecutor { ) -> Result, Error> { Ok(match io.op() { IndexOperator::Equality(value) | IndexOperator::Exactness(value) => { - Some(ThingIterator::UniqueEqual(UniqueEqualThingIterator::new( - irf, - opt.ns()?, - opt.db()?, - &ix.what, - &ix.name, - value, - ))) + if let Value::Number(n) = value.as_ref() { + let values = Self::get_number_variants(n); + if values.len() == 1 { + Some(Self::new_unique_equal_iterator(irf, opt, ix, &values[0])?) + } else { + Some(Self::new_multiple_unique_equal_iterators(irf, opt, ix, values)?) + } + } else { + Some(Self::new_unique_equal_iterator(irf, opt, ix, value)?) + } } IndexOperator::Union(value) => Some(ThingIterator::UniqueUnion( UniqueUnionThingIterator::new(irf, opt, ix, value)?, @@ -502,6 +578,35 @@ impl QueryExecutor { }) } + fn new_unique_equal_iterator( + irf: IteratorRef, + opt: &Options, + ix: &DefineIndexStatement, + value: &Value, + ) -> Result { + Ok(ThingIterator::UniqueEqual(UniqueEqualThingIterator::new( + irf, + opt.ns()?, + opt.db()?, + &ix.what, + &ix.name, + value, + ))) + } + + fn new_multiple_unique_equal_iterators( + irf: IteratorRef, + opt: &Options, + ix: &DefineIndexStatement, + values: Vec, + ) -> Result { + let mut iterators = VecDeque::with_capacity(values.len()); + for value in values { + iterators.push_back(Self::new_unique_equal_iterator(irf, opt, ix, &value)?); + } + Ok(ThingIterator::Multiples(Box::new(MultipleIterators::new(iterators)))) + } + async fn new_search_index_iterator( &self, irf: IteratorRef, diff --git a/core/src/idx/planner/iterators.rs b/core/src/idx/planner/iterators.rs index e56e034e..05c54ddf 100644 --- a/core/src/idx/planner/iterators.rs +++ b/core/src/idx/planner/iterators.rs @@ -113,6 +113,7 @@ pub(crate) enum ThingIterator { UniqueJoin(Box), Matches(MatchesThingIterator), Knn(KnnIterator), + Multiples(Box), } impl ThingIterator { @@ -133,6 +134,7 @@ impl ThingIterator { Self::Knn(i) => i.next_batch(ctx, size).await, Self::IndexJoin(i) => Box::pin(i.next_batch(ctx, txn, size)).await, Self::UniqueJoin(i) => Box::pin(i.next_batch(ctx, txn, size)).await, + Self::Multiples(i) => Box::pin(i.next_batch(ctx, txn, size)).await, } } } @@ -828,3 +830,42 @@ impl KnnIterator { Ok(records) } } + +pub(crate) struct MultipleIterators { + iterators: VecDeque, + current: Option, +} + +impl MultipleIterators { + pub(super) fn new(iterators: VecDeque) -> Self { + Self { + iterators, + current: None, + } + } + + async fn next_batch( + &mut self, + ctx: &Context, + txn: &Transaction, + limit: u32, + ) -> Result { + loop { + // Do we have an iterator + if let Some(i) = &mut self.current { + // If so, take the next batch + let b: B = i.next_batch(ctx, txn, limit).await?; + // Return the batch if it is not empty + if !b.is_empty() { + return Ok(b); + } + } + // Otherwise check if there is another iterator + self.current = self.iterators.pop_front(); + if self.current.is_none() { + // If none, we are done + return Ok(B::empty()); + } + } + } +} diff --git a/sdk/tests/define.rs b/sdk/tests/define.rs index 885e7b2b..e3acf7ba 100644 --- a/sdk/tests/define.rs +++ b/sdk/tests/define.rs @@ -715,6 +715,88 @@ async fn define_statement_index_single() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn define_statement_index_numbers() -> Result<(), Error> { + let sql = " + DEFINE INDEX index ON TABLE test COLUMNS number; + CREATE test:int SET number = 0; + CREATE test:float SET number = 0.0; + -- TODO: CREATE test:dec_int SET number = 0dec; + -- TODO: CREATE test:dec_dec SET number = 0.0dec; + SELECT * FROM test WITH NOINDEX WHERE number = 0 ORDER BY id; + SELECT * FROM test WHERE number = 0 ORDER BY id; + SELECT * FROM test WHERE number = 0.0 ORDER BY id; + -- TODO: SELECT * FROM test WHERE number = 0dec ORDER BY id; + -- TODO: SELECT * FROM test WHERE number = 0.0dec ORDER BY id; + "; + let mut t = Test::new(sql).await?; + t.skip_ok(3)?; + for _ in 0..3 { + t.expect_val( + "[ + // { + // id: test:dec, + // number: 0.0dec + // }, + // { + // id: test:int, + // number: 0dec + // }, + { + id: test:float, + number: 0f + }, + { + id: test:int, + number: 0 + } + ]", + )?; + } + Ok(()) +} + +#[tokio::test] +async fn define_statement_unique_index_numbers() -> Result<(), Error> { + let sql = " + DEFINE INDEX index ON TABLE test COLUMNS number UNIQUE; + CREATE test:int SET number = 0; + CREATE test:float SET number = 0.0; + -- TODO: CREATE test:dec_int SET number = 0dec; + -- TODO: CREATE test:dec_dec SET number = 0.0dec; + SELECT * FROM test WITH NOINDEX WHERE number = 0 ORDER BY id; + SELECT * FROM test WHERE number = 0 ORDER BY id; + SELECT * FROM test WHERE number = 0.0 ORDER BY id; + -- TODO: SELECT * FROM test WHERE number = 0dec ORDER BY id; + -- TODO: SELECT * FROM test WHERE number = 0.0dec ORDER BY id; + "; + let mut t = Test::new(sql).await?; + t.skip_ok(3)?; + for _ in 0..3 { + t.expect_val( + "[ + // { + // id: test:dec, + // number: 0.0dec + // }, + // { + // id: test:int, + // number: 0dec + // }, + { + id: test:float, + number: 0f + }, + { + id: test:int, + number: 0 + } + ]", + )?; + } + Ok(()) +} + #[tokio::test] async fn define_statement_index_concurrently() -> Result<(), Error> { let sql = "