surrealpatch/lib/src/sql/scoring.rs
2023-08-17 19:03:46 +01:00

135 lines
2.6 KiB
Rust

use crate::sql::common::{closeparentheses, commas, openparentheses};
use crate::sql::error::IResult;
use crate::sql::Error::Parser;
use nom::branch::alt;
use nom::bytes::complete::tag_no_case;
use nom::combinator::map;
use nom::number::complete::recognize_float;
use nom::Err::Failure;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::hash::{Hash, Hasher};
/// Scoring method as written in SQL: `BM25(k1,b)` or `VS`.
///
/// Parsed by [`scoring`] and printed back through `Display`. `Eq` and `Hash`
/// are implemented by hand, comparing/hashing the `f32` parameters by bit
/// pattern, because `f32` implements neither trait.
///
/// NOTE(review): this type is serialized via serde and `#[revisioned]`; the
/// variant and field order very likely form part of the encoded format, so
/// avoid reordering without a revision bump — confirm against the revision crate.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[revisioned(revision = 1)]
pub enum Scoring {
// BM25 ("best matching") ranking with its two tuning parameters `k1` and `b`.
// `Scoring::bm25()` supplies the defaults k1 = 1.2, b = 0.75.
Bm {
k1: f32,
b: f32,
}, // BestMatching25
// Vector-search scoring; carries no parameters.
Vs, // VectorSearch
}
impl Eq for Scoring {}
impl PartialEq for Scoring {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
Scoring::Bm {
k1,
b,
},
Scoring::Bm {
k1: other_k1,
b: other_b,
},
) => k1.to_bits() == other_k1.to_bits() && b.to_bits() == other_b.to_bits(),
(Scoring::Vs, Scoring::Vs) => true,
_ => false,
}
}
}
impl Hash for Scoring {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Scoring::Bm {
k1,
b,
} => {
k1.to_bits().hash(state);
b.to_bits().hash(state);
}
Scoring::Vs => 0.hash(state),
}
}
}
impl Scoring {
    /// Returns BM25 scoring with its default parameters, `k1 = 1.2` and
    /// `b = 0.75`. The parser uses this when `BM25` appears without an
    /// argument list.
    pub(crate) fn bm25() -> Self {
        let (k1, b) = (1.2, 0.75);
        Self::Bm { k1, b }
    }
}
/// Formats as `VS` or `BM25(k1,b)`. The parameters use `f32`'s default
/// `Display` form, so `1.0` is written as `1`.
impl fmt::Display for Scoring {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (k1, b) = match self {
            Self::Vs => return f.write_str("VS"),
            Self::Bm { k1, b } => (k1, b),
        };
        write!(f, "BM25({},{})", k1, b)
    }
}
/// Parses a scoring clause.
///
/// Accepted forms (keywords are matched case-insensitively):
/// - `VS` yields [`Scoring::Vs`].
/// - `BM25(<k1>,<b>)` yields [`Scoring::Bm`] with the given parameters.
/// - `BM25` on its own yields the defaults from `Scoring::bm25()`.
///
/// # Errors
///
/// Returns a non-recoverable `nom::Err::Failure` when a `BM25(...)` parameter
/// is recognised as a number but is not a finite `f32`. An example is `1e39`,
/// which would otherwise overflow to infinity. Such a value could never be
/// printed back in a re-parseable form, because `Display` would emit `inf`.
pub fn scoring(i: &str) -> IResult<&str, Scoring> {
    alt((
        map(tag_no_case("VS"), |_| Scoring::Vs),
        |i| {
            let (i, _) = tag_no_case("BM25")(i)?;
            let (i, _) = openparentheses(i)?;
            let (i, k1) = bm25_parameter(i)?;
            let (i, _) = commas(i)?;
            let (i, b) = bm25_parameter(i)?;
            let (i, _) = closeparentheses(i)?;
            Ok((
                i,
                Scoring::Bm {
                    k1,
                    b,
                },
            ))
        },
        // Reached only when the parenthesised branch fails with a recoverable
        // error (for example, no `(` follows the keyword). The bare keyword
        // then selects the default parameters.
        map(tag_no_case("BM25"), |_| Scoring::bm25()),
    ))(i)
}

/// Parses one numeric BM25 parameter as a finite `f32`.
///
/// Any parse error or non-finite value (overflow to infinity) becomes a
/// `Failure` that points at the start of the offending number.
fn bm25_parameter(i: &str) -> IResult<&str, f32> {
    let (rest, text) = recognize_float(i)?;
    match text.parse::<f32>() {
        Ok(value) if value.is_finite() => Ok((rest, value)),
        _ => Err(Failure(Parser(i))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` with `scoring` (which must succeed) and renders the
    /// resulting value back through `Display`.
    fn parse_and_render(sql: &str) -> String {
        let (_, parsed) = scoring(sql).expect("scoring clause should parse");
        parsed.to_string()
    }

    #[test]
    fn scoring_bm_25_with_parameters() {
        assert_eq!("BM25(1,0.6)", parse_and_render("BM25(1.0,0.6)"));
    }

    #[test]
    fn scoring_bm_25_without_parameters() {
        assert_eq!("BM25(1.2,0.75)", parse_and_render("BM25"));
    }

    #[test]
    fn scoring_vs() {
        assert_eq!("VS", parse_and_render("VS"));
    }
}