Enable different output serialization formats in WebSocket RPC

This commit is contained in:
Tobie Morgan Hitchcock 2022-10-25 06:19:44 -07:00
parent d12384f3fb
commit a0d86248e2
2 changed files with 101 additions and 48 deletions

View file

@ -11,6 +11,7 @@ use crate::rpc::args::Take;
use crate::rpc::paths::{ID, METHOD, PARAMS}; use crate::rpc::paths::{ID, METHOD, PARAMS};
use crate::rpc::res; use crate::rpc::res;
use crate::rpc::res::Failure; use crate::rpc::res::Failure;
use crate::rpc::res::Output;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use serde::Serialize; use serde::Serialize;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -41,6 +42,7 @@ async fn socket(ws: WebSocket, session: Session) {
pub struct Rpc { pub struct Rpc {
session: Session, session: Session,
format: Output,
vars: BTreeMap<String, Value>, vars: BTreeMap<String, Value>,
} }
@ -49,11 +51,14 @@ impl Rpc {
pub fn new(mut session: Session) -> Arc<RwLock<Rpc>> { pub fn new(mut session: Session) -> Arc<RwLock<Rpc>> {
// Create a new RPC variables store // Create a new RPC variables store
let vars = BTreeMap::new(); let vars = BTreeMap::new();
// Set the default output format
let format = Output::Json;
// Enable real-time live queries // Enable real-time live queries
session.rt = true; session.rt = true;
// Create and store the Rpc connection // Create and store the Rpc connection
Arc::new(RwLock::new(Rpc { Arc::new(RwLock::new(Rpc {
session, session,
format,
vars, vars,
})) }))
} }
@ -132,17 +137,26 @@ impl Rpc {
// Call RPC methods from the WebSocket // Call RPC methods from the WebSocket
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) { async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
// Get the current output format
let out = { rpc.read().await.format.clone() };
// Clone the RPC // Clone the RPC
let rpc = rpc.clone(); let rpc = rpc.clone();
// Convert the message
let str = match msg.to_str() {
Ok(v) => v,
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(chn).await,
};
// Parse the request // Parse the request
let req = match surrealdb::sql::json(str) { let req = match msg {
Ok(v) if v.is_some() => v, // This is a text message
_ => return res::failure(None, Failure::PARSE_ERROR).send(chn).await, m if m.is_text() => {
// This won't panic due to the check above
let val = m.to_str().unwrap();
// Parse the SurrealQL object
match surrealdb::sql::json(val) {
// The SurrealQL message parsed ok
Ok(v) => v,
// The SurrealQL message failed to parse
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
}
}
// Unsupported message type
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
}; };
// Fetch the 'id' argument // Fetch the 'id' argument
let id = match req.pick(&*ID) { let id = match req.pick(&*ID) {
@ -157,7 +171,7 @@ impl Rpc {
// Fetch the 'method' argument // Fetch the 'method' argument
let method = match req.pick(&*METHOD) { let method = match req.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(), Value::Strand(v) => v.to_raw(),
_ => return res::failure(id, Failure::INVALID_REQUEST).send(chn).await, _ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
}; };
// Fetch the 'params' argument // Fetch the 'params' argument
let params = match req.pick(&*PARAMS) { let params = match req.pick(&*PARAMS) {
@ -169,108 +183,115 @@ impl Rpc {
"ping" => Ok(Value::None), "ping" => Ok(Value::None),
"info" => match params.len() { "info" => match params.len() {
0 => rpc.read().await.info().await, 0 => rpc.read().await.info().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"use" => match params.take_two() { "use" => match params.take_two() {
(Value::Strand(ns), Value::Strand(db)) => rpc.write().await.yuse(ns, db).await, (Value::Strand(ns), Value::Strand(db)) => rpc.write().await.yuse(ns, db).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"signup" => match params.take_one() { "signup" => match params.take_one() {
Value::Object(v) => rpc.write().await.signup(v).await, Value::Object(v) => rpc.write().await.signup(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"signin" => match params.take_one() { "signin" => match params.take_one() {
Value::Object(v) => rpc.write().await.signin(v).await, Value::Object(v) => rpc.write().await.signin(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"invalidate" => match params.len() { "invalidate" => match params.len() {
0 => rpc.write().await.invalidate().await, 0 => rpc.write().await.invalidate().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"authenticate" => match params.take_one() { "authenticate" => match params.take_one() {
Value::None => rpc.write().await.invalidate().await, Value::None => rpc.write().await.invalidate().await,
Value::Strand(v) => rpc.write().await.authenticate(v).await, Value::Strand(v) => rpc.write().await.authenticate(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"kill" => match params.take_one() { "kill" => match params.take_one() {
v if v.is_uuid() => rpc.read().await.kill(v).await, v if v.is_uuid() => rpc.read().await.kill(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"live" => match params.take_one() { "live" => match params.take_one() {
v if v.is_strand() => rpc.read().await.live(v).await, v if v.is_strand() => rpc.read().await.live(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"let" => match params.take_two() { "let" => match params.take_two() {
(Value::Strand(s), v) => rpc.write().await.set(s, v).await, (Value::Strand(s), v) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"set" => match params.take_two() { "set" => match params.take_two() {
(Value::Strand(s), v) => rpc.write().await.set(s, v).await, (Value::Strand(s), v) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"query" => match params.take_two() { "query" => match params.take_two() {
(Value::Strand(s), o) if o.is_none() => { (Value::Strand(s), o) if o.is_none() => {
let res = rpc.read().await.query(s).await; return match rpc.read().await.query(s).await {
return match res { Ok(v) => res::success(id, v).send(out, chn).await,
Ok(v) => res::success(id, v).send(chn).await, Err(e) => {
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
}; };
} }
(Value::Strand(s), Value::Object(o)) => { (Value::Strand(s), Value::Object(o)) => {
let res = rpc.read().await.query_with(s, o).await; return match rpc.read().await.query_with(s, o).await {
return match res { Ok(v) => res::success(id, v).send(out, chn).await,
Ok(v) => res::success(id, v).send(chn).await, Err(e) => {
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
}; };
} }
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"select" => match params.take_one() { "select" => match params.take_one() {
v if v.is_thing() => rpc.read().await.select(v).await, v if v.is_thing() => rpc.read().await.select(v).await,
v if v.is_strand() => rpc.read().await.select(v).await, v if v.is_strand() => rpc.read().await.select(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"create" => match params.take_two() { "create" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.create(v, None).await, (v, o) if v.is_thing() && o.is_none() => rpc.read().await.create(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.create(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.create(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.create(v, o).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.create(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.create(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.create(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"update" => match params.take_two() { "update" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.update(v, None).await, (v, o) if v.is_thing() && o.is_none() => rpc.read().await.update(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.update(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.update(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.update(v, o).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.update(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.update(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.update(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"change" => match params.take_two() { "change" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.change(v, None).await, (v, o) if v.is_thing() && o.is_none() => rpc.read().await.change(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.change(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.change(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.change(v, o).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.change(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.change(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.change(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"modify" => match params.take_two() { "modify" => match params.take_two() {
(v, o) if v.is_thing() && o.is_array() => rpc.read().await.modify(v, o).await, (v, o) if v.is_thing() && o.is_array() => rpc.read().await.modify(v, o).await,
(v, o) if v.is_strand() && o.is_array() => rpc.read().await.modify(v, o).await, (v, o) if v.is_strand() && o.is_array() => rpc.read().await.modify(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"delete" => match params.take_one() { "delete" => match params.take_one() {
v if v.is_thing() => rpc.read().await.delete(v).await, v if v.is_thing() => rpc.read().await.delete(v).await,
v if v.is_strand() => rpc.read().await.delete(v).await, v if v.is_strand() => rpc.read().await.delete(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
// Specify the output format for text requests
"format" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
"version" => match params.len() { "version" => match params.len() {
0 => Ok(format!("{}-{}", PKG_NAME, *PKG_VERS).into()), 0 => Ok(format!("{}-{}", PKG_NAME, *PKG_VERS).into()),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(chn).await, _ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
}; };
// Return the final response // Return the final response
match res { match res {
Ok(v) => res::success(id, v).send(chn).await, Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
} }
} }
@ -278,6 +299,16 @@ impl Rpc {
// Methods for authentication // Methods for authentication
// ------------------------------ // ------------------------------
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() {
"json" | "application/json" => self.format = Output::Json,
"cbor" | "application/cbor" => self.format = Output::Cbor,
"msgpack" | "application/msgpack" => self.format = Output::Pack,
_ => return Err(Error::InvalidType),
};
Ok(Value::None)
}
async fn yuse(&mut self, ns: Strand, db: Strand) -> Result<Value, Error> { async fn yuse(&mut self, ns: Strand, db: Strand) -> Result<Value, Error> {
self.session.ns = Some(ns.0); self.session.ns = Some(ns.0);
self.session.db = Some(db.0); self.session.db = Some(db.0);

View file

@ -4,12 +4,12 @@ use surrealdb::channel::Sender;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use warp::ws::Message; use warp::ws::Message;
#[derive(Serialize)] #[derive(Clone)]
enum Content<T> { pub enum Output {
#[serde(rename = "result")] Json, // JSON
Success(T), Cbor, // CBOR
#[serde(rename = "error")] Pack, // MsgPack
Failure(Failure), Full, // Full type serialization
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -19,12 +19,34 @@ pub struct Response<T> {
content: Content<T>, content: Content<T>,
} }
#[derive(Serialize)]
enum Content<T> {
#[serde(rename = "result")]
Success(T),
#[serde(rename = "error")]
Failure(Failure),
}
impl<T: Serialize> Response<T> { impl<T: Serialize> Response<T> {
// Send the response to the channel // Send the response to the channel
pub async fn send(self, chn: Sender<Message>) { pub async fn send(self, out: Output, chn: Sender<Message>) {
let res = serde_json::to_string(&self).unwrap(); match out {
let res = Message::text(res); Output::Json => {
let _ = chn.send(res).await; let res = serde_json::to_string(&self).unwrap();
let res = Message::text(res);
let _ = chn.send(res).await;
}
Output::Cbor => {
let res = serde_cbor::to_vec(&self).unwrap();
let res = Message::binary(res);
let _ = chn.send(res).await;
}
Output::Pack => {
let res = serde_pack::to_vec(&self).unwrap();
let res = Message::binary(res);
let _ = chn.send(res).await;
}
}
} }
} }