From 178e2a0d4a1bf2371396368d96a3733e3aed6c36 Mon Sep 17 00:00:00 2001 From: Mees Delzenne Date: Wed, 13 Sep 2023 08:06:28 +0200 Subject: [PATCH] Implement parsing for ML models. (#2691) Co-authored-by: Tobie Morgan Hitchcock --- lib/src/api/opt/mod.rs | 3 +- lib/src/dbs/iterator.rs | 2 +- lib/src/sql/mock.rs | 130 +++++++++++ lib/src/sql/mod.rs | 3 +- lib/src/sql/model.rs | 206 +++++++++--------- lib/src/sql/statements/define/mod.rs | 5 + lib/src/sql/statements/define/model.rs | 59 +++++ lib/src/sql/statements/select.rs | 4 +- .../value/serde/ser/{model => mock}/mod.rs | 42 ++-- lib/src/sql/value/serde/ser/mod.rs | 2 +- lib/src/sql/value/serde/ser/value/mod.rs | 21 +- lib/src/sql/value/value.rs | 49 +++-- 12 files changed, 365 insertions(+), 161 deletions(-) create mode 100644 lib/src/sql/mock.rs create mode 100644 lib/src/sql/statements/define/model.rs rename lib/src/sql/value/serde/ser/{model => mock}/mod.rs (66%) diff --git a/lib/src/api/opt/mod.rs b/lib/src/api/opt/mod.rs index fde91240..8b1f38a4 100644 --- a/lib/src/api/opt/mod.rs +++ b/lib/src/api/opt/mod.rs @@ -394,7 +394,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue { Value::Param(param) => json!(param), Value::Idiom(idiom) => json!(idiom), Value::Table(table) => json!(table), - Value::Model(model) => json!(model), + Value::Mock(mock) => json!(mock), Value::Regex(regex) => json!(regex), Value::Block(block) => json!(block), Value::Range(range) => json!(range), @@ -409,6 +409,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue { }, Value::Cast(cast) => json!(cast), Value::Function(function) => json!(function), + Value::MlModel(model) => json!(model), Value::Query(query) => json!(query), Value::Subquery(subquery) => json!(subquery), Value::Expression(expression) => json!(expression), diff --git a/lib/src/dbs/iterator.rs b/lib/src/dbs/iterator.rs index 3bda9deb..704bcc34 100644 --- a/lib/src/dbs/iterator.rs +++ b/lib/src/dbs/iterator.rs @@ -142,7 +142,7 @@ impl Iterator { // Add the record to the iterator self.ingest(Iterable::Thing(v)); } - Value::Model(v) => { + Value::Mock(v) => { // Check if there is a data clause if let Some(data) = stm.data() { // Check if there is an id field specified diff --git a/lib/src/sql/mock.rs b/lib/src/sql/mock.rs new file mode 100644 index 00000000..c11e75d5 --- /dev/null +++ b/lib/src/sql/mock.rs @@ -0,0 +1,130 @@ +use crate::sql::common::take_u64; +use crate::sql::error::IResult; +use crate::sql::escape::escape_ident; +use crate::sql::id::Id; +use crate::sql::ident::ident_raw; +use crate::sql::thing::Thing; +use nom::character::complete::char; +use nom::combinator::map; +use nom::{branch::alt, combinator::value}; +use revision::revisioned; +use serde::{Deserialize, Serialize}; +use std::fmt; + +pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Mock"; + +pub struct IntoIter { + model: Mock, + index: u64, +} + +impl Iterator for IntoIter { + type Item = Thing; + fn next(&mut self) -> Option { + match &self.model { + Mock::Count(tb, c) => { + if self.index < *c { + self.index += 1; + Some(Thing { + tb: tb.to_string(), + id: Id::rand(), + }) + } else { + None + } + } + Mock::Range(tb, b, e) => { + if self.index == 0 { + self.index = *b - 1; + } + if self.index < *e { + self.index += 1; + Some(Thing { + tb: tb.to_string(), + id: Id::from(self.index), + }) + } else { + None + } + } + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] +#[serde(rename = "$surrealdb::private::sql::Mock")] +#[revisioned(revision = 1)] +pub enum Mock { + Count(String, u64), + Range(String, u64, u64), + // Add new variants here +} + +impl IntoIterator for Mock { + type Item = Thing; + type IntoIter = IntoIter; + fn into_iter(self) -> Self::IntoIter { + IntoIter { + model: self, + index: 0, + } + } +} + +impl fmt::Display for Mock { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Mock::Count(tb, c) => { + write!(f, "|{}:{}|", escape_ident(tb), c) + } + Mock::Range(tb, b, e) => { + write!(f, "|{}:{}..{}|", escape_ident(tb), b, e) + } + } + } +} + +pub fn mock(i: &str) -> IResult<&str, Mock> { + let (i, _) = char('|')(i)?; + let (i, t) = ident_raw(i)?; + let (i, _) = char(':')(i)?; + let (i, c) = take_u64(i)?; + let (i, e) = alt((value(None, char('|')), map(mock_range, Some)))(i)?; + if let Some(e) = e { + Ok((i, Mock::Range(t, c, e))) + } else { + Ok((i, Mock::Count(t, c))) + } +} + +fn mock_range(i: &str) -> IResult<&str, u64> { + let (i, _) = char('.')(i)?; + let (i, _) = char('.')(i)?; + let (i, e) = take_u64(i)?; + let (i, _) = char('|')(i)?; + Ok((i, e)) +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn mock_count() { + let sql = "|test:1000|"; + let res = mock(sql); + let out = res.unwrap().1; + assert_eq!("|test:1000|", format!("{}", out)); + assert_eq!(out, Mock::Count(String::from("test"), 1000)); + } + + #[test] + fn mock_range() { + let sql = "|test:1..1000|"; + let res = mock(sql); + let out = res.unwrap().1; + assert_eq!("|test:1..1000|", format!("{}", out)); + assert_eq!(out, Mock::Range(String::from("test"), 1, 1000)); + } +} diff --git a/lib/src/sql/mod.rs b/lib/src/sql/mod.rs index 3cce2d94..b8fc4844 100644 --- a/lib/src/sql/mod.rs +++ b/lib/src/sql/mod.rs @@ -38,6 +38,7 @@ pub(crate) mod index; pub(crate) mod kind; pub(crate) mod language; pub(crate) mod limit; +pub(crate) mod mock; pub(crate) mod model; pub(crate) mod number; pub(crate) mod object; @@ -114,7 +115,7 @@ pub use self::idiom::Idioms; pub use self::index::Index; pub use self::kind::Kind; pub use self::limit::Limit; -pub use self::model::Model; +pub use self::mock::Mock; pub use self::number::Number; pub use self::object::Object; pub use self::operation::Operation; diff --git a/lib/src/sql/model.rs b/lib/src/sql/model.rs index 37f2786e..be056ad0 100644 --- a/lib/src/sql/model.rs +++ b/lib/src/sql/model.rs @@ -1,131 +1,133 @@ -use crate::sql::common::take_u64; -use crate::sql::error::IResult; -use crate::sql::escape::escape_ident; -use crate::sql::id::Id; -use crate::sql::ident::ident_raw; -use crate::sql::thing::Thing; -use nom::character::complete::char; -use nom::combinator::map; -use nom::{branch::alt, combinator::value}; +use async_recursion::async_recursion; +use derive::Store; +use nom::{ + bytes::complete::{tag, take_while1}, + character::complete::i64, + combinator::{cut, recognize}, + multi::separated_list1, +}; use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt; -pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Model"; +use crate::{ + ctx::Context, + dbs::{Options, Transaction}, + doc::CursorDoc, + err::Error, + sql::{error::IResult, value::Value}, +}; -pub struct IntoIter { - model: Model, - index: u64, -} +use super::{ + common::{closechevron, closeparentheses, openchevron, openparentheses, val_char}, + error::{expect_tag_no_case, expected}, + util::expect_delimited, + value::value, +}; -impl Iterator for IntoIter { - type Item = Thing; - fn next(&mut self) -> Option { - match &self.model { - Model::Count(tb, c) => { - if self.index < *c { - self.index += 1; - Some(Thing { - tb: tb.to_string(), - id: Id::rand(), - }) - } else { - None - } - } - Model::Range(tb, b, e) => { - if self.index == 0 { - self.index = *b - 1; - } - if self.index < *e { - self.index += 1; - Some(Thing { - tb: tb.to_string(), - id: Id::from(self.index), - }) - } else { - None - } - } - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Hash)] -#[serde(rename = "$surrealdb::private::sql::Model")] +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[revisioned(revision = 1)] -pub enum Model { - Count(String, u64), - Range(String, u64, u64), - // Add new variants here -} - -impl IntoIterator for Model { - type Item = Thing; - type IntoIter = IntoIter; - fn into_iter(self) -> Self::IntoIter { - IntoIter { - model: self, - index: 0, - } - } +pub struct Model { + pub name: String, + pub version: String, + pub parameters: Value, } impl fmt::Display for Model { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Model::Count(tb, c) => { - write!(f, "|{}:{}|", escape_ident(tb), c) - } - Model::Range(tb, b, e) => { - write!(f, "|{}:{}..{}|", escape_ident(tb), b, e) - } - } + write!(f, "ml::{}<{}>({})", self.name, self.version, self.parameters) + } +} + +impl Model { + #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] + #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + pub(crate) async fn compute( + &self, + _ctx: &Context<'_>, + _opt: &Options, + _txn: &Transaction, + _doc: Option<&'async_recursion CursorDoc<'_>>, + ) -> Result { + Err(Error::Unimplemented("ML model evaluation not yet implemented".to_string())) } } pub fn model(i: &str) -> IResult<&str, Model> { - let (i, _) = char('|')(i)?; - let (i, t) = ident_raw(i)?; - let (i, _) = char(':')(i)?; - let (i, c) = take_u64(i)?; - let (i, e) = alt((value(None, char('|')), map(model_range, Some)))(i)?; - if let Some(e) = e { - Ok((i, Model::Range(t, c, e))) - } else { - Ok((i, Model::Count(t, c))) - } + let (i, _) = tag("ml::")(i)?; + + cut(|i| { + let (i, name) = recognize(separated_list1(tag("::"), take_while1(val_char)))(i)?; + + let (i, version) = + expected("a version", expect_delimited(openchevron, version, closechevron))(i)?; + + let (i, parameters) = expected( + "model parameters", + expect_delimited(openparentheses, value, closeparentheses), + )(i)?; + + Ok(( + i, + Model { + name: name.to_owned(), + version, + parameters, + }, + )) + })(i) } -fn model_range(i: &str) -> IResult<&str, u64> { - let (i, _) = char('.')(i)?; - let (i, _) = char('.')(i)?; - let (i, e) = take_u64(i)?; - let (i, _) = char('|')(i)?; - //Ok((i, Model::Range(t, b, e))) - Ok((i, e)) +pub fn version(i: &str) -> IResult<&str, String> { + use std::fmt::Write; + + let (i, major) = expected("a version number", i64)(i)?; + let (i, _) = expect_tag_no_case(".")(i)?; + let (i, minor) = expected("a version number", i64)(i)?; + let (i, _) = expect_tag_no_case(".")(i)?; + let (i, patch) = expected("a version number", i64)(i)?; + + let mut res = String::new(); + // Writing into a string can never error. + write!(&mut res, "{major}.{minor}.{patch}").unwrap(); + Ok((i, res)) } #[cfg(test)] -mod tests { - +mod test { use super::*; + use crate::sql::query; #[test] - fn model_count() { - let sql = "|test:1000|"; + fn ml_model_example() { + let sql = r#"ml::insurance::prediction<1.0.0>({ + age: 18, + disposable_income: "yes", + purchased_before: true + }) + "#; let res = model(sql); - let out = res.unwrap().1; - assert_eq!("|test:1000|", format!("{}", out)); - assert_eq!(out, Model::Count(String::from("test"), 1000)); + let out = res.unwrap().1.to_string(); + assert_eq!("ml::insurance::prediction<1.0.0>({ age: 18, disposable_income: 'yes', purchased_before: true })",out); } #[test] - fn model_range() { - let sql = "|test:1..1000|"; - let res = model(sql); - let out = res.unwrap().1; - assert_eq!("|test:1..1000|", format!("{}", out)); - assert_eq!(out, Model::Range(String::from("test"), 1, 1000)); + fn ml_model_example_in_select() { + let sql = r" + SELECT + name, + age, + ml::insurance::prediction<1.0.0>({ + age: age, + disposable_income: math::round(income), + purchased_before: array::len(->purchased->property) > 0, + }) AS likely_to_buy FROM person:tobie; + "; + let res = query::query(sql); + let out = res.unwrap().1.to_string(); + assert_eq!( + "SELECT name, age, ml::insurance::prediction<1.0.0>({ age: age, disposable_income: math::round(income), purchased_before: array::len(->purchased->property) > 0 }) AS likely_to_buy FROM person:tobie;", + out, + ); } } diff --git a/lib/src/sql/statements/define/mod.rs b/lib/src/sql/statements/define/mod.rs index ba9d1ca6..d2bdc1c3 100644 --- a/lib/src/sql/statements/define/mod.rs +++ b/lib/src/sql/statements/define/mod.rs @@ -4,6 +4,7 @@ mod event; mod field; mod function; mod index; +mod model; mod namespace; mod param; mod scope; @@ -17,6 +18,7 @@ pub use event::{event, DefineEventStatement}; pub use field::{field, DefineFieldStatement}; pub use function::{function, DefineFunctionStatement}; pub use index::{index, DefineIndexStatement}; +pub use model::DefineModelStatement; pub use namespace::{namespace, DefineNamespaceStatement}; use nom::bytes::complete::tag_no_case; pub use param::{param, DefineParamStatement}; @@ -55,6 +57,7 @@ pub enum DefineStatement { Field(DefineFieldStatement), Index(DefineIndexStatement), User(DefineUserStatement), + MlModel(DefineModelStatement), } impl DefineStatement { @@ -83,6 +86,7 @@ impl DefineStatement { Self::Index(ref v) => v.compute(ctx, opt, txn, doc).await, Self::Analyzer(ref v) => v.compute(ctx, opt, txn, doc).await, Self::User(ref v) => v.compute(ctx, opt, txn, doc).await, + Self::MlModel(ref v) => v.compute(ctx, opt, txn, doc).await, } } } @@ -102,6 +106,7 @@ impl Display for DefineStatement { Self::Field(v) => Display::fmt(v, f), Self::Index(v) => Display::fmt(v, f), Self::Analyzer(v) => Display::fmt(v, f), + Self::MlModel(v) => Display::fmt(v, f), } } } diff --git a/lib/src/sql/statements/define/model.rs b/lib/src/sql/statements/define/model.rs new file mode 100644 index 00000000..d60ff784 --- /dev/null +++ b/lib/src/sql/statements/define/model.rs @@ -0,0 +1,59 @@ +use crate::sql::fmt::is_pretty; +use crate::sql::fmt::pretty_indent; +use crate::sql::permission::Permission; +use async_recursion::async_recursion; +use derive::Store; +use revision::revisioned; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::fmt::Write; + +use crate::{ + ctx::Context, + dbs::{Options, Transaction}, + doc::CursorDoc, + err::Error, + sql::{Ident, Strand, Value}, +}; + +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] +#[revisioned(revision = 1)] +pub struct DefineModelStatement { + pub name: Ident, + pub version: String, + pub comment: Option, + pub permissions: Permission, +} + +impl fmt::Display for DefineModelStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DEFINE MODEL ml::{}<{}>", self.name, self.version)?; + if let Some(comment) = self.comment.as_ref() { + write!(f, "COMMENT {}", comment)?; + } + if !self.permissions.is_full() { + let _indent = if is_pretty() { + Some(pretty_indent()) + } else { + f.write_char(' ')?; + None + }; + write!(f, "PERMISSIONS {}", self.permissions)?; + } + Ok(()) + } +} + +impl DefineModelStatement { + #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] + #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))] + pub(crate) async fn compute( + &self, + _ctx: &Context<'_>, + _opt: &Options, + _txn: &Transaction, + _doc: Option<&'async_recursion CursorDoc<'_>>, + ) -> Result { + Err(Error::Unimplemented("Ml model definition not yet implemented".to_string())) + } +} diff --git a/lib/src/sql/statements/select.rs b/lib/src/sql/statements/select.rs index 7e6aef70..03dfbda9 100644 --- a/lib/src/sql/statements/select.rs +++ b/lib/src/sql/statements/select.rs @@ -105,7 +105,7 @@ impl SelectStatement { Value::Thing(v) => i.ingest(Iterable::Thing(v)), Value::Range(v) => i.ingest(Iterable::Range(*v)), Value::Edges(v) => i.ingest(Iterable::Edges(*v)), - Value::Model(v) => { + Value::Mock(v) => { for v in v { i.ingest(Iterable::Thing(v)); } @@ -118,7 +118,7 @@ impl SelectStatement { } Value::Thing(v) => i.ingest(Iterable::Thing(v)), Value::Edges(v) => i.ingest(Iterable::Edges(*v)), - Value::Model(v) => { + Value::Mock(v) => { for v in v { i.ingest(Iterable::Thing(v)); } diff --git a/lib/src/sql/value/serde/ser/model/mod.rs b/lib/src/sql/value/serde/ser/mock/mod.rs similarity index 66% rename from lib/src/sql/value/serde/ser/model/mod.rs rename to lib/src/sql/value/serde/ser/mock/mod.rs index 8d5fc3ac..5f5ed6da 100644 --- a/lib/src/sql/value/serde/ser/model/mod.rs +++ b/lib/src/sql/value/serde/ser/mock/mod.rs @@ -1,6 +1,6 @@ use crate::err::Error; use crate::sql::value::serde::ser; -use crate::sql::Model; +use crate::sql::Mock; use ser::Serializer as _; use serde::ser::Error as _; use serde::ser::Impossible; @@ -9,18 +9,18 @@ use serde::ser::Serialize; pub(super) struct Serializer; impl ser::Serializer for Serializer { - type Ok = Model; + type Ok = Mock; type Error = Error; - type SerializeSeq = Impossible; - type SerializeTuple = Impossible; - type SerializeTupleStruct = Impossible; - type SerializeTupleVariant = SerializeModel; - type SerializeMap = Impossible; - type SerializeStruct = Impossible; - type SerializeStructVariant = Impossible; + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = SerializeMock; + type SerializeMap = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; - const EXPECTED: &'static str = "an enum `Model`"; + const EXPECTED: &'static str = "an enum `Mock`"; fn serialize_tuple_variant( self, @@ -36,14 +36,14 @@ impl ser::Serializer for Serializer { return Err(Error::custom(format!("unexpected tuple variant `{name}::{variant}`"))); } }; - Ok(SerializeModel { + Ok(SerializeMock { inner, index: 0, }) } } -pub(super) struct SerializeModel { +pub(super) struct SerializeMock { index: usize, inner: Inner, } @@ -53,8 +53,8 @@ enum Inner { Range(Option, Option, Option), } -impl serde::ser::SerializeTupleVariant for SerializeModel { - type Ok = Model; +impl serde::ser::SerializeTupleVariant for SerializeMock { + type Ok = Mock; type Error = Error; fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> @@ -76,9 +76,7 @@ impl serde::ser::SerializeTupleVariant for SerializeModel { Inner::Count(..) => "Count", Inner::Range(..) => "Range", }; - return Err(Error::custom(format!( - "unexpected `Model::{variant}` index `{index}`" - ))); + return Err(Error::custom(format!("unexpected `Mock::{variant}` index `{index}`"))); } } self.index += 1; @@ -87,9 +85,9 @@ impl serde::ser::SerializeTupleVariant for SerializeModel { fn end(self) -> Result { match self.inner { - Inner::Count(Some(one), Some(two)) => Ok(Model::Count(one, two)), - Inner::Range(Some(one), Some(two), Some(three)) => Ok(Model::Range(one, two, three)), - _ => Err(Error::custom("`Model` missing required value(s)")), + Inner::Count(Some(one), Some(two)) => Ok(Mock::Count(one, two)), + Inner::Range(Some(one), Some(two), Some(three)) => Ok(Mock::Range(one, two, three)), + _ => Err(Error::custom("`Mock` missing required value(s)")), } } } @@ -101,14 +99,14 @@ mod tests { #[test] fn count() { - let model = Model::Count(Default::default(), Default::default()); + let model = Mock::Count(Default::default(), Default::default()); let serialized = model.serialize(Serializer.wrap()).unwrap(); assert_eq!(model, serialized); } #[test] fn range() { - let model = Model::Range(Default::default(), 1, 2); + let model = Mock::Range(Default::default(), 1, 2); let serialized = model.serialize(Serializer.wrap()).unwrap(); assert_eq!(model, serialized); } diff --git a/lib/src/sql/value/serde/ser/mod.rs b/lib/src/sql/value/serde/ser/mod.rs index 0fd71482..4e2db038 100644 --- a/lib/src/sql/value/serde/ser/mod.rs +++ b/lib/src/sql/value/serde/ser/mod.rs @@ -30,7 +30,7 @@ mod index; mod kind; mod language; mod limit; -mod model; +mod mock; mod number; mod operator; mod order; diff --git a/lib/src/sql/value/serde/ser/value/mod.rs b/lib/src/sql/value/serde/ser/value/mod.rs index f647f76d..dc2f0c73 100644 --- a/lib/src/sql/value/serde/ser/value/mod.rs +++ b/lib/src/sql/value/serde/ser/value/mod.rs @@ -27,7 +27,7 @@ use ser::cast::SerializeCast; use ser::edges::SerializeEdges; use ser::expression::SerializeExpression; use ser::function::SerializeFunction; -use ser::model::SerializeModel; +use ser::mock::SerializeMock; use ser::range::SerializeRange; use ser::thing::SerializeThing; use ser::Serializer as _; @@ -323,14 +323,9 @@ impl ser::Serializer for Serializer { len: usize, ) -> Result { Ok(match name { - sql::model::TOKEN => { - SerializeTupleVariant::Model(ser::model::Serializer.serialize_tuple_variant( - name, - variant_index, - variant, - len, - )?) - } + sql::mock::TOKEN => SerializeTupleVariant::Model( + ser::mock::Serializer.serialize_tuple_variant(name, variant_index, variant, len)?, + ), sql::function::TOKEN => { SerializeTupleVariant::Function(ser::function::Serializer.serialize_tuple_variant( name, @@ -462,7 +457,7 @@ impl serde::ser::SerializeMap for SerializeMap { } pub(super) enum SerializeTupleVariant { - Model(SerializeModel), + Model(SerializeMock), Function(SerializeFunction), Unknown { variant: &'static str, @@ -517,7 +512,7 @@ impl serde::ser::SerializeTupleVariant for SerializeTupleVariant { fn end(self) -> Result { match self { - Self::Model(model) => Ok(Value::Model(model.end()?)), + Self::Model(model) => Ok(Value::Mock(model.end()?)), Self::Function(function) => Ok(Value::Function(Box::new(function.end()?))), Self::Unknown { variant, @@ -789,9 +784,9 @@ mod tests { #[test] fn model() { - let model = Model::Count("foo".to_owned(), Default::default()); + let model = Mock::Count("foo".to_owned(), Default::default()); let value = to_value(&model).unwrap(); - let expected = Value::Model(model); + let expected = Value::Mock(model); assert_eq!(value, expected); assert_eq!(expected, to_value(&expected).unwrap()); } diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 478b6782..114fcba0 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -27,6 +27,7 @@ use crate::sql::geometry::{geometry, Geometry}; use crate::sql::id::{Gen, Id}; use crate::sql::idiom::{self, reparse_idiom_start, Idiom}; use crate::sql::kind::Kind; +use crate::sql::mock::{mock, Mock}; use crate::sql::model::{model, Model}; use crate::sql::number::{number, Number}; use crate::sql::object::{key, object, Object}; @@ -140,7 +141,7 @@ pub enum Value { Param(Param), Idiom(Idiom), Table(Table), - Model(Model), + Mock(Mock), Regex(Regex), Cast(Box), Block(Box), @@ -153,6 +154,7 @@ pub enum Value { Subquery(Box), Expression(Box), Query(Query), + MlModel(Box), // Add new variants here } @@ -189,9 +191,9 @@ impl From for Value { } } -impl From for Value { - fn from(v: Model) -> Self { - Value::Model(v) +impl From for Value { + fn from(v: Mock) -> Self { + Value::Mock(v) } } @@ -303,6 +305,12 @@ impl From for Value { } } +impl From for Value { + fn from(v: Model) -> Self { + Value::MlModel(Box::new(v)) + } +} + impl From for Value { fn from(v: Subquery) -> Self { Value::Subquery(Box::new(v)) @@ -854,9 +862,9 @@ impl Value { matches!(self, Value::Thing(_)) } - /// Check if this Value is a Model - pub fn is_model(&self) -> bool { - matches!(self, Value::Model(_)) + /// Check if this Value is a Mock + pub fn is_mock(&self) -> bool { + matches!(self, Value::Mock(_)) } /// Check if this Value is a Range @@ -1055,7 +1063,8 @@ impl Value { pub fn can_start_idiom(&self) -> bool { match self { Value::Function(x) => !x.is_script(), - Value::Subquery(_) + Value::MlModel(_) + | Value::Subquery(_) | Value::Constant(_) | Value::Datetime(_) | Value::Duration(_) @@ -2526,10 +2535,11 @@ impl fmt::Display for Value { Value::Edges(v) => write!(f, "{v}"), Value::Expression(v) => write!(f, "{v}"), Value::Function(v) => write!(f, "{v}"), + Value::MlModel(v) => write!(f, "{v}"), Value::Future(v) => write!(f, "{v}"), Value::Geometry(v) => write!(f, "{v}"), Value::Idiom(v) => write!(f, "{v}"), - Value::Model(v) => write!(f, "{v}"), + Value::Mock(v) => write!(f, "{v}"), Value::Number(v) => write!(f, "{v}"), Value::Object(v) => write!(f, "{v}"), Value::Param(v) => write!(f, "{v}"), @@ -2556,6 +2566,7 @@ impl Value { Value::Function(v) => { v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable) } + Value::MlModel(m) => m.parameters.writeable(), Value::Subquery(v) => v.writeable(), Value::Expression(v) => v.writeable(), _ => false, @@ -2586,6 +2597,7 @@ impl Value { Value::Future(v) => v.compute(ctx, opt, txn, doc).await, Value::Constant(v) => v.compute(ctx, opt, txn, doc).await, Value::Function(v) => v.compute(ctx, opt, txn, doc).await, + Value::MlModel(v) => v.compute(ctx, opt, txn, doc).await, Value::Subquery(v) => v.compute(ctx, opt, txn, doc).await, Value::Expression(v) => v.compute(ctx, opt, txn, doc).await, _ => Ok(self.to_owned()), @@ -2741,7 +2753,7 @@ pub fn single(i: &str) -> IResult<&str, Value> { alt(( into(future), into(cast), - function_or_const, + path_like, into(geometry), into(subquery), into(datetime), @@ -2756,7 +2768,7 @@ pub fn single(i: &str) -> IResult<&str, Value> { into(block), into(param), into(regex), - into(model), + into(mock), into(edges), into(range), into(thing), @@ -2780,7 +2792,7 @@ pub fn select_start(i: &str) -> IResult<&str, Value> { alt(( into(future), into(cast), - function_or_const, + path_like, into(geometry), into(subquery), into(datetime), @@ -2792,7 +2804,7 @@ pub fn select_start(i: &str) -> IResult<&str, Value> { into(block), into(param), into(regex), - into(model), + into(mock), into(edges), into(range), into(thing), @@ -2803,8 +2815,9 @@ pub fn select_start(i: &str) -> IResult<&str, Value> { reparse_idiom_start(v, i) } -pub fn function_or_const(i: &str) -> IResult<&str, Value> { - alt((into(defined_function), |i| { +/// A path like production: Constants, predefined functions, user defined functions and ml models. +pub fn path_like(i: &str) -> IResult<&str, Value> { + alt((into(defined_function), into(model), |i| { let (i, v) = builtin_name(i)?; match v { builtin::BuiltinName::Constant(x) => Ok((i, x.into())), @@ -2841,14 +2854,14 @@ pub fn what(i: &str) -> IResult<&str, Value> { let _diving = crate::sql::parser::depth::dive(i)?; let (i, v) = alt(( into(idiom::multi_without_start), - function_or_const, + path_like, into(subquery), into(datetime), into(duration), into(future), into(block), into(param), - into(model), + into(mock), into(edges), into(range), into(thing), @@ -2995,7 +3008,7 @@ mod tests { assert_eq!(24, std::mem::size_of::()); assert_eq!(24, std::mem::size_of::()); assert_eq!(56, std::mem::size_of::()); - assert_eq!(40, std::mem::size_of::()); + assert_eq!(40, std::mem::size_of::()); assert_eq!(32, std::mem::size_of::()); assert_eq!(8, std::mem::size_of::>()); assert_eq!(8, std::mem::size_of::>());