diff --git a/src/net/rpc.rs b/src/net/rpc.rs index 42e95ef3..def13fac 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -11,6 +11,7 @@ use crate::rpc::args::Take; use crate::rpc::paths::{ID, METHOD, PARAMS}; use crate::rpc::res; use crate::rpc::res::Failure; +use crate::rpc::res::Output; use futures::{SinkExt, StreamExt}; use serde::Serialize; use std::collections::BTreeMap; @@ -41,6 +42,7 @@ async fn socket(ws: WebSocket, session: Session) { pub struct Rpc { session: Session, + format: Output, vars: BTreeMap, } @@ -49,11 +51,14 @@ impl Rpc { pub fn new(mut session: Session) -> Arc> { // Create a new RPC variables store let vars = BTreeMap::new(); + // Set the default output format + let format = Output::Json; // Enable real-time live queries session.rt = true; // Create and store the Rpc connection Arc::new(RwLock::new(Rpc { session, + format, vars, })) } @@ -132,17 +137,26 @@ impl Rpc { // Call RPC methods from the WebSocket async fn call(rpc: Arc>, msg: Message, chn: Sender) { + // Get the current output format + let out = { rpc.read().await.format.clone() }; // Clone the RPC let rpc = rpc.clone(); - // Convert the message - let str = match msg.to_str() { - Ok(v) => v, - _ => return res::failure(None, Failure::INTERNAL_ERROR).send(chn).await, - }; // Parse the request - let req = match surrealdb::sql::json(str) { - Ok(v) if v.is_some() => v, - _ => return res::failure(None, Failure::PARSE_ERROR).send(chn).await, + let req = match msg { + // This is a text message + m if m.is_text() => { + // This won't panic due to the check above + let val = m.to_str().unwrap(); + // Parse the SurrealQL object + match surrealdb::sql::json(val) { + // The SurrealQL message parsed ok + Ok(v) => v, + // The SurrealQL message failed to parse + _ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await, + } + } + // Unsupported message type + _ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await, }; // Fetch the 'id' argument let id = match req.pick(&*ID) { @@ -157,7 +171,7 @@ impl Rpc { // Fetch the 'method' argument let method = match req.pick(&*METHOD) { Value::Strand(v) => v.to_raw(), - _ => return res::failure(id, Failure::INVALID_REQUEST).send(chn).await, + _ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await, }; // Fetch the 'params' argument let params = match req.pick(&*PARAMS) { @@ -169,108 +183,115 @@ impl Rpc { "ping" => Ok(Value::None), "info" => match params.len() { 0 => rpc.read().await.info().await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "use" => match params.take_two() { (Value::Strand(ns), Value::Strand(db)) => rpc.write().await.yuse(ns, db).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "signup" => match params.take_one() { Value::Object(v) => rpc.write().await.signup(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "signin" => match params.take_one() { Value::Object(v) => rpc.write().await.signin(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "invalidate" => match params.len() { 0 => rpc.write().await.invalidate().await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "authenticate" => match params.take_one() { Value::None => rpc.write().await.invalidate().await, Value::Strand(v) => rpc.write().await.authenticate(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "kill" => match params.take_one() { v if v.is_uuid() => rpc.read().await.kill(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "live" => match params.take_one() { v if v.is_strand() => rpc.read().await.live(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "let" => match params.take_two() { (Value::Strand(s), v) => rpc.write().await.set(s, v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "set" => match params.take_two() { (Value::Strand(s), v) => rpc.write().await.set(s, v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "query" => match params.take_two() { (Value::Strand(s), o) if o.is_none() => { - let res = rpc.read().await.query(s).await; - return match res { - Ok(v) => res::success(id, v).send(chn).await, - Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, + return match rpc.read().await.query(s).await { + Ok(v) => res::success(id, v).send(out, chn).await, + Err(e) => { + res::failure(id, Failure::custom(e.to_string())).send(out, chn).await + } }; } (Value::Strand(s), Value::Object(o)) => { - let res = rpc.read().await.query_with(s, o).await; - return match res { - Ok(v) => res::success(id, v).send(chn).await, - Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, + return match rpc.read().await.query_with(s, o).await { + Ok(v) => res::success(id, v).send(out, chn).await, + Err(e) => { + res::failure(id, Failure::custom(e.to_string())).send(out, chn).await + } }; } - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "select" => match params.take_one() { v if v.is_thing() => rpc.read().await.select(v).await, v if v.is_strand() => rpc.read().await.select(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "create" => match params.take_two() { (v, o) if v.is_thing() && o.is_none() => rpc.read().await.create(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.create(v, None).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.create(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.create(v, o).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "update" => match params.take_two() { (v, o) if v.is_thing() && o.is_none() => rpc.read().await.update(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.update(v, None).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.update(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.update(v, o).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "change" => match params.take_two() { (v, o) if v.is_thing() && o.is_none() => rpc.read().await.change(v, None).await, (v, o) if v.is_strand() && o.is_none() => rpc.read().await.change(v, None).await, (v, o) if v.is_thing() && o.is_object() => rpc.read().await.change(v, o).await, (v, o) if v.is_strand() && o.is_object() => rpc.read().await.change(v, o).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "modify" => match params.take_two() { (v, o) if v.is_thing() && o.is_array() => rpc.read().await.modify(v, o).await, (v, o) if v.is_strand() && o.is_array() => rpc.read().await.modify(v, o).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "delete" => match params.take_one() { v if v.is_thing() => rpc.read().await.delete(v).await, v if v.is_strand() => rpc.read().await.delete(v).await, - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, + }, + // Specify the output format for text requests + "format" => match params.needs_one() { + Ok(Value::Strand(v)) => rpc.write().await.format(v).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, "version" => match params.len() { 0 => Ok(format!("{}-{}", PKG_NAME, *PKG_VERS).into()), - _ => return res::failure(id, Failure::INVALID_PARAMS).send(chn).await, + _ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }, - _ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(chn).await, + _ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await, }; // Return the final response match res { - Ok(v) => res::success(id, v).send(chn).await, - Err(e) => res::failure(id, Failure::custom(e.to_string())).send(chn).await, + Ok(v) => res::success(id, v).send(out, chn).await, + Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await, } } @@ -278,6 +299,16 @@ impl Rpc { // Methods for authentication // ------------------------------ + async fn format(&mut self, out: Strand) -> Result { + match out.as_str() { + "json" | "application/json" => self.format = Output::Json, + "cbor" | "application/cbor" => self.format = Output::Cbor, + "msgpack" | "application/msgpack" => self.format = Output::Pack, + _ => return Err(Error::InvalidType), + }; + Ok(Value::None) + } + async fn yuse(&mut self, ns: Strand, db: Strand) -> Result { self.session.ns = Some(ns.0); self.session.db = Some(db.0); diff --git a/src/rpc/res.rs b/src/rpc/res.rs index 64f381d1..3befa1cc 100644 --- a/src/rpc/res.rs +++ b/src/rpc/res.rs @@ -4,12 +4,12 @@ use surrealdb::channel::Sender; use surrealdb::sql::Value; use warp::ws::Message; -#[derive(Serialize)] -enum Content { - #[serde(rename = "result")] - Success(T), - #[serde(rename = "error")] - Failure(Failure), +#[derive(Clone)] +pub enum Output { + Json, // JSON + Cbor, // CBOR + Pack, // MsgPack + Full, // Full type serialization } #[derive(Serialize)] @@ -19,12 +19,34 @@ pub struct Response { content: Content, } +#[derive(Serialize)] +enum Content { + #[serde(rename = "result")] + Success(T), + #[serde(rename = "error")] + Failure(Failure), +} + impl Response { // Send the response to the channel - pub async fn send(self, chn: Sender) { - let res = serde_json::to_string(&self).unwrap(); - let res = Message::text(res); - let _ = chn.send(res).await; + pub async fn send(self, out: Output, chn: Sender) { + match out { + Output::Json => { + let res = serde_json::to_string(&self).unwrap(); + let res = Message::text(res); + let _ = chn.send(res).await; + } + Output::Cbor => { + let res = serde_cbor::to_vec(&self).unwrap(); + let res = Message::binary(res); + let _ = chn.send(res).await; + } + Output::Pack => { + let res = serde_pack::to_vec(&self).unwrap(); + let res = Message::binary(res); + let _ = chn.send(res).await; + } + } } }