Enable different output serialization formats in WebSocket RPC
This commit is contained in:
parent
d12384f3fb
commit
a0d86248e2
2 changed files with 101 additions and 48 deletions
107
src/net/rpc.rs
107
src/net/rpc.rs
|
@ -11,6 +11,7 @@ use crate::rpc::args::Take;
|
|||
use crate::rpc::paths::{ID, METHOD, PARAMS};
|
||||
use crate::rpc::res;
|
||||
use crate::rpc::res::Failure;
|
||||
use crate::rpc::res::Output;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::Serialize;
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -41,6 +42,7 @@ async fn socket(ws: WebSocket, session: Session) {
|
|||
|
||||
pub struct Rpc {
|
||||
session: Session,
|
||||
format: Output,
|
||||
vars: BTreeMap<String, Value>,
|
||||
}
|
||||
|
||||
|
@ -49,11 +51,14 @@ impl Rpc {
|
|||
pub fn new(mut session: Session) -> Arc<RwLock<Rpc>> {
|
||||
// Create a new RPC variables store
|
||||
let vars = BTreeMap::new();
|
||||
// Set the default output format
|
||||
let format = Output::Json;
|
||||
// Enable real-time live queries
|
||||
session.rt = true;
|
||||
// Create and store the Rpc connection
|
||||
Arc::new(RwLock::new(Rpc {
|
||||
session,
|
||||
format,
|
||||
vars,
|
||||
}))
|
||||
}
|
||||
|
@ -132,17 +137,26 @@ impl Rpc {
|
|||
|
||||
// Call RPC methods from the WebSocket
|
||||
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
|
||||
// Get the current output format
|
||||
let out = { rpc.read().await.format.clone() };
|
||||
// Clone the RPC
|
||||
let rpc = rpc.clone();
|
||||
// Convert the message
|
||||
let str = match msg.to_str() {
|
||||
Ok(v) => v,
|
||||
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(chn).await,
|
||||
};
|
||||
// Parse the request
|
||||
let req = match surrealdb::sql::json(str) {
|
||||
Ok(v) if v.is_some() => v,
|
||||
_ => return res::failure(None, Failure::PARSE_ERROR).send(chn).await,
|
||||
let req = match msg {
|
||||
// This is a text message
|
||||
m if m.is_text() => {
|
||||
// This won't panic due to the check above
|
||||
let val = m.to_str().unwrap();
|
||||
// Parse the SurrealQL object
|
||||
match surrealdb::sql::json(val) {
|
||||
// The SurrealQL message parsed ok
|
||||
Ok(v) => v,
|
||||
// The SurrealQL message failed to parse
|
||||
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
|
||||
}
|
||||
}
|
||||
// Unsupported message type
|
||||
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
|
||||
};
|
||||
// Fetch the 'id' argument
|
||||
let id = match req.pick(&*ID) {
|
||||
|
@ -157,7 +171,7 @@ impl Rpc {
|
|||
// Fetch the 'method' argument
|
||||
let method = match req.pick(&*METHOD) {
|
||||
Value::Strand(v) => v.to_raw(),
|
||||
_ => return res::failure(id, Failure::INVALID_REQUEST).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
|
||||
};
|
||||
// Fetch the 'params' argument
|
||||
let params = match req.pick(&*PARAMS) {
|
||||
|
@ -169,108 +183,115 @@ impl Rpc {
|
|||
"ping" => Ok(Value::None),
|
||||
"info" => match params.len() {
|
||||
0 => rpc.read().await.info().await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"use" => match params.take_two() {
|
||||
(Value::Strand(ns), Value::Strand(db)) => rpc.write().await.yuse(ns, db).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"signup" => match params.take_one() {
|
||||
Value::Object(v) => rpc.write().await.signup(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"signin" => match params.take_one() {
|
||||
Value::Object(v) => rpc.write().await.signin(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"invalidate" => match params.len() {
|
||||
0 => rpc.write().await.invalidate().await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"authenticate" => match params.take_one() {
|
||||
Value::None => rpc.write().await.invalidate().await,
|
||||
Value::Strand(v) => rpc.write().await.authenticate(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"kill" => match params.take_one() {
|
||||
v if v.is_uuid() => rpc.read().await.kill(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"live" => match params.take_one() {
|
||||
v if v.is_strand() => rpc.read().await.live(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"let" => match params.take_two() {
|
||||
(Value::Strand(s), v) => rpc.write().await.set(s, v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"set" => match params.take_two() {
|
||||
(Value::Strand(s), v) => rpc.write().await.set(s, v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"query" => match params.take_two() {
|
||||
(Value::Strand(s), o) if o.is_none() => {
|
||||
let res = rpc.read().await.query(s).await;
|
||||
return match res {
|
||||
Ok(v) => res::success(id, v).send(chn).await,
|
||||
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
|
||||
return match rpc.read().await.query(s).await {
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => {
|
||||
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
|
||||
}
|
||||
};
|
||||
}
|
||||
(Value::Strand(s), Value::Object(o)) => {
|
||||
let res = rpc.read().await.query_with(s, o).await;
|
||||
return match res {
|
||||
Ok(v) => res::success(id, v).send(chn).await,
|
||||
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
|
||||
return match rpc.read().await.query_with(s, o).await {
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => {
|
||||
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"select" => match params.take_one() {
|
||||
v if v.is_thing() => rpc.read().await.select(v).await,
|
||||
v if v.is_strand() => rpc.read().await.select(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"create" => match params.take_two() {
|
||||
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.create(v, None).await,
|
||||
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.create(v, None).await,
|
||||
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.create(v, o).await,
|
||||
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.create(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"update" => match params.take_two() {
|
||||
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.update(v, None).await,
|
||||
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.update(v, None).await,
|
||||
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.update(v, o).await,
|
||||
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.update(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"change" => match params.take_two() {
|
||||
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.change(v, None).await,
|
||||
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.change(v, None).await,
|
||||
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.change(v, o).await,
|
||||
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.change(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"modify" => match params.take_two() {
|
||||
(v, o) if v.is_thing() && o.is_array() => rpc.read().await.modify(v, o).await,
|
||||
(v, o) if v.is_strand() && o.is_array() => rpc.read().await.modify(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"delete" => match params.take_one() {
|
||||
v if v.is_thing() => rpc.read().await.delete(v).await,
|
||||
v if v.is_strand() => rpc.read().await.delete(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
// Specify the output format for text requests
|
||||
"format" => match params.needs_one() {
|
||||
Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
"version" => match params.len() {
|
||||
0 => Ok(format!("{}-{}", PKG_NAME, *PKG_VERS).into()),
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(chn).await,
|
||||
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
|
||||
};
|
||||
// Return the final response
|
||||
match res {
|
||||
Ok(v) => res::success(id, v).send(chn).await,
|
||||
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -278,6 +299,16 @@ impl Rpc {
|
|||
// Methods for authentication
|
||||
// ------------------------------
|
||||
|
||||
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
||||
match out.as_str() {
|
||||
"json" | "application/json" => self.format = Output::Json,
|
||||
"cbor" | "application/cbor" => self.format = Output::Cbor,
|
||||
"msgpack" | "application/msgpack" => self.format = Output::Pack,
|
||||
_ => return Err(Error::InvalidType),
|
||||
};
|
||||
Ok(Value::None)
|
||||
}
|
||||
|
||||
async fn yuse(&mut self, ns: Strand, db: Strand) -> Result<Value, Error> {
|
||||
self.session.ns = Some(ns.0);
|
||||
self.session.db = Some(db.0);
|
||||
|
|
|
@ -4,12 +4,12 @@ use surrealdb::channel::Sender;
|
|||
use surrealdb::sql::Value;
|
||||
use warp::ws::Message;
|
||||
|
||||
#[derive(Serialize)]
|
||||
enum Content<T> {
|
||||
#[serde(rename = "result")]
|
||||
Success(T),
|
||||
#[serde(rename = "error")]
|
||||
Failure(Failure),
|
||||
#[derive(Clone)]
|
||||
pub enum Output {
|
||||
Json, // JSON
|
||||
Cbor, // CBOR
|
||||
Pack, // MsgPack
|
||||
Full, // Full type serialization
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -19,13 +19,35 @@ pub struct Response<T> {
|
|||
content: Content<T>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
enum Content<T> {
|
||||
#[serde(rename = "result")]
|
||||
Success(T),
|
||||
#[serde(rename = "error")]
|
||||
Failure(Failure),
|
||||
}
|
||||
|
||||
impl<T: Serialize> Response<T> {
|
||||
// Send the response to the channel
|
||||
pub async fn send(self, chn: Sender<Message>) {
|
||||
pub async fn send(self, out: Output, chn: Sender<Message>) {
|
||||
match out {
|
||||
Output::Json => {
|
||||
let res = serde_json::to_string(&self).unwrap();
|
||||
let res = Message::text(res);
|
||||
let _ = chn.send(res).await;
|
||||
}
|
||||
Output::Cbor => {
|
||||
let res = serde_cbor::to_vec(&self).unwrap();
|
||||
let res = Message::binary(res);
|
||||
let _ = chn.send(res).await;
|
||||
}
|
||||
Output::Pack => {
|
||||
let res = serde_pack::to_vec(&self).unwrap();
|
||||
let res = Message::binary(res);
|
||||
let _ = chn.send(res).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
|
|
Loading…
Reference in a new issue