run rpc method (#3766)
Co-authored-by: Micha de Vries <mt.dev@hotmail.com>
This commit is contained in:
parent
5f1b55f7d2
commit
e93649503c
6 changed files with 144 additions and 5 deletions
|
@ -6,6 +6,7 @@ use super::rpc_error::RpcError;
|
|||
pub trait Take {
|
||||
fn needs_one(self) -> Result<Value, RpcError>;
|
||||
fn needs_two(self) -> Result<(Value, Value), RpcError>;
|
||||
fn needs_three(self) -> Result<(Value, Value, Value), RpcError>;
|
||||
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
|
||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
|
||||
}
|
||||
|
@ -34,9 +35,20 @@ impl Take for Array {
|
|||
(_, _) => Ok((Value::None, Value::None)),
|
||||
}
|
||||
}
|
||||
/// Convert the array to three arguments
|
||||
fn needs_three(self) -> Result<(Value, Value, Value), RpcError> {
|
||||
if self.len() != 3 {
|
||||
return Err(RpcError::InvalidParams);
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
match (x.next(), x.next(), x.next()) {
|
||||
(Some(a), Some(b), Some(c)) => Ok((a, b, c)),
|
||||
_ => Err(RpcError::InvalidParams),
|
||||
}
|
||||
}
|
||||
/// Convert the array to two arguments
|
||||
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> {
|
||||
if self.is_empty() && self.len() > 2 {
|
||||
if self.is_empty() || self.len() > 2 {
|
||||
return Err(RpcError::InvalidParams);
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
@ -48,7 +60,7 @@ impl Take for Array {
|
|||
}
|
||||
/// Convert the array to three arguments
|
||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> {
|
||||
if self.is_empty() && self.len() > 3 {
|
||||
if self.is_empty() || self.len() > 3 {
|
||||
return Err(RpcError::InvalidParams);
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
|
|
@ -22,6 +22,7 @@ pub enum Method {
|
|||
Version,
|
||||
Query,
|
||||
Relate,
|
||||
Run,
|
||||
}
|
||||
|
||||
impl Method {
|
||||
|
@ -51,6 +52,7 @@ impl Method {
|
|||
"version" => Self::Version,
|
||||
"query" => Self::Query,
|
||||
"relate" => Self::Relate,
|
||||
"run" => Self::Run,
|
||||
_ => Self::Unknown,
|
||||
}
|
||||
}
|
||||
|
@ -81,6 +83,7 @@ impl Method {
|
|||
Self::Version => "version",
|
||||
Self::Query => "query",
|
||||
Self::Relate => "relate",
|
||||
Self::Run => "run",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -105,7 +108,7 @@ impl Method {
|
|||
| Method::Patch | Method::Delete
|
||||
| Method::Version
|
||||
| Method::Query | Method::Relate
|
||||
| Method::Unknown
|
||||
| Method::Run | Method::Unknown
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
dbs::{QueryType, Response, Session},
|
||||
kvs::Datastore,
|
||||
rpc::args::Take,
|
||||
sql::{Array, Value},
|
||||
sql::{Array, Function, Model, Statement, Strand, Value},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -59,6 +59,7 @@ pub trait RpcContext {
|
|||
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Unknown => Err(RpcError::MethodNotFound),
|
||||
}
|
||||
}
|
||||
|
@ -77,6 +78,7 @@ pub trait RpcContext {
|
|||
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
|
||||
Method::Unknown => Err(RpcError::MethodNotFound),
|
||||
_ => Err(RpcError::MethodNotFound),
|
||||
}
|
||||
|
@ -481,7 +483,7 @@ pub trait RpcContext {
|
|||
}
|
||||
|
||||
// ------------------------------
|
||||
// Methods for querying
|
||||
// Methods for relating
|
||||
// ------------------------------
|
||||
|
||||
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
|
||||
|
@ -489,6 +491,48 @@ pub trait RpcContext {
|
|||
out
|
||||
}
|
||||
|
||||
// ------------------------------
|
||||
// Methods for running functions
|
||||
// ------------------------------
|
||||
|
||||
async fn run(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
|
||||
let Ok((Value::Strand(Strand(func_name)), version, args)) = params.needs_one_two_or_three()
|
||||
else {
|
||||
return Err(RpcError::InvalidParams);
|
||||
};
|
||||
|
||||
let version = match version {
|
||||
Value::Strand(Strand(v)) => Some(v),
|
||||
Value::None | Value::Null => None,
|
||||
_ => return Err(RpcError::InvalidParams),
|
||||
};
|
||||
|
||||
let args = match args {
|
||||
Value::Array(Array(arr)) => arr,
|
||||
Value::None | Value::Null => vec![],
|
||||
_ => return Err(RpcError::InvalidParams),
|
||||
};
|
||||
|
||||
let func: Value = match &func_name[0..4] {
|
||||
"fn::" => Function::Custom(func_name.chars().skip(4).collect(), args).into(),
|
||||
"ml::" => Model {
|
||||
name: func_name.chars().skip(4).collect(),
|
||||
version: version.ok_or(RpcError::InvalidParams)?,
|
||||
args,
|
||||
}
|
||||
.into(),
|
||||
_ => Function::Normal(func_name, args).into(),
|
||||
};
|
||||
|
||||
let mut res = self
|
||||
.kvs()
|
||||
.process(Statement::Value(func).into(), self.session(), Some(self.vars().clone()))
|
||||
.await?;
|
||||
let out = res.remove(0).result?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ------------------------------
|
||||
// Private methods
|
||||
// ------------------------------
|
||||
|
|
|
@ -29,6 +29,12 @@ impl From<RemoveStatement> for Query {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Statement> for Query {
|
||||
fn from(s: Statement) -> Self {
|
||||
Query(Statements(vec![s]))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Query {
|
||||
type Target = Vec<Statement>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
|
|
|
@ -429,4 +429,32 @@ impl Socket {
|
|||
}
|
||||
}
|
||||
}
|
||||
pub async fn send_message_run(
|
||||
&mut self,
|
||||
fn_name: &str,
|
||||
version: Option<&str>,
|
||||
args: Vec<serde_json::Value>,
|
||||
) -> Result<serde_json::Value> {
|
||||
// Send message and receive response
|
||||
let msg = self.send_request("run", json!([fn_name, version, args])).await?;
|
||||
// Check response message structure
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!(
|
||||
"expected a result from the received object, got this instead: {:?}",
|
||||
obj
|
||||
),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1463,3 +1463,49 @@ async fn session_reauthentication_expired() {
|
|||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn run_functions() {
|
||||
// Setup database server
|
||||
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||
// Connect to WebSocket
|
||||
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
|
||||
// Authenticate the connection
|
||||
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||
// Specify a namespace and database
|
||||
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||
// Define function
|
||||
socket
|
||||
.send_message_query("DEFINE FUNCTION fn::foo() {RETURN 'fn::foo called';}")
|
||||
.await
|
||||
.unwrap();
|
||||
socket
|
||||
.send_message_query(
|
||||
"DEFINE FUNCTION fn::bar($val: string) {RETURN 'fn::bar called with: ' + $val;}",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// call functions
|
||||
let res = socket.send_message_run("fn::foo", None, vec![]).await.unwrap();
|
||||
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::foo called"));
|
||||
let res = socket.send_message_run("fn::bar", None, vec![]).await;
|
||||
assert!(res.is_err());
|
||||
let res = socket.send_message_run("fn::bar", None, vec![42.into()]).await;
|
||||
assert!(res.is_err());
|
||||
let res = socket.send_message_run("fn::bar", None, vec!["first".into(), "second".into()]).await;
|
||||
assert!(res.is_err());
|
||||
let res = socket.send_message_run("fn::bar", None, vec!["string_val".into()]).await.unwrap();
|
||||
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val"));
|
||||
|
||||
// normal functions
|
||||
let res = socket.send_message_run("math::abs", None, vec![(-42).into()]).await.unwrap();
|
||||
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42)));
|
||||
let res = socket
|
||||
.send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()])
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(6)));
|
||||
|
||||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue