Allow a model to take multiple arguments (#2890)

This commit is contained in:
Mees Delzenne 2023-10-25 13:38:03 +02:00 committed by GitHub
parent cfdd7c195c
commit 66810fd36b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 9 deletions

View file

@ -19,9 +19,9 @@ use crate::{
};
use super::{
common::{closechevron, closeparentheses, openchevron, openparentheses, val_char},
common::{closechevron, closeparentheses, commas, openchevron, openparentheses, val_char},
error::{expect_tag_no_case, expected},
util::expect_delimited,
util::{delimited_list1, expect_delimited},
value::value,
};
@ -30,12 +30,19 @@ use super::{
pub struct Model {
pub name: String,
pub version: String,
pub parameters: Value,
pub args: Vec<Value>,
}
impl fmt::Display for Model {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ml::{}<{}>({})", self.name, self.version, self.parameters)
write!(f, "ml::{}<{}>(", self.name, self.version)?;
for (idx, p) in self.args.iter().enumerate() {
if idx != 0 {
write!(f, ",")?;
}
write!(f, "{}", p)?;
}
write!(f, ")")
}
}
@ -62,9 +69,9 @@ pub fn model(i: &str) -> IResult<&str, Model> {
let (i, version) =
expected("a version", expect_delimited(openchevron, version, closechevron))(i)?;
let (i, parameters) = expected(
"model parameters",
expect_delimited(openparentheses, value, closeparentheses),
let (i, args) = expected(
"model arguments",
delimited_list1(openparentheses, commas, value, closeparentheses),
)(i)?;
Ok((
@ -72,7 +79,7 @@ pub fn model(i: &str) -> IResult<&str, Model> {
Model {
name: name.to_owned(),
version,
parameters,
args,
},
))
})(i)
@ -130,4 +137,12 @@ mod test {
out,
);
}
#[test]
fn ml_model_with_mutiple_arguments() {
let sql = "ml::insurance::prediction<1.0.0>(1,2,3,4,);";
let res = query::query(sql);
let out = res.unwrap().1.to_string();
assert_eq!("ml::insurance::prediction<1.0.0>(1,2,3,4);", out,);
}
}

View file

@ -2582,7 +2582,7 @@ impl Value {
Value::Function(v) => {
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
}
Value::MlModel(m) => m.parameters.writeable(),
Value::MlModel(m) => m.args.iter().any(Value::writeable),
Value::Subquery(v) => v.writeable(),
Value::Expression(v) => v.writeable(),
_ => false,