surrealpatch/lib/src/sql/model.rs

149 lines
3.7 KiB
Rust
Raw Normal View History

use async_recursion::async_recursion;
use derive::Store;
use nom::{
bytes::complete::{tag, take_while1},
character::complete::i64,
combinator::{cut, recognize},
multi::separated_list1,
};
use revision::revisioned;
2020-06-29 15:36:01 +00:00
use serde::{Deserialize, Serialize};
use std::fmt;
use crate::{
ctx::Context,
dbs::{Options, Transaction},
doc::CursorDoc,
err::Error,
sql::{error::IResult, value::Value},
};
use super::{
common::{closechevron, closeparentheses, commas, openchevron, openparentheses, val_char},
error::{expect_tag_no_case, expected},
util::{delimited_list1, expect_delimited},
value::value,
};
/// A call to a machine-learning model, written in SurrealQL as
/// `ml::name<major.minor.patch>(args...)`, e.g. `ml::insurance::prediction<1.0.0>({ ... })`.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct Model {
	/// The `::`-separated model name that follows the `ml::` prefix,
	/// e.g. `insurance::prediction`.
	pub name: String,
	/// The model version, stored as the `major.minor.patch` string produced
	/// by the `version` parser.
	pub version: String,
	/// The argument values passed to the model between the parentheses.
	pub args: Vec<Value>,
}
impl fmt::Display for Model {
	/// Renders the call back to SurrealQL as `ml::name<version>(arg1,arg2,...)`.
	///
	/// Arguments are separated by a bare `,` with no space, and a call with no
	/// arguments renders as `ml::name<version>()`.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let Self {
			name,
			version,
			args,
		} = self;
		write!(f, "ml::{name}<{version}>(")?;
		let mut remaining = args.iter();
		// Emit the first argument on its own so every later one can be
		// prefixed with the separator.
		if let Some(first) = remaining.next() {
			write!(f, "{first}")?;
			for arg in remaining {
				write!(f, ",{arg}")?;
			}
		}
		f.write_str(")")
	}
}
impl Model {
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_txn: &Transaction,
_doc: Option<&'async_recursion CursorDoc<'_>>,
) -> Result<Value, Error> {
Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string()))
2020-06-29 15:36:01 +00:00
}
}
pub fn model(i: &str) -> IResult<&str, Model> {
let (i, _) = tag("ml::")(i)?;
cut(|i| {
let (i, name) = recognize(separated_list1(tag("::"), take_while1(val_char)))(i)?;
let (i, version) =
expected("a version", expect_delimited(openchevron, version, closechevron))(i)?;
let (i, args) = expected(
"model arguments",
delimited_list1(openparentheses, commas, value, closeparentheses),
)(i)?;
Ok((
i,
Model {
name: name.to_owned(),
version,
args,
},
))
})(i)
2020-06-29 15:36:01 +00:00
}
pub fn version(i: &str) -> IResult<&str, String> {
use std::fmt::Write;
let (i, major) = expected("a version number", i64)(i)?;
let (i, _) = expect_tag_no_case(".")(i)?;
let (i, minor) = expected("a version number", i64)(i)?;
let (i, _) = expect_tag_no_case(".")(i)?;
let (i, patch) = expected("a version number", i64)(i)?;
let mut res = String::new();
// Writing into a string can never error.
write!(&mut res, "{major}.{minor}.{patch}").unwrap();
Ok((i, res))
2020-06-29 15:36:01 +00:00
}
#[cfg(test)]
mod test {
	use super::*;
	use crate::sql::query;

	#[test]
	fn ml_model_example() {
		let sql = r#"ml::insurance::prediction<1.0.0>({
age: 18,
disposable_income: "yes",
purchased_before: true
})
"#;
		let res = model(sql);
		let out = res.unwrap().1.to_string();
		assert_eq!(
			"ml::insurance::prediction<1.0.0>({ age: 18, disposable_income: 'yes', purchased_before: true })",
			out
		);
	}

	#[test]
	fn ml_model_example_in_select() {
		let sql = r"
SELECT
name,
age,
ml::insurance::prediction<1.0.0>({
age: age,
disposable_income: math::round(income),
purchased_before: array::len(->purchased->property) > 0,
}) AS likely_to_buy FROM person:tobie;
";
		let res = query::query(sql);
		let out = res.unwrap().1.to_string();
		assert_eq!(
			"SELECT name, age, ml::insurance::prediction<1.0.0>({ age: age, disposable_income: math::round(income), purchased_before: array::len(->purchased->property) > 0 }) AS likely_to_buy FROM person:tobie;",
			out,
		);
	}

	#[test]
	fn ml_model_with_multiple_arguments() {
		let sql = "ml::insurance::prediction<1.0.0>(1,2,3,4,);";
		let res = query::query(sql);
		let out = res.unwrap().1.to_string();
		assert_eq!("ml::insurance::prediction<1.0.0>(1,2,3,4);", out);
	}

	#[test]
	fn ml_model_version_is_normalised() {
		let (rest, ver) = version("01.2.30>").unwrap();
		assert_eq!("1.2.30", ver);
		assert_eq!(">", rest);
	}

	#[test]
	fn ml_model_display_without_arguments() {
		let m = Model {
			name: "insurance::prediction".to_owned(),
			version: "1.0.0".to_owned(),
			args: Vec::new(),
		};
		assert_eq!("ml::insurance::prediction<1.0.0>()", m.to_string());
	}
}