diff --git a/core/src/rpc/args.rs b/core/src/rpc/args.rs index f446a5bf..8e6221e0 100644 --- a/core/src/rpc/args.rs +++ b/core/src/rpc/args.rs @@ -6,6 +6,7 @@ use super::rpc_error::RpcError; pub trait Take { fn needs_one(self) -> Result; fn needs_two(self) -> Result<(Value, Value), RpcError>; + fn needs_three(self) -> Result<(Value, Value, Value), RpcError>; fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>; fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>; } @@ -34,9 +35,20 @@ impl Take for Array { (_, _) => Ok((Value::None, Value::None)), } } + /// Convert the array to three arguments + fn needs_three(self) -> Result<(Value, Value, Value), RpcError> { + if self.len() != 3 { + return Err(RpcError::InvalidParams); + } + let mut x = self.into_iter(); + match (x.next(), x.next(), x.next()) { + (Some(a), Some(b), Some(c)) => Ok((a, b, c)), + _ => Err(RpcError::InvalidParams), + } + } /// Convert the array to two arguments fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> { - if self.is_empty() && self.len() > 2 { + if self.is_empty() || self.len() > 2 { return Err(RpcError::InvalidParams); } let mut x = self.into_iter(); @@ -48,7 +60,7 @@ impl Take for Array { } /// Convert the array to three arguments fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> { - if self.is_empty() && self.len() > 3 { + if self.is_empty() || self.len() > 3 { return Err(RpcError::InvalidParams); } let mut x = self.into_iter(); diff --git a/core/src/rpc/method.rs b/core/src/rpc/method.rs index a7e7f41c..89de5bb7 100644 --- a/core/src/rpc/method.rs +++ b/core/src/rpc/method.rs @@ -22,6 +22,7 @@ pub enum Method { Version, Query, Relate, + Run, } impl Method { @@ -51,6 +52,7 @@ impl Method { "version" => Self::Version, "query" => Self::Query, "relate" => Self::Relate, + "run" => Self::Run, _ => Self::Unknown, } } @@ -81,6 +83,7 @@ impl Method { Self::Version => "version", Self::Query => "query", Self::Relate => "relate", + Self::Run => "run", } } } @@ -105,7 +108,7 @@ impl Method { | Method::Patch | Method::Delete | Method::Version | Method::Query | Method::Relate - | Method::Unknown + | Method::Run | Method::Unknown ) } } diff --git a/core/src/rpc/rpc_context.rs b/core/src/rpc/rpc_context.rs index 86298577..33b55aa9 100644 --- a/core/src/rpc/rpc_context.rs +++ b/core/src/rpc/rpc_context.rs @@ -4,7 +4,7 @@ use crate::{ dbs::{QueryType, Response, Session}, kvs::Datastore, rpc::args::Take, - sql::{Array, Value}, + sql::{Array, Function, Model, Statement, Strand, Value}, }; use uuid::Uuid; @@ -59,6 +59,7 @@ pub trait RpcContext { Method::Version => self.version(params).await.map(Into::into).map_err(Into::into), Method::Query => self.query(params).await.map(Into::into).map_err(Into::into), Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), + Method::Run => self.run(params).await.map(Into::into).map_err(Into::into), Method::Unknown => Err(RpcError::MethodNotFound), } } @@ -77,6 +78,7 @@ pub trait RpcContext { Method::Version => self.version(params).await.map(Into::into).map_err(Into::into), Method::Query => self.query(params).await.map(Into::into).map_err(Into::into), Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into), + Method::Run => self.run(params).await.map(Into::into).map_err(Into::into), Method::Unknown => Err(RpcError::MethodNotFound), _ => Err(RpcError::MethodNotFound), } @@ -481,7 +483,7 @@ pub trait RpcContext { } // ------------------------------ - // Methods for querying + // Methods for relating // ------------------------------ async fn relate(&self, _params: Array) -> Result, RpcError> { @@ -489,6 +491,48 @@ pub trait RpcContext { out } + // ------------------------------ + // Methods for running functions + // ------------------------------ + + async fn run(&self, params: Array) -> Result, RpcError> { + let Ok((Value::Strand(Strand(func_name)), version, args)) = params.needs_one_two_or_three() + else { + return Err(RpcError::InvalidParams); + }; + + let version = match version { + Value::Strand(Strand(v)) => Some(v), + Value::None | Value::Null => None, + _ => return Err(RpcError::InvalidParams), + }; + + let args = match args { + Value::Array(Array(arr)) => arr, + Value::None | Value::Null => vec![], + _ => return Err(RpcError::InvalidParams), + }; + + let func: Value = match &func_name[0..4] { + "fn::" => Function::Custom(func_name.chars().skip(4).collect(), args).into(), + "ml::" => Model { + name: func_name.chars().skip(4).collect(), + version: version.ok_or(RpcError::InvalidParams)?, + args, + } + .into(), + _ => Function::Normal(func_name, args).into(), + }; + + let mut res = self + .kvs() + .process(Statement::Value(func).into(), self.session(), Some(self.vars().clone())) + .await?; + let out = res.remove(0).result?; + + Ok(out) + } + // ------------------------------ // Private methods // ------------------------------ diff --git a/core/src/sql/query.rs b/core/src/sql/query.rs index 6293fd13..7b18309c 100644 --- a/core/src/sql/query.rs +++ b/core/src/sql/query.rs @@ -29,6 +29,12 @@ impl From for Query { } } +impl From for Query { + fn from(s: Statement) -> Self { + Query(Statements(vec![s])) + } +} + impl Deref for Query { type Target = Vec; fn deref(&self) -> &Self::Target { diff --git a/tests/common/socket.rs b/tests/common/socket.rs index 303dad9f..89dd6a20 100644 --- a/tests/common/socket.rs +++ b/tests/common/socket.rs @@ -429,4 +429,32 @@ impl Socket { } } } + pub async fn send_message_run( + &mut self, + fn_name: &str, + version: Option<&str>, + args: Vec, + ) -> Result { + // Send message and receive response + let msg = self.send_request("run", json!([fn_name, version, args])).await?; + // Check response message structure + match msg.as_object() { + Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => { + Err(format!("unexpected error from query request: {:?}", obj.get("error")).into()) + } + Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj + .get("result") + .ok_or(TestError::AssertionError { + message: format!( + "expected a result from the received object, got this instead: {:?}", + obj + ), + })? + .to_owned()), + _ => { + error!("{:?}", msg.as_object().unwrap().keys().collect::>()); + Err(format!("unexpected response: {:?}", msg).into()) + } + } + } } diff --git a/tests/common/tests.rs b/tests/common/tests.rs index eada6be1..7030c9e1 100644 --- a/tests/common/tests.rs +++ b/tests/common/tests.rs @@ -1463,3 +1463,49 @@ async fn session_reauthentication_expired() { // Test passed server.finish().unwrap(); } + +#[test(tokio::test)] +async fn run_functions() { + // Setup database server + let (addr, mut server) = common::start_server_with_defaults().await.unwrap(); + // Connect to WebSocket + let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap(); + // Authenticate the connection + socket.send_message_signin(USER, PASS, None, None, None).await.unwrap(); + // Specify a namespace and database + socket.send_message_use(Some(NS), Some(DB)).await.unwrap(); + // Define function + socket + .send_message_query("DEFINE FUNCTION fn::foo() {RETURN 'fn::foo called';}") + .await + .unwrap(); + socket + .send_message_query( + "DEFINE FUNCTION fn::bar($val: string) {RETURN 'fn::bar called with: ' + $val;}", + ) + .await + .unwrap(); + // call functions + let res = socket.send_message_run("fn::foo", None, vec![]).await.unwrap(); + assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::foo called")); + let res = socket.send_message_run("fn::bar", None, vec![]).await; + assert!(res.is_err()); + let res = socket.send_message_run("fn::bar", None, vec![42.into()]).await; + assert!(res.is_err()); + let res = socket.send_message_run("fn::bar", None, vec!["first".into(), "second".into()]).await; + assert!(res.is_err()); + let res = socket.send_message_run("fn::bar", None, vec!["string_val".into()]).await.unwrap(); + assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val")); + + // normal functions + let res = socket.send_message_run("math::abs", None, vec![(-42).into()]).await.unwrap(); + assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42))); + let res = socket + .send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()]) + .await + .unwrap(); + assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(6))); + + // Test passed + server.finish().unwrap(); +}