diff --git a/Cargo.lock b/Cargo.lock index 12361505..99023689 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5910,6 +5910,7 @@ dependencies = [ "bytes", "cedar-policy", "chrono", + "ciborium", "criterion", "deunicode", "dmp", @@ -5947,6 +5948,7 @@ dependencies = [ "reqwest", "revision", "ring 0.17.8", + "rmpv", "roaring", "rocksdb", "rquickjs", diff --git a/core/Cargo.toml b/core/Cargo.toml index 16857e1e..f4b4ff1f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -67,6 +67,7 @@ base64 = "0.21.5" bcrypt = "0.15.0" bincode = "1.3.3" bytes = "1.5.0" +ciborium = "0.2.1" cedar-policy = "2.4.2" channel = { version = "1.9.0", package = "async-channel" } chrono = { version = "0.4.31", features = ["serde"] } @@ -122,6 +123,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [ "multipart", ], optional = true } revision = { version = "0.7.0", features = ["chrono", "geo", "roaring", "regex", "rust_decimal", "uuid"] } +rmpv = "1.0.1" roaring = { version = "0.10.2", features = ["serde"] } rocksdb = { version = "0.21.0", features = ["lz4", "snappy"], optional = true } rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] } diff --git a/core/src/obs/mod.rs b/core/src/obs/mod.rs index 4c88b6db..c70b048c 100644 --- a/core/src/obs/mod.rs +++ b/core/src/obs/mod.rs @@ -103,7 +103,7 @@ pub async fn del(file: &str) -> Result<(), Error> { } /// Hashes the bytes of a file to a string for the storage of a file. -pub fn hash(data: &Vec) -> String { +pub fn hash(data: &[u8]) -> String { let mut hasher = Sha1::new(); hasher.update(data); let result = hasher.finalize(); diff --git a/core/src/rpc/format/bincode.rs b/core/src/rpc/format/bincode.rs new file mode 100644 index 00000000..d31afe85 --- /dev/null +++ b/core/src/rpc/format/bincode.rs @@ -0,0 +1,14 @@ +use crate::rpc::format::ResTrait; +use crate::rpc::request::Request; +use crate::rpc::RpcError; +use crate::sql::serde::{deserialize, serialize}; +use crate::sql::Value; + +pub fn req(val: &[u8]) -> Result { + deserialize::(val).map_err(|_| RpcError::ParseError)?.try_into() +} + +pub fn res(res: impl ResTrait) -> Result, RpcError> { + // Serialize the response with full internal type information + Ok(serialize(&res).unwrap()) +} diff --git a/src/rpc/format/cbor/convert.rs b/core/src/rpc/format/cbor/convert.rs similarity index 97% rename from src/rpc/format/cbor/convert.rs rename to core/src/rpc/format/cbor/convert.rs index e272fa41..a1056dd0 100644 --- a/src/rpc/format/cbor/convert.rs +++ b/core/src/rpc/format/cbor/convert.rs @@ -3,14 +3,15 @@ use geo::{LineString, Point, Polygon}; use geo_types::{MultiLineString, MultiPoint, MultiPolygon}; use std::iter::once; use std::ops::Deref; -use surrealdb::sql::Datetime; -use surrealdb::sql::Duration; -use surrealdb::sql::Geometry; -use surrealdb::sql::Id; -use surrealdb::sql::Number; -use surrealdb::sql::Thing; -use surrealdb::sql::Uuid; -use surrealdb::sql::Value; + +use crate::sql::Datetime; +use crate::sql::Duration; +use crate::sql::Geometry; +use crate::sql::Id; +use crate::sql::Number; +use crate::sql::Thing; +use crate::sql::Uuid; +use crate::sql::Value; // Tags from the spec - https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml const TAG_SPEC_DATETIME: u64 = 0; @@ -327,7 +328,6 @@ impl TryFrom for Cbor { Number::Decimal(v) => { Ok(Cbor(Data::Tag(TAG_STRING_DECIMAL, Box::new(Data::Text(v.to_string()))))) } - _ => Err("Found an unsupported Number type being converted to CBOR"), }, Value::Strand(v) => Ok(Cbor(Data::Text(v.0))), Value::Duration(v) => { @@ -390,7 +390,6 @@ impl TryFrom for Cbor { Id::Generate(_) => { return Err("Cannot encode an ungenerated Record ID into CBOR") } - _ => return Err("Found an unsupported Id type being converted to CBOR"), }, ])), ))), @@ -459,6 +458,5 @@ fn encode_geometry(v: Geometry) -> Result { Ok(Data::Tag(TAG_GEOMETRY_COLLECTION, Box::new(Data::Array(data)))) } - _ => Err("Found an unsupported Geometry type being converted to CBOR"), } } diff --git a/core/src/rpc/format/cbor/mod.rs b/core/src/rpc/format/cbor/mod.rs new file mode 100644 index 00000000..911ad005 --- /dev/null +++ b/core/src/rpc/format/cbor/mod.rs @@ -0,0 +1,29 @@ +mod convert; + +pub use convert::Cbor; + +use crate::rpc::request::Request; +use crate::rpc::RpcError; +use crate::sql::Value; +use ciborium::Value as Data; + +use super::ResTrait; + +pub fn req(val: Vec) -> Result { + ciborium::from_reader::(&mut val.as_slice()) + .map_err(|_| RpcError::ParseError) + .map(Cbor)? + .try_into() +} + +pub fn res(res: impl ResTrait) -> Result, RpcError> { + // Convert the response into a value + let val: Value = res.into(); + let val: Cbor = val.try_into()?; + // Create a new vector for encoding output + let mut res = Vec::new(); + // Serialize the value into CBOR binary data + ciborium::into_writer(&val.0, &mut res).unwrap(); + // Return the message length, and message as binary + Ok(res) +} diff --git a/core/src/rpc/format/json.rs b/core/src/rpc/format/json.rs new file mode 100644 index 00000000..5906b471 --- /dev/null +++ b/core/src/rpc/format/json.rs @@ -0,0 +1,22 @@ +use crate::rpc::request::Request; +use crate::rpc::RpcError; +use crate::sql::Value; +use crate::syn; + +use super::ResTrait; + +pub fn req(val: &[u8]) -> Result { + syn::value_legacy_strand(std::str::from_utf8(val).or(Err(RpcError::ParseError))?) + .or(Err(RpcError::ParseError))? + .try_into() +} + +pub fn res(res: impl ResTrait) -> Result, RpcError> { + // Convert the response into simplified JSON + let val: Value = res.into(); + let val = val.into_json(); + // Serialize the response with simplified type information + let res = serde_json::to_string(&val).unwrap(); + // Return the message length, and message as binary + Ok(res.into()) +} diff --git a/core/src/rpc/format/mod.rs b/core/src/rpc/format/mod.rs new file mode 100644 index 00000000..da9042be --- /dev/null +++ b/core/src/rpc/format/mod.rs @@ -0,0 +1,82 @@ +mod bincode; +pub mod cbor; +mod json; +pub mod msgpack; +mod revision; + +use ::revision::Revisioned; +use serde::Serialize; + +use super::{request::Request, RpcError}; +use crate::sql::Value; + +pub const PROTOCOLS: [&str; 5] = [ + "json", // For basic JSON serialisation + "cbor", // For basic CBOR serialisation + "msgpack", // For basic Msgpack serialisation + "bincode", // For full internal serialisation + "revision", // For full versioned serialisation +]; + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[non_exhaustive] +pub enum Format { + None, // No format is specified yet + Json, // For basic JSON serialisation + Cbor, // For basic CBOR serialisation + Msgpack, // For basic Msgpack serialisation + Bincode, // For full internal serialisation + Revision, // For full versioned serialisation + Unsupported, // Unsupported format +} + +pub trait ResTrait: Serialize + Into + Revisioned {} +impl + Revisioned> ResTrait for T {} + +impl From<&str> for Format { + fn from(v: &str) -> Self { + match v { + s if s == PROTOCOLS[0] => Format::Json, + s if s == PROTOCOLS[1] => Format::Cbor, + s if s == PROTOCOLS[2] => Format::Msgpack, + s if s == PROTOCOLS[3] => Format::Bincode, + s if s == PROTOCOLS[4] => Format::Revision, + _ => Format::None, + } + } +} + +impl Format { + /// Check if this format has been set + pub fn is_none(&self) -> bool { + matches!(self, Format::None) + } + + /// Process a request using the specified format + pub fn req(&self, val: impl Into>) -> Result { + let val = val.into(); + match self { + Self::None => Err(RpcError::InvalidRequest), + Self::Unsupported => Err(RpcError::InvalidRequest), + Self::Json => json::req(&val), + Self::Cbor => cbor::req(val), + Self::Msgpack => msgpack::req(val), + Self::Bincode => bincode::req(&val), + Self::Revision => revision::req(val), + } + .map_err(Into::into) + } + + /// Process a response using the specified format + pub fn res(&self, val: impl ResTrait) -> Result, RpcError> { + match self { + Self::None => Err(RpcError::InvalidRequest), + Self::Unsupported => Err(RpcError::InvalidRequest), + Self::Json => json::res(val), + Self::Cbor => cbor::res(val), + Self::Msgpack => msgpack::res(val), + Self::Bincode => bincode::res(val), + Self::Revision => revision::res(val), + } + } +} diff --git a/src/rpc/format/msgpack/convert.rs b/core/src/rpc/format/msgpack/convert.rs similarity index 96% rename from src/rpc/format/msgpack/convert.rs rename to core/src/rpc/format/msgpack/convert.rs index 3802d429..e241737c 100644 --- a/src/rpc/format/msgpack/convert.rs +++ b/core/src/rpc/format/msgpack/convert.rs @@ -1,10 +1,10 @@ +use crate::sql::Datetime; +use crate::sql::Duration; +use crate::sql::Number; +use crate::sql::Thing; +use crate::sql::Uuid; +use crate::sql::Value; use rmpv::Value as Data; -use surrealdb::sql::Datetime; -use surrealdb::sql::Duration; -use surrealdb::sql::Number; -use surrealdb::sql::Thing; -use surrealdb::sql::Uuid; -use surrealdb::sql::Value; const TAG_NONE: i8 = 1; const TAG_UUID: i8 = 2; @@ -114,6 +114,7 @@ impl TryFrom for Pack { Number::Decimal(v) => { Ok(Pack(Data::Ext(TAG_DECIMAL, v.to_string().as_bytes().to_vec()))) } + #[allow(unreachable_patterns)] _ => unreachable!(), }, Value::Strand(v) => Ok(Pack(Data::String(v.0.into()))), diff --git a/core/src/rpc/format/msgpack/mod.rs b/core/src/rpc/format/msgpack/mod.rs new file mode 100644 index 00000000..e4b0e6c4 --- /dev/null +++ b/core/src/rpc/format/msgpack/mod.rs @@ -0,0 +1,26 @@ +mod convert; + +use crate::rpc::format::ResTrait; +use crate::rpc::RpcError; +pub use convert::Pack; + +use crate::rpc::request::Request; +use crate::sql::Value; + +pub fn req(val: Vec) -> Result { + rmpv::decode::read_value(&mut val.as_slice()) + .map_err(|_| RpcError::ParseError) + .map(Pack)? + .try_into() +} + +pub fn res(res: impl ResTrait) -> Result, RpcError> { + // Convert the response into a value + let val: Value = res.into(); + let val: Pack = val.try_into()?; + // Create a new vector for encoding output + let mut res = Vec::new(); + // Serialize the value into MsgPack binary data + rmpv::encode::write_value(&mut res, &val.0).unwrap(); + Ok(res) +} diff --git a/core/src/rpc/format/revision.rs b/core/src/rpc/format/revision.rs new file mode 100644 index 00000000..a22b99e4 --- /dev/null +++ b/core/src/rpc/format/revision.rs @@ -0,0 +1,16 @@ +use crate::rpc::format::ResTrait; +use crate::rpc::request::Request; +use crate::rpc::RpcError; +use crate::sql::Value; +use revision::Revisioned; + +pub fn req(val: Vec) -> Result { + Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into() +} + +pub fn res(res: impl ResTrait) -> Result, RpcError> { + // Serialize the response with full internal type information + let mut buf = Vec::new(); + res.serialize_revisioned(&mut buf).unwrap(); + Ok(buf) +} diff --git a/core/src/rpc/mod.rs b/core/src/rpc/mod.rs index 2b6a8875..b4cc0925 100644 --- a/core/src/rpc/mod.rs +++ b/core/src/rpc/mod.rs @@ -1,6 +1,8 @@ pub mod args; pub mod basic_context; +pub mod format; pub mod method; +pub mod request; mod response; pub mod rpc_context; mod rpc_error; diff --git a/src/rpc/request.rs b/core/src/rpc/request.rs similarity index 94% rename from src/rpc/request.rs rename to core/src/rpc/request.rs index b621cb91..4a742508 100644 --- a/src/rpc/request.rs +++ b/core/src/rpc/request.rs @@ -1,9 +1,9 @@ use crate::rpc::format::cbor::Cbor; use crate::rpc::format::msgpack::Pack; +use crate::rpc::RpcError; +use crate::sql::Part; +use crate::sql::{Array, Value}; use once_cell::sync::Lazy; -use surrealdb::rpc::RpcError; -use surrealdb::sql::Part; -use surrealdb::sql::{Array, Value}; pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); diff --git a/src/net/rpc.rs b/src/net/rpc.rs index dc74325f..ddf8f86f 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -5,8 +5,7 @@ use crate::cnf; use crate::dbs::DB; use crate::err::Error; use crate::rpc::connection::Connection; -use crate::rpc::format::Format; -use crate::rpc::format::PROTOCOLS; +use crate::rpc::format::HttpFormat; use crate::rpc::post_context::PostRpcContext; use crate::rpc::response::IntoRpcResponse; use crate::rpc::WEBSOCKETS; @@ -22,6 +21,8 @@ use bytes::Bytes; use http::HeaderValue; use http_body::Body as HttpBody; use surrealdb::dbs::Session; +use surrealdb::rpc::format::Format; +use surrealdb::rpc::format::PROTOCOLS; use surrealdb::rpc::method::Method; use tower_http::request_id::RequestId; use uuid::Uuid; diff --git a/src/rpc/connection.rs b/src/rpc/connection.rs index 85cd4872..c3f145a5 100644 --- a/src/rpc/connection.rs +++ b/src/rpc/connection.rs @@ -3,7 +3,7 @@ use crate::cnf::{ }; use crate::dbs::DB; use crate::rpc::failure::Failure; -use crate::rpc::format::Format; +use crate::rpc::format::WsFormat; use crate::rpc::response::{failure, IntoRpcResponse}; use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::telemetry; @@ -20,6 +20,7 @@ use surrealdb::channel::{self, Receiver, Sender}; use surrealdb::dbs::Session; use surrealdb::kvs::Datastore; use surrealdb::rpc::args::Take; +use surrealdb::rpc::format::Format; use surrealdb::rpc::method::Method; use surrealdb::rpc::RpcContext; use surrealdb::rpc::{Data, RpcError}; diff --git a/src/rpc/format.rs b/src/rpc/format.rs new file mode 100644 index 00000000..defc00e6 --- /dev/null +++ b/src/rpc/format.rs @@ -0,0 +1,83 @@ +use crate::net::headers::{Accept, ContentType}; +use crate::rpc::failure::Failure; +use crate::rpc::response::Response; +use axum::extract::ws::Message; +use axum::response::IntoResponse; +use axum::response::Response as AxumResponse; +use bytes::Bytes; +use http::StatusCode; +use surrealdb::rpc::format::Format; +use surrealdb::rpc::request::Request; +use surrealdb::rpc::RpcError; + +impl From<&Accept> for Format { + fn from(value: &Accept) -> Self { + match value { + Accept::TextPlain => Format::None, + Accept::ApplicationJson => Format::Json, + Accept::ApplicationCbor => Format::Cbor, + Accept::ApplicationPack => Format::Msgpack, + Accept::ApplicationOctetStream => Format::Unsupported, + Accept::Surrealdb => Format::Bincode, + } + } +} + +impl From<&ContentType> for Format { + fn from(value: &ContentType) -> Self { + match value { + ContentType::TextPlain => Format::None, + ContentType::ApplicationJson => Format::Json, + ContentType::ApplicationCbor => Format::Cbor, + ContentType::ApplicationPack => Format::Msgpack, + ContentType::ApplicationOctetStream => Format::Unsupported, + ContentType::Surrealdb => Format::Bincode, + } + } +} + +pub trait WsFormat { + fn req_ws(&self, msg: Message) -> Result; + fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure>; +} + +impl WsFormat for Format { + fn req_ws(&self, msg: Message) -> Result { + let val = msg.into_data(); + self.req(val).map_err(Into::into) + } + + fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> { + let res = self.res(res).map_err(Failure::from)?; + if matches!(self, Format::Json) { + // If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) } + // This would be safe as in the case of JSON res come from a call to Into::> for String + Ok((res.len(), Message::Text(String::from_utf8(res).unwrap()))) + } else { + Ok((res.len(), Message::Binary(res))) + } + } +} + +pub trait HttpFormat { + fn req_http(&self, body: Bytes) -> Result; + fn res_http(&self, res: Response) -> Result; +} + +impl HttpFormat for Format { + fn req_http(&self, body: Bytes) -> Result { + self.req(body).map_err(Into::into) + } + + fn res_http(&self, res: Response) -> Result { + let val = res.into_value(); + let res = self.res(val)?; + if matches!(self, Format::Json) { + // If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) } + // This would be safe as in the case of JSON res come from a call to Into::> for String + Ok((StatusCode::OK, String::from_utf8(res).unwrap()).into_response()) + } else { + Ok((StatusCode::OK, res).into_response()) + } + } +} diff --git a/src/rpc/format/bincode.rs b/src/rpc/format/bincode.rs deleted file mode 100644 index 4760a89f..00000000 --- a/src/rpc/format/bincode.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::net::headers::ContentType; -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::IntoResponse; -use axum::response::Response as AxumResponse; -use bytes::Bytes; -use http::header::CONTENT_TYPE; -use http::HeaderValue; -use surrealdb::rpc::RpcError; -use surrealdb::sql::serde::deserialize; -use surrealdb::sql::Value; - -pub fn req_ws(msg: Message) -> Result { - match msg { - Message::Binary(val) => { - deserialize::(&val).map_err(|_| RpcError::ParseError)?.try_into() - } - _ => Err(RpcError::InvalidRequest), - } -} - -pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { - // Serialize the response with full internal type information - let res = surrealdb::sql::serde::serialize(&res).unwrap(); - // Return the message length, and message as binary - Ok((res.len(), Message::Binary(res))) -} - -pub fn req_http(val: &Bytes) -> Result { - deserialize::(val).map_err(|_| RpcError::ParseError)?.try_into() -} - -pub fn res_http(res: Response) -> Result { - // Serialize the response with full internal type information - let res = surrealdb::sql::serde::serialize(&res).unwrap(); - // Return the message length, and message as binary - // TODO: Check what this header should be, I'm being consistent with /sql - Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response()) -} diff --git a/src/rpc/format/cbor/mod.rs b/src/rpc/format/cbor/mod.rs deleted file mode 100644 index 498ecff1..00000000 --- a/src/rpc/format/cbor/mod.rs +++ /dev/null @@ -1,57 +0,0 @@ -mod convert; - -use bytes::Bytes; -pub use convert::Cbor; -use http::header::CONTENT_TYPE; -use http::HeaderValue; -use surrealdb::rpc::RpcError; - -use crate::net::headers::ContentType; -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::{IntoResponse, Response as AxumResponse}; -use ciborium::Value as Data; - -pub fn req_ws(msg: Message) -> Result { - match msg { - Message::Text(val) => { - surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into() - } - Message::Binary(val) => ciborium::from_reader::(&mut val.as_slice()) - .map_err(|_| RpcError::ParseError) - .map(Cbor)? - .try_into(), - _ => Err(RpcError::InvalidRequest), - } -} - -pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { - // Convert the response into a value - let val: Cbor = res.into_value().try_into()?; - // Create a new vector for encoding output - let mut res = Vec::new(); - // Serialize the value into CBOR binary data - ciborium::into_writer(&val.0, &mut res).unwrap(); - // Return the message length, and message as binary - Ok((res.len(), Message::Binary(res))) -} - -pub fn req_http(body: Bytes) -> Result { - let val: Vec = body.into(); - ciborium::from_reader::(&mut val.as_slice()) - .map_err(|_| RpcError::ParseError) - .map(Cbor)? - .try_into() -} - -pub fn res_http(res: Response) -> Result { - // Convert the response into a value - let val: Cbor = res.into_value().try_into()?; - // Create a new vector for encoding output - let mut res = Vec::new(); - // Serialize the value into CBOR binary data - ciborium::into_writer(&val.0, &mut res).unwrap(); - // Return the message length, and message as binary - Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationCbor))], res).into_response()) -} diff --git a/src/rpc/format/json.rs b/src/rpc/format/json.rs deleted file mode 100644 index 2099c9b4..00000000 --- a/src/rpc/format/json.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::IntoResponse; -use axum::response::Response as AxumResponse; -use bytes::Bytes; -use http::StatusCode; -use surrealdb::rpc::RpcError; -use surrealdb::sql; - -pub fn req_ws(msg: Message) -> Result { - match msg { - Message::Text(val) => { - surrealdb::syn::value_legacy_strand(&val).map_err(|_| RpcError::ParseError)?.try_into() - } - _ => Err(RpcError::InvalidRequest), - } -} - -pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { - // Convert the response into simplified JSON - let val = res.into_json(); - // Serialize the response with simplified type information - let res = serde_json::to_string(&val).unwrap(); - // Return the message length, and message as binary - Ok((res.len(), Message::Text(res))) -} - -pub fn req_http(val: &Bytes) -> Result { - sql::value(std::str::from_utf8(val).or(Err(RpcError::ParseError))?) - .or(Err(RpcError::ParseError))? - .try_into() -} - -pub fn res_http(res: Response) -> Result { - // Convert the response into simplified JSON - let val = res.into_json(); - // Serialize the response with simplified type information - let res = serde_json::to_string(&val).unwrap(); - // Return the message length, and message as binary - Ok((StatusCode::OK, res).into_response()) -} diff --git a/src/rpc/format/mod.rs b/src/rpc/format/mod.rs deleted file mode 100644 index a562afc9..00000000 --- a/src/rpc/format/mod.rs +++ /dev/null @@ -1,129 +0,0 @@ -mod bincode; -pub mod cbor; -mod json; -pub mod msgpack; -mod revision; - -use crate::net::headers::{Accept, ContentType}; -use crate::rpc::failure::Failure; -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::Response as AxumResponse; -use bytes::Bytes; -use surrealdb::rpc::RpcError; - -pub const PROTOCOLS: [&str; 5] = [ - "json", // For basic JSON serialisation - "cbor", // For basic CBOR serialisation - "msgpack", // For basic Msgpack serialisation - "bincode", // For full internal serialisation - "revision", // For full versioned serialisation -]; - -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub enum Format { - None, // No format is specified yet - Json, // For basic JSON serialisation - Cbor, // For basic CBOR serialisation - Msgpack, // For basic Msgpack serialisation - Bincode, // For full internal serialisation - Revision, // For full versioned serialisation - Unsupported, // Unsupported format -} - -impl From<&Accept> for Format { - fn from(value: &Accept) -> Self { - match value { - Accept::TextPlain => Format::None, - Accept::ApplicationJson => Format::Json, - Accept::ApplicationCbor => Format::Cbor, - Accept::ApplicationPack => Format::Msgpack, - Accept::ApplicationOctetStream => Format::Unsupported, - Accept::Surrealdb => Format::Bincode, - } - } -} - -impl From<&ContentType> for Format { - fn from(value: &ContentType) -> Self { - match value { - ContentType::TextPlain => Format::None, - ContentType::ApplicationJson => Format::Json, - ContentType::ApplicationCbor => Format::Cbor, - ContentType::ApplicationPack => Format::Msgpack, - ContentType::ApplicationOctetStream => Format::Unsupported, - ContentType::Surrealdb => Format::Bincode, - } - } -} - -impl From<&str> for Format { - fn from(v: &str) -> Self { - match v { - s if s == PROTOCOLS[0] => Format::Json, - s if s == PROTOCOLS[1] => Format::Cbor, - s if s == PROTOCOLS[2] => Format::Msgpack, - s if s == PROTOCOLS[3] => Format::Bincode, - s if s == PROTOCOLS[4] => Format::Revision, - _ => Format::None, - } - } -} - -impl Format { - /// Check if this format has been set - pub fn is_none(&self) -> bool { - matches!(self, Format::None) - } - /// Process a request using the specified format - pub fn req_ws(&self, msg: Message) -> Result { - match self { - Self::None => unreachable!(), // We should never arrive at this code - Self::Unsupported => unreachable!(), // We should never arrive at this code - Self::Json => json::req_ws(msg), - Self::Cbor => cbor::req_ws(msg), - Self::Msgpack => msgpack::req_ws(msg), - Self::Bincode => bincode::req_ws(msg), - Self::Revision => revision::req_ws(msg), - } - .map_err(Into::into) - } - /// Process a response using the specified format - pub fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> { - match self { - Self::None => unreachable!(), // We should never arrive at this code - Self::Unsupported => unreachable!(), // We should never arrive at this code - Self::Json => json::res_ws(res), - Self::Cbor => cbor::res_ws(res), - Self::Msgpack => msgpack::res_ws(res), - Self::Bincode => bincode::res_ws(res), - Self::Revision => revision::res_ws(res), - } - .map_err(Into::into) - } - /// Process a request using the specified format - pub fn req_http(&self, body: Bytes) -> Result { - match self { - Self::None => unreachable!(), // We should never arrive at this code - Self::Unsupported => unreachable!(), // We should never arrive at this code - Self::Json => json::req_http(&body), - Self::Cbor => cbor::req_http(body), - Self::Msgpack => msgpack::req_http(body), - Self::Bincode => bincode::req_http(&body), - Self::Revision => revision::req_http(body), - } - } - /// Process a response using the specified format - pub fn res_http(&self, res: Response) -> Result { - match self { - Self::None => unreachable!(), // We should never arrive at this code - Self::Unsupported => unreachable!(), // We should never arrive at this code - Self::Json => json::res_http(res), - Self::Cbor => cbor::res_http(res), - Self::Msgpack => msgpack::res_http(res), - Self::Bincode => bincode::res_http(res), - Self::Revision => revision::res_http(res), - } - } -} diff --git a/src/rpc/format/msgpack/mod.rs b/src/rpc/format/msgpack/mod.rs deleted file mode 100644 index 66a54608..00000000 --- a/src/rpc/format/msgpack/mod.rs +++ /dev/null @@ -1,55 +0,0 @@ -mod convert; - -use bytes::Bytes; -pub use convert::Pack; -use http::header::CONTENT_TYPE; -use http::HeaderValue; -use surrealdb::rpc::RpcError; - -use crate::net::headers::ContentType; -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::{IntoResponse, Response as AxumResponse}; - -pub fn req_ws(msg: Message) -> Result { - match msg { - Message::Text(val) => { - surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into() - } - Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice()) - .map_err(|_| RpcError::ParseError) - .map(Pack)? - .try_into(), - _ => Err(RpcError::InvalidRequest), - } -} - -pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { - // Convert the response into a value - let val: Pack = res.into_value().try_into()?; - // Create a new vector for encoding output - let mut res = Vec::new(); - // Serialize the value into MsgPack binary data - rmpv::encode::write_value(&mut res, &val.0).unwrap(); - // Return the message length, and message as binary - Ok((res.len(), Message::Binary(res))) -} -pub fn req_http(body: Bytes) -> Result { - let val: Vec = body.into(); - rmpv::decode::read_value(&mut val.as_slice()) - .map_err(|_| RpcError::ParseError) - .map(Pack)? - .try_into() -} - -pub fn res_http(res: Response) -> Result { - // Convert the response into a value - let val: Pack = res.into_value().try_into()?; - // Create a new vector for encoding output - let mut res = Vec::new(); - // Serialize the value into MsgPack binary data - rmpv::encode::write_value(&mut res, &val.0).unwrap(); - // Return the message length, and message as binary - Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationPack))], res).into_response()) -} diff --git a/src/rpc/format/revision.rs b/src/rpc/format/revision.rs deleted file mode 100644 index bbc650cf..00000000 --- a/src/rpc/format/revision.rs +++ /dev/null @@ -1,42 +0,0 @@ -use crate::net::headers::ContentType; -use crate::rpc::request::Request; -use crate::rpc::response::Response; -use axum::extract::ws::Message; -use axum::response::{IntoResponse, Response as AxumResponse}; -use bytes::Bytes; -use http::header::CONTENT_TYPE; -use http::HeaderValue; -use revision::Revisioned; -use surrealdb::rpc::RpcError; -use surrealdb::sql::Value; - -pub fn req_ws(msg: Message) -> Result { - match msg { - Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice()) - .map_err(|_| RpcError::ParseError)? - .try_into(), - _ => Err(RpcError::InvalidRequest), - } -} - -pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> { - // Serialize the response with full internal type information - let mut buf = Vec::new(); - res.serialize_revisioned(&mut buf).unwrap(); - // Return the message length, and message as binary - Ok((buf.len(), Message::Binary(buf))) -} - -pub fn req_http(body: Bytes) -> Result { - let val: Vec = body.into(); - Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into() -} - -pub fn res_http(res: Response) -> Result { - // Serialize the response with full internal type information - let mut buf = Vec::new(); - res.serialize_revisioned(&mut buf).unwrap(); - // Return the message length, and message as binary - // TODO: Check what this header should be, new header needed for revisioned - Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], buf).into_response()) -} diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs index 0e417411..adac30e0 100644 --- a/src/rpc/mod.rs +++ b/src/rpc/mod.rs @@ -3,7 +3,6 @@ pub mod connection; pub mod failure; pub mod format; pub mod post_context; -pub mod request; pub mod response; use crate::dbs::DB; diff --git a/src/rpc/response.rs b/src/rpc/response.rs index 3bf62e25..07cbc32b 100644 --- a/src/rpc/response.rs +++ b/src/rpc/response.rs @@ -1,13 +1,13 @@ use crate::rpc::failure::Failure; -use crate::rpc::format::Format; +use crate::rpc::format::WsFormat; use crate::telemetry::metrics::ws::record_rpc; use axum::extract::ws::Message; use opentelemetry::Context as TelemetryContext; use revision::revisioned; use serde::Serialize; -use serde_json::Value as Json; use std::sync::Arc; use surrealdb::channel::Sender; +use surrealdb::rpc::format::Format; use surrealdb::rpc::Data; use surrealdb::sql::Value; use tracing::Span; @@ -20,12 +20,6 @@ pub struct Response { } impl Response { - /// Convert and simplify the value into JSON - #[inline] - pub fn into_json(self) -> Json { - Json::from(self.into_value()) - } - #[inline] pub fn into_value(self) -> Value { let mut value = match self.result { @@ -68,6 +62,12 @@ impl Response { } } +impl From for Value { + fn from(value: Response) -> Self { + value.into_value() + } +} + /// Create a JSON RPC result response pub fn success>(id: Option, data: T) -> Response { Response { diff --git a/tests/common/socket.rs b/tests/common/socket.rs index 7d76c4d8..13d71a38 100644 --- a/tests/common/socket.rs +++ b/tests/common/socket.rs @@ -108,9 +108,7 @@ impl Socket { match format { Format::Json => Ok(Message::Text(serde_json::to_string(message)?)), Format::Cbor => { - pub mod try_from_impls { - include!("../../src/rpc/format/cbor/convert.rs"); - } + use surrealdb::rpc::format::cbor::Cbor; // For tests we need to convert the serde_json::Value // to a SurrealQL value, so that record ids, uuids, // datetimes, and durations are stored properly. @@ -119,7 +117,7 @@ impl Socket { // Then we parse the JSON in to SurrealQL. let surrealql = surrealdb::syn::value_legacy_strand(&json)?; // Then we convert the SurrealQL in to CBOR. - let cbor = try_from_impls::Cbor::try_from(surrealql)?; + let cbor = Cbor::try_from(surrealql)?; // Then serialize the CBOR as binary data. let mut output = Vec::new(); ciborium::into_writer(&cbor.0, &mut output).unwrap(); @@ -127,9 +125,7 @@ impl Socket { Ok(Message::Binary(output)) } Format::Pack => { - pub mod try_from_impls { - include!("../../src/rpc/format/msgpack/convert.rs"); - } + use surrealdb::rpc::format::msgpack::Pack; // For tests we need to convert the serde_json::Value // to a SurrealQL value, so that record ids, uuids, // datetimes, and durations are stored properly. @@ -138,7 +134,7 @@ impl Socket { // Then we parse the JSON in to SurrealQL. let surrealql = surrealdb::syn::value_legacy_strand(&json)?; // Then we convert the SurrealQL in to MessagePack. - let pack = try_from_impls::Pack::try_from(surrealql)?; + let pack = Pack::try_from(surrealql)?; // Then serialize the MessagePack as binary data. let mut output = Vec::new(); rmpv::encode::write_value(&mut output, &pack.0).unwrap(); @@ -165,15 +161,13 @@ impl Socket { debug!("Response {msg:?}"); match format { Format::Cbor => { - pub mod try_from_impls { - include!("../../src/rpc/format/cbor/convert.rs"); - } + use surrealdb::rpc::format::cbor::Cbor; // For tests we need to convert the binary data to // a serde_json::Value so that test assertions work. // First of all we deserialize the CBOR data. let msg: ciborium::Value = ciborium::from_reader(&mut msg.as_slice())?; // Then we convert it to a SurrealQL Value. - let msg: Value = try_from_impls::Cbor(msg).try_into()?; + let msg: Value = Cbor(msg).try_into()?; // Then we convert the SurrealQL to JSON. let msg = msg.into_json(); // Then output the response. @@ -181,15 +175,13 @@ impl Socket { Ok(Some(msg)) } Format::Pack => { - pub mod try_from_impls { - include!("../../src/rpc/format/msgpack/convert.rs"); - } + use surrealdb::rpc::format::msgpack::Pack; // For tests we need to convert the binary data to // a serde_json::Value so that test assertions work. // First of all we deserialize the MessagePack data. let msg: rmpv::Value = rmpv::decode::read_value(&mut msg.as_slice())?; // Then we convert it to a SurrealQL Value. - let msg: Value = try_from_impls::Pack(msg).try_into()?; + let msg: Value = Pack(msg).try_into()?; // Then we convert the SurrealQL to JSON. let msg = msg.into_json(); // Then output the response.