use crate::ctx::Context; use crate::dbs::{Options, Transaction}; use crate::doc::CursorDoc; use crate::err::Error; use crate::sql::value::Value; use derive::Store; use revision::revisioned; use serde::{Deserialize, Serialize}; use std::fmt; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::iam::Action; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::ml::execution::compute::ModelComputation; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::ml::storage::surml_file::SurMlFile; #[cfg(any(feature = "ml", feature = "ml2"))] use crate::sql::Permission; #[cfg(any(feature = "ml", feature = "ml2"))] use futures::future::try_join_all; #[cfg(any(feature = "ml", feature = "ml2"))] use std::collections::HashMap; #[cfg(any(feature = "ml", feature = "ml2"))] const ARGUMENTS: &str = "The model expects 1 argument. The argument can be either a number, an object, or an array of numbers."; #[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[revisioned(revision = 1)] pub struct Model { pub name: String, pub version: String, pub args: Vec, } impl fmt::Display for Model { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ml::{}<{}>(", self.name, self.version)?; for (idx, p) in self.args.iter().enumerate() { if idx != 0 { write!(f, ",")?; } write!(f, "{}", p)?; } write!(f, ")") } } impl Model { #[cfg(any(feature = "ml", feature = "ml2"))] pub(crate) async fn compute( &self, ctx: &Context<'_>, opt: &Options, txn: &Transaction, doc: Option<&CursorDoc<'_>>, ) -> Result { // Ensure futures are run let opt = &opt.new_with_futures(true); // Get the full name of this model let name = format!("ml::{}", self.name); // Check this function is allowed ctx.check_allowed_function(name.as_str())?; // Get the model definition let val = { // Claim transaction let mut run = txn.lock().await; // Get the function definition run.get_and_cache_db_model(opt.ns(), opt.db(), &self.name, &self.version).await? }; // Calculate the model path let path = format!( "ml/{}/{}/{}-{}-{}.surml", opt.ns(), opt.db(), self.name, self.version, val.hash ); // Check permissions if opt.check_perms(Action::View) { match &val.permissions { Permission::Full => (), Permission::None => { return Err(Error::FunctionPermissions { name: self.name.to_owned(), }) } Permission::Specific(e) => { // Disable permissions let opt = &opt.new_with_perms(false); // Process the PERMISSION clause if !e.compute(ctx, opt, txn, doc).await?.is_truthy() { return Err(Error::FunctionPermissions { name: self.name.to_owned(), }); } } } } // Compute the function arguments let mut args = try_join_all(self.args.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?; // Check the minimum argument length if args.len() != 1 { return Err(Error::InvalidArguments { name: format!("ml::{}<{}>", self.name, self.version), message: ARGUMENTS.into(), }); } // Take the first and only specified argument match args.swap_remove(0) { // Perform bufferered compute Value::Object(v) => { // Compute the model function arguments let mut args = v .into_iter() .map(|(k, v)| Ok((k, Value::try_into(v)?))) .collect::, Error>>() .map_err(|_| Error::InvalidArguments { name: format!("ml::{}<{}>", self.name, self.version), message: ARGUMENTS.into(), })?; // Get the model file as bytes let bytes = crate::obs::get(&path).await?; // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).unwrap(); let compute_unit = ModelComputation { surml_file: &mut file, }; compute_unit.buffered_compute(&mut args).map_err(Error::ModelComputation) }) .await .unwrap()?; // Convert the output to a value Ok(outcome[0].into()) } // Perform raw compute Value::Number(v) => { // Compute the model function arguments let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments { name: format!("ml::{}<{}>", self.name, self.version), message: ARGUMENTS.into(), })?; // Get the model file as bytes let bytes = crate::obs::get(&path).await?; // Convert the argument to a tensor let tensor = ndarray::arr1::(&[args]).into_dyn(); // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).unwrap(); let compute_unit = ModelComputation { surml_file: &mut file, }; compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) }) .await .unwrap()?; // Convert the output to a value Ok(outcome[0].into()) } // Perform raw compute Value::Array(v) => { // Compute the model function arguments let args = v .into_iter() .map(Value::try_into) .collect::, Error>>() .map_err(|_| Error::InvalidArguments { name: format!("ml::{}<{}>", self.name, self.version), message: ARGUMENTS.into(), })?; // Get the model file as bytes let bytes = crate::obs::get(&path).await?; // Convert the argument to a tensor let tensor = ndarray::arr1::(&args).into_dyn(); // Run the compute in a blocking task let outcome = tokio::task::spawn_blocking(move || { let mut file = SurMlFile::from_bytes(bytes).unwrap(); let compute_unit = ModelComputation { surml_file: &mut file, }; compute_unit.raw_compute(tensor, None).map_err(Error::ModelComputation) }) .await .unwrap()?; // Convert the output to a value Ok(outcome[0].into()) } // _ => Err(Error::InvalidArguments { name: format!("ml::{}<{}>", self.name, self.version), message: ARGUMENTS.into(), }), } } #[cfg(not(any(feature = "ml", feature = "ml2")))] pub(crate) async fn compute( &self, _ctx: &Context<'_>, _opt: &Options, _txn: &Transaction, _doc: Option<&CursorDoc<'_>>, ) -> Result { Err(Error::InvalidModel { message: String::from("Machine learning computation is not enabled."), }) } }