Support CBOR and MessagePack binary serialisation in WebSocket (#3251)

Co-authored-by: Rushmore Mushambi <rushmore@surrealdb.com>
This commit is contained in:
Tobie Morgan Hitchcock 2024-01-09 21:50:27 +00:00 committed by GitHub
parent 32dae92a29
commit f7e6e028a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 473 additions and 6 deletions

View file

@ -0,0 +1,182 @@
use ciborium::Value as Data;
use std::collections::BTreeMap;
use surrealdb::sql::Datetime;
use surrealdb::sql::Duration;
use surrealdb::sql::Id;
use surrealdb::sql::Number;
use surrealdb::sql::Thing;
use surrealdb::sql::Uuid;
use surrealdb::sql::Value;
// Custom CBOR tag numbers used by the conversions in this file to mark
// SurrealQL values which have no plain CBOR equivalent. Each tag wraps
// a CBOR payload (see the `Data::Tag` arms in the conversions below).

// Marks a SurrealQL `NONE` value (encoded around a CBOR null).
const TAG_NONE: u64 = 78_773_250;
// Marks a UUID, carried as CBOR text.
const TAG_UUID: u64 = 78_773_251;
// Marks a decimal number, carried as CBOR text.
const TAG_DECIMAL: u64 = 78_773_252;
// Marks a duration, carried as CBOR text.
const TAG_DURATION: u64 = 78_773_253;
// Marks a datetime, carried as CBOR text.
const TAG_DATETIME: u64 = 78_773_254;
// Marks a record id, carried as CBOR text or as a two element array.
const TAG_RECORDID: u64 = 78_773_255;
/// Newtype wrapper around a `ciborium::Value`, used as the source and
/// target type of the `TryFrom` conversions to and from SurrealQL values
/// implemented in this file.
#[derive(Debug)]
pub struct Cbor(pub Data);
/// Decodes CBOR data into a SurrealQL value.
///
/// Plain CBOR types map onto their SurrealQL counterparts, while the custom
/// tags declared at the top of this file are decoded as follows:
///
/// * `TAG_NONE` becomes `Value::None` (the tagged payload is ignored).
/// * `TAG_UUID`, `TAG_DECIMAL`, `TAG_DURATION` and `TAG_DATETIME` expect a
///   CBOR text payload which is parsed into the matching SurrealQL type.
/// * `TAG_RECORDID` expects either a CBOR text payload, or a two element
///   array of `[table, id]` where `table` is text and `id` is text, an
///   integer, an array, or a map.
///
/// # Errors
///
/// Returns a static message describing the problem when a payload has the
/// wrong CBOR type, fails to parse, uses an unknown tag, or uses a CBOR data
/// type which is not handled here.
impl TryFrom<Cbor> for Value {
    type Error = &'static str;

    fn try_from(val: Cbor) -> Result<Self, &'static str> {
        match val.0 {
            Data::Null => Ok(Value::Null),
            Data::Bool(v) => Ok(Value::from(v)),
            Data::Integer(v) => Ok(Value::from(i128::from(v))),
            Data::Float(v) => Ok(Value::from(v)),
            Data::Bytes(v) => Ok(Value::Bytes(v.into())),
            Data::Text(v) => Ok(Value::from(v)),
            // Convert every element, stopping at the first failure
            Data::Array(v) => {
                v.into_iter().map(|v| Value::try_from(Cbor(v))).collect::<Result<Value, &str>>()
            }
            // Map keys are converted to values first and then stringified
            Data::Map(v) => v
                .into_iter()
                .map(|(k, v)| {
                    let k = Value::try_from(Cbor(k)).map(|k| k.as_raw_string());
                    let v = Value::try_from(Cbor(v));
                    Ok((k?, v?))
                })
                .collect::<Result<Value, &str>>(),
            Data::Tag(t, v) => {
                match t {
                    // A literal NONE
                    TAG_NONE => Ok(Value::None),
                    // A literal uuid
                    TAG_UUID => match *v {
                        Data::Text(v) => match Uuid::try_from(v) {
                            Ok(v) => Ok(v.into()),
                            _ => Err("Expected a valid UUID value"),
                        },
                        _ => Err("Expected a CBOR text data type"),
                    },
                    // A literal decimal
                    TAG_DECIMAL => match *v {
                        Data::Text(v) => match Number::try_from(v) {
                            Ok(v) => Ok(v.into()),
                            _ => Err("Expected a valid Decimal value"),
                        },
                        _ => Err("Expected a CBOR text data type"),
                    },
                    // A literal duration
                    TAG_DURATION => match *v {
                        Data::Text(v) => match Duration::try_from(v) {
                            Ok(v) => Ok(v.into()),
                            _ => Err("Expected a valid Duration value"),
                        },
                        _ => Err("Expected a CBOR text data type"),
                    },
                    // A literal datetime
                    TAG_DATETIME => match *v {
                        Data::Text(v) => match Datetime::try_from(v) {
                            Ok(v) => Ok(v.into()),
                            _ => Err("Expected a valid Datetime value"),
                        },
                        _ => Err("Expected a CBOR text data type"),
                    },
                    // A literal recordid
                    TAG_RECORDID => match *v {
                        // The whole record id given as a single string
                        Data::Text(v) => match Thing::try_from(v) {
                            Ok(v) => Ok(v.into()),
                            _ => Err("Expected a valid RecordID value"),
                        },
                        // The record id given as a `[table, id]` pair. The length
                        // guard makes both `remove(0)` calls safe: the first
                        // takes the table, the second takes the id.
                        Data::Array(mut v) if v.len() == 2 => match (v.remove(0), v.remove(0)) {
                            (Data::Text(tb), Data::Text(id)) => {
                                Ok(Value::from(Thing::from((tb, id))))
                            }
                            // Reject integers outside of the i64 range instead of
                            // silently wrapping them on to a different record id.
                            (Data::Text(tb), Data::Integer(id)) => match i64::try_from(id) {
                                Ok(id) => Ok(Value::from(Thing::from((tb, Id::from(id))))),
                                Err(_) => Err("Expected a CBOR integer which fits in a 64-bit signed integer"),
                            },
                            (Data::Text(tb), Data::Array(id)) => Ok(Value::from(Thing::from((
                                tb,
                                Id::from(
                                    id.into_iter()
                                        .map(|v| Value::try_from(Cbor(v)))
                                        .collect::<Result<Vec<Value>, &str>>()?,
                                ),
                            )))),
                            (Data::Text(tb), Data::Map(id)) => Ok(Value::from(Thing::from((
                                tb,
                                Id::from(
                                    id.into_iter()
                                        .map(|(k, v)| {
                                            let k =
                                                Value::try_from(Cbor(k)).map(|k| k.as_raw_string());
                                            let v = Value::try_from(Cbor(v));
                                            Ok((k?, v?))
                                        })
                                        .collect::<Result<BTreeMap<String, Value>, &str>>()?,
                                ),
                            )))),
                            _ => Err("Expected a CBOR array with 2 elements, a text data type, and a valid ID type"),
                        },
                        _ => Err("Expected a CBOR text data type, or a CBOR array with 2 elements"),
                    },
                    // An unknown tag
                    _ => Err("Encountered an unknown CBOR tag"),
                }
            }
            _ => Err("Encountered an unknown CBOR data type"),
        }
    }
}
impl TryFrom<Value> for Cbor {
type Error = &'static str;
fn try_from(val: Value) -> Result<Self, &'static str> {
match val {
Value::None => Ok(Cbor(Data::Tag(TAG_NONE, Box::new(Data::Null)))),
Value::Null => Ok(Cbor(Data::Null)),
Value::Bool(v) => Ok(Cbor(Data::Bool(v))),
Value::Number(v) => match v {
Number::Int(v) => Ok(Cbor(Data::Integer(v.into()))),
Number::Float(v) => Ok(Cbor(Data::Float(v))),
Number::Decimal(v) => {
Ok(Cbor(Data::Tag(TAG_DECIMAL, Box::new(Data::Text(v.to_string())))))
}
},
Value::Strand(v) => Ok(Cbor(Data::Text(v.0))),
Value::Duration(v) => {
Ok(Cbor(Data::Tag(TAG_DURATION, Box::new(Data::Text(v.to_raw())))))
}
Value::Datetime(v) => {
Ok(Cbor(Data::Tag(TAG_DATETIME, Box::new(Data::Text(v.to_raw())))))
}
Value::Uuid(v) => Ok(Cbor(Data::Tag(TAG_UUID, Box::new(Data::Text(v.to_raw()))))),
Value::Array(v) => Ok(Cbor(Data::Array(
v.into_iter()
.map(|v| {
let v = Cbor::try_from(v)?.0;
Ok(v)
})
.collect::<Result<Vec<Data>, &str>>()?,
))),
Value::Object(v) => Ok(Cbor(Data::Map(
v.into_iter()
.map(|(k, v)| {
let k = Data::Text(k);
let v = Cbor::try_from(v)?.0;
Ok((k, v))
})
.collect::<Result<Vec<(Data, Data)>, &str>>()?,
))),
Value::Bytes(v) => Ok(Cbor(Data::Bytes(v.into_inner()))),
Value::Thing(v) => Ok(Cbor(Data::Tag(
TAG_RECORDID,
Box::new(Data::Array(vec![
Data::Text(v.tb),
match v.id {
Id::Number(v) => Data::Integer(v.into()),
Id::String(v) => Data::Text(v),
Id::Array(v) => Cbor::try_from(Value::from(v))?.0,
Id::Object(v) => Cbor::try_from(Value::from(v))?.0,
Id::Generate(_) => unreachable!(),
},
])),
))),
// We shouldn't reach here
_ => unreachable!(),
}
}
}

View file

@ -1,24 +1,33 @@
mod convert;
pub use convert::Cbor;
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use ciborium::Value as Data;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
Message::Binary(val) => ciborium::from_reader::<Data, _>(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR)
.map(Cbor)?
.try_into(),
_ => Err(Failure::INVALID_REQUEST),
}
}
/// Serialises an RPC response into a binary CBOR WebSocket message.
///
/// Returns the length of the encoded payload together with the message.
/// A response value which cannot be converted into CBOR is propagated to
/// the caller as a `Failure`.
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
    // Convert the response into a value
    let val: Cbor = res.into_value().try_into()?;
    // Create a new vector for encoding output
    let mut res = Vec::new();
    // Serialize the value into CBOR binary data
    // NOTE(review): panics if encoding fails; writing into an in-memory
    // buffer is not expected to fail, but confirm for all CBOR values.
    ciborium::into_writer(&val.0, &mut res).unwrap();
    // Return the message length, and message as binary
    Ok((res.len(), Message::Binary(res)))
}

View file

@ -0,0 +1,145 @@
use rmpv::Value as Data;
use surrealdb::sql::Datetime;
use surrealdb::sql::Duration;
use surrealdb::sql::Number;
use surrealdb::sql::Thing;
use surrealdb::sql::Uuid;
use surrealdb::sql::Value;
// MessagePack extension type identifiers used by the conversions in this
// file to mark SurrealQL values which have no plain MessagePack equivalent.
// Each identifier accompanies a byte payload (see the `Data::Ext` arms in
// the conversions below).

// Marks a SurrealQL `NONE` value (encoded with an empty payload).
const TAG_NONE: i8 = 1;
// Marks a UUID, carried as UTF-8 text bytes.
const TAG_UUID: i8 = 2;
// Marks a decimal number, carried as UTF-8 text bytes.
const TAG_DECIMAL: i8 = 3;
// Marks a duration, carried as UTF-8 text bytes.
const TAG_DURATION: i8 = 4;
// Marks a datetime, carried as UTF-8 text bytes.
const TAG_DATETIME: i8 = 5;
// Marks a record id, carried as UTF-8 text bytes.
const TAG_RECORDID: i8 = 6;
/// Newtype wrapper around an `rmpv::Value`, used as the source and target
/// type of the `TryFrom` conversions to and from SurrealQL values
/// implemented in this file.
#[derive(Debug)]
pub struct Pack(pub Data);
/// Decodes MessagePack data into a SurrealQL value.
///
/// Plain MessagePack types map onto their SurrealQL counterparts, while
/// extension values are decoded according to the extension identifiers
/// declared at the top of this file. Every extension other than `TAG_NONE`
/// carries its value as UTF-8 text which is parsed into the matching type.
///
/// # Errors
///
/// Returns a static message when an extension payload is not valid UTF-8,
/// fails to parse, uses an unknown identifier, or when the MessagePack data
/// type is not handled here.
impl TryFrom<Pack> for Value {
    type Error = &'static str;

    fn try_from(val: Pack) -> Result<Self, &'static str> {
        // Reads an extension payload as UTF-8 text
        fn utf8(bytes: &[u8]) -> Result<&str, &'static str> {
            std::str::from_utf8(bytes).map_err(|_| "Expected a valid UTF-8 string")
        }

        match val.0 {
            Data::Nil => Ok(Value::Null),
            Data::Boolean(flag) => Ok(Value::from(flag)),
            // Integers which cannot be read back fall through to NULL
            Data::Integer(int) if int.is_i64() => {
                Ok(int.as_i64().map_or(Value::Null, Value::from))
            }
            Data::Integer(int) if int.is_u64() => {
                Ok(int.as_u64().map_or(Value::Null, Value::from))
            }
            Data::F32(num) => Ok(Value::from(num)),
            Data::F64(num) => Ok(Value::from(num)),
            // Strings which are not valid UTF-8 become NULL
            Data::String(string) => Ok(string.into_str().map_or(Value::Null, Value::from)),
            Data::Binary(bytes) => Ok(Value::Bytes(bytes.into())),
            // Convert every element, stopping at the first failure
            Data::Array(items) => items
                .into_iter()
                .map(|item| Value::try_from(Pack(item)))
                .collect::<Result<Value, &str>>(),
            // Map keys are converted to values first and then stringified
            Data::Map(entries) => entries
                .into_iter()
                .map(|(key, item)| {
                    let key = Value::try_from(Pack(key))?.as_raw_string();
                    let item = Value::try_from(Pack(item))?;
                    Ok((key, item))
                })
                .collect::<Result<Value, &str>>(),
            Data::Ext(tag, bytes) => match tag {
                // A literal NONE
                TAG_NONE => Ok(Value::None),
                // A literal uuid
                TAG_UUID => Uuid::try_from(utf8(&bytes)?)
                    .map(Value::from)
                    .map_err(|_| "Expected a valid UUID value"),
                // A literal decimal
                TAG_DECIMAL => Number::try_from(utf8(&bytes)?)
                    .map(Value::from)
                    .map_err(|_| "Expected a valid Decimal value"),
                // A literal duration
                TAG_DURATION => Duration::try_from(utf8(&bytes)?)
                    .map(Value::from)
                    .map_err(|_| "Expected a valid Duration value"),
                // A literal datetime
                TAG_DATETIME => Datetime::try_from(utf8(&bytes)?)
                    .map(Value::from)
                    .map_err(|_| "Expected a valid Datetime value"),
                // A literal recordid
                TAG_RECORDID => Thing::try_from(utf8(&bytes)?)
                    .map(Value::from)
                    .map_err(|_| "Expected a valid RecordID value"),
                // An unknown tag
                _ => Err("Encountered an unknown MessagePack tag"),
            },
            _ => Err("Encountered an unknown MessagePack data type"),
        }
    }
}
/// Encodes a SurrealQL value as MessagePack data.
///
/// Values with a direct MessagePack equivalent are encoded natively. `NONE`,
/// decimals, durations, datetimes, uuids and record ids are encoded as
/// extension values, using the identifiers declared at the top of this file
/// and the textual form of the value as the payload.
///
/// # Errors
///
/// Returns a static message when the value (or a nested value) has no
/// MessagePack encoding defined here. This case previously panicked via
/// `unreachable!()`.
impl TryFrom<Value> for Pack {
    type Error = &'static str;

    fn try_from(val: Value) -> Result<Self, &'static str> {
        match val {
            Value::None => Ok(Pack(Data::Ext(TAG_NONE, vec![]))),
            Value::Null => Ok(Pack(Data::Nil)),
            Value::Bool(v) => Ok(Pack(Data::Boolean(v))),
            Value::Number(v) => match v {
                Number::Int(v) => Ok(Pack(Data::Integer(v.into()))),
                Number::Float(v) => Ok(Pack(Data::F64(v))),
                // Decimals are sent as text so that no precision is lost
                Number::Decimal(v) => {
                    Ok(Pack(Data::Ext(TAG_DECIMAL, v.to_string().as_bytes().to_vec())))
                }
            },
            Value::Strand(v) => Ok(Pack(Data::String(v.0.into()))),
            Value::Duration(v) => Ok(Pack(Data::Ext(TAG_DURATION, v.to_raw().as_bytes().to_vec()))),
            Value::Datetime(v) => Ok(Pack(Data::Ext(TAG_DATETIME, v.to_raw().as_bytes().to_vec()))),
            Value::Uuid(v) => Ok(Pack(Data::Ext(TAG_UUID, v.to_raw().as_bytes().to_vec()))),
            // Encode every element, stopping at the first failure
            Value::Array(v) => Ok(Pack(Data::Array(
                v.into_iter()
                    .map(|v| {
                        let v = Pack::try_from(v)?.0;
                        Ok(v)
                    })
                    .collect::<Result<Vec<Data>, &str>>()?,
            ))),
            // Object keys are always encoded as MessagePack strings
            Value::Object(v) => Ok(Pack(Data::Map(
                v.into_iter()
                    .map(|(k, v)| {
                        let k = Data::String(k.into());
                        let v = Pack::try_from(v)?.0;
                        Ok((k, v))
                    })
                    .collect::<Result<Vec<(Data, Data)>, &str>>()?,
            ))),
            Value::Bytes(v) => Ok(Pack(Data::Binary(v.into_inner()))),
            Value::Thing(v) => Ok(Pack(Data::Ext(TAG_RECORDID, v.to_raw().as_bytes().to_vec()))),
            // Any other SurrealQL value has no MessagePack encoding defined here
            _ => Err("Found unsupported SurrealQL value being encoded into a MessagePack value"),
        }
    }
}

View file

@ -1,3 +1,7 @@
mod convert;
pub use convert::Pack;
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
@ -8,15 +12,21 @@ pub fn req(msg: Message) -> Result<Request, Failure> {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
Message::Binary(val) => rmpv::decode::read_value(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR)
.map(Pack)?
.try_into(),
_ => Err(Failure::INVALID_REQUEST),
}
}
/// Serialises an RPC response into a binary MessagePack WebSocket message.
///
/// Returns the length of the encoded payload together with the message.
/// A response value which cannot be converted into MessagePack is
/// propagated to the caller as a `Failure`.
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
    // Convert the response into a value
    let val: Pack = res.into_value().try_into()?;
    // Create a new vector for encoding output
    let mut res = Vec::new();
    // Serialize the value into MsgPack binary data
    // NOTE(review): panics if encoding fails; writing into an in-memory
    // buffer is not expected to fail, but confirm for all values.
    rmpv::encode::write_value(&mut res, &val.0).unwrap();
    // Return the message length, and message as binary
    Ok((res.len(), Message::Binary(res)))
}

View file

@ -1,4 +1,6 @@
use crate::rpc::failure::Failure;
use crate::rpc::format::cbor::Cbor;
use crate::rpc::format::msgpack::Pack;
use once_cell::sync::Lazy;
use surrealdb::sql::Part;
use surrealdb::sql::{Array, Value};
@ -13,6 +15,20 @@ pub struct Request {
pub params: Array,
}
impl TryFrom<Cbor> for Request {
type Error = Failure;
fn try_from(val: Cbor) -> Result<Self, Failure> {
<Cbor as TryInto<Value>>::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into()
}
}
impl TryFrom<Pack> for Request {
type Error = Failure;
fn try_from(val: Pack) -> Result<Self, Failure> {
<Pack as TryInto<Value>>::try_into(val).map_err(|_| Failure::INVALID_REQUEST)?.try_into()
}
}
impl TryFrom<Value> for Request {
type Error = Failure;
fn try_from(val: Value) -> Result<Self, Failure> {

View file

@ -3,12 +3,16 @@ use std::string::ToString;
/// The serialisation formats which can be used for messages.
#[derive(Debug, Copy, Clone)]
pub enum Format {
    Json,
    Cbor,
    Pack,
}

/// Displays the lowercase name of the format: `json`, `cbor` or `msgpack`.
///
/// `Display` is implemented rather than `ToString` directly, as the standard
/// library then provides `to_string()` through its blanket implementation,
/// so existing callers of `to_string()` keep working unchanged.
impl std::fmt::Display for Format {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Json => "json",
            Self::Cbor => "cbor",
            Self::Pack => "msgpack",
        })
    }
}

View file

@ -69,6 +69,44 @@ impl Socket {
// Format the message
let msg = match format {
Format::Json => Message::Text(serde_json::to_string(&message)?),
Format::Cbor => {
pub mod try_from_impls {
include!("../../src/rpc/format/cbor/convert.rs");
}
// For tests we need to convert the serde_json::Value
// to a SurrealQL value, so that record ids, uuids,
// datetimes, and durations are stored properly.
// First of all we convert the JSON type to a string.
let json = message.to_string();
// Then we parse the JSON in to SurrealQL.
let surrealql = surrealdb::sql::value(&json)?;
// Then we convert the SurrealQL in to CBOR.
let cbor = try_from_impls::Cbor::try_from(surrealql)?;
// Then serialize the CBOR as binary data.
let mut output = Vec::new();
ciborium::into_writer(&cbor.0, &mut output).unwrap();
// Then output the message.
Message::Binary(output)
}
Format::Pack => {
pub mod try_from_impls {
include!("../../src/rpc/format/msgpack/convert.rs");
}
// For tests we need to convert the serde_json::Value
// to a SurrealQL value, so that record ids, uuids,
// datetimes, and durations are stored properly.
// First of all we convert the JSON type to a string.
let json = message.to_string();
// Then we parse the JSON in to SurrealQL.
let surrealql = surrealdb::sql::value(&json)?;
// Then we convert the SurrealQL in to MessagePack.
let pack = try_from_impls::Pack::try_from(surrealql)?;
// Then serialize the MessagePack as binary data.
let mut output = Vec::new();
rmpv::encode::write_value(&mut output, &pack.0).unwrap();
// Then output the message.
Message::Binary(output)
}
};
// Send the message
tokio::select! {
@ -108,6 +146,49 @@ impl Socket {
debug!("Received message: {msg}");
return Ok(msg);
},
_ => {
return Err("Expected to receive a binary message".to_string().into());
}
}
},
Some(Message::Binary(msg)) => {
debug!("Response {msg:?} received in {:?}", now.elapsed());
match format {
Format::Cbor => {
pub mod try_from_impls {
include!("../../src/rpc/format/cbor/convert.rs");
}
// For tests we need to convert the binary data to
// a serde_json::Value so that test assertions work.
// First of all we deserialize the CBOR data.
let msg: ciborium::Value = ciborium::from_reader(&mut msg.as_slice())?;
// Then we convert it to a SurrealQL Value.
let msg: Value = try_from_impls::Cbor(msg).try_into()?;
// Then we convert the SurrealQL to JSON.
let msg = msg.into_json();
// Then output the response.
debug!("Received message: {msg:?}");
return Ok(msg);
},
Format::Pack => {
pub mod try_from_impls {
include!("../../src/rpc/format/msgpack/convert.rs");
}
// For tests we need to convert the binary data to
// a serde_json::Value so that test assertions work.
// First of all we deserialize the MessagePack data.
let msg: rmpv::Value = rmpv::decode::read_value(&mut msg.as_slice())?;
// Then we convert it to a SurrealQL Value.
let msg: Value = try_from_impls::Pack(msg).try_into()?;
// Then we convert the SurrealQL to JSON.
let msg = msg.into_json();
// Then output the response.
debug!("Received message: {msg:?}");
return Ok(msg);
},
_ => {
return Err("Expected to receive a text message".to_string().into());
}
}
},
Some(_) => {

View file

@ -25,4 +25,24 @@ mod ws_integration {
// Run all of the common tests
include!("common/tests.rs");
}
/// Tests for the CBOR protocol format.
///
/// Runs the shared test suite with both the server format and the message
/// format set to CBOR.
mod cbor {
// The format requested from the server, presumably sent as the WebSocket
// protocol header by the included tests (confirm in common/tests.rs)
const SERVER: Option<Format> = Some(Format::Cbor);
// The format used to encode the messages which the tests send
const FORMAT: Format = Format::Cbor;
// Run all of the common tests, which are textually included here so
// that they pick up the constants defined above
include!("common/tests.rs");
}
/// Tests for the MessagePack protocol format.
///
/// Runs the shared test suite with both the server format and the message
/// format set to MessagePack.
mod pack {
// The format requested from the server, presumably sent as the WebSocket
// protocol header by the included tests (confirm in common/tests.rs)
const SERVER: Option<Format> = Some(Format::Pack);
// The format used to encode the messages which the tests send
const FORMAT: Format = Format::Pack;
// Run all of the common tests, which are textually included here so
// that they pick up the constants defined above
include!("common/tests.rs");
}
}