refactor rpc code (#3790)

Co-authored-by: Micha de Vries <micha@devrie.sh>
This commit is contained in:
Raphael Darley 2024-04-17 20:56:08 +02:00 committed by GitHub
parent b2b08e0ad1
commit cd653bdf7e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 319 additions and 414 deletions

2
Cargo.lock generated
View file

@ -5910,6 +5910,7 @@ dependencies = [
"bytes", "bytes",
"cedar-policy", "cedar-policy",
"chrono", "chrono",
"ciborium",
"criterion", "criterion",
"deunicode", "deunicode",
"dmp", "dmp",
@ -5947,6 +5948,7 @@ dependencies = [
"reqwest", "reqwest",
"revision", "revision",
"ring 0.17.8", "ring 0.17.8",
"rmpv",
"roaring", "roaring",
"rocksdb", "rocksdb",
"rquickjs", "rquickjs",

View file

@ -67,6 +67,7 @@ base64 = "0.21.5"
bcrypt = "0.15.0" bcrypt = "0.15.0"
bincode = "1.3.3" bincode = "1.3.3"
bytes = "1.5.0" bytes = "1.5.0"
ciborium = "0.2.1"
cedar-policy = "2.4.2" cedar-policy = "2.4.2"
channel = { version = "1.9.0", package = "async-channel" } channel = { version = "1.9.0", package = "async-channel" }
chrono = { version = "0.4.31", features = ["serde"] } chrono = { version = "0.4.31", features = ["serde"] }
@ -122,6 +123,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
"multipart", "multipart",
], optional = true } ], optional = true }
revision = { version = "0.7.0", features = ["chrono", "geo", "roaring", "regex", "rust_decimal", "uuid"] } revision = { version = "0.7.0", features = ["chrono", "geo", "roaring", "regex", "rust_decimal", "uuid"] }
rmpv = "1.0.1"
roaring = { version = "0.10.2", features = ["serde"] } roaring = { version = "0.10.2", features = ["serde"] }
rocksdb = { version = "0.21.0", features = ["lz4", "snappy"], optional = true } rocksdb = { version = "0.21.0", features = ["lz4", "snappy"], optional = true }
rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] } rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] }

View file

@ -103,7 +103,7 @@ pub async fn del(file: &str) -> Result<(), Error> {
} }
/// Hashes the bytes of a file to a string for the storage of a file. /// Hashes the bytes of a file to a string for the storage of a file.
pub fn hash(data: &Vec<u8>) -> String { pub fn hash(data: &[u8]) -> String {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(data); hasher.update(data);
let result = hasher.finalize(); let result = hasher.finalize();

View file

@ -0,0 +1,14 @@
use crate::rpc::format::ResTrait;
use crate::rpc::request::Request;
use crate::rpc::RpcError;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Value;
pub fn req(val: &[u8]) -> Result<Request, RpcError> {
deserialize::<Value>(val).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res(res: impl ResTrait) -> Result<Vec<u8>, RpcError> {
// Serialize the response with full internal type information
Ok(serialize(&res).unwrap())
}

View file

@ -3,14 +3,15 @@ use geo::{LineString, Point, Polygon};
use geo_types::{MultiLineString, MultiPoint, MultiPolygon}; use geo_types::{MultiLineString, MultiPoint, MultiPolygon};
use std::iter::once; use std::iter::once;
use std::ops::Deref; use std::ops::Deref;
use surrealdb::sql::Datetime;
use surrealdb::sql::Duration; use crate::sql::Datetime;
use surrealdb::sql::Geometry; use crate::sql::Duration;
use surrealdb::sql::Id; use crate::sql::Geometry;
use surrealdb::sql::Number; use crate::sql::Id;
use surrealdb::sql::Thing; use crate::sql::Number;
use surrealdb::sql::Uuid; use crate::sql::Thing;
use surrealdb::sql::Value; use crate::sql::Uuid;
use crate::sql::Value;
// Tags from the spec - https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml // Tags from the spec - https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml
const TAG_SPEC_DATETIME: u64 = 0; const TAG_SPEC_DATETIME: u64 = 0;
@ -327,7 +328,6 @@ impl TryFrom<Value> for Cbor {
Number::Decimal(v) => { Number::Decimal(v) => {
Ok(Cbor(Data::Tag(TAG_STRING_DECIMAL, Box::new(Data::Text(v.to_string()))))) Ok(Cbor(Data::Tag(TAG_STRING_DECIMAL, Box::new(Data::Text(v.to_string())))))
} }
_ => Err("Found an unsupported Number type being converted to CBOR"),
}, },
Value::Strand(v) => Ok(Cbor(Data::Text(v.0))), Value::Strand(v) => Ok(Cbor(Data::Text(v.0))),
Value::Duration(v) => { Value::Duration(v) => {
@ -390,7 +390,6 @@ impl TryFrom<Value> for Cbor {
Id::Generate(_) => { Id::Generate(_) => {
return Err("Cannot encode an ungenerated Record ID into CBOR") return Err("Cannot encode an ungenerated Record ID into CBOR")
} }
_ => return Err("Found an unsupported Id type being converted to CBOR"),
}, },
])), ])),
))), ))),
@ -459,6 +458,5 @@ fn encode_geometry(v: Geometry) -> Result<Data, &'static str> {
Ok(Data::Tag(TAG_GEOMETRY_COLLECTION, Box::new(Data::Array(data)))) Ok(Data::Tag(TAG_GEOMETRY_COLLECTION, Box::new(Data::Array(data))))
} }
_ => Err("Found an unsupported Geometry type being converted to CBOR"),
} }
} }

View file

@ -0,0 +1,29 @@
mod convert;
pub use convert::Cbor;
use crate::rpc::request::Request;
use crate::rpc::RpcError;
use crate::sql::Value;
use ciborium::Value as Data;
use super::ResTrait;
pub fn req(val: Vec<u8>) -> Result<Request, RpcError> {
ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Cbor)?
.try_into()
}
pub fn res(res: impl ResTrait) -> Result<Vec<u8>, RpcError> {
// Convert the response into a value
let val: Value = res.into();
let val: Cbor = val.try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val.0, &mut res).unwrap();
// Return the message length, and message as binary
Ok(res)
}

View file

@ -0,0 +1,22 @@
use crate::rpc::request::Request;
use crate::rpc::RpcError;
use crate::sql::Value;
use crate::syn;
use super::ResTrait;
pub fn req(val: &[u8]) -> Result<Request, RpcError> {
syn::value_legacy_strand(std::str::from_utf8(val).or(Err(RpcError::ParseError))?)
.or(Err(RpcError::ParseError))?
.try_into()
}
pub fn res(res: impl ResTrait) -> Result<Vec<u8>, RpcError> {
// Convert the response into simplified JSON
let val: Value = res.into();
let val = val.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok(res.into())
}

View file

@ -0,0 +1,82 @@
mod bincode;
pub mod cbor;
mod json;
pub mod msgpack;
mod revision;
use ::revision::Revisioned;
use serde::Serialize;
use super::{request::Request, RpcError};
use crate::sql::Value;
pub const PROTOCOLS: [&str; 5] = [
"json", // For basic JSON serialisation
"cbor", // For basic CBOR serialisation
"msgpack", // For basic Msgpack serialisation
"bincode", // For full internal serialisation
"revision", // For full versioned serialisation
];
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[non_exhaustive]
pub enum Format {
None, // No format is specified yet
Json, // For basic JSON serialisation
Cbor, // For basic CBOR serialisation
Msgpack, // For basic Msgpack serialisation
Bincode, // For full internal serialisation
Revision, // For full versioned serialisation
Unsupported, // Unsupported format
}
pub trait ResTrait: Serialize + Into<Value> + Revisioned {}
impl<T: Serialize + Into<Value> + Revisioned> ResTrait for T {}
impl From<&str> for Format {
fn from(v: &str) -> Self {
match v {
s if s == PROTOCOLS[0] => Format::Json,
s if s == PROTOCOLS[1] => Format::Cbor,
s if s == PROTOCOLS[2] => Format::Msgpack,
s if s == PROTOCOLS[3] => Format::Bincode,
s if s == PROTOCOLS[4] => Format::Revision,
_ => Format::None,
}
}
}
impl Format {
/// Check if this format has been set
pub fn is_none(&self) -> bool {
matches!(self, Format::None)
}
/// Process a request using the specified format
pub fn req(&self, val: impl Into<Vec<u8>>) -> Result<Request, RpcError> {
let val = val.into();
match self {
Self::None => Err(RpcError::InvalidRequest),
Self::Unsupported => Err(RpcError::InvalidRequest),
Self::Json => json::req(&val),
Self::Cbor => cbor::req(val),
Self::Msgpack => msgpack::req(val),
Self::Bincode => bincode::req(&val),
Self::Revision => revision::req(val),
}
.map_err(Into::into)
}
/// Process a response using the specified format
pub fn res(&self, val: impl ResTrait) -> Result<Vec<u8>, RpcError> {
match self {
Self::None => Err(RpcError::InvalidRequest),
Self::Unsupported => Err(RpcError::InvalidRequest),
Self::Json => json::res(val),
Self::Cbor => cbor::res(val),
Self::Msgpack => msgpack::res(val),
Self::Bincode => bincode::res(val),
Self::Revision => revision::res(val),
}
}
}

View file

@ -1,10 +1,10 @@
use crate::sql::Datetime;
use crate::sql::Duration;
use crate::sql::Number;
use crate::sql::Thing;
use crate::sql::Uuid;
use crate::sql::Value;
use rmpv::Value as Data; use rmpv::Value as Data;
use surrealdb::sql::Datetime;
use surrealdb::sql::Duration;
use surrealdb::sql::Number;
use surrealdb::sql::Thing;
use surrealdb::sql::Uuid;
use surrealdb::sql::Value;
const TAG_NONE: i8 = 1; const TAG_NONE: i8 = 1;
const TAG_UUID: i8 = 2; const TAG_UUID: i8 = 2;
@ -114,6 +114,7 @@ impl TryFrom<Value> for Pack {
Number::Decimal(v) => { Number::Decimal(v) => {
Ok(Pack(Data::Ext(TAG_DECIMAL, v.to_string().as_bytes().to_vec()))) Ok(Pack(Data::Ext(TAG_DECIMAL, v.to_string().as_bytes().to_vec())))
} }
#[allow(unreachable_patterns)]
_ => unreachable!(), _ => unreachable!(),
}, },
Value::Strand(v) => Ok(Pack(Data::String(v.0.into()))), Value::Strand(v) => Ok(Pack(Data::String(v.0.into()))),

View file

@ -0,0 +1,26 @@
mod convert;
use crate::rpc::format::ResTrait;
use crate::rpc::RpcError;
pub use convert::Pack;
use crate::rpc::request::Request;
use crate::sql::Value;
pub fn req(val: Vec<u8>) -> Result<Request, RpcError> {
rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Pack)?
.try_into()
}
pub fn res(res: impl ResTrait) -> Result<Vec<u8>, RpcError> {
// Convert the response into a value
let val: Value = res.into();
let val: Pack = val.try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into MsgPack binary data
rmpv::encode::write_value(&mut res, &val.0).unwrap();
Ok(res)
}

View file

@ -0,0 +1,16 @@
use crate::rpc::format::ResTrait;
use crate::rpc::request::Request;
use crate::rpc::RpcError;
use crate::sql::Value;
use revision::Revisioned;
pub fn req(val: Vec<u8>) -> Result<Request, RpcError> {
Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res(res: impl ResTrait) -> Result<Vec<u8>, RpcError> {
// Serialize the response with full internal type information
let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
Ok(buf)
}

View file

@ -1,6 +1,8 @@
pub mod args; pub mod args;
pub mod basic_context; pub mod basic_context;
pub mod format;
pub mod method; pub mod method;
pub mod request;
mod response; mod response;
pub mod rpc_context; pub mod rpc_context;
mod rpc_error; mod rpc_error;

View file

@ -1,9 +1,9 @@
use crate::rpc::format::cbor::Cbor; use crate::rpc::format::cbor::Cbor;
use crate::rpc::format::msgpack::Pack; use crate::rpc::format::msgpack::Pack;
use crate::rpc::RpcError;
use crate::sql::Part;
use crate::sql::{Array, Value};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Part;
use surrealdb::sql::{Array, Value};
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);

View file

@ -5,8 +5,7 @@ use crate::cnf;
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::rpc::connection::Connection; use crate::rpc::connection::Connection;
use crate::rpc::format::Format; use crate::rpc::format::HttpFormat;
use crate::rpc::format::PROTOCOLS;
use crate::rpc::post_context::PostRpcContext; use crate::rpc::post_context::PostRpcContext;
use crate::rpc::response::IntoRpcResponse; use crate::rpc::response::IntoRpcResponse;
use crate::rpc::WEBSOCKETS; use crate::rpc::WEBSOCKETS;
@ -22,6 +21,8 @@ use bytes::Bytes;
use http::HeaderValue; use http::HeaderValue;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::format::PROTOCOLS;
use surrealdb::rpc::method::Method; use surrealdb::rpc::method::Method;
use tower_http::request_id::RequestId; use tower_http::request_id::RequestId;
use uuid::Uuid; use uuid::Uuid;

View file

@ -3,7 +3,7 @@ use crate::cnf::{
}; };
use crate::dbs::DB; use crate::dbs::DB;
use crate::rpc::failure::Failure; use crate::rpc::failure::Failure;
use crate::rpc::format::Format; use crate::rpc::format::WsFormat;
use crate::rpc::response::{failure, IntoRpcResponse}; use crate::rpc::response::{failure, IntoRpcResponse};
use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
use crate::telemetry; use crate::telemetry;
@ -20,6 +20,7 @@ use surrealdb::channel::{self, Receiver, Sender};
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore; use surrealdb::kvs::Datastore;
use surrealdb::rpc::args::Take; use surrealdb::rpc::args::Take;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::method::Method; use surrealdb::rpc::method::Method;
use surrealdb::rpc::RpcContext; use surrealdb::rpc::RpcContext;
use surrealdb::rpc::{Data, RpcError}; use surrealdb::rpc::{Data, RpcError};

83
src/rpc/format.rs Normal file
View file

@ -0,0 +1,83 @@
use crate::net::headers::{Accept, ContentType};
use crate::rpc::failure::Failure;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use http::StatusCode;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::request::Request;
use surrealdb::rpc::RpcError;
impl From<&Accept> for Format {
fn from(value: &Accept) -> Self {
match value {
Accept::TextPlain => Format::None,
Accept::ApplicationJson => Format::Json,
Accept::ApplicationCbor => Format::Cbor,
Accept::ApplicationPack => Format::Msgpack,
Accept::ApplicationOctetStream => Format::Unsupported,
Accept::Surrealdb => Format::Bincode,
}
}
}
impl From<&ContentType> for Format {
fn from(value: &ContentType) -> Self {
match value {
ContentType::TextPlain => Format::None,
ContentType::ApplicationJson => Format::Json,
ContentType::ApplicationCbor => Format::Cbor,
ContentType::ApplicationPack => Format::Msgpack,
ContentType::ApplicationOctetStream => Format::Unsupported,
ContentType::Surrealdb => Format::Bincode,
}
}
}
pub trait WsFormat {
fn req_ws(&self, msg: Message) -> Result<Request, Failure>;
fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure>;
}
impl WsFormat for Format {
fn req_ws(&self, msg: Message) -> Result<Request, Failure> {
let val = msg.into_data();
self.req(val).map_err(Into::into)
}
fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> {
let res = self.res(res).map_err(Failure::from)?;
if matches!(self, Format::Json) {
// If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) }
// This would be safe as in the case of JSON res come from a call to Into::<Vec<u8>> for String
Ok((res.len(), Message::Text(String::from_utf8(res).unwrap())))
} else {
Ok((res.len(), Message::Binary(res)))
}
}
}
pub trait HttpFormat {
fn req_http(&self, body: Bytes) -> Result<Request, RpcError>;
fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError>;
}
impl HttpFormat for Format {
fn req_http(&self, body: Bytes) -> Result<Request, RpcError> {
self.req(body).map_err(Into::into)
}
fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError> {
let val = res.into_value();
let res = self.res(val)?;
if matches!(self, Format::Json) {
// If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) }
// This would be safe as in the case of JSON res come from a call to Into::<Vec<u8>> for String
Ok((StatusCode::OK, String::from_utf8(res).unwrap()).into_response())
} else {
Ok((StatusCode::OK, res).into_response())
}
}
}

View file

@ -1,40 +0,0 @@
use crate::net::headers::ContentType;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Value;
pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg {
Message::Binary(val) => {
deserialize::<Value>(&val).map_err(|_| RpcError::ParseError)?.try_into()
}
_ => Err(RpcError::InvalidRequest),
}
}
pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}
pub fn req_http(val: &Bytes) -> Result<Request, RpcError> {
deserialize::<Value>(val).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary
// TODO: Check what this header should be, I'm being consistent with /sql
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], res).into_response())
}

View file

@ -1,57 +0,0 @@
mod convert;
use bytes::Bytes;
pub use convert::Cbor;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use crate::net::headers::ContentType;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
use ciborium::Value as Data;
pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into()
}
Message::Binary(val) => ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Cbor)?
.try_into(),
_ => Err(RpcError::InvalidRequest),
}
}
pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into a value
let val: Cbor = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val.0, &mut res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Cbor)?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into a value
let val: Cbor = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val.0, &mut res).unwrap();
// Return the message length, and message as binary
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationCbor))], res).into_response())
}

View file

@ -1,42 +0,0 @@
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::IntoResponse;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use http::StatusCode;
use surrealdb::rpc::RpcError;
use surrealdb::sql;
pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg {
Message::Text(val) => {
surrealdb::syn::value_legacy_strand(&val).map_err(|_| RpcError::ParseError)?.try_into()
}
_ => Err(RpcError::InvalidRequest),
}
}
pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Text(res)))
}
pub fn req_http(val: &Bytes) -> Result<Request, RpcError> {
sql::value(std::str::from_utf8(val).or(Err(RpcError::ParseError))?)
.or(Err(RpcError::ParseError))?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok((StatusCode::OK, res).into_response())
}

View file

@ -1,129 +0,0 @@
mod bincode;
pub mod cbor;
mod json;
pub mod msgpack;
mod revision;
use crate::net::headers::{Accept, ContentType};
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::Response as AxumResponse;
use bytes::Bytes;
use surrealdb::rpc::RpcError;
pub const PROTOCOLS: [&str; 5] = [
"json", // For basic JSON serialisation
"cbor", // For basic CBOR serialisation
"msgpack", // For basic Msgpack serialisation
"bincode", // For full internal serialisation
"revision", // For full versioned serialisation
];
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Format {
None, // No format is specified yet
Json, // For basic JSON serialisation
Cbor, // For basic CBOR serialisation
Msgpack, // For basic Msgpack serialisation
Bincode, // For full internal serialisation
Revision, // For full versioned serialisation
Unsupported, // Unsupported format
}
impl From<&Accept> for Format {
fn from(value: &Accept) -> Self {
match value {
Accept::TextPlain => Format::None,
Accept::ApplicationJson => Format::Json,
Accept::ApplicationCbor => Format::Cbor,
Accept::ApplicationPack => Format::Msgpack,
Accept::ApplicationOctetStream => Format::Unsupported,
Accept::Surrealdb => Format::Bincode,
}
}
}
impl From<&ContentType> for Format {
fn from(value: &ContentType) -> Self {
match value {
ContentType::TextPlain => Format::None,
ContentType::ApplicationJson => Format::Json,
ContentType::ApplicationCbor => Format::Cbor,
ContentType::ApplicationPack => Format::Msgpack,
ContentType::ApplicationOctetStream => Format::Unsupported,
ContentType::Surrealdb => Format::Bincode,
}
}
}
impl From<&str> for Format {
fn from(v: &str) -> Self {
match v {
s if s == PROTOCOLS[0] => Format::Json,
s if s == PROTOCOLS[1] => Format::Cbor,
s if s == PROTOCOLS[2] => Format::Msgpack,
s if s == PROTOCOLS[3] => Format::Bincode,
s if s == PROTOCOLS[4] => Format::Revision,
_ => Format::None,
}
}
}
impl Format {
/// Check if this format has been set
pub fn is_none(&self) -> bool {
matches!(self, Format::None)
}
/// Process a request using the specified format
pub fn req_ws(&self, msg: Message) -> Result<Request, Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::req_ws(msg),
Self::Cbor => cbor::req_ws(msg),
Self::Msgpack => msgpack::req_ws(msg),
Self::Bincode => bincode::req_ws(msg),
Self::Revision => revision::req_ws(msg),
}
.map_err(Into::into)
}
/// Process a response using the specified format
pub fn res_ws(&self, res: Response) -> Result<(usize, Message), Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::res_ws(res),
Self::Cbor => cbor::res_ws(res),
Self::Msgpack => msgpack::res_ws(res),
Self::Bincode => bincode::res_ws(res),
Self::Revision => revision::res_ws(res),
}
.map_err(Into::into)
}
/// Process a request using the specified format
pub fn req_http(&self, body: Bytes) -> Result<Request, RpcError> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::req_http(&body),
Self::Cbor => cbor::req_http(body),
Self::Msgpack => msgpack::req_http(body),
Self::Bincode => bincode::req_http(&body),
Self::Revision => revision::req_http(body),
}
}
/// Process a response using the specified format
pub fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Unsupported => unreachable!(), // We should never arrive at this code
Self::Json => json::res_http(res),
Self::Cbor => cbor::res_http(res),
Self::Msgpack => msgpack::res_http(res),
Self::Bincode => bincode::res_http(res),
Self::Revision => revision::res_http(res),
}
}
}

View file

@ -1,55 +0,0 @@
mod convert;
use bytes::Bytes;
pub use convert::Pack;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use surrealdb::rpc::RpcError;
use crate::net::headers::ContentType;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| RpcError::ParseError)?.try_into()
}
Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Pack)?
.try_into(),
_ => Err(RpcError::InvalidRequest),
}
}
pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Convert the response into a value
let val: Pack = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into MsgPack binary data
rmpv::encode::write_value(&mut res, &val.0).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)
.map(Pack)?
.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Convert the response into a value
let val: Pack = res.into_value().try_into()?;
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into MsgPack binary data
rmpv::encode::write_value(&mut res, &val.0).unwrap();
// Return the message length, and message as binary
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::ApplicationPack))], res).into_response())
}

View file

@ -1,42 +0,0 @@
use crate::net::headers::ContentType;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use axum::response::{IntoResponse, Response as AxumResponse};
use bytes::Bytes;
use http::header::CONTENT_TYPE;
use http::HeaderValue;
use revision::Revisioned;
use surrealdb::rpc::RpcError;
use surrealdb::sql::Value;
pub fn req_ws(msg: Message) -> Result<Request, RpcError> {
match msg {
Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice())
.map_err(|_| RpcError::ParseError)?
.try_into(),
_ => Err(RpcError::InvalidRequest),
}
}
pub fn res_ws(res: Response) -> Result<(usize, Message), RpcError> {
// Serialize the response with full internal type information
let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary
Ok((buf.len(), Message::Binary(buf)))
}
pub fn req_http(body: Bytes) -> Result<Request, RpcError> {
let val: Vec<u8> = body.into();
Value::deserialize_revisioned(&mut val.as_slice()).map_err(|_| RpcError::ParseError)?.try_into()
}
pub fn res_http(res: Response) -> Result<AxumResponse, RpcError> {
// Serialize the response with full internal type information
let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary
// TODO: Check what this header should be, new header needed for revisioned
Ok(([(CONTENT_TYPE, HeaderValue::from(ContentType::Surrealdb))], buf).into_response())
}

View file

@ -3,7 +3,6 @@ pub mod connection;
pub mod failure; pub mod failure;
pub mod format; pub mod format;
pub mod post_context; pub mod post_context;
pub mod request;
pub mod response; pub mod response;
use crate::dbs::DB; use crate::dbs::DB;

View file

@ -1,13 +1,13 @@
use crate::rpc::failure::Failure; use crate::rpc::failure::Failure;
use crate::rpc::format::Format; use crate::rpc::format::WsFormat;
use crate::telemetry::metrics::ws::record_rpc; use crate::telemetry::metrics::ws::record_rpc;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use opentelemetry::Context as TelemetryContext; use opentelemetry::Context as TelemetryContext;
use revision::revisioned; use revision::revisioned;
use serde::Serialize; use serde::Serialize;
use serde_json::Value as Json;
use std::sync::Arc; use std::sync::Arc;
use surrealdb::channel::Sender; use surrealdb::channel::Sender;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::Data; use surrealdb::rpc::Data;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tracing::Span; use tracing::Span;
@ -20,12 +20,6 @@ pub struct Response {
} }
impl Response { impl Response {
/// Convert and simplify the value into JSON
#[inline]
pub fn into_json(self) -> Json {
Json::from(self.into_value())
}
#[inline] #[inline]
pub fn into_value(self) -> Value { pub fn into_value(self) -> Value {
let mut value = match self.result { let mut value = match self.result {
@ -68,6 +62,12 @@ impl Response {
} }
} }
impl From<Response> for Value {
fn from(value: Response) -> Self {
value.into_value()
}
}
/// Create a JSON RPC result response /// Create a JSON RPC result response
pub fn success<T: Into<Data>>(id: Option<Value>, data: T) -> Response { pub fn success<T: Into<Data>>(id: Option<Value>, data: T) -> Response {
Response { Response {

View file

@ -108,9 +108,7 @@ impl Socket {
match format { match format {
Format::Json => Ok(Message::Text(serde_json::to_string(message)?)), Format::Json => Ok(Message::Text(serde_json::to_string(message)?)),
Format::Cbor => { Format::Cbor => {
pub mod try_from_impls { use surrealdb::rpc::format::cbor::Cbor;
include!("../../src/rpc/format/cbor/convert.rs");
}
// For tests we need to convert the serde_json::Value // For tests we need to convert the serde_json::Value
// to a SurrealQL value, so that record ids, uuids, // to a SurrealQL value, so that record ids, uuids,
// datetimes, and durations are stored properly. // datetimes, and durations are stored properly.
@ -119,7 +117,7 @@ impl Socket {
// Then we parse the JSON in to SurrealQL. // Then we parse the JSON in to SurrealQL.
let surrealql = surrealdb::syn::value_legacy_strand(&json)?; let surrealql = surrealdb::syn::value_legacy_strand(&json)?;
// Then we convert the SurrealQL in to CBOR. // Then we convert the SurrealQL in to CBOR.
let cbor = try_from_impls::Cbor::try_from(surrealql)?; let cbor = Cbor::try_from(surrealql)?;
// Then serialize the CBOR as binary data. // Then serialize the CBOR as binary data.
let mut output = Vec::new(); let mut output = Vec::new();
ciborium::into_writer(&cbor.0, &mut output).unwrap(); ciborium::into_writer(&cbor.0, &mut output).unwrap();
@ -127,9 +125,7 @@ impl Socket {
Ok(Message::Binary(output)) Ok(Message::Binary(output))
} }
Format::Pack => { Format::Pack => {
pub mod try_from_impls { use surrealdb::rpc::format::msgpack::Pack;
include!("../../src/rpc/format/msgpack/convert.rs");
}
// For tests we need to convert the serde_json::Value // For tests we need to convert the serde_json::Value
// to a SurrealQL value, so that record ids, uuids, // to a SurrealQL value, so that record ids, uuids,
// datetimes, and durations are stored properly. // datetimes, and durations are stored properly.
@ -138,7 +134,7 @@ impl Socket {
// Then we parse the JSON in to SurrealQL. // Then we parse the JSON in to SurrealQL.
let surrealql = surrealdb::syn::value_legacy_strand(&json)?; let surrealql = surrealdb::syn::value_legacy_strand(&json)?;
// Then we convert the SurrealQL in to MessagePack. // Then we convert the SurrealQL in to MessagePack.
let pack = try_from_impls::Pack::try_from(surrealql)?; let pack = Pack::try_from(surrealql)?;
// Then serialize the MessagePack as binary data. // Then serialize the MessagePack as binary data.
let mut output = Vec::new(); let mut output = Vec::new();
rmpv::encode::write_value(&mut output, &pack.0).unwrap(); rmpv::encode::write_value(&mut output, &pack.0).unwrap();
@ -165,15 +161,13 @@ impl Socket {
debug!("Response {msg:?}"); debug!("Response {msg:?}");
match format { match format {
Format::Cbor => { Format::Cbor => {
pub mod try_from_impls { use surrealdb::rpc::format::cbor::Cbor;
include!("../../src/rpc/format/cbor/convert.rs");
}
// For tests we need to convert the binary data to // For tests we need to convert the binary data to
// a serde_json::Value so that test assertions work. // a serde_json::Value so that test assertions work.
// First of all we deserialize the CBOR data. // First of all we deserialize the CBOR data.
let msg: ciborium::Value = ciborium::from_reader(&mut msg.as_slice())?; let msg: ciborium::Value = ciborium::from_reader(&mut msg.as_slice())?;
// Then we convert it to a SurrealQL Value. // Then we convert it to a SurrealQL Value.
let msg: Value = try_from_impls::Cbor(msg).try_into()?; let msg: Value = Cbor(msg).try_into()?;
// Then we convert the SurrealQL to JSON. // Then we convert the SurrealQL to JSON.
let msg = msg.into_json(); let msg = msg.into_json();
// Then output the response. // Then output the response.
@ -181,15 +175,13 @@ impl Socket {
Ok(Some(msg)) Ok(Some(msg))
} }
Format::Pack => { Format::Pack => {
pub mod try_from_impls { use surrealdb::rpc::format::msgpack::Pack;
include!("../../src/rpc/format/msgpack/convert.rs");
}
// For tests we need to convert the binary data to // For tests we need to convert the binary data to
// a serde_json::Value so that test assertions work. // a serde_json::Value so that test assertions work.
// First of all we deserialize the MessagePack data. // First of all we deserialize the MessagePack data.
let msg: rmpv::Value = rmpv::decode::read_value(&mut msg.as_slice())?; let msg: rmpv::Value = rmpv::decode::read_value(&mut msg.as_slice())?;
// Then we convert it to a SurrealQL Value. // Then we convert it to a SurrealQL Value.
let msg: Value = try_from_impls::Pack(msg).try_into()?; let msg: Value = Pack(msg).try_into()?;
// Then we convert the SurrealQL to JSON. // Then we convert the SurrealQL to JSON.
let msg = msg.into_json(); let msg = msg.into_json();
// Then output the response. // Then output the response.