From 66810fd36bac5198e2935f912b8eff7546f0aff1 Mon Sep 17 00:00:00 2001 From: Mees Delzenne Date: Wed, 25 Oct 2023 13:38:03 +0200 Subject: [PATCH] Allow a model to take multiple arguments (#2890) --- lib/src/sql/model.rs | 31 +++++++++++++++++++++++-------- lib/src/sql/value/value.rs | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/lib/src/sql/model.rs b/lib/src/sql/model.rs index be056ad0..3d5802b5 100644 --- a/lib/src/sql/model.rs +++ b/lib/src/sql/model.rs @@ -19,9 +19,9 @@ use crate::{ }; use super::{ - common::{closechevron, closeparentheses, openchevron, openparentheses, val_char}, + common::{closechevron, closeparentheses, commas, openchevron, openparentheses, val_char}, error::{expect_tag_no_case, expected}, - util::expect_delimited, + util::{delimited_list1, expect_delimited}, value::value, }; @@ -30,12 +30,19 @@ use super::{ pub struct Model { pub name: String, pub version: String, - pub parameters: Value, + pub args: Vec, } impl fmt::Display for Model { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ml::{}<{}>({})", self.name, self.version, self.parameters) + write!(f, "ml::{}<{}>(", self.name, self.version)?; + for (idx, p) in self.args.iter().enumerate() { + if idx != 0 { + write!(f, ",")?; + } + write!(f, "{}", p)?; + } + write!(f, ")") } } @@ -62,9 +69,9 @@ pub fn model(i: &str) -> IResult<&str, Model> { let (i, version) = expected("a version", expect_delimited(openchevron, version, closechevron))(i)?; - let (i, parameters) = expected( - "model parameters", - expect_delimited(openparentheses, value, closeparentheses), + let (i, args) = expected( + "model arguments", + delimited_list1(openparentheses, commas, value, closeparentheses), )(i)?; Ok(( @@ -72,7 +79,7 @@ pub fn model(i: &str) -> IResult<&str, Model> { Model { name: name.to_owned(), version, - parameters, + args, }, )) })(i) @@ -130,4 +137,12 @@ mod test { out, ); } + + #[test] + fn ml_model_with_mutiple_arguments() { + let sql = "ml::insurance::prediction<1.0.0>(1,2,3,4,);"; + let res = query::query(sql); + let out = res.unwrap().1.to_string(); + assert_eq!("ml::insurance::prediction<1.0.0>(1,2,3,4);", out,); + } } diff --git a/lib/src/sql/value/value.rs b/lib/src/sql/value/value.rs index 95db24c7..839b5a38 100644 --- a/lib/src/sql/value/value.rs +++ b/lib/src/sql/value/value.rs @@ -2582,7 +2582,7 @@ impl Value { Value::Function(v) => { v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable) } - Value::MlModel(m) => m.parameters.writeable(), + Value::MlModel(m) => m.args.iter().any(Value::writeable), Value::Subquery(v) => v.writeable(), Value::Expression(v) => v.writeable(), _ => false,