Implement revision types for client/server communication (#3474)

This commit is contained in:
Rushmore Mushambi 2024-02-12 13:52:36 +02:00 committed by GitHub
parent 2d2b3a40bb
commit d55d1a3b6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 321 additions and 150 deletions

2
Cargo.lock generated
View file

@ -5278,6 +5278,7 @@ dependencies = [
"rand 0.8.5", "rand 0.8.5",
"rcgen", "rcgen",
"reqwest", "reqwest",
"revision",
"rmp-serde", "rmp-serde",
"rmpv", "rmpv",
"rustyline", "rustyline",
@ -5330,6 +5331,7 @@ dependencies = [
"rand 0.8.5", "rand 0.8.5",
"regex", "regex",
"reqwest", "reqwest",
"revision",
"ring 0.17.7", "ring 0.17.7",
"rust_decimal", "rust_decimal",
"rustls", "rustls",

View file

@ -65,6 +65,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
"blocking", "blocking",
"gzip", "gzip",
] } ] }
revision = "0.5.0"
rmpv = "1.0.1" rmpv = "1.0.1"
rustyline = { version = "12.0.0", features = ["derive"] } rustyline = { version = "12.0.0", features = ["derive"] }
serde = { version = "1.0.193", features = ["derive"] } serde = { version = "1.0.193", features = ["derive"] }

View file

@ -1,9 +1,11 @@
use crate::sql::{Object, Uuid, Value}; use crate::sql::{Object, Uuid, Value};
use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug, Display}; use std::fmt::{self, Debug, Display};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")] #[serde(rename_all = "UPPERCASE")]
#[revisioned(revision = 1)]
pub enum Action { pub enum Action {
Create, Create,
Update, Update,
@ -21,6 +23,7 @@ impl Display for Action {
} }
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[revisioned(revision = 1)]
pub struct Notification { pub struct Notification {
/// The id of the LIVE query to which this notification belongs /// The id of the LIVE query to which this notification belongs
pub id: Uuid, pub id: Uuid,

View file

@ -1,5 +1,7 @@
use crate::err::Error; use crate::err::Error;
use crate::sql::value::Value; use crate::sql::value::Value;
use revision::revisioned;
use revision::Revisioned;
use serde::ser::SerializeStruct; use serde::ser::SerializeStruct;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize; use serde::Serialize;
@ -40,6 +42,7 @@ impl Response {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")] #[serde(rename_all = "UPPERCASE")]
#[revisioned(revision = 1)]
#[doc(hidden)] #[doc(hidden)]
pub enum Status { pub enum Status {
Ok, Ok,
@ -66,3 +69,49 @@ impl Serialize for Response {
val.end() val.end()
} }
} }
#[derive(Debug, Serialize, Deserialize)]
#[revisioned(revision = 1)]
#[doc(hidden)]
pub struct QueryMethodResponse {
pub time: String,
pub status: Status,
pub result: Value,
}
impl From<&Response> for QueryMethodResponse {
fn from(res: &Response) -> Self {
let time = res.speed();
let (status, result) = match &res.result {
Ok(value) => (Status::Ok, value.clone()),
Err(error) => (Status::Err, Value::from(error.to_string())),
};
Self {
status,
result,
time,
}
}
}
#[doc(hidden)]
impl Revisioned for Response {
#[inline]
fn serialize_revisioned<W: std::io::Write>(
&self,
writer: &mut W,
) -> std::result::Result<(), revision::Error> {
QueryMethodResponse::from(self).serialize_revisioned(writer)
}
#[inline]
fn deserialize_revisioned<R: std::io::Read>(
_reader: &mut R,
) -> std::result::Result<Self, revision::Error> {
unreachable!("deserialising `Response` directly is not supported")
}
fn revision() -> u16 {
1
}
}

View file

@ -92,6 +92,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
"stream", "stream",
"multipart", "multipart",
], optional = true } ], optional = true }
revision = "0.5.0"
rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] } rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] }
rustls = { version = "0.21.10", optional = true } rustls = { version = "0.21.10", optional = true }
semver = { version = "1.0.20", features = ["serde"] } semver = { version = "1.0.20", features = ["serde"] }

View file

@ -134,11 +134,9 @@ impl IntoEndpoint for &str {
) )
} }
}; };
Ok(Endpoint { let mut endpoint = Endpoint::new(url);
url, endpoint.path = path;
path, Ok(endpoint)
config: Default::default(),
})
} }
} }

View file

@ -175,9 +175,10 @@ impl Connection for Any {
#[cfg(feature = "protocol-ws")] #[cfg(feature = "protocol-ws")]
{ {
features.insert(ExtraFeatures::LiveQueries); features.insert(ExtraFeatures::LiveQueries);
let url = address.url.join(engine::remote::ws::PATH)?; let mut endpoint = address;
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]
let maybe_connector = address.config.tls_config.map(Connector::from); let maybe_connector = endpoint.config.tls_config.clone().map(Connector::from);
#[cfg(not(any(feature = "native-tls", feature = "rustls")))] #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let maybe_connector = None; let maybe_connector = None;
@ -188,13 +189,13 @@ impl Connection for Any {
..Default::default() ..Default::default()
}; };
let socket = engine::remote::ws::native::connect( let socket = engine::remote::ws::native::connect(
&url, &endpoint,
Some(config), Some(config),
maybe_connector.clone(), maybe_connector.clone(),
) )
.await?; .await?;
engine::remote::ws::native::router( engine::remote::ws::native::router(
url, endpoint,
maybe_connector, maybe_connector,
capacity, capacity,
config, config,

View file

@ -151,9 +151,9 @@ impl Connection for Any {
#[cfg(feature = "protocol-ws")] #[cfg(feature = "protocol-ws")]
{ {
features.insert(ExtraFeatures::LiveQueries); features.insert(ExtraFeatures::LiveQueries);
let mut address = address; let mut endpoint = address;
address.url = address.url.join(engine::remote::ws::PATH)?; endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
engine::remote::ws::wasm::router(address, capacity, conn_tx, route_rx); engine::remote::ws::wasm::router(endpoint, capacity, conn_tx, route_rx);
conn_rx.into_recv_async().await??; conn_rx.into_recv_async().await??;
} }

View file

@ -15,18 +15,24 @@ use crate::api::Connect;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::dbs::Notification; use crate::dbs::Notification;
use crate::dbs::QueryMethodResponse;
use crate::dbs::Status; use crate::dbs::Status;
use crate::method::Stats; use crate::method::Stats;
use crate::opt::IntoEndpoint; use crate::opt::IntoEndpoint;
use crate::sql::Value; use crate::sql::Value;
use indexmap::IndexMap; use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::Deserialize; use serde::Deserialize;
use std::io::Read;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::time::Duration; use std::time::Duration;
pub(crate) const PATH: &str = "rpc"; pub(crate) const PATH: &str = "rpc";
const PING_INTERVAL: Duration = Duration::from_secs(5); const PING_INTERVAL: Duration = Duration::from_secs(5);
const PING_METHOD: &str = "ping"; const PING_METHOD: &str = "ping";
const REVISION_HEADER: &str = "revision";
/// The WS scheme used to connect to `ws://` endpoints /// The WS scheme used to connect to `ws://` endpoints
#[derive(Debug)] #[derive(Debug)]
@ -78,12 +84,14 @@ impl Surreal<Client> {
} }
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) struct Failure { pub(crate) struct Failure {
pub(crate) code: i64, pub(crate) code: i64,
pub(crate) message: String, pub(crate) message: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) enum Data { pub(crate) enum Data {
Other(Value), Other(Value),
Query(Vec<QueryMethodResponse>), Query(Vec<QueryMethodResponse>),
@ -104,13 +112,6 @@ impl From<Failure> for Error {
} }
} }
#[derive(Debug, Deserialize)]
pub(crate) struct QueryMethodResponse {
time: String,
status: Status,
result: Value,
}
impl DbResponse { impl DbResponse {
fn from(result: ServerResult) -> Result<Self> { fn from(result: ServerResult) -> Result<Self> {
match result.map_err(Error::from)? { match result.map_err(Error::from)? {
@ -148,7 +149,30 @@ impl DbResponse {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) struct Response { pub(crate) struct Response {
id: Option<Value>, id: Option<Value>,
pub(crate) result: ServerResult, pub(crate) result: ServerResult,
} }
fn serialize(value: &Value, revisioned: bool) -> Result<Vec<u8>> {
if revisioned {
let mut buf = Vec::new();
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
return Ok(buf);
}
crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
}
fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
where
A: Read,
T: Revisioned + DeserializeOwned,
{
if revisioned {
return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
}
let mut buf = Vec::new();
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}

View file

@ -1,4 +1,5 @@
use super::PATH; use super::PATH;
use super::{deserialize, serialize};
use crate::api::conn::Connection; use crate::api::conn::Connection;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method; use crate::api::conn::Method;
@ -19,7 +20,6 @@ use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
use flume::Receiver; use flume::Receiver;
@ -28,6 +28,7 @@ use futures::SinkExt;
use futures::StreamExt; use futures::StreamExt;
use futures_concurrency::stream::Merge as _; use futures_concurrency::stream::Merge as _;
use indexmap::IndexMap; use indexmap::IndexMap;
use revision::revisioned;
use serde::Deserialize; use serde::Deserialize;
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -42,14 +43,16 @@ use std::sync::OnceLock;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time; use tokio::time;
use tokio::time::MissedTickBehavior; use tokio::time::MissedTickBehavior;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::error::Error as WsError; use tokio_tungstenite::tungstenite::error::Error as WsError;
use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::Connector; use tokio_tungstenite::Connector;
use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::WebSocketStream;
use trice::Instant; use trice::Instant;
use url::Url;
type WsResult<T> = std::result::Result<T, WsError>; type WsResult<T> = std::result::Result<T, WsError>;
@ -78,17 +81,29 @@ impl From<Tls> for Connector {
} }
pub(crate) async fn connect( pub(crate) async fn connect(
url: &Url, endpoint: &Endpoint,
config: Option<WebSocketConfig>, config: Option<WebSocketConfig>,
#[allow(unused_variables)] maybe_connector: Option<Connector>, #[allow(unused_variables)] maybe_connector: Option<Connector>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> { ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let mut request = (&endpoint.url).into_client_request()?;
if endpoint.supports_revision {
request
.headers_mut()
.insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static(super::REVISION_HEADER));
}
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]
let (socket, _) = let (socket, _) = tokio_tungstenite::connect_async_tls_with_config(
tokio_tungstenite::connect_async_tls_with_config(url, config, NAGLE_ALG, maybe_connector) request,
config,
NAGLE_ALG,
maybe_connector,
)
.await?; .await?;
#[cfg(not(any(feature = "native-tls", feature = "rustls")))] #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let (socket, _) = tokio_tungstenite::connect_async_with_config(url, config, NAGLE_ALG).await?; let (socket, _) = tokio_tungstenite::connect_async_with_config(request, config, NAGLE_ALG).await?;
Ok(socket) Ok(socket)
} }
@ -104,13 +119,13 @@ impl Connection for Client {
} }
fn connect( fn connect(
address: Endpoint, mut address: Endpoint,
capacity: usize, capacity: usize,
) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
Box::pin(async move { Box::pin(async move {
let url = address.url.join(PATH)?; address.url = address.url.join(PATH)?;
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]
let maybe_connector = address.config.tls_config.map(Connector::from); let maybe_connector = address.config.tls_config.clone().map(Connector::from);
#[cfg(not(any(feature = "native-tls", feature = "rustls")))] #[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let maybe_connector = None; let maybe_connector = None;
@ -121,14 +136,14 @@ impl Connection for Client {
..Default::default() ..Default::default()
}; };
let socket = connect(&url, Some(config), maybe_connector.clone()).await?; let socket = connect(&address, Some(config), maybe_connector.clone()).await?;
let (route_tx, route_rx) = match capacity { let (route_tx, route_rx) = match capacity {
0 => flume::unbounded(), 0 => flume::unbounded(),
capacity => flume::bounded(capacity), capacity => flume::bounded(capacity),
}; };
router(url, maybe_connector, capacity, config, socket, route_rx); router(address, maybe_connector, capacity, config, socket, route_rx);
let mut features = HashSet::new(); let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries); features.insert(ExtraFeatures::LiveQueries);
@ -164,7 +179,7 @@ impl Connection for Client {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub(crate) fn router( pub(crate) fn router(
url: Url, endpoint: Endpoint,
maybe_connector: Option<Connector>, maybe_connector: Option<Connector>,
capacity: usize, capacity: usize,
config: WebSocketConfig, config: WebSocketConfig,
@ -176,7 +191,7 @@ pub(crate) fn router(
let mut request = BTreeMap::new(); let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into()); request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request); let value = Value::from(request);
let value = serialize(&value).unwrap(); let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value) Message::Binary(value)
}; };
@ -270,7 +285,8 @@ pub(crate) fn router(
} }
let payload = Value::from(request); let payload = Value::from(request);
trace!("Request {payload}"); trace!("Request {payload}");
let payload = serialize(&payload).unwrap(); let payload =
serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload) Message::Binary(payload)
}; };
if let Method::Authenticate if let Method::Authenticate
@ -314,7 +330,7 @@ pub(crate) fn router(
last_activity = Instant::now(); last_activity = Instant::now();
match result { match result {
Ok(message) => { Ok(message) => {
match Response::try_from(&message) { match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => { Ok(option) => {
// We are only interested in responses that are not empty // We are only interested in responses that are not empty
if let Some(response) = option { if let Some(response) = option {
@ -379,8 +395,11 @@ pub(crate) fn router(
); );
let value = let value =
Value::from(request); Value::from(request);
let value = let value = serialize(
serialize(&value) &value,
endpoint
.supports_revision,
)
.unwrap(); .unwrap();
Message::Binary(value) Message::Binary(value)
}; };
@ -402,6 +421,7 @@ pub(crate) fn router(
} }
Err(error) => { Err(error) => {
#[derive(Deserialize)] #[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response { struct Response {
id: Option<Value>, id: Option<Value>,
} }
@ -410,8 +430,10 @@ pub(crate) fn router(
if let Message::Binary(binary) = message { if let Message::Binary(binary) = message {
if let Ok(Response { if let Ok(Response {
id, id,
}) = deserialize(&binary) }) = deserialize(
{ &mut &binary[..],
endpoint.supports_revision,
) {
// Return an error if an ID was returned // Return an error if an ID was returned
if let Some(Ok(id)) = if let Some(Ok(id)) =
id.map(Value::coerce_to_i64) id.map(Value::coerce_to_i64)
@ -473,7 +495,7 @@ pub(crate) fn router(
'reconnect: loop { 'reconnect: loop {
trace!("Reconnecting..."); trace!("Reconnecting...");
match connect(&url, Some(config), maybe_connector.clone()).await { match connect(&endpoint, Some(config), maybe_connector.clone()).await {
Ok(s) => { Ok(s) => {
socket = s; socket = s;
for (_, message) in &replay { for (_, message) in &replay {
@ -512,19 +534,21 @@ pub(crate) fn router(
} }
impl Response { impl Response {
fn try_from(message: &Message) -> Result<Option<Self>> { fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
match message { match message {
Message::Text(text) => { Message::Text(text) => {
trace!("Received an unexpected text message; {text}"); trace!("Received an unexpected text message; {text}");
Ok(None) Ok(None)
} }
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| { Message::Binary(binary) => {
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
Error::ResponseFromBinary { Error::ResponseFromBinary {
binary: binary.clone(), binary: binary.clone(),
error, error: bincode::ErrorKind::Custom(error.to_string()).into(),
} }
.into() .into()
}), })
}
Message::Ping(..) => { Message::Ping(..) => {
trace!("Received a ping from the server"); trace!("Received a ping from the server");
Ok(None) Ok(None)

View file

@ -1,4 +1,5 @@
use super::PATH; use super::PATH;
use super::{deserialize, serialize};
use crate::api::conn::Connection; use crate::api::conn::Connection;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method; use crate::api::conn::Method;
@ -17,7 +18,6 @@ use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
use flume::Receiver; use flume::Receiver;
@ -29,6 +29,7 @@ use indexmap::IndexMap;
use pharos::Channel; use pharos::Channel;
use pharos::Observable; use pharos::Observable;
use pharos::ObserveConfig; use pharos::ObserveConfig;
use revision::revisioned;
use serde::Deserialize; use serde::Deserialize;
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
use std::collections::BTreeMap; use std::collections::BTreeMap;
@ -117,13 +118,17 @@ impl Connection for Client {
} }
pub(crate) fn router( pub(crate) fn router(
address: Endpoint, endpoint: Endpoint,
capacity: usize, capacity: usize,
conn_tx: Sender<Result<()>>, conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>, route_rx: Receiver<Option<Route>>,
) { ) {
spawn_local(async move { spawn_local(async move {
let (mut ws, mut socket) = match WsMeta::connect(&address.url, None).await { let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
let (mut ws, mut socket) = match connect {
Ok(pair) => pair, Ok(pair) => pair,
Err(error) => { Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await; let _ = conn_tx.into_send_async(Err(error.into())).await;
@ -151,7 +156,7 @@ pub(crate) fn router(
let mut request = BTreeMap::new(); let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into()); request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request); let value = Value::from(request);
let value = serialize(&value).unwrap(); let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value) Message::Binary(value)
}; };
@ -244,7 +249,7 @@ pub(crate) fn router(
} }
let payload = Value::from(request); let payload = Value::from(request);
trace!("Request {payload}"); trace!("Request {payload}");
let payload = serialize(&payload).unwrap(); let payload = serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload) Message::Binary(payload)
}; };
if let Method::Authenticate if let Method::Authenticate
@ -285,7 +290,7 @@ pub(crate) fn router(
} }
Either::Response(message) => { Either::Response(message) => {
last_activity = Instant::now(); last_activity = Instant::now();
match Response::try_from(&message) { match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => { Ok(option) => {
// We are only interested in responses that are not empty // We are only interested in responses that are not empty
if let Some(response) = option { if let Some(response) = option {
@ -335,7 +340,11 @@ pub(crate) fn router(
.into(), .into(),
); );
let value = Value::from(request); let value = Value::from(request);
let value = serialize(&value).unwrap(); let value = serialize(
&value,
endpoint.supports_revision,
)
.unwrap();
Message::Binary(value) Message::Binary(value)
}; };
if let Err(error) = if let Err(error) =
@ -355,6 +364,7 @@ pub(crate) fn router(
} }
Err(error) => { Err(error) => {
#[derive(Deserialize)] #[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response { struct Response {
id: Option<Value>, id: Option<Value>,
} }
@ -363,7 +373,7 @@ pub(crate) fn router(
if let Message::Binary(binary) = message { if let Message::Binary(binary) = message {
if let Ok(Response { if let Ok(Response {
id, id,
}) = deserialize(&binary) }) = deserialize(&mut &binary[..], endpoint.supports_revision)
{ {
// Return an error if an ID was returned // Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
@ -418,7 +428,11 @@ pub(crate) fn router(
'reconnect: loop { 'reconnect: loop {
trace!("Reconnecting..."); trace!("Reconnecting...");
match WsMeta::connect(&address.url, None).await { let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
match connect {
Ok((mut meta, stream)) => { Ok((mut meta, stream)) => {
socket = stream; socket = stream;
events = { events = {
@ -471,19 +485,21 @@ pub(crate) fn router(
} }
impl Response { impl Response {
fn try_from(message: &Message) -> Result<Option<Self>> { fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
match message { match message {
Message::Text(text) => { Message::Text(text) => {
trace!("Received an unexpected text message; {text}"); trace!("Received an unexpected text message; {text}");
Ok(None) Ok(None)
} }
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| { Message::Binary(binary) => {
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
Error::ResponseFromBinary { Error::ResponseFromBinary {
binary: binary.clone(), binary: binary.clone(),
error, error: bincode::ErrorKind::Custom(error.to_string()).into(),
} }
.into() .into()
}), })
}
} }
} }
} }

View file

@ -29,11 +29,7 @@ impl IntoEndpoint<Test> for () {
type Client = Client; type Client = Client;
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
Ok(Endpoint { Ok(Endpoint::new(Url::parse("test://")?))
url: Url::parse("test://")?,
path: String::new(),
config: Default::default(),
})
} }
} }

View file

@ -10,6 +10,7 @@ pub mod opt;
mod conn; mod conn;
pub use method::query::Response; pub use method::query::Response;
use semver::Version;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Router; use crate::api::conn::Router;
@ -32,6 +33,7 @@ use self::opt::EndpointKind;
pub type Result<T> = std::result::Result<T, crate::Error>; pub type Result<T> = std::result::Result<T, crate::Error>;
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c"); const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c");
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
/// Connection trait implemented by supported engines /// Connection trait implemented by supported engines
pub trait Connection: conn::Connection {} pub trait Connection: conn::Connection {}
@ -93,11 +95,19 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let endpoint = self.address?; let mut endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme()); let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let client = Client::connect(endpoint, self.capacity).await?; let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
if endpoint_kind.is_remote() { if endpoint_kind.is_remote() {
client.check_server_version().await?; let mut version = client.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
client.check_server_version(&version).await?;
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
// Switch to revision based serialisation
endpoint.supports_revision = true;
client = Client::connect(endpoint, self.capacity).await?;
}
} }
Ok(client) Ok(client)
}) })
@ -117,19 +127,27 @@ where
if self.router.get().is_some() { if self.router.get().is_some() {
return Err(Error::AlreadyConnected.into()); return Err(Error::AlreadyConnected.into());
} }
let endpoint = self.address?; let mut endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme()); let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let arc = Client::connect(endpoint, self.capacity).await?.router; let mut client = Surreal {
let cell = Arc::into_inner(arc).expect("new connection to have no references"); router: Client::connect(endpoint.clone(), self.capacity).await?.router,
let router = cell.into_inner().expect("router to be set");
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
let client = Surreal {
router: self.router,
engine: PhantomData::<Client>, engine: PhantomData::<Client>,
}; };
if endpoint_kind.is_remote() { if endpoint_kind.is_remote() {
client.check_server_version().await?; let mut version = client.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
client.check_server_version(&version).await?;
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
// Switch to revision based serialisation
endpoint.supports_revision = true;
client = Client::connect(endpoint, self.capacity).await?;
} }
}
let cell =
Arc::into_inner(client.router).expect("new connection to have no references");
let router = cell.into_inner().expect("router to be set");
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
Ok(()) Ok(())
}) })
} }
@ -151,18 +169,15 @@ impl<C> Surreal<C>
where where
C: Connection, C: Connection,
{ {
async fn check_server_version(&self) -> Result<()> { async fn check_server_version(&self, version: &Version) -> Result<()> {
let (versions, build_meta) = SUPPORTED_VERSIONS; let (versions, build_meta) = SUPPORTED_VERSIONS;
// invalid version requirements should be caught during development // invalid version requirements should be caught during development
let req = VersionReq::parse(versions).expect("valid supported versions"); let req = VersionReq::parse(versions).expect("valid supported versions");
let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata"); let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
let mut version = self.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
let server_build = &version.build; let server_build = &version.build;
if !req.matches(&version) { if !req.matches(version) {
return Err(Error::VersionMismatch { return Err(Error::VersionMismatch {
server_version: version, server_version: version.clone(),
supported_versions: versions.to_owned(), supported_versions: versions.to_owned(),
} }
.into()); .into());

View file

@ -16,11 +16,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "fdb://"; let protocol = "fdb://";
Ok(Endpoint { let url = Url::parse(protocol)
url: Url::parse(protocol).unwrap(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
path: super::path_to_string(protocol, self), let mut endpoint = Endpoint::new(url);
config: Default::default(), endpoint.path = super::path_to_string(protocol, self);
}) Ok(endpoint)
} }
} }

View file

@ -17,11 +17,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("http://{self}"); let url = format!("http://{self}");
Ok(Endpoint { Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
} }
} }
@ -40,11 +36,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("https://{self}"); let url = format!("https://{self}");
Ok(Endpoint { Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
} }
} }

View file

@ -14,11 +14,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "indxdb://"; let protocol = "indxdb://";
Ok(Endpoint { let url = Url::parse(protocol)
url: Url::parse(protocol).unwrap(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
path: super::path_to_string(protocol, self), let mut endpoint = Endpoint::new(url);
config: Default::default(), endpoint.path = super::path_to_string(protocol, self);
}) Ok(endpoint)
} }
} }

View file

@ -10,11 +10,12 @@ impl IntoEndpoint<Mem> for () {
type Client = Db; type Client = Db;
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
Ok(Endpoint { let protocol = "mem://";
url: Url::parse("mem://").unwrap(), let url = Url::parse(protocol)
path: "memory".to_owned(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
config: Default::default(), let mut endpoint = Endpoint::new(url);
}) endpoint.path = "memory".to_owned();
Ok(endpoint)
} }
} }

View file

@ -24,7 +24,7 @@ use url::Url;
use super::Config; use super::Config;
/// A server address used to connect to the server /// A server address used to connect to the server
#[derive(Debug)] #[derive(Debug, Clone)]
#[allow(dead_code)] // used by the embedded and remote connections #[allow(dead_code)] // used by the embedded and remote connections
pub struct Endpoint { pub struct Endpoint {
#[doc(hidden)] #[doc(hidden)]
@ -32,9 +32,20 @@ pub struct Endpoint {
#[doc(hidden)] #[doc(hidden)]
pub path: String, pub path: String,
pub(crate) config: Config, pub(crate) config: Config,
// Whether or not the remote server supports revision based serialisation
pub(crate) supports_revision: bool,
} }
impl Endpoint { impl Endpoint {
pub(crate) fn new(url: Url) -> Self {
Self {
url,
path: String::new(),
config: Default::default(),
supports_revision: false,
}
}
#[doc(hidden)] #[doc(hidden)]
pub fn parse_kind(&self) -> Result<EndpointKind> { pub fn parse_kind(&self) -> Result<EndpointKind> {
match EndpointKind::from(self.url.scheme()) { match EndpointKind::from(self.url.scheme()) {
@ -151,6 +162,10 @@ impl EndpointKind {
) )
} }
pub(crate) fn is_ws(&self) -> bool {
matches!(self, EndpointKind::Ws | EndpointKind::Wss)
}
pub fn is_local(&self) -> bool { pub fn is_local(&self) -> bool {
!self.is_remote() !self.is_remote()
} }

View file

@ -17,11 +17,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "rocksdb://"; let protocol = "rocksdb://";
Ok(Endpoint { let url = Url::parse(protocol)
url: Url::parse(protocol).unwrap(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
path: super::path_to_string(protocol, self), let mut endpoint = Endpoint::new(url);
config: Default::default(), endpoint.path = super::path_to_string(protocol, self);
}) Ok(endpoint)
} }
} }
@ -40,11 +40,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "file://"; let protocol = "file://";
Ok(Endpoint { let url = Url::parse(protocol)
url: Url::parse(protocol).unwrap(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
path: super::path_to_string(protocol, self), let mut endpoint = Endpoint::new(url);
config: Default::default(), endpoint.path = super::path_to_string(protocol, self);
}) Ok(endpoint)
} }
} }

View file

@ -16,11 +16,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "speedb://"; let protocol = "speedb://";
Ok(Endpoint { let url = Url::parse(protocol)
url: Url::parse(protocol).unwrap(), .unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
path: super::path_to_string(protocol, self), let mut endpoint = Endpoint::new(url);
config: Default::default(), endpoint.path = super::path_to_string(protocol, self);
}) Ok(endpoint)
} }
} }

View file

@ -1,5 +1,6 @@
use crate::api::engine::local::Db; use crate::api::engine::local::Db;
use crate::api::engine::local::TiKv; use crate::api::engine::local::TiKv;
use crate::api::err::Error;
use crate::api::opt::Config; use crate::api::opt::Config;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
use crate::api::opt::IntoEndpoint; use crate::api::opt::IntoEndpoint;
@ -15,11 +16,9 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("tikv://{self}"); let url = format!("tikv://{self}");
Ok(Endpoint { let mut endpoint = Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url.clone()))?);
url: Url::parse(&url).unwrap(), endpoint.path = url;
path: url, Ok(endpoint)
config: Default::default(),
})
} }
} }

View file

@ -17,11 +17,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("ws://{self}"); let url = format!("ws://{self}");
Ok(Endpoint { Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
} }
} }
@ -40,11 +36,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> { fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("wss://{self}"); let url = format!("wss://{self}");
Ok(Endpoint { Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
} }
} }

View file

@ -1,4 +1,6 @@
use crate::err::Error; use crate::err::Error;
use revision::revisioned;
use revision::Revisioned;
use serde::Serialize; use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use surrealdb::sql::Value; use surrealdb::sql::Value;
@ -9,6 +11,34 @@ pub struct Failure {
pub(crate) message: Cow<'static, str>, pub(crate) message: Cow<'static, str>,
} }
#[derive(Clone, Debug, Serialize)]
#[revisioned(revision = 1)]
struct Inner {
code: i64,
message: String,
}
impl Revisioned for Failure {
fn serialize_revisioned<W: std::io::Write>(
&self,
writer: &mut W,
) -> Result<(), revision::Error> {
let inner = Inner {
code: self.code,
message: self.message.as_ref().to_owned(),
};
inner.serialize_revisioned(writer)
}
fn deserialize_revisioned<R: std::io::Read>(_reader: &mut R) -> Result<Self, revision::Error> {
unreachable!("deserialization not supported for this type")
}
fn revision() -> u16 {
1
}
}
impl From<&str> for Failure { impl From<&str> for Failure {
fn from(err: &str) -> Self { fn from(err: &str) -> Self {
Failure::custom(err.to_string()) Failure::custom(err.to_string())

View file

@ -2,13 +2,22 @@ use crate::rpc::failure::Failure;
use crate::rpc::request::Request; use crate::rpc::request::Request;
use crate::rpc::response::Response; use crate::rpc::response::Response;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use revision::Revisioned;
use surrealdb::sql::Value;
pub fn req(_msg: Message) -> Result<Request, Failure> { pub fn req(msg: Message) -> Result<Request, Failure> {
// This format is not yet implemented match msg {
Err(Failure::INTERNAL_ERROR) Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR)?
.try_into(),
_ => Err(Failure::INVALID_REQUEST),
}
} }
pub fn res(_res: Response) -> Result<(usize, Message), Failure> { pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// This format is not yet implemented // Serialize the response with full internal type information
Err(Failure::INTERNAL_ERROR) let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary
Ok((buf.len(), Message::Binary(buf)))
} }

View file

@ -3,6 +3,7 @@ use crate::rpc::format::Format;
use crate::telemetry::metrics::ws::record_rpc; use crate::telemetry::metrics::ws::record_rpc;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use opentelemetry::Context as TelemetryContext; use opentelemetry::Context as TelemetryContext;
use revision::revisioned;
use serde::Serialize; use serde::Serialize;
use serde_json::Value as Json; use serde_json::Value as Json;
use surrealdb::channel::Sender; use surrealdb::channel::Sender;
@ -16,6 +17,7 @@ use tracing::Span;
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data` // The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync. // In future, they will possibly be merged to avoid having to keep them in sync.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub enum Data { pub enum Data {
/// Generally methods return a `sql::Value` /// Generally methods return a `sql::Value`
Other(Value), Other(Value),
@ -61,6 +63,7 @@ impl From<Data> for Value {
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub struct Response { pub struct Response {
id: Option<Value>, id: Option<Value>,
result: Result<Data, Failure>, result: Result<Data, Failure>,