run rpc method (#3766)

Co-authored-by: Micha de Vries <mt.dev@hotmail.com>
This commit is contained in:
Raphael Darley 2024-03-26 15:27:08 +00:00 committed by GitHub
parent 5f1b55f7d2
commit e93649503c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 144 additions and 5 deletions

View file

@ -6,6 +6,7 @@ use super::rpc_error::RpcError;
pub trait Take {
fn needs_one(self) -> Result<Value, RpcError>;
fn needs_two(self) -> Result<(Value, Value), RpcError>;
fn needs_three(self) -> Result<(Value, Value, Value), RpcError>;
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError>;
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError>;
}
@ -34,9 +35,20 @@ impl Take for Array {
(_, _) => Ok((Value::None, Value::None)),
}
}
/// Convert the array to three arguments
fn needs_three(self) -> Result<(Value, Value, Value), RpcError> {
if self.len() != 3 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
match (x.next(), x.next(), x.next()) {
(Some(a), Some(b), Some(c)) => Ok((a, b, c)),
_ => Err(RpcError::InvalidParams),
}
}
/// Convert the array to two arguments
fn needs_one_or_two(self) -> Result<(Value, Value), RpcError> {
if self.is_empty() && self.len() > 2 {
if self.is_empty() || self.len() > 2 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();
@ -48,7 +60,7 @@ impl Take for Array {
}
/// Convert the array to three arguments
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), RpcError> {
if self.is_empty() && self.len() > 3 {
if self.is_empty() || self.len() > 3 {
return Err(RpcError::InvalidParams);
}
let mut x = self.into_iter();

View file

@ -22,6 +22,7 @@ pub enum Method {
Version,
Query,
Relate,
Run,
}
impl Method {
@ -51,6 +52,7 @@ impl Method {
"version" => Self::Version,
"query" => Self::Query,
"relate" => Self::Relate,
"run" => Self::Run,
_ => Self::Unknown,
}
}
@ -81,6 +83,7 @@ impl Method {
Self::Version => "version",
Self::Query => "query",
Self::Relate => "relate",
Self::Run => "run",
}
}
}
@ -105,7 +108,7 @@ impl Method {
| Method::Patch | Method::Delete
| Method::Version
| Method::Query | Method::Relate
| Method::Unknown
| Method::Run | Method::Unknown
)
}
}

View file

@ -4,7 +4,7 @@ use crate::{
dbs::{QueryType, Response, Session},
kvs::Datastore,
rpc::args::Take,
sql::{Array, Value},
sql::{Array, Function, Model, Statement, Strand, Value},
};
use uuid::Uuid;
@ -59,6 +59,7 @@ pub trait RpcContext {
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
Method::Unknown => Err(RpcError::MethodNotFound),
}
}
@ -77,6 +78,7 @@ pub trait RpcContext {
Method::Version => self.version(params).await.map(Into::into).map_err(Into::into),
Method::Query => self.query(params).await.map(Into::into).map_err(Into::into),
Method::Relate => self.relate(params).await.map(Into::into).map_err(Into::into),
Method::Run => self.run(params).await.map(Into::into).map_err(Into::into),
Method::Unknown => Err(RpcError::MethodNotFound),
_ => Err(RpcError::MethodNotFound),
}
@ -481,7 +483,7 @@ pub trait RpcContext {
}
// ------------------------------
// Methods for querying
// Methods for relating
// ------------------------------
async fn relate(&self, _params: Array) -> Result<impl Into<Data>, RpcError> {
@ -489,6 +491,48 @@ pub trait RpcContext {
out
}
// ------------------------------
// Methods for running functions
// ------------------------------
async fn run(&self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok((Value::Strand(Strand(func_name)), version, args)) = params.needs_one_two_or_three()
else {
return Err(RpcError::InvalidParams);
};
let version = match version {
Value::Strand(Strand(v)) => Some(v),
Value::None | Value::Null => None,
_ => return Err(RpcError::InvalidParams),
};
let args = match args {
Value::Array(Array(arr)) => arr,
Value::None | Value::Null => vec![],
_ => return Err(RpcError::InvalidParams),
};
let func: Value = match &func_name[0..4] {
"fn::" => Function::Custom(func_name.chars().skip(4).collect(), args).into(),
"ml::" => Model {
name: func_name.chars().skip(4).collect(),
version: version.ok_or(RpcError::InvalidParams)?,
args,
}
.into(),
_ => Function::Normal(func_name, args).into(),
};
let mut res = self
.kvs()
.process(Statement::Value(func).into(), self.session(), Some(self.vars().clone()))
.await?;
let out = res.remove(0).result?;
Ok(out)
}
// ------------------------------
// Private methods
// ------------------------------

View file

@ -29,6 +29,12 @@ impl From<RemoveStatement> for Query {
}
}
impl From<Statement> for Query {
fn from(s: Statement) -> Self {
Query(Statements(vec![s]))
}
}
impl Deref for Query {
type Target = Vec<Statement>;
fn deref(&self) -> &Self::Target {

View file

@ -429,4 +429,32 @@ impl Socket {
}
}
}
pub async fn send_message_run(
&mut self,
fn_name: &str,
version: Option<&str>,
args: Vec<serde_json::Value>,
) -> Result<serde_json::Value> {
// Send message and receive response
let msg = self.send_request("run", json!([fn_name, version, args])).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
}

View file

@ -1463,3 +1463,49 @@ async fn session_reauthentication_expired() {
// Test passed
server.finish().unwrap();
}
#[test(tokio::test)]
async fn run_functions() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
// Define function
socket
.send_message_query("DEFINE FUNCTION fn::foo() {RETURN 'fn::foo called';}")
.await
.unwrap();
socket
.send_message_query(
"DEFINE FUNCTION fn::bar($val: string) {RETURN 'fn::bar called with: ' + $val;}",
)
.await
.unwrap();
// call functions
let res = socket.send_message_run("fn::foo", None, vec![]).await.unwrap();
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::foo called"));
let res = socket.send_message_run("fn::bar", None, vec![]).await;
assert!(res.is_err());
let res = socket.send_message_run("fn::bar", None, vec![42.into()]).await;
assert!(res.is_err());
let res = socket.send_message_run("fn::bar", None, vec!["first".into(), "second".into()]).await;
assert!(res.is_err());
let res = socket.send_message_run("fn::bar", None, vec!["string_val".into()]).await.unwrap();
assert!(matches!(res, serde_json::Value::String(s) if &s == "fn::bar called with: string_val"));
// normal functions
let res = socket.send_message_run("math::abs", None, vec![(-42).into()]).await.unwrap();
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(42)));
let res = socket
.send_message_run("math::max", None, vec![vec![1, 2, 3, 4, 5, 6].into()])
.await
.unwrap();
assert!(matches!(res, serde_json::Value::Number(n) if n.as_u64() == Some(6)));
// Test passed
server.finish().unwrap();
}