Enable different output serialization formats in WebSocket RPC

This commit is contained in:
Tobie Morgan Hitchcock 2022-10-25 06:19:44 -07:00
parent d12384f3fb
commit a0d86248e2
2 changed files with 101 additions and 48 deletions

View file

@ -11,6 +11,7 @@ use crate::rpc::args::Take;
use crate::rpc::paths::{ID, METHOD, PARAMS};
use crate::rpc::res;
use crate::rpc::res::Failure;
use crate::rpc::res::Output;
use futures::{SinkExt, StreamExt};
use serde::Serialize;
use std::collections::BTreeMap;
@ -41,6 +42,7 @@ async fn socket(ws: WebSocket, session: Session) {
pub struct Rpc {
session: Session,
format: Output,
vars: BTreeMap<String, Value>,
}
@ -49,11 +51,14 @@ impl Rpc {
pub fn new(mut session: Session) -> Arc<RwLock<Rpc>> {
// Create a new RPC variables store
let vars = BTreeMap::new();
// Set the default output format
let format = Output::Json;
// Enable real-time live queries
session.rt = true;
// Create and store the Rpc connection
Arc::new(RwLock::new(Rpc {
session,
format,
vars,
}))
}
@ -132,17 +137,26 @@ impl Rpc {
// Call RPC methods from the WebSocket
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
// Get the current output format
let out = { rpc.read().await.format.clone() };
// Clone the RPC
let rpc = rpc.clone();
// Convert the message
let str = match msg.to_str() {
Ok(v) => v,
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(chn).await,
};
// Parse the request
let req = match surrealdb::sql::json(str) {
Ok(v) if v.is_some() => v,
_ => return res::failure(None, Failure::PARSE_ERROR).send(chn).await,
let req = match msg {
// This is a text message
m if m.is_text() => {
// This won't panic due to the check above
let val = m.to_str().unwrap();
// Parse the SurrealQL object
match surrealdb::sql::json(val) {
// The SurrealQL message parsed ok
Ok(v) => v,
// The SurrealQL message failed to parse
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
}
}
// Unsupported message type
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
};
// Fetch the 'id' argument
let id = match req.pick(&*ID) {
@ -157,7 +171,7 @@ impl Rpc {
// Fetch the 'method' argument
let method = match req.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(),
_ => return res::failure(id, Failure::INVALID_REQUEST).send(chn).await,
_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
};
// Fetch the 'params' argument
let params = match req.pick(&*PARAMS) {
@ -169,108 +183,115 @@ impl Rpc {
"ping" => Ok(Value::None),
"info" => match params.len() {
0 => rpc.read().await.info().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"use" => match params.take_two() {
(Value::Strand(ns), Value::Strand(db)) => rpc.write().await.yuse(ns, db).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"signup" => match params.take_one() {
Value::Object(v) => rpc.write().await.signup(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"signin" => match params.take_one() {
Value::Object(v) => rpc.write().await.signin(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"invalidate" => match params.len() {
0 => rpc.write().await.invalidate().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"authenticate" => match params.take_one() {
Value::None => rpc.write().await.invalidate().await,
Value::Strand(v) => rpc.write().await.authenticate(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"kill" => match params.take_one() {
v if v.is_uuid() => rpc.read().await.kill(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"live" => match params.take_one() {
v if v.is_strand() => rpc.read().await.live(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"let" => match params.take_two() {
(Value::Strand(s), v) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"set" => match params.take_two() {
(Value::Strand(s), v) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"query" => match params.take_two() {
(Value::Strand(s), o) if o.is_none() => {
let res = rpc.read().await.query(s).await;
return match res {
Ok(v) => res::success(id, v).send(chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
return match rpc.read().await.query(s).await {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
};
}
(Value::Strand(s), Value::Object(o)) => {
let res = rpc.read().await.query_with(s, o).await;
return match res {
Ok(v) => res::success(id, v).send(chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
return match rpc.read().await.query_with(s, o).await {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
};
}
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"select" => match params.take_one() {
v if v.is_thing() => rpc.read().await.select(v).await,
v if v.is_strand() => rpc.read().await.select(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"create" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.create(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.create(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.create(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.create(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"update" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.update(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.update(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.update(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.update(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"change" => match params.take_two() {
(v, o) if v.is_thing() && o.is_none() => rpc.read().await.change(v, None).await,
(v, o) if v.is_strand() && o.is_none() => rpc.read().await.change(v, None).await,
(v, o) if v.is_thing() && o.is_object() => rpc.read().await.change(v, o).await,
(v, o) if v.is_strand() && o.is_object() => rpc.read().await.change(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"modify" => match params.take_two() {
(v, o) if v.is_thing() && o.is_array() => rpc.read().await.modify(v, o).await,
(v, o) if v.is_strand() && o.is_array() => rpc.read().await.modify(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"delete" => match params.take_one() {
v if v.is_thing() => rpc.read().await.delete(v).await,
v if v.is_strand() => rpc.read().await.delete(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
// Specify the output format for text requests
"format" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
"version" => match params.len() {
0 => Ok(format!("{}-{}", PKG_NAME, *PKG_VERS).into()),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(chn).await,
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
};
// Return the final response
match res {
Ok(v) => res::success(id, v).send(chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await,
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
}
}
@ -278,6 +299,16 @@ impl Rpc {
// Methods for authentication
// ------------------------------
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() {
"json" | "application/json" => self.format = Output::Json,
"cbor" | "application/cbor" => self.format = Output::Cbor,
"msgpack" | "application/msgpack" => self.format = Output::Pack,
_ => return Err(Error::InvalidType),
};
Ok(Value::None)
}
async fn yuse(&mut self, ns: Strand, db: Strand) -> Result<Value, Error> {
self.session.ns = Some(ns.0);
self.session.db = Some(db.0);

View file

@ -4,12 +4,12 @@ use surrealdb::channel::Sender;
use surrealdb::sql::Value;
use warp::ws::Message;
#[derive(Serialize)]
enum Content<T> {
#[serde(rename = "result")]
Success(T),
#[serde(rename = "error")]
Failure(Failure),
#[derive(Clone)]
pub enum Output {
Json, // JSON
Cbor, // CBOR
Pack, // MsgPack
Full, // Full type serialization
}
#[derive(Serialize)]
@ -19,12 +19,34 @@ pub struct Response<T> {
content: Content<T>,
}
#[derive(Serialize)]
enum Content<T> {
#[serde(rename = "result")]
Success(T),
#[serde(rename = "error")]
Failure(Failure),
}
impl<T: Serialize> Response<T> {
// Send the response to the channel
pub async fn send(self, chn: Sender<Message>) {
let res = serde_json::to_string(&self).unwrap();
let res = Message::text(res);
let _ = chn.send(res).await;
pub async fn send(self, out: Output, chn: Sender<Message>) {
match out {
Output::Json => {
let res = serde_json::to_string(&self).unwrap();
let res = Message::text(res);
let _ = chn.send(res).await;
}
Output::Cbor => {
let res = serde_cbor::to_vec(&self).unwrap();
let res = Message::binary(res);
let _ = chn.send(res).await;
}
Output::Pack => {
let res = serde_pack::to_vec(&self).unwrap();
let res = Message::binary(res);
let _ = chn.send(res).await;
}
}
}
}