Implement parsing for ML models. (#2691)

Co-authored-by: Tobie Morgan Hitchcock <tobie@surrealdb.com>
This commit is contained in:
Mees Delzenne 2023-09-13 08:06:28 +02:00 committed by GitHub
parent 538ad50c65
commit 178e2a0d4a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 365 additions and 161 deletions

View file

@ -394,7 +394,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
Value::Param(param) => json!(param), Value::Param(param) => json!(param),
Value::Idiom(idiom) => json!(idiom), Value::Idiom(idiom) => json!(idiom),
Value::Table(table) => json!(table), Value::Table(table) => json!(table),
Value::Model(model) => json!(model), Value::Mock(mock) => json!(mock),
Value::Regex(regex) => json!(regex), Value::Regex(regex) => json!(regex),
Value::Block(block) => json!(block), Value::Block(block) => json!(block),
Value::Range(range) => json!(range), Value::Range(range) => json!(range),
@ -409,6 +409,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
}, },
Value::Cast(cast) => json!(cast), Value::Cast(cast) => json!(cast),
Value::Function(function) => json!(function), Value::Function(function) => json!(function),
Value::MlModel(model) => json!(model),
Value::Query(query) => json!(query), Value::Query(query) => json!(query),
Value::Subquery(subquery) => json!(subquery), Value::Subquery(subquery) => json!(subquery),
Value::Expression(expression) => json!(expression), Value::Expression(expression) => json!(expression),

View file

@ -142,7 +142,7 @@ impl Iterator {
// Add the record to the iterator // Add the record to the iterator
self.ingest(Iterable::Thing(v)); self.ingest(Iterable::Thing(v));
} }
Value::Model(v) => { Value::Mock(v) => {
// Check if there is a data clause // Check if there is a data clause
if let Some(data) = stm.data() { if let Some(data) = stm.data() {
// Check if there is an id field specified // Check if there is an id field specified

130
lib/src/sql/mock.rs Normal file
View file

@ -0,0 +1,130 @@
use crate::sql::common::take_u64;
use crate::sql::error::IResult;
use crate::sql::escape::escape_ident;
use crate::sql::id::Id;
use crate::sql::ident::ident_raw;
use crate::sql::thing::Thing;
use nom::character::complete::char;
use nom::combinator::map;
use nom::{branch::alt, combinator::value};
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Mock";
pub struct IntoIter {
model: Mock,
index: u64,
}
impl Iterator for IntoIter {
type Item = Thing;
fn next(&mut self) -> Option<Thing> {
match &self.model {
Mock::Count(tb, c) => {
if self.index < *c {
self.index += 1;
Some(Thing {
tb: tb.to_string(),
id: Id::rand(),
})
} else {
None
}
}
Mock::Range(tb, b, e) => {
if self.index == 0 {
self.index = *b - 1;
}
if self.index < *e {
self.index += 1;
Some(Thing {
tb: tb.to_string(),
id: Id::from(self.index),
})
} else {
None
}
}
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Mock")]
#[revisioned(revision = 1)]
pub enum Mock {
Count(String, u64),
Range(String, u64, u64),
// Add new variants here
}
impl IntoIterator for Mock {
type Item = Thing;
type IntoIter = IntoIter;
fn into_iter(self) -> Self::IntoIter {
IntoIter {
model: self,
index: 0,
}
}
}
impl fmt::Display for Mock {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Mock::Count(tb, c) => {
write!(f, "|{}:{}|", escape_ident(tb), c)
}
Mock::Range(tb, b, e) => {
write!(f, "|{}:{}..{}|", escape_ident(tb), b, e)
}
}
}
}
pub fn mock(i: &str) -> IResult<&str, Mock> {
let (i, _) = char('|')(i)?;
let (i, t) = ident_raw(i)?;
let (i, _) = char(':')(i)?;
let (i, c) = take_u64(i)?;
let (i, e) = alt((value(None, char('|')), map(mock_range, Some)))(i)?;
if let Some(e) = e {
Ok((i, Mock::Range(t, c, e)))
} else {
Ok((i, Mock::Count(t, c)))
}
}
fn mock_range(i: &str) -> IResult<&str, u64> {
let (i, _) = char('.')(i)?;
let (i, _) = char('.')(i)?;
let (i, e) = take_u64(i)?;
let (i, _) = char('|')(i)?;
Ok((i, e))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mock_count() {
let sql = "|test:1000|";
let res = mock(sql);
let out = res.unwrap().1;
assert_eq!("|test:1000|", format!("{}", out));
assert_eq!(out, Mock::Count(String::from("test"), 1000));
}
#[test]
fn mock_range() {
let sql = "|test:1..1000|";
let res = mock(sql);
let out = res.unwrap().1;
assert_eq!("|test:1..1000|", format!("{}", out));
assert_eq!(out, Mock::Range(String::from("test"), 1, 1000));
}
}

View file

@ -38,6 +38,7 @@ pub(crate) mod index;
pub(crate) mod kind; pub(crate) mod kind;
pub(crate) mod language; pub(crate) mod language;
pub(crate) mod limit; pub(crate) mod limit;
pub(crate) mod mock;
pub(crate) mod model; pub(crate) mod model;
pub(crate) mod number; pub(crate) mod number;
pub(crate) mod object; pub(crate) mod object;
@ -114,7 +115,7 @@ pub use self::idiom::Idioms;
pub use self::index::Index; pub use self::index::Index;
pub use self::kind::Kind; pub use self::kind::Kind;
pub use self::limit::Limit; pub use self::limit::Limit;
pub use self::model::Model; pub use self::mock::Mock;
pub use self::number::Number; pub use self::number::Number;
pub use self::object::Object; pub use self::object::Object;
pub use self::operation::Operation; pub use self::operation::Operation;

View file

@ -1,131 +1,133 @@
use crate::sql::common::take_u64; use async_recursion::async_recursion;
use crate::sql::error::IResult; use derive::Store;
use crate::sql::escape::escape_ident; use nom::{
use crate::sql::id::Id; bytes::complete::{tag, take_while1},
use crate::sql::ident::ident_raw; character::complete::i64,
use crate::sql::thing::Thing; combinator::{cut, recognize},
use nom::character::complete::char; multi::separated_list1,
use nom::combinator::map; };
use nom::{branch::alt, combinator::value};
use revision::revisioned; use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Model"; use crate::{
ctx::Context,
dbs::{Options, Transaction},
doc::CursorDoc,
err::Error,
sql::{error::IResult, value::Value},
};
pub struct IntoIter { use super::{
model: Model, common::{closechevron, closeparentheses, openchevron, openparentheses, val_char},
index: u64, error::{expect_tag_no_case, expected},
} util::expect_delimited,
value::value,
};
impl Iterator for IntoIter { #[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
type Item = Thing;
fn next(&mut self) -> Option<Thing> {
match &self.model {
Model::Count(tb, c) => {
if self.index < *c {
self.index += 1;
Some(Thing {
tb: tb.to_string(),
id: Id::rand(),
})
} else {
None
}
}
Model::Range(tb, b, e) => {
if self.index == 0 {
self.index = *b - 1;
}
if self.index < *e {
self.index += 1;
Some(Thing {
tb: tb.to_string(),
id: Id::from(self.index),
})
} else {
None
}
}
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Model")]
#[revisioned(revision = 1)] #[revisioned(revision = 1)]
pub enum Model { pub struct Model {
Count(String, u64), pub name: String,
Range(String, u64, u64), pub version: String,
// Add new variants here pub parameters: Value,
}
impl IntoIterator for Model {
type Item = Thing;
type IntoIter = IntoIter;
fn into_iter(self) -> Self::IntoIter {
IntoIter {
model: self,
index: 0,
}
}
} }
impl fmt::Display for Model { impl fmt::Display for Model {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { write!(f, "ml::{}<{}>({})", self.name, self.version, self.parameters)
Model::Count(tb, c) => {
write!(f, "|{}:{}|", escape_ident(tb), c)
}
Model::Range(tb, b, e) => {
write!(f, "|{}:{}..{}|", escape_ident(tb), b, e)
}
} }
}
impl Model {
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_txn: &Transaction,
_doc: Option<&'async_recursion CursorDoc<'_>>,
) -> Result<Value, Error> {
Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string()))
} }
} }
pub fn model(i: &str) -> IResult<&str, Model> { pub fn model(i: &str) -> IResult<&str, Model> {
let (i, _) = char('|')(i)?; let (i, _) = tag("ml::")(i)?;
let (i, t) = ident_raw(i)?;
let (i, _) = char(':')(i)?; cut(|i| {
let (i, c) = take_u64(i)?; let (i, name) = recognize(separated_list1(tag("::"), take_while1(val_char)))(i)?;
let (i, e) = alt((value(None, char('|')), map(model_range, Some)))(i)?;
if let Some(e) = e { let (i, version) =
Ok((i, Model::Range(t, c, e))) expected("a version", expect_delimited(openchevron, version, closechevron))(i)?;
} else {
Ok((i, Model::Count(t, c))) let (i, parameters) = expected(
} "model parameters",
expect_delimited(openparentheses, value, closeparentheses),
)(i)?;
Ok((
i,
Model {
name: name.to_owned(),
version,
parameters,
},
))
})(i)
} }
fn model_range(i: &str) -> IResult<&str, u64> { pub fn version(i: &str) -> IResult<&str, String> {
let (i, _) = char('.')(i)?; use std::fmt::Write;
let (i, _) = char('.')(i)?;
let (i, e) = take_u64(i)?; let (i, major) = expected("a version number", i64)(i)?;
let (i, _) = char('|')(i)?; let (i, _) = expect_tag_no_case(".")(i)?;
//Ok((i, Model::Range(t, b, e))) let (i, minor) = expected("a version number", i64)(i)?;
Ok((i, e)) let (i, _) = expect_tag_no_case(".")(i)?;
let (i, patch) = expected("a version number", i64)(i)?;
let mut res = String::new();
// Writing into a string can never error.
write!(&mut res, "{major}.{minor}.{patch}").unwrap();
Ok((i, res))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod test {
use super::*; use super::*;
use crate::sql::query;
#[test] #[test]
fn model_count() { fn ml_model_example() {
let sql = "|test:1000|"; let sql = r#"ml::insurance::prediction<1.0.0>({
age: 18,
disposable_income: "yes",
purchased_before: true
})
"#;
let res = model(sql); let res = model(sql);
let out = res.unwrap().1; let out = res.unwrap().1.to_string();
assert_eq!("|test:1000|", format!("{}", out)); assert_eq!("ml::insurance::prediction<1.0.0>({ age: 18, disposable_income: 'yes', purchased_before: true })",out);
assert_eq!(out, Model::Count(String::from("test"), 1000));
} }
#[test] #[test]
fn model_range() { fn ml_model_example_in_select() {
let sql = "|test:1..1000|"; let sql = r"
let res = model(sql); SELECT
let out = res.unwrap().1; name,
assert_eq!("|test:1..1000|", format!("{}", out)); age,
assert_eq!(out, Model::Range(String::from("test"), 1, 1000)); ml::insurance::prediction<1.0.0>({
age: age,
disposable_income: math::round(income),
purchased_before: array::len(->purchased->property) > 0,
}) AS likely_to_buy FROM person:tobie;
";
let res = query::query(sql);
let out = res.unwrap().1.to_string();
assert_eq!(
"SELECT name, age, ml::insurance::prediction<1.0.0>({ age: age, disposable_income: math::round(income), purchased_before: array::len(->purchased->property) > 0 }) AS likely_to_buy FROM person:tobie;",
out,
);
} }
} }

View file

@ -4,6 +4,7 @@ mod event;
mod field; mod field;
mod function; mod function;
mod index; mod index;
mod model;
mod namespace; mod namespace;
mod param; mod param;
mod scope; mod scope;
@ -17,6 +18,7 @@ pub use event::{event, DefineEventStatement};
pub use field::{field, DefineFieldStatement}; pub use field::{field, DefineFieldStatement};
pub use function::{function, DefineFunctionStatement}; pub use function::{function, DefineFunctionStatement};
pub use index::{index, DefineIndexStatement}; pub use index::{index, DefineIndexStatement};
pub use model::DefineModelStatement;
pub use namespace::{namespace, DefineNamespaceStatement}; pub use namespace::{namespace, DefineNamespaceStatement};
use nom::bytes::complete::tag_no_case; use nom::bytes::complete::tag_no_case;
pub use param::{param, DefineParamStatement}; pub use param::{param, DefineParamStatement};
@ -55,6 +57,7 @@ pub enum DefineStatement {
Field(DefineFieldStatement), Field(DefineFieldStatement),
Index(DefineIndexStatement), Index(DefineIndexStatement),
User(DefineUserStatement), User(DefineUserStatement),
MlModel(DefineModelStatement),
} }
impl DefineStatement { impl DefineStatement {
@ -83,6 +86,7 @@ impl DefineStatement {
Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::User(ref v) => v.compute(ctx, opt, txn, doc).await, Self::User(ref v) => v.compute(ctx, opt, txn, doc).await,
Self::MlModel(ref v) => v.compute(ctx, opt, txn, doc).await,
} }
} }
} }
@ -102,6 +106,7 @@ impl Display for DefineStatement {
Self::Field(v) => Display::fmt(v, f), Self::Field(v) => Display::fmt(v, f),
Self::Index(v) => Display::fmt(v, f), Self::Index(v) => Display::fmt(v, f),
Self::Analyzer(v) => Display::fmt(v, f), Self::Analyzer(v) => Display::fmt(v, f),
Self::MlModel(v) => Display::fmt(v, f),
} }
} }
} }

View file

@ -0,0 +1,59 @@
use crate::sql::fmt::is_pretty;
use crate::sql::fmt::pretty_indent;
use crate::sql::permission::Permission;
use async_recursion::async_recursion;
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::fmt::Write;
use crate::{
ctx::Context,
dbs::{Options, Transaction},
doc::CursorDoc,
err::Error,
sql::{Ident, Strand, Value},
};
#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct DefineModelStatement {
pub name: Ident,
pub version: String,
pub comment: Option<Strand>,
pub permissions: Permission,
}
impl fmt::Display for DefineModelStatement {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DEFINE MODEL ml::{}<{}>", self.name, self.version)?;
if let Some(comment) = self.comment.as_ref() {
write!(f, "COMMENT {}", comment)?;
}
if !self.permissions.is_full() {
let _indent = if is_pretty() {
Some(pretty_indent())
} else {
f.write_char(' ')?;
None
};
write!(f, "PERMISSIONS {}", self.permissions)?;
}
Ok(())
}
}
impl DefineModelStatement {
#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
pub(crate) async fn compute(
&self,
_ctx: &Context<'_>,
_opt: &Options,
_txn: &Transaction,
_doc: Option<&'async_recursion CursorDoc<'_>>,
) -> Result<Value, Error> {
Err(Error::Unimplemented("Ml model definition not yet implemented".to_string()))
}
}

View file

@ -105,7 +105,7 @@ impl SelectStatement {
Value::Thing(v) => i.ingest(Iterable::Thing(v)), Value::Thing(v) => i.ingest(Iterable::Thing(v)),
Value::Range(v) => i.ingest(Iterable::Range(*v)), Value::Range(v) => i.ingest(Iterable::Range(*v)),
Value::Edges(v) => i.ingest(Iterable::Edges(*v)), Value::Edges(v) => i.ingest(Iterable::Edges(*v)),
Value::Model(v) => { Value::Mock(v) => {
for v in v { for v in v {
i.ingest(Iterable::Thing(v)); i.ingest(Iterable::Thing(v));
} }
@ -118,7 +118,7 @@ impl SelectStatement {
} }
Value::Thing(v) => i.ingest(Iterable::Thing(v)), Value::Thing(v) => i.ingest(Iterable::Thing(v)),
Value::Edges(v) => i.ingest(Iterable::Edges(*v)), Value::Edges(v) => i.ingest(Iterable::Edges(*v)),
Value::Model(v) => { Value::Mock(v) => {
for v in v { for v in v {
i.ingest(Iterable::Thing(v)); i.ingest(Iterable::Thing(v));
} }

View file

@ -1,6 +1,6 @@
use crate::err::Error; use crate::err::Error;
use crate::sql::value::serde::ser; use crate::sql::value::serde::ser;
use crate::sql::Model; use crate::sql::Mock;
use ser::Serializer as _; use ser::Serializer as _;
use serde::ser::Error as _; use serde::ser::Error as _;
use serde::ser::Impossible; use serde::ser::Impossible;
@ -9,18 +9,18 @@ use serde::ser::Serialize;
pub(super) struct Serializer; pub(super) struct Serializer;
impl ser::Serializer for Serializer { impl ser::Serializer for Serializer {
type Ok = Model; type Ok = Mock;
type Error = Error; type Error = Error;
type SerializeSeq = Impossible<Model, Error>; type SerializeSeq = Impossible<Mock, Error>;
type SerializeTuple = Impossible<Model, Error>; type SerializeTuple = Impossible<Mock, Error>;
type SerializeTupleStruct = Impossible<Model, Error>; type SerializeTupleStruct = Impossible<Mock, Error>;
type SerializeTupleVariant = SerializeModel; type SerializeTupleVariant = SerializeMock;
type SerializeMap = Impossible<Model, Error>; type SerializeMap = Impossible<Mock, Error>;
type SerializeStruct = Impossible<Model, Error>; type SerializeStruct = Impossible<Mock, Error>;
type SerializeStructVariant = Impossible<Model, Error>; type SerializeStructVariant = Impossible<Mock, Error>;
const EXPECTED: &'static str = "an enum `Model`"; const EXPECTED: &'static str = "an enum `Mock`";
fn serialize_tuple_variant( fn serialize_tuple_variant(
self, self,
@ -36,14 +36,14 @@ impl ser::Serializer for Serializer {
return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`"))); return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`")));
} }
}; };
Ok(SerializeModel { Ok(SerializeMock {
inner, inner,
index: 0, index: 0,
}) })
} }
} }
pub(super) struct SerializeModel { pub(super) struct SerializeMock {
index: usize, index: usize,
inner: Inner, inner: Inner,
} }
@ -53,8 +53,8 @@ enum Inner {
Range(Option<String>, Option<u64>, Option<u64>), Range(Option<String>, Option<u64>, Option<u64>),
} }
impl serde::ser::SerializeTupleVariant for SerializeModel { impl serde::ser::SerializeTupleVariant for SerializeMock {
type Ok = Model; type Ok = Mock;
type Error = Error; type Error = Error;
fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error> fn serialize_field<T>(&mut self, value: &T) -> Result<(), Self::Error>
@ -76,9 +76,7 @@ impl serde::ser::SerializeTupleVariant for SerializeModel {
Inner::Count(..) => "Count", Inner::Count(..) => "Count",
Inner::Range(..) => "Range", Inner::Range(..) => "Range",
}; };
return Err(Error::custom(format!( return Err(Error::custom(format!("unexpected `Mock::{variant}` index `{index}`")));
"unexpected `Model::{variant}` index `{index}`"
)));
} }
} }
self.index += 1; self.index += 1;
@ -87,9 +85,9 @@ impl serde::ser::SerializeTupleVariant for SerializeModel {
fn end(self) -> Result<Self::Ok, Self::Error> { fn end(self) -> Result<Self::Ok, Self::Error> {
match self.inner { match self.inner {
Inner::Count(Some(one), Some(two)) => Ok(Model::Count(one, two)), Inner::Count(Some(one), Some(two)) => Ok(Mock::Count(one, two)),
Inner::Range(Some(one), Some(two), Some(three)) => Ok(Model::Range(one, two, three)), Inner::Range(Some(one), Some(two), Some(three)) => Ok(Mock::Range(one, two, three)),
_ => Err(Error::custom("`Model` missing required value(s)")), _ => Err(Error::custom("`Mock` missing required value(s)")),
} }
} }
} }
@ -101,14 +99,14 @@ mod tests {
#[test] #[test]
fn count() { fn count() {
let model = Model::Count(Default::default(), Default::default()); let model = Mock::Count(Default::default(), Default::default());
let serialized = model.serialize(Serializer.wrap()).unwrap(); let serialized = model.serialize(Serializer.wrap()).unwrap();
assert_eq!(model, serialized); assert_eq!(model, serialized);
} }
#[test] #[test]
fn range() { fn range() {
let model = Model::Range(Default::default(), 1, 2); let model = Mock::Range(Default::default(), 1, 2);
let serialized = model.serialize(Serializer.wrap()).unwrap(); let serialized = model.serialize(Serializer.wrap()).unwrap();
assert_eq!(model, serialized); assert_eq!(model, serialized);
} }

View file

@ -30,7 +30,7 @@ mod index;
mod kind; mod kind;
mod language; mod language;
mod limit; mod limit;
mod model; mod mock;
mod number; mod number;
mod operator; mod operator;
mod order; mod order;

View file

@ -27,7 +27,7 @@ use ser::cast::SerializeCast;
use ser::edges::SerializeEdges; use ser::edges::SerializeEdges;
use ser::expression::SerializeExpression; use ser::expression::SerializeExpression;
use ser::function::SerializeFunction; use ser::function::SerializeFunction;
use ser::model::SerializeModel; use ser::mock::SerializeMock;
use ser::range::SerializeRange; use ser::range::SerializeRange;
use ser::thing::SerializeThing; use ser::thing::SerializeThing;
use ser::Serializer as _; use ser::Serializer as _;
@ -323,14 +323,9 @@ impl ser::Serializer for Serializer {
len: usize, len: usize,
) -> Result<Self::SerializeTupleVariant, Error> { ) -> Result<Self::SerializeTupleVariant, Error> {
Ok(match name { Ok(match name {
sql::model::TOKEN => { sql::mock::TOKEN => SerializeTupleVariant::Model(
SerializeTupleVariant::Model(ser::model::Serializer.serialize_tuple_variant( ser::mock::Serializer.serialize_tuple_variant(name, variant_index, variant, len)?,
name, ),
variant_index,
variant,
len,
)?)
}
sql::function::TOKEN => { sql::function::TOKEN => {
SerializeTupleVariant::Function(ser::function::Serializer.serialize_tuple_variant( SerializeTupleVariant::Function(ser::function::Serializer.serialize_tuple_variant(
name, name,
@ -462,7 +457,7 @@ impl serde::ser::SerializeMap for SerializeMap {
} }
pub(super) enum SerializeTupleVariant { pub(super) enum SerializeTupleVariant {
Model(SerializeModel), Model(SerializeMock),
Function(SerializeFunction), Function(SerializeFunction),
Unknown { Unknown {
variant: &'static str, variant: &'static str,
@ -517,7 +512,7 @@ impl serde::ser::SerializeTupleVariant for SerializeTupleVariant {
fn end(self) -> Result<Value, Error> { fn end(self) -> Result<Value, Error> {
match self { match self {
Self::Model(model) => Ok(Value::Model(model.end()?)), Self::Model(model) => Ok(Value::Mock(model.end()?)),
Self::Function(function) => Ok(Value::Function(Box::new(function.end()?))), Self::Function(function) => Ok(Value::Function(Box::new(function.end()?))),
Self::Unknown { Self::Unknown {
variant, variant,
@ -789,9 +784,9 @@ mod tests {
#[test] #[test]
fn model() { fn model() {
let model = Model::Count("foo".to_owned(), Default::default()); let model = Mock::Count("foo".to_owned(), Default::default());
let value = to_value(&model).unwrap(); let value = to_value(&model).unwrap();
let expected = Value::Model(model); let expected = Value::Mock(model);
assert_eq!(value, expected); assert_eq!(value, expected);
assert_eq!(expected, to_value(&expected).unwrap()); assert_eq!(expected, to_value(&expected).unwrap());
} }

View file

@ -27,6 +27,7 @@ use crate::sql::geometry::{geometry, Geometry};
use crate::sql::id::{Gen, Id}; use crate::sql::id::{Gen, Id};
use crate::sql::idiom::{self, reparse_idiom_start, Idiom}; use crate::sql::idiom::{self, reparse_idiom_start, Idiom};
use crate::sql::kind::Kind; use crate::sql::kind::Kind;
use crate::sql::mock::{mock, Mock};
use crate::sql::model::{model, Model}; use crate::sql::model::{model, Model};
use crate::sql::number::{number, Number}; use crate::sql::number::{number, Number};
use crate::sql::object::{key, object, Object}; use crate::sql::object::{key, object, Object};
@ -140,7 +141,7 @@ pub enum Value {
Param(Param), Param(Param),
Idiom(Idiom), Idiom(Idiom),
Table(Table), Table(Table),
Model(Model), Mock(Mock),
Regex(Regex), Regex(Regex),
Cast(Box<Cast>), Cast(Box<Cast>),
Block(Box<Block>), Block(Box<Block>),
@ -153,6 +154,7 @@ pub enum Value {
Subquery(Box<Subquery>), Subquery(Box<Subquery>),
Expression(Box<Expression>), Expression(Box<Expression>),
Query(Query), Query(Query),
MlModel(Box<Model>),
// Add new variants here // Add new variants here
} }
@ -189,9 +191,9 @@ impl From<Idiom> for Value {
} }
} }
impl From<Model> for Value { impl From<Mock> for Value {
fn from(v: Model) -> Self { fn from(v: Mock) -> Self {
Value::Model(v) Value::Mock(v)
} }
} }
@ -303,6 +305,12 @@ impl From<Function> for Value {
} }
} }
impl From<Model> for Value {
fn from(v: Model) -> Self {
Value::MlModel(Box::new(v))
}
}
impl From<Subquery> for Value { impl From<Subquery> for Value {
fn from(v: Subquery) -> Self { fn from(v: Subquery) -> Self {
Value::Subquery(Box::new(v)) Value::Subquery(Box::new(v))
@ -854,9 +862,9 @@ impl Value {
matches!(self, Value::Thing(_)) matches!(self, Value::Thing(_))
} }
/// Check if this Value is a Model /// Check if this Value is a Mock
pub fn is_model(&self) -> bool { pub fn is_mock(&self) -> bool {
matches!(self, Value::Model(_)) matches!(self, Value::Mock(_))
} }
/// Check if this Value is a Range /// Check if this Value is a Range
@ -1055,7 +1063,8 @@ impl Value {
pub fn can_start_idiom(&self) -> bool { pub fn can_start_idiom(&self) -> bool {
match self { match self {
Value::Function(x) => !x.is_script(), Value::Function(x) => !x.is_script(),
Value::Subquery(_) Value::MlModel(_)
| Value::Subquery(_)
| Value::Constant(_) | Value::Constant(_)
| Value::Datetime(_) | Value::Datetime(_)
| Value::Duration(_) | Value::Duration(_)
@ -2526,10 +2535,11 @@ impl fmt::Display for Value {
Value::Edges(v) => write!(f, "{v}"), Value::Edges(v) => write!(f, "{v}"),
Value::Expression(v) => write!(f, "{v}"), Value::Expression(v) => write!(f, "{v}"),
Value::Function(v) => write!(f, "{v}"), Value::Function(v) => write!(f, "{v}"),
Value::MlModel(v) => write!(f, "{v}"),
Value::Future(v) => write!(f, "{v}"), Value::Future(v) => write!(f, "{v}"),
Value::Geometry(v) => write!(f, "{v}"), Value::Geometry(v) => write!(f, "{v}"),
Value::Idiom(v) => write!(f, "{v}"), Value::Idiom(v) => write!(f, "{v}"),
Value::Model(v) => write!(f, "{v}"), Value::Mock(v) => write!(f, "{v}"),
Value::Number(v) => write!(f, "{v}"), Value::Number(v) => write!(f, "{v}"),
Value::Object(v) => write!(f, "{v}"), Value::Object(v) => write!(f, "{v}"),
Value::Param(v) => write!(f, "{v}"), Value::Param(v) => write!(f, "{v}"),
@ -2556,6 +2566,7 @@ impl Value {
Value::Function(v) => { Value::Function(v) => {
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable) v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
} }
Value::MlModel(m) => m.parameters.writeable(),
Value::Subquery(v) => v.writeable(), Value::Subquery(v) => v.writeable(),
Value::Expression(v) => v.writeable(), Value::Expression(v) => v.writeable(),
_ => false, _ => false,
@ -2586,6 +2597,7 @@ impl Value {
Value::Future(v) => v.compute(ctx, opt, txn, doc).await, Value::Future(v) => v.compute(ctx, opt, txn, doc).await,
Value::Constant(v) => v.compute(ctx, opt, txn, doc).await, Value::Constant(v) => v.compute(ctx, opt, txn, doc).await,
Value::Function(v) => v.compute(ctx, opt, txn, doc).await, Value::Function(v) => v.compute(ctx, opt, txn, doc).await,
Value::MlModel(v) => v.compute(ctx, opt, txn, doc).await,
Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await, Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await,
Value::Expression(v) => v.compute(ctx, opt, txn, doc).await, Value::Expression(v) => v.compute(ctx, opt, txn, doc).await,
_ => Ok(self.to_owned()), _ => Ok(self.to_owned()),
@ -2741,7 +2753,7 @@ pub fn single(i: &str) -> IResult<&str, Value> {
alt(( alt((
into(future), into(future),
into(cast), into(cast),
function_or_const, path_like,
into(geometry), into(geometry),
into(subquery), into(subquery),
into(datetime), into(datetime),
@ -2756,7 +2768,7 @@ pub fn single(i: &str) -> IResult<&str, Value> {
into(block), into(block),
into(param), into(param),
into(regex), into(regex),
into(model), into(mock),
into(edges), into(edges),
into(range), into(range),
into(thing), into(thing),
@ -2780,7 +2792,7 @@ pub fn select_start(i: &str) -> IResult<&str, Value> {
alt(( alt((
into(future), into(future),
into(cast), into(cast),
function_or_const, path_like,
into(geometry), into(geometry),
into(subquery), into(subquery),
into(datetime), into(datetime),
@ -2792,7 +2804,7 @@ pub fn select_start(i: &str) -> IResult<&str, Value> {
into(block), into(block),
into(param), into(param),
into(regex), into(regex),
into(model), into(mock),
into(edges), into(edges),
into(range), into(range),
into(thing), into(thing),
@ -2803,8 +2815,9 @@ pub fn select_start(i: &str) -> IResult<&str, Value> {
reparse_idiom_start(v, i) reparse_idiom_start(v, i)
} }
pub fn function_or_const(i: &str) -> IResult<&str, Value> { /// A path like production: Constants, predefined functions, user defined functions and ml models.
alt((into(defined_function), |i| { pub fn path_like(i: &str) -> IResult<&str, Value> {
alt((into(defined_function), into(model), |i| {
let (i, v) = builtin_name(i)?; let (i, v) = builtin_name(i)?;
match v { match v {
builtin::BuiltinName::Constant(x) => Ok((i, x.into())), builtin::BuiltinName::Constant(x) => Ok((i, x.into())),
@ -2841,14 +2854,14 @@ pub fn what(i: &str) -> IResult<&str, Value> {
let _diving = crate::sql::parser::depth::dive(i)?; let _diving = crate::sql::parser::depth::dive(i)?;
let (i, v) = alt(( let (i, v) = alt((
into(idiom::multi_without_start), into(idiom::multi_without_start),
function_or_const, path_like,
into(subquery), into(subquery),
into(datetime), into(datetime),
into(duration), into(duration),
into(future), into(future),
into(block), into(block),
into(param), into(param),
into(model), into(mock),
into(edges), into(edges),
into(range), into(range),
into(thing), into(thing),
@ -2995,7 +3008,7 @@ mod tests {
assert_eq!(24, std::mem::size_of::<crate::sql::idiom::Idiom>()); assert_eq!(24, std::mem::size_of::<crate::sql::idiom::Idiom>());
assert_eq!(24, std::mem::size_of::<crate::sql::table::Table>()); assert_eq!(24, std::mem::size_of::<crate::sql::table::Table>());
assert_eq!(56, std::mem::size_of::<crate::sql::thing::Thing>()); assert_eq!(56, std::mem::size_of::<crate::sql::thing::Thing>());
assert_eq!(40, std::mem::size_of::<crate::sql::model::Model>()); assert_eq!(40, std::mem::size_of::<crate::sql::mock::Mock>());
assert_eq!(32, std::mem::size_of::<crate::sql::regex::Regex>()); assert_eq!(32, std::mem::size_of::<crate::sql::regex::Regex>());
assert_eq!(8, std::mem::size_of::<Box<crate::sql::range::Range>>()); assert_eq!(8, std::mem::size_of::<Box<crate::sql::range::Range>>());
assert_eq!(8, std::mem::size_of::<Box<crate::sql::edges::Edges>>()); assert_eq!(8, std::mem::size_of::<Box<crate::sql::edges::Edges>>());