run rpc method (#3766)
Co-authored-by: Micha de Vries <mt.dev@hotmail.com>
This commit is contained in:
parent
5f1b55f7d2
commit
e93649503c
6 changed files with 144 additions and 5 deletions
|
@ -6,6 +6,7 @@ use super::rpc_error::RpcError;
|
||||||
pub trait Take {
|
pub trait Take {
|
||||||
fn needs_one(self) -> Result<Value, RpcError>;
|
fn needs_one(self) -> Result<Value, RpcError>;
|
||||||
fn needs_two(self) -> Result<(Value, Value), RpcError>;
|
fn needs_two(self) -> Result<(Value, Value), RpcError>;
|
||||||
|
fn needs_three(self) -> Result<(Value, Value, Value), RpcError>;
|
||||||
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
|
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
|
||||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
|
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
|
||||||
}
|
}
|
||||||
|
@ -34,9 +35,20 @@ impl Take for Array {
|
||||||
(_, _) => Ok((Value::None, Value::None)),
|
(_, _) => Ok((Value::None, Value::None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// Convert the array to three arguments
|
||||||
|
fn needs_three(self) -> Result<(Value, Value, Value), RpcError> {
|
||||||
|
if self.len() != 3 {
|
||||||
|
return Err(RpcError::InvalidParams);
|
||||||
|
}
|
||||||
|
let mut x = self.into_iter();
|
||||||
|
match (x.next(), x.next(), x.next()) {
|
||||||
|
(Some(a), Some(b), Some(c)) => Ok((a, b, c)),
|
||||||
|
_ => Err(RpcError::InvalidParams),
|
||||||
|
}
|
||||||
|
}
|
||||||
/// Convert the array to two arguments
|
/// Convert the array to two arguments
|
||||||
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> {
|
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> {
|
||||||
if self.is_empty() && self.len() > 2 {
|
if self.is_empty() || self.len() > 2 {
|
||||||
return Err(RpcError::InvalidParams);
|
return Err(RpcError::InvalidParams);
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
@ -48,7 +60,7 @@ impl Take for Array {
|
||||||
}
|
}
|
||||||
/// Convert the array to three arguments
|
/// Convert the array to three arguments
|
||||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> {
|
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> {
|
||||||
if self.is_empty() && self.len() > 3 {
|
if self.is_empty() || self.len() > 3 {
|
||||||
return Err(RpcError::InvalidParams);
|
return Err(RpcError::InvalidParams);
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
|
|
@ -22,6 +22,7 @@ pub enum Method {
|
||||||
Version,
|
Version,
|
||||||
Query,
|
Query,
|
||||||
Relate,
|
Relate,
|
||||||
|
Run,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Method {
|
impl Method {
|
||||||
|
@ -51,6 +52,7 @@ impl Method {
|
||||||
"version" => Self::Version,
|
"version" => Self::Version,
|
||||||
"query" => Self::Query,
|
"query" => Self::Query,
|
||||||
"relate" => Self::Relate,
|
"relate" => Self::Relate,
|
||||||
|
"run" => Self::Run,
|
||||||
_ => Self::Unknown,
|
_ => Self::Unknown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,6 +83,7 @@ impl Method {
|
||||||
Self::Version => "version",
|
Self::Version => "version",
|
||||||
Self::Query => "query",
|
Self::Query => "query",
|
||||||
Self::Relate => "relate",
|
Self::Relate => "relate",
|
||||||
|
Self::Run => "run",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +108,7 @@ impl Method {
|
||||||
| Method::Patch | Method::Delete
|
| Method::Patch | Method::Delete
|
||||||
| Method::Version
|
| Method::Version
|
||||||
| Method::Query | Method::Relate
|
| Method::Query | Method::Relate
|
||||||
| Method::Unknown
|
| Method::Run | Method::Unknown
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
||||||
dbs::{QueryType, Response, Session},
|
dbs::{QueryType, Response, Session},
|
||||||
kvs::Datastore,
|
kvs::Datastore,
|
||||||
rpc::args::Take,
|
rpc::args::Take,
|
||||||
sql::{Array, Value},
|
sql::{Array, Function, Model, Statement, Strand, Value},
|
||||||
};
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ pub trait RpcContext {
|
||||||
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
||||||
|
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Unknown => Err(RpcError::MethodNotFound),
|
Method::Unknown => Err(RpcError::MethodNotFound),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -77,6 +78,7 @@ pub trait RpcContext {
|
||||||
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
|
||||||
|
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
|
||||||
Method::Unknown => Err(RpcError::MethodNotFound),
|
Method::Unknown => Err(RpcError::MethodNotFound),
|
||||||
_ => Err(RpcError::MethodNotFound),
|
_ => Err(RpcError::MethodNotFound),
|
||||||
}
|
}
|
||||||
|
@ -481,7 +483,7 @@ pub trait RpcContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
// Methods for querying
|
// Methods for relating
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
|
|
||||||
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
|
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
|
||||||
|
@ -489,6 +491,48 @@ pub trait RpcContext {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for running functions
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn run(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
|
||||||
|
let Ok((Value::Strand(Strand(func_name)), version, args)) = params.needs_one_two_or_three()
|
||||||
|
else {
|
||||||
|
return Err(RpcError::InvalidParams);
|
||||||
|
};
|
||||||
|
|
||||||
|
let version = match version {
|
||||||
|
Value::Strand(Strand(v)) => Some(v),
|
||||||
|
Value::None | Value::Null => None,
|
||||||
|
_ => return Err(RpcError::InvalidParams),
|
||||||
|
};
|
||||||
|
|
||||||
|
let args = match args {
|
||||||
|
Value::Array(Array(arr)) => arr,
|
||||||
|
Value::None | Value::Null => vec![],
|
||||||
|
_ => return Err(RpcError::InvalidParams),
|
||||||
|
};
|
||||||
|
|
||||||
|
let func: Value = match &func_name[0..4] {
|
||||||
|
"fn::" => Function::Custom(func_name.chars().skip(4).collect(), args).into(),
|
||||||
|
"ml::" => Model {
|
||||||
|
name: func_name.chars().skip(4).collect(),
|
||||||
|
version: version.ok_or(RpcError::InvalidParams)?,
|
||||||
|
args,
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
_ => Function::Normal(func_name, args).into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut res = self
|
||||||
|
.kvs()
|
||||||
|
.process(Statement::Value(func).into(), self.session(), Some(self.vars().clone()))
|
||||||
|
.await?;
|
||||||
|
let out = res.remove(0).result?;
|
||||||
|
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
// Private methods
|
// Private methods
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
|
|
|
@ -29,6 +29,12 @@ impl From<RemoveStatement> for Query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Statement> for Query {
|
||||||
|
fn from(s: Statement) -> Self {
|
||||||
|
Query(Statements(vec![s]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Deref for Query {
|
impl Deref for Query {
|
||||||
type Target = Vec<Statement>;
|
type Target = Vec<Statement>;
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
|
|
|
@ -429,4 +429,32 @@ impl Socket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub async fn send_message_run(
|
||||||
|
&mut self,
|
||||||
|
fn_name: &str,
|
||||||
|
version: Option<&str>,
|
||||||
|
args: Vec<serde_json::Value>,
|
||||||
|
) -> Result<serde_json::Value> {
|
||||||
|
// Send message and receive response
|
||||||
|
let msg = self.send_request("run", json!([fn_name, version, args])).await?;
|
||||||
|
// Check response message structure
|
||||||
|
match msg.as_object() {
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||||
|
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||||
|
}
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||||
|
.get("result")
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!(
|
||||||
|
"expected a result from the received object, got this instead: {:?}",
|
||||||
|
obj
|
||||||
|
),
|
||||||
|
})?
|
||||||
|
.to_owned()),
|
||||||
|
_ => {
|
||||||
|
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||||
|
Err(format!("unexpected response: {:?}", msg).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1463,3 +1463,49 @@ async fn session_reauthentication_expired() {
|
||||||
// Test passed
|
// Test passed
|
||||||
server.finish().unwrap();
|
server.finish().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test(tokio::test)]
|
||||||
|
async fn run_functions() {
|
||||||
|
// Setup database server
|
||||||
|
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||||
|
// Connect to WebSocket
|
||||||
|
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
|
||||||
|
// Authenticate the connection
|
||||||
|
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||||
|
// Specify a namespace and database
|
||||||
|
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||||
|
// Define function
|
||||||
|
socket
|
||||||
|
.send_message_query("DEFINE FUNCTION fn::foo() {RETURN 'fn::foo called';}")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
socket
|
||||||
|
.send_message_query(
|
||||||
|
"DEFINE FUNCTION fn::bar($val: string) {RETURN 'fn::bar called with: ' + $val;}",
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// call functions
|
||||||
|
let res = socket.send_message_run("fn::foo", None, vec![]).await.unwrap();
|
||||||
|
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::foo called"));
|
||||||
|
let res = socket.send_message_run("fn::bar", None, vec![]).await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
let res = socket.send_message_run("fn::bar", None, vec![42.into()]).await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
let res = socket.send_message_run("fn::bar", None, vec!["first".into(), "second".into()]).await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
let res = socket.send_message_run("fn::bar", None, vec!["string_val".into()]).await.unwrap();
|
||||||
|
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val"));
|
||||||
|
|
||||||
|
// normal functions
|
||||||
|
let res = socket.send_message_run("math::abs", None, vec![(-42).into()]).await.unwrap();
|
||||||
|
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42)));
|
||||||
|
let res = socket
|
||||||
|
.send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()])
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(6)));
|
||||||
|
|
||||||
|
// Test passed
|
||||||
|
server.finish().unwrap();
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue