Change KNN operator syntax. (#3617)

Co-authored-by: Emmanuel Keller <emmanuel.keller@surrealdb.com>
This commit is contained in:
Mees Delzenne 2024-03-05 17:01:17 +01:00 committed by GitHub
parent 957eff19a9
commit 807b4681fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 66 additions and 24 deletions

View file

@ -140,9 +140,9 @@ impl fmt::Display for Operator {
} }
Self::Knn(k, dist) => { Self::Knn(k, dist) => {
if let Some(d) = dist { if let Some(d) = dist {
write!(f, "<{k},{d}>") write!(f, "<|{k},{d}|>")
} else { } else {
write!(f, "<{k}>") write!(f, "<|{k}|>")
} }
} }
} }

View file

@ -150,12 +150,25 @@ pub fn knn_distance(i: &str) -> IResult<&str, Distance> {
} }
pub fn knn(i: &str) -> IResult<&str, Operator> { pub fn knn(i: &str) -> IResult<&str, Operator> {
alt((
|i| {
let (i, _) = opt(tag_no_case("knn"))(i)?; let (i, _) = opt(tag_no_case("knn"))(i)?;
let (i, _) = char('<')(i)?; let (i, _) = char('<')(i)?;
let (i, k) = u32(i)?; let (i, k) = u32(i)?;
let (i, dist) = opt(knn_distance)(i)?; let (i, dist) = opt(knn_distance)(i)?;
let (i, _) = char('>')(i)?; let (i, _) = char('>')(i)?;
Ok((i, Operator::Knn(k, dist))) Ok((i, Operator::Knn(k, dist)))
},
|i| {
let (i, _) = tag("<|")(i)?;
cut(|i| {
let (i, k) = u32(i)?;
let (i, dist) = opt(knn_distance)(i)?;
let (i, _) = tag("|>")(i)?;
Ok((i, Operator::Knn(k, dist)))
})(i)
},
))(i)
} }
pub fn dir(i: &str) -> IResult<&str, Dir> { pub fn dir(i: &str) -> IResult<&str, Dir> {
@ -218,7 +231,13 @@ mod tests {
let res = knn("<5>"); let res = knn("<5>");
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!("<5>", format!("{}", out)); assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None));
let res = knn("<|5|>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None)); assert_eq!(out, Operator::Knn(5, None));
} }
@ -227,16 +246,22 @@ mod tests {
let res = knn("<3,EUCLIDEAN>"); let res = knn("<3,EUCLIDEAN>");
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!("<3,EUCLIDEAN>", format!("{}", out)); assert_eq!("<|3,EUCLIDEAN|>", format!("{}", out));
assert_eq!(out, Operator::Knn(3, Some(Distance::Euclidean)));
let res = knn("<|3,EUCLIDEAN|>");
assert!(res.is_ok());
let out = res.unwrap().1;
assert_eq!("<|3,EUCLIDEAN|>", format!("{}", out));
assert_eq!(out, Operator::Knn(3, Some(Distance::Euclidean))); assert_eq!(out, Operator::Knn(3, Some(Distance::Euclidean)));
} }
#[test] #[test]
fn test_knn_with_prefix() { fn test_knn_with_prefix() {
let res = knn("knn<5>"); let res = knn("<|5|>");
assert!(res.is_ok()); assert!(res.is_ok());
let out = res.unwrap().1; let out = res.unwrap().1;
assert_eq!("<5>", format!("{}", out)); assert_eq!("<|5|>", format!("{}", out));
assert_eq!(out, Operator::Knn(5, None)); assert_eq!(out, Operator::Knn(5, None));
} }
} }

View file

@ -170,6 +170,10 @@ impl<'a> Lexer<'a> {
self.reader.next(); self.reader.next();
t!("||") t!("||")
} }
Some(b'>') => {
self.reader.next();
t!("|>")
}
_ => t!("|"), _ => t!("|"),
}, },
b'&' => match self.reader.peek() { b'&' => match self.reader.peek() {
@ -227,6 +231,10 @@ impl<'a> Lexer<'a> {
self.reader.next(); self.reader.next();
t!("<=") t!("<=")
} }
Some(b'|') => {
self.reader.next();
t!("<|")
}
Some(b'-') => { Some(b'-') => {
self.reader.next(); self.reader.next();
match self.reader.peek() { match self.reader.peek() {

View file

@ -84,7 +84,6 @@ pub(crate) static KEYWORDS: phf::Map<UniCase<&'static str>, TokenKind> = phf_map
UniCase::ascii("IS") => TokenKind::Keyword(Keyword::Is), UniCase::ascii("IS") => TokenKind::Keyword(Keyword::Is),
UniCase::ascii("KEY") => TokenKind::Keyword(Keyword::Key), UniCase::ascii("KEY") => TokenKind::Keyword(Keyword::Key),
UniCase::ascii("KILL") => TokenKind::Keyword(Keyword::Kill), UniCase::ascii("KILL") => TokenKind::Keyword(Keyword::Kill),
UniCase::ascii("KNN") => TokenKind::Keyword(Keyword::Knn),
UniCase::ascii("LET") => TokenKind::Keyword(Keyword::Let), UniCase::ascii("LET") => TokenKind::Keyword(Keyword::Let),
UniCase::ascii("LIMIT") => TokenKind::Keyword(Keyword::Limit), UniCase::ascii("LIMIT") => TokenKind::Keyword(Keyword::Limit),
UniCase::ascii("LIVE") => TokenKind::Keyword(Keyword::Live), UniCase::ascii("LIVE") => TokenKind::Keyword(Keyword::Live),

View file

@ -108,7 +108,7 @@ impl Parser<'_> {
| t!("INTERSECTS") | t!("INTERSECTS")
| t!("NOT") | t!("NOT")
| t!("IN") | t!("IN")
| t!("KNN") => Some((9, 10)), | t!("<|") => Some((9, 10)),
t!("+") | t!("-") => Some((11, 12)), t!("+") | t!("-") => Some((11, 12)),
t!("*") | t!("×") | t!("/") | t!("÷") => Some((13, 14)), t!("*") | t!("×") | t!("/") | t!("÷") => Some((13, 14)),
@ -253,11 +253,10 @@ impl Parser<'_> {
Operator::NotInside Operator::NotInside
} }
t!("IN") => Operator::Inside, t!("IN") => Operator::Inside,
t!("KNN") => { t!("<|") => {
let start = expected!(self, t!("<")).span;
let amount = self.next_token_value()?; let amount = self.next_token_value()?;
let dist = self.eat(t!(",")).then(|| self.parse_distance()).transpose()?; let dist = self.eat(t!(",")).then(|| self.parse_distance()).transpose()?;
self.expect_closing_delimiter(t!(">"), start)?; self.expect_closing_delimiter(t!("|>"), token.span)?;
Operator::Knn(amount, dist) Operator::Knn(amount, dist)
} }

View file

@ -92,7 +92,6 @@ keyword! {
Is => "IS", Is => "IS",
Key => "KEY", Key => "KEY",
Kill => "KILL", Kill => "KILL",
Knn => "KNN",
Let => "LET", Let => "LET",
Limit => "LIMIT", Limit => "LIMIT",
Live => "LIVE", Live => "LIVE",

View file

@ -53,6 +53,12 @@ macro_rules! t {
(">") => { (">") => {
$crate::syn::v2::token::TokenKind::RightChefron $crate::syn::v2::token::TokenKind::RightChefron
}; };
("<|") => {
$crate::syn::v2::token::TokenKind::Operator($crate::syn::v2::token::Operator::KnnOpen)
};
("|>") => {
$crate::syn::v2::token::TokenKind::Operator($crate::syn::v2::token::Operator::KnnClose)
};
(";") => { (";") => {
$crate::syn::v2::token::TokenKind::SemiColon $crate::syn::v2::token::TokenKind::SemiColon

View file

@ -128,6 +128,10 @@ pub enum Operator {
Tco, Tco,
/// `??` /// `??`
Nco, Nco,
/// `<|`
KnnOpen,
/// `|>`
KnnClose,
} }
impl Operator { impl Operator {
@ -169,6 +173,8 @@ impl Operator {
Operator::Ext => "+?=", Operator::Ext => "+?=",
Operator::Tco => "?:", Operator::Tco => "?:",
Operator::Nco => "??", Operator::Nco => "??",
Operator::KnnOpen => "<|",
Operator::KnnClose => "|>",
} }
} }
} }

View file

@ -14,8 +14,8 @@ async fn select_where_mtree_knn() -> Result<(), Error> {
CREATE pts:3 SET point = [8,9,10,11]; CREATE pts:3 SET point = [8,9,10,11];
DEFINE INDEX mt_pts ON pts FIELDS point MTREE DIMENSION 4; DEFINE INDEX mt_pts ON pts FIELDS point MTREE DIMENSION 4;
LET $pt = [2,3,4,5]; LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt; SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt;
SELECT id FROM pts WHERE point knn<2> $pt EXPLAIN; SELECT id FROM pts WHERE point <|2|> $pt EXPLAIN;
"; ";
let dbs = new_ds().await?; let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test"); let ses = Session::owner().with_ns("test").with_db("test");
@ -70,7 +70,7 @@ async fn delete_update_mtree_index() -> Result<(), Error> {
DELETE pts:2; DELETE pts:2;
UPDATE pts:3 SET point = [12,13,14,15]; UPDATE pts:3 SET point = [12,13,14,15];
LET $pt = [2,3,4,5]; LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<5> $pt ORDER BY dist; SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|5|> $pt ORDER BY dist;
"; ";
let dbs = new_ds().await?; let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test"); let ses = Session::owner().with_ns("test").with_db("test");
@ -153,9 +153,9 @@ async fn select_where_brut_force_knn() -> Result<(), Error> {
CREATE pts:2 SET point = [4,5,6,7]; CREATE pts:2 SET point = [4,5,6,7];
CREATE pts:3 SET point = [8,9,10,11]; CREATE pts:3 SET point = [8,9,10,11];
LET $pt = [2,3,4,5]; LET $pt = [2,3,4,5];
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt; SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt;
SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point knn<2,EUCLIDEAN> $pt PARALLEL; SELECT id, vector::distance::euclidean(point, $pt) AS dist FROM pts WHERE point <|2,EUCLIDEAN|> $pt PARALLEL;
SELECT id FROM pts WHERE point knn<2> $pt EXPLAIN; SELECT id FROM pts WHERE point <|2|> $pt EXPLAIN;
"; ";
let dbs = new_ds().await?; let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test"); let ses = Session::owner().with_ns("test").with_db("test");