Allow a model to take multiple arguments (#2890)
This commit is contained in:
parent
cfdd7c195c
commit
66810fd36b
2 changed files with 24 additions and 9 deletions
|
@ -19,9 +19,9 @@ use crate::{
|
|||
};
|
||||
|
||||
use super::{
|
||||
common::{closechevron, closeparentheses, openchevron, openparentheses, val_char},
|
||||
common::{closechevron, closeparentheses, commas, openchevron, openparentheses, val_char},
|
||||
error::{expect_tag_no_case, expected},
|
||||
util::expect_delimited,
|
||||
util::{delimited_list1, expect_delimited},
|
||||
value::value,
|
||||
};
|
||||
|
||||
|
@ -30,12 +30,19 @@ use super::{
|
|||
pub struct Model {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub parameters: Value,
|
||||
pub args: Vec<Value>,
|
||||
}
|
||||
|
||||
impl fmt::Display for Model {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "ml::{}<{}>({})", self.name, self.version, self.parameters)
|
||||
write!(f, "ml::{}<{}>(", self.name, self.version)?;
|
||||
for (idx, p) in self.args.iter().enumerate() {
|
||||
if idx != 0 {
|
||||
write!(f, ",")?;
|
||||
}
|
||||
write!(f, "{}", p)?;
|
||||
}
|
||||
write!(f, ")")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,9 +69,9 @@ pub fn model(i: &str) -> IResult<&str, Model> {
|
|||
let (i, version) =
|
||||
expected("a version", expect_delimited(openchevron, version, closechevron))(i)?;
|
||||
|
||||
let (i, parameters) = expected(
|
||||
"model parameters",
|
||||
expect_delimited(openparentheses, value, closeparentheses),
|
||||
let (i, args) = expected(
|
||||
"model arguments",
|
||||
delimited_list1(openparentheses, commas, value, closeparentheses),
|
||||
)(i)?;
|
||||
|
||||
Ok((
|
||||
|
@ -72,7 +79,7 @@ pub fn model(i: &str) -> IResult<&str, Model> {
|
|||
Model {
|
||||
name: name.to_owned(),
|
||||
version,
|
||||
parameters,
|
||||
args,
|
||||
},
|
||||
))
|
||||
})(i)
|
||||
|
@ -130,4 +137,12 @@ mod test {
|
|||
out,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ml_model_with_mutiple_arguments() {
|
||||
let sql = "ml::insurance::prediction<1.0.0>(1,2,3,4,);";
|
||||
let res = query::query(sql);
|
||||
let out = res.unwrap().1.to_string();
|
||||
assert_eq!("ml::insurance::prediction<1.0.0>(1,2,3,4);", out,);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2582,7 +2582,7 @@ impl Value {
|
|||
Value::Function(v) => {
|
||||
v.is_custom() || v.is_script() || v.args().iter().any(Value::writeable)
|
||||
}
|
||||
Value::MlModel(m) => m.parameters.writeable(),
|
||||
Value::MlModel(m) => m.args.iter().any(Value::writeable),
|
||||
Value::Subquery(v) => v.writeable(),
|
||||
Value::Expression(v) => v.writeable(),
|
||||
_ => false,
|
||||
|
|
Loading…
Reference in a new issue