Change KNN operator syntax. (#3617)

Co-authored-by: Emmanuel Keller <emmanuel.keller@surrealdb.com>
This commit is contained in:
Mees Delzenne 2024-03-05 17:01:17 +01:00 committed by GitHub
parent 957eff19a9
commit 807b4681fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 66 additions and 24 deletions

View file

@ -140,9 +140,9 @@ impl fmt::Display for Operator {
}
Self::Knn(k, dist) => {
if let Some(d) = dist {
write!(f, "<{k},{d}>")
write!(f, "<|{k},{d}|>")
} else {
write!(f, "<{k}>")
write!(f, "<|{k}|>")
}
}
}

View file

@ -150,12 +150,25 @@ pub fn knn_distance(i: &str) -> IResult<&str, Distance> {
}
/// Parses a KNN operator from `i` and returns it as `Operator::Knn(k, dist)`,
/// where `k` is read with `u32` and `dist` is the optional result of
/// `knn_distance`.
///
/// This hunk shows the function before and after the change with the removed
/// and added lines interleaved: the flat statements directly below are the
/// previous body, and the `alt((...))` expression is its replacement.
pub fn knn(i: &str) -> IResult<&str, Operator> {
// Previous body: optional case-insensitive `knn` prefix, then `<`, the
// count, an optional distance and a closing `>`.
let (i, _) = opt(tag_no_case("knn"))(i)?;
let (i, _) = char('<')(i)?;
let (i, k) = u32(i)?;
let (i, dist) = opt(knn_distance)(i)?;
let (i, _) = char('>')(i)?;
Ok((i, Operator::Knn(k, dist)))
// Replacement body: two alternatives, tried in the order listed.
alt((
|i| {
// First alternative: the same `[knn]<k[dist]>` form the previous body parsed.
let (i, _) = opt(tag_no_case("knn"))(i)?;
let (i, _) = char('<')(i)?;
let (i, k) = u32(i)?;
let (i, dist) = opt(knn_distance)(i)?;
let (i, _) = char('>')(i)?;
Ok((i, Operator::Knn(k, dist)))
},
|i| {
// Second alternative: the `<|k[dist]|>` form introduced by this change.
let (i, _) = tag("<|")(i)?;
// NOTE(review): presumably nom's `cut`, i.e. once `<|` has matched, a
// failure inside becomes a hard error rather than a backtrack — confirm
// against the imports of this file, which are not visible here.
cut(|i| {
let (i, k) = u32(i)?;
let (i, dist) = opt(knn_distance)(i)?;
let (i, _) = tag("|>")(i)?;
Ok((i, Operator::Knn(k, dist)))
})(i)
},
))(i)
}
pub fn dir(i: &str) -> IResult<&str, Dir> {
@ -218,7 +231,13 @@ mod tests {
let res = knn("<5>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<5>", format!("{}", out));
assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None));
let res = knn("<|5|>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None));
}
@ -227,16 +246,22 @@ mod tests {
let res = knn("<3,EUCLIDEAN>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<3,EUCLIDEAN>", format!("{}", out));
assert_eq!("<|3,EUCLIDEAN|>", format!("{}", out));
assert_eq!(out, Operator::Knn(3, Some(Distance::Euclidean)));
let res = knn("<|3,EUCLIDEAN|>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<|3,EUCLIDEAN|>", format!("{}", out));
assert_eq!(out, Operator::Knn(3, Some(Distance::Euclidean)));
}
// Checks that `knn` accepts its input, that the parsed operator is
// `Operator::Knn(5, None)`, and how that operator is displayed.
// This hunk interleaves removed and added lines: each pair of consecutive
// `let res = ...` / `assert_eq!("<...>", ...)` lines is the old line followed
// by its replacement.
// NOTE(review): the input changes from `knn("knn<5>")` to `knn("<|5|>")`, so
// the updated test no longer exercises the `knn` prefix its name refers to and
// repeats the `<|5|>` case asserted earlier in this test module — confirm
// whether `knn("knn<5>")` should be kept as the input here (with `<|5|>` as
// the expected display output).
#[test]
fn test_knn_with_prefix() {
let res = knn("knn<5>");
let res = knn("<|5|>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<5>", format!("{}", out));
assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None));
}
}

View file

@ -170,6 +170,10 @@ impl<'a> Lexer<'a> {
self.reader.next();
t!("||")
}
Some(b'>') => {
self.reader.next();
t!("|>")
}
_ => t!("|"),
},
b'&' => match self.reader.peek() {
@ -227,6 +231,10 @@ impl<'a> Lexer<'a> {
self.reader.next();
t!("<=")
}
Some(b'|') => {
self.reader.next();
t!("<|")
}
Some(b'-') => {
self.reader.next();
match self.reader.peek() {

View file

@ -84,7 +84,6 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("IS") => TokenKind::Keyword(Keyword::Is),
UniCase::ascii("KEY") => TokenKind::Keyword(Keyword::Key),
UniCase::ascii("KILL") => TokenKind::Keyword(Keyword::Kill),
UniCase::ascii("KNN") => TokenKind::Keyword(Keyword::Knn),
UniCase::ascii("LET") => TokenKind::Keyword(Keyword::Let),
UniCase::ascii("LIMIT") => TokenKind::Keyword(Keyword::Limit),
UniCase::ascii("LIVE") => TokenKind::Keyword(Keyword::Live),

View file

@ -108,7 +108,7 @@ impl Parser<'_> {
| t!("INTERSECTS")
| t!("NOT")
| t!("IN")
| t!("KNN") => Some((9, 10)),
| t!("<|") => Some((9, 10)),
t!("+") | t!("-") => Some((11, 12)),
t!("*") | t!("×") | t!("/") | t!("÷") => Some((13, 14)),
@ -253,11 +253,10 @@ impl Parser<'_> {
Operator::NotInside
}
t!("IN") => Operator::Inside,
t!("KNN") => {
let start = expected!(self, t!("<")).span;
t!("<|") => {
let amount = self.next_token_value()?;
let dist = self.eat(t!(",")).then(|| self.parse_distance()).transpose()?;
self.expect_closing_delimiter(t!(">"), start)?;
self.expect_closing_delimiter(t!("|>"), token.span)?;
Operator::Knn(amount, dist)
}

View file

@ -92,7 +92,6 @@ keyword! {
Is => "IS",
Key => "KEY",
Kill => "KILL",
Knn => "KNN",
Let => "LET",
Limit => "LIMIT",
Live => "LIVE",

View file

@ -53,6 +53,12 @@ macro_rules! t {
(">") => {
$crate::syn::v2::token::TokenKind::RightChefron
};
("<|") => {
$crate::syn::v2::token::TokenKind::Operator($crate::syn::v2::token::Operator::KnnOpen)
};
("|>") => {
$crate::syn::v2::token::TokenKind::Operator($crate::syn::v2::token::Operator::KnnClose)
};
(";") => {
$crate::syn::v2::token::TokenKind::SemiColon

View file

@ -128,6 +128,10 @@ pub enum Operator {
Tco,
/// `??`
Nco,
/// `<|`
KnnOpen,
/// `|>`
KnnClose,
}
impl Operator {
@ -169,6 +173,8 @@ impl Operator {
Operator::Ext => "+?=",
Operator::Tco => "?:",
Operator::Nco => "??",
Operator::KnnOpen => "<|",
Operator::KnnClose => "|>",
}
}
}

View file

@ -14,8 +14,8 @@ async fn select_where_mtree_knn() -> Result<(), Error> {
CREATE pts:3 SET point = [8,9,10,11];
DEFINE INDEX mt_pts ON pts FIELDS point MTREE DIMENSION 4;
LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt;
SELECT id FROM pts WHERE point knn<2> $pt EXPLAIN;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt;
SELECT id FROM pts WHERE point <|2|> $pt EXPLAIN;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
@ -70,7 +70,7 @@ async fn delete_update_mtree_index() -> Result<(), Error> {
DELETE pts:2;
UPDATE pts:3 SET point = [12,13,14,15];
LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<5> $pt ORDER BY dist;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|5|> $pt ORDER BY dist;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
@ -153,9 +153,9 @@ async fn select_where_brut_force_knn() -> Result<(), Error> {
CREATE pts:2 SET point = [4,5,6,7];
CREATE pts:3 SET point = [8,9,10,11];
LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt PARALLEL;
SELECT id FROM pts WHERE point knn<2> $pt EXPLAIN;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt PARALLEL;
SELECT id FROM pts WHERE point <|2|> $pt EXPLAIN;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");