Implement revision types for client/server communication (#3474)
This commit is contained in:
parent
2d2b3a40bb
commit
d55d1a3b6e
25 changed files with 321 additions and 150 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -5278,6 +5278,7 @@ dependencies = [
|
|||
"rand 0.8.5",
|
||||
"rcgen",
|
||||
"reqwest",
|
||||
"revision",
|
||||
"rmp-serde",
|
||||
"rmpv",
|
||||
"rustyline",
|
||||
|
@ -5330,6 +5331,7 @@ dependencies = [
|
|||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"revision",
|
||||
"ring 0.17.7",
|
||||
"rust_decimal",
|
||||
"rustls",
|
||||
|
|
|
@ -65,6 +65,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
|
|||
"blocking",
|
||||
"gzip",
|
||||
] }
|
||||
revision = "0.5.0"
|
||||
rmpv = "1.0.1"
|
||||
rustyline = { version = "12.0.0", features = ["derive"] }
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use crate::sql::{Object, Uuid, Value};
|
||||
use revision::revisioned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{self, Debug, Display};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
#[revisioned(revision = 1)]
|
||||
pub enum Action {
|
||||
Create,
|
||||
Update,
|
||||
|
@ -21,6 +23,7 @@ impl Display for Action {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub struct Notification {
|
||||
/// The id of the LIVE query to which this notification belongs
|
||||
pub id: Uuid,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use crate::err::Error;
|
||||
use crate::sql::value::Value;
|
||||
use revision::revisioned;
|
||||
use revision::Revisioned;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
@ -40,6 +42,7 @@ impl Response {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
#[revisioned(revision = 1)]
|
||||
#[doc(hidden)]
|
||||
pub enum Status {
|
||||
Ok,
|
||||
|
@ -66,3 +69,49 @@ impl Serialize for Response {
|
|||
val.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
#[doc(hidden)]
|
||||
pub struct QueryMethodResponse {
|
||||
pub time: String,
|
||||
pub status: Status,
|
||||
pub result: Value,
|
||||
}
|
||||
|
||||
impl From<&Response> for QueryMethodResponse {
|
||||
fn from(res: &Response) -> Self {
|
||||
let time = res.speed();
|
||||
let (status, result) = match &res.result {
|
||||
Ok(value) => (Status::Ok, value.clone()),
|
||||
Err(error) => (Status::Err, Value::from(error.to_string())),
|
||||
};
|
||||
Self {
|
||||
status,
|
||||
result,
|
||||
time,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
impl Revisioned for Response {
|
||||
#[inline]
|
||||
fn serialize_revisioned<W: std::io::Write>(
|
||||
&self,
|
||||
writer: &mut W,
|
||||
) -> std::result::Result<(), revision::Error> {
|
||||
QueryMethodResponse::from(self).serialize_revisioned(writer)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn deserialize_revisioned<R: std::io::Read>(
|
||||
_reader: &mut R,
|
||||
) -> std::result::Result<Self, revision::Error> {
|
||||
unreachable!("deserialising `Response` directly is not supported")
|
||||
}
|
||||
|
||||
fn revision() -> u16 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -92,6 +92,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
|
|||
"stream",
|
||||
"multipart",
|
||||
], optional = true }
|
||||
revision = "0.5.0"
|
||||
rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] }
|
||||
rustls = { version = "0.21.10", optional = true }
|
||||
semver = { version = "1.0.20", features = ["serde"] }
|
||||
|
|
|
@ -134,11 +134,9 @@ impl IntoEndpoint for &str {
|
|||
)
|
||||
}
|
||||
};
|
||||
Ok(Endpoint {
|
||||
url,
|
||||
path,
|
||||
config: Default::default(),
|
||||
})
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = path;
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -175,9 +175,10 @@ impl Connection for Any {
|
|||
#[cfg(feature = "protocol-ws")]
|
||||
{
|
||||
features.insert(ExtraFeatures::LiveQueries);
|
||||
let url = address.url.join(engine::remote::ws::PATH)?;
|
||||
let mut endpoint = address;
|
||||
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
let maybe_connector = address.config.tls_config.map(Connector::from);
|
||||
let maybe_connector = endpoint.config.tls_config.clone().map(Connector::from);
|
||||
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
|
||||
let maybe_connector = None;
|
||||
|
||||
|
@ -188,13 +189,13 @@ impl Connection for Any {
|
|||
..Default::default()
|
||||
};
|
||||
let socket = engine::remote::ws::native::connect(
|
||||
&url,
|
||||
&endpoint,
|
||||
Some(config),
|
||||
maybe_connector.clone(),
|
||||
)
|
||||
.await?;
|
||||
engine::remote::ws::native::router(
|
||||
url,
|
||||
endpoint,
|
||||
maybe_connector,
|
||||
capacity,
|
||||
config,
|
||||
|
|
|
@ -151,9 +151,9 @@ impl Connection for Any {
|
|||
#[cfg(feature = "protocol-ws")]
|
||||
{
|
||||
features.insert(ExtraFeatures::LiveQueries);
|
||||
let mut address = address;
|
||||
address.url = address.url.join(engine::remote::ws::PATH)?;
|
||||
engine::remote::ws::wasm::router(address, capacity, conn_tx, route_rx);
|
||||
let mut endpoint = address;
|
||||
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
|
||||
engine::remote::ws::wasm::router(endpoint, capacity, conn_tx, route_rx);
|
||||
conn_rx.into_recv_async().await??;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,18 +15,24 @@ use crate::api::Connect;
|
|||
use crate::api::Result;
|
||||
use crate::api::Surreal;
|
||||
use crate::dbs::Notification;
|
||||
use crate::dbs::QueryMethodResponse;
|
||||
use crate::dbs::Status;
|
||||
use crate::method::Stats;
|
||||
use crate::opt::IntoEndpoint;
|
||||
use crate::sql::Value;
|
||||
use indexmap::IndexMap;
|
||||
use revision::revisioned;
|
||||
use revision::Revisioned;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Deserialize;
|
||||
use std::io::Read;
|
||||
use std::marker::PhantomData;
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) const PATH: &str = "rpc";
|
||||
const PING_INTERVAL: Duration = Duration::from_secs(5);
|
||||
const PING_METHOD: &str = "ping";
|
||||
const REVISION_HEADER: &str = "revision";
|
||||
|
||||
/// The WS scheme used to connect to `ws://` endpoints
|
||||
#[derive(Debug)]
|
||||
|
@ -78,12 +84,14 @@ impl Surreal<Client> {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub(crate) struct Failure {
|
||||
pub(crate) code: i64,
|
||||
pub(crate) message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub(crate) enum Data {
|
||||
Other(Value),
|
||||
Query(Vec<QueryMethodResponse>),
|
||||
|
@ -104,13 +112,6 @@ impl From<Failure> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct QueryMethodResponse {
|
||||
time: String,
|
||||
status: Status,
|
||||
result: Value,
|
||||
}
|
||||
|
||||
impl DbResponse {
|
||||
fn from(result: ServerResult) -> Result<Self> {
|
||||
match result.map_err(Error::from)? {
|
||||
|
@ -148,7 +149,30 @@ impl DbResponse {
|
|||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub(crate) struct Response {
|
||||
id: Option<Value>,
|
||||
pub(crate) result: ServerResult,
|
||||
}
|
||||
|
||||
fn serialize(value: &Value, revisioned: bool) -> Result<Vec<u8>> {
|
||||
if revisioned {
|
||||
let mut buf = Vec::new();
|
||||
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
|
||||
return Ok(buf);
|
||||
}
|
||||
crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
|
||||
}
|
||||
|
||||
fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
|
||||
where
|
||||
A: Read,
|
||||
T: Revisioned + DeserializeOwned,
|
||||
{
|
||||
if revisioned {
|
||||
return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
|
||||
}
|
||||
let mut buf = Vec::new();
|
||||
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
|
||||
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::PATH;
|
||||
use super::{deserialize, serialize};
|
||||
use crate::api::conn::Connection;
|
||||
use crate::api::conn::DbResponse;
|
||||
use crate::api::conn::Method;
|
||||
|
@ -19,7 +20,6 @@ use crate::api::Result;
|
|||
use crate::api::Surreal;
|
||||
use crate::engine::remote::ws::Data;
|
||||
use crate::engine::IntervalStream;
|
||||
use crate::sql::serde::{deserialize, serialize};
|
||||
use crate::sql::Strand;
|
||||
use crate::sql::Value;
|
||||
use flume::Receiver;
|
||||
|
@ -28,6 +28,7 @@ use futures::SinkExt;
|
|||
use futures::StreamExt;
|
||||
use futures_concurrency::stream::Merge as _;
|
||||
use indexmap::IndexMap;
|
||||
use revision::revisioned;
|
||||
use serde::Deserialize;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -42,14 +43,16 @@ use std::sync::OnceLock;
|
|||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::error::Error as WsError;
|
||||
use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
|
||||
use tokio_tungstenite::tungstenite::http::HeaderValue;
|
||||
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::Connector;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use trice::Instant;
|
||||
use url::Url;
|
||||
|
||||
type WsResult<T> = std::result::Result<T, WsError>;
|
||||
|
||||
|
@ -78,17 +81,29 @@ impl From<Tls> for Connector {
|
|||
}
|
||||
|
||||
pub(crate) async fn connect(
|
||||
url: &Url,
|
||||
endpoint: &Endpoint,
|
||||
config: Option<WebSocketConfig>,
|
||||
#[allow(unused_variables)] maybe_connector: Option<Connector>,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
|
||||
let mut request = (&endpoint.url).into_client_request()?;
|
||||
|
||||
if endpoint.supports_revision {
|
||||
request
|
||||
.headers_mut()
|
||||
.insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static(super::REVISION_HEADER));
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
let (socket, _) =
|
||||
tokio_tungstenite::connect_async_tls_with_config(url, config, NAGLE_ALG, maybe_connector)
|
||||
.await?;
|
||||
let (socket, _) = tokio_tungstenite::connect_async_tls_with_config(
|
||||
request,
|
||||
config,
|
||||
NAGLE_ALG,
|
||||
maybe_connector,
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
|
||||
let (socket, _) = tokio_tungstenite::connect_async_with_config(url, config, NAGLE_ALG).await?;
|
||||
let (socket, _) = tokio_tungstenite::connect_async_with_config(request, config, NAGLE_ALG).await?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
@ -104,13 +119,13 @@ impl Connection for Client {
|
|||
}
|
||||
|
||||
fn connect(
|
||||
address: Endpoint,
|
||||
mut address: Endpoint,
|
||||
capacity: usize,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
|
||||
Box::pin(async move {
|
||||
let url = address.url.join(PATH)?;
|
||||
address.url = address.url.join(PATH)?;
|
||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||
let maybe_connector = address.config.tls_config.map(Connector::from);
|
||||
let maybe_connector = address.config.tls_config.clone().map(Connector::from);
|
||||
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
|
||||
let maybe_connector = None;
|
||||
|
||||
|
@ -121,14 +136,14 @@ impl Connection for Client {
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
let socket = connect(&url, Some(config), maybe_connector.clone()).await?;
|
||||
let socket = connect(&address, Some(config), maybe_connector.clone()).await?;
|
||||
|
||||
let (route_tx, route_rx) = match capacity {
|
||||
0 => flume::unbounded(),
|
||||
capacity => flume::bounded(capacity),
|
||||
};
|
||||
|
||||
router(url, maybe_connector, capacity, config, socket, route_rx);
|
||||
router(address, maybe_connector, capacity, config, socket, route_rx);
|
||||
|
||||
let mut features = HashSet::new();
|
||||
features.insert(ExtraFeatures::LiveQueries);
|
||||
|
@ -164,7 +179,7 @@ impl Connection for Client {
|
|||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) fn router(
|
||||
url: Url,
|
||||
endpoint: Endpoint,
|
||||
maybe_connector: Option<Connector>,
|
||||
capacity: usize,
|
||||
config: WebSocketConfig,
|
||||
|
@ -176,7 +191,7 @@ pub(crate) fn router(
|
|||
let mut request = BTreeMap::new();
|
||||
request.insert("method".to_owned(), PING_METHOD.into());
|
||||
let value = Value::from(request);
|
||||
let value = serialize(&value).unwrap();
|
||||
let value = serialize(&value, endpoint.supports_revision).unwrap();
|
||||
Message::Binary(value)
|
||||
};
|
||||
|
||||
|
@ -270,7 +285,8 @@ pub(crate) fn router(
|
|||
}
|
||||
let payload = Value::from(request);
|
||||
trace!("Request {payload}");
|
||||
let payload = serialize(&payload).unwrap();
|
||||
let payload =
|
||||
serialize(&payload, endpoint.supports_revision).unwrap();
|
||||
Message::Binary(payload)
|
||||
};
|
||||
if let Method::Authenticate
|
||||
|
@ -314,7 +330,7 @@ pub(crate) fn router(
|
|||
last_activity = Instant::now();
|
||||
match result {
|
||||
Ok(message) => {
|
||||
match Response::try_from(&message) {
|
||||
match Response::try_from(&message, endpoint.supports_revision) {
|
||||
Ok(option) => {
|
||||
// We are only interested in responses that are not empty
|
||||
if let Some(response) = option {
|
||||
|
@ -379,9 +395,12 @@ pub(crate) fn router(
|
|||
);
|
||||
let value =
|
||||
Value::from(request);
|
||||
let value =
|
||||
serialize(&value)
|
||||
.unwrap();
|
||||
let value = serialize(
|
||||
&value,
|
||||
endpoint
|
||||
.supports_revision,
|
||||
)
|
||||
.unwrap();
|
||||
Message::Binary(value)
|
||||
};
|
||||
if let Err(error) =
|
||||
|
@ -402,6 +421,7 @@ pub(crate) fn router(
|
|||
}
|
||||
Err(error) => {
|
||||
#[derive(Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
struct Response {
|
||||
id: Option<Value>,
|
||||
}
|
||||
|
@ -410,8 +430,10 @@ pub(crate) fn router(
|
|||
if let Message::Binary(binary) = message {
|
||||
if let Ok(Response {
|
||||
id,
|
||||
}) = deserialize(&binary)
|
||||
{
|
||||
}) = deserialize(
|
||||
&mut &binary[..],
|
||||
endpoint.supports_revision,
|
||||
) {
|
||||
// Return an error if an ID was returned
|
||||
if let Some(Ok(id)) =
|
||||
id.map(Value::coerce_to_i64)
|
||||
|
@ -473,7 +495,7 @@ pub(crate) fn router(
|
|||
|
||||
'reconnect: loop {
|
||||
trace!("Reconnecting...");
|
||||
match connect(&url, Some(config), maybe_connector.clone()).await {
|
||||
match connect(&endpoint, Some(config), maybe_connector.clone()).await {
|
||||
Ok(s) => {
|
||||
socket = s;
|
||||
for (_, message) in &replay {
|
||||
|
@ -512,19 +534,21 @@ pub(crate) fn router(
|
|||
}
|
||||
|
||||
impl Response {
|
||||
fn try_from(message: &Message) -> Result<Option<Self>> {
|
||||
fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
trace!("Received an unexpected text message; {text}");
|
||||
Ok(None)
|
||||
}
|
||||
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
|
||||
Error::ResponseFromBinary {
|
||||
binary: binary.clone(),
|
||||
error,
|
||||
}
|
||||
.into()
|
||||
}),
|
||||
Message::Binary(binary) => {
|
||||
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
|
||||
Error::ResponseFromBinary {
|
||||
binary: binary.clone(),
|
||||
error: bincode::ErrorKind::Custom(error.to_string()).into(),
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
Message::Ping(..) => {
|
||||
trace!("Received a ping from the server");
|
||||
Ok(None)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::PATH;
|
||||
use super::{deserialize, serialize};
|
||||
use crate::api::conn::Connection;
|
||||
use crate::api::conn::DbResponse;
|
||||
use crate::api::conn::Method;
|
||||
|
@ -17,7 +18,6 @@ use crate::api::Result;
|
|||
use crate::api::Surreal;
|
||||
use crate::engine::remote::ws::Data;
|
||||
use crate::engine::IntervalStream;
|
||||
use crate::sql::serde::{deserialize, serialize};
|
||||
use crate::sql::Strand;
|
||||
use crate::sql::Value;
|
||||
use flume::Receiver;
|
||||
|
@ -29,6 +29,7 @@ use indexmap::IndexMap;
|
|||
use pharos::Channel;
|
||||
use pharos::Observable;
|
||||
use pharos::ObserveConfig;
|
||||
use revision::revisioned;
|
||||
use serde::Deserialize;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -117,13 +118,17 @@ impl Connection for Client {
|
|||
}
|
||||
|
||||
pub(crate) fn router(
|
||||
address: Endpoint,
|
||||
endpoint: Endpoint,
|
||||
capacity: usize,
|
||||
conn_tx: Sender<Result<()>>,
|
||||
route_rx: Receiver<Option<Route>>,
|
||||
) {
|
||||
spawn_local(async move {
|
||||
let (mut ws, mut socket) = match WsMeta::connect(&address.url, None).await {
|
||||
let connect = match endpoint.supports_revision {
|
||||
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
|
||||
false => WsMeta::connect(&endpoint.url, None).await,
|
||||
};
|
||||
let (mut ws, mut socket) = match connect {
|
||||
Ok(pair) => pair,
|
||||
Err(error) => {
|
||||
let _ = conn_tx.into_send_async(Err(error.into())).await;
|
||||
|
@ -151,7 +156,7 @@ pub(crate) fn router(
|
|||
let mut request = BTreeMap::new();
|
||||
request.insert("method".to_owned(), PING_METHOD.into());
|
||||
let value = Value::from(request);
|
||||
let value = serialize(&value).unwrap();
|
||||
let value = serialize(&value, endpoint.supports_revision).unwrap();
|
||||
Message::Binary(value)
|
||||
};
|
||||
|
||||
|
@ -244,7 +249,7 @@ pub(crate) fn router(
|
|||
}
|
||||
let payload = Value::from(request);
|
||||
trace!("Request {payload}");
|
||||
let payload = serialize(&payload).unwrap();
|
||||
let payload = serialize(&payload, endpoint.supports_revision).unwrap();
|
||||
Message::Binary(payload)
|
||||
};
|
||||
if let Method::Authenticate
|
||||
|
@ -285,7 +290,7 @@ pub(crate) fn router(
|
|||
}
|
||||
Either::Response(message) => {
|
||||
last_activity = Instant::now();
|
||||
match Response::try_from(&message) {
|
||||
match Response::try_from(&message, endpoint.supports_revision) {
|
||||
Ok(option) => {
|
||||
// We are only interested in responses that are not empty
|
||||
if let Some(response) = option {
|
||||
|
@ -335,7 +340,11 @@ pub(crate) fn router(
|
|||
.into(),
|
||||
);
|
||||
let value = Value::from(request);
|
||||
let value = serialize(&value).unwrap();
|
||||
let value = serialize(
|
||||
&value,
|
||||
endpoint.supports_revision,
|
||||
)
|
||||
.unwrap();
|
||||
Message::Binary(value)
|
||||
};
|
||||
if let Err(error) =
|
||||
|
@ -355,6 +364,7 @@ pub(crate) fn router(
|
|||
}
|
||||
Err(error) => {
|
||||
#[derive(Deserialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
struct Response {
|
||||
id: Option<Value>,
|
||||
}
|
||||
|
@ -363,7 +373,7 @@ pub(crate) fn router(
|
|||
if let Message::Binary(binary) = message {
|
||||
if let Ok(Response {
|
||||
id,
|
||||
}) = deserialize(&binary)
|
||||
}) = deserialize(&mut &binary[..], endpoint.supports_revision)
|
||||
{
|
||||
// Return an error if an ID was returned
|
||||
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
|
||||
|
@ -418,7 +428,11 @@ pub(crate) fn router(
|
|||
|
||||
'reconnect: loop {
|
||||
trace!("Reconnecting...");
|
||||
match WsMeta::connect(&address.url, None).await {
|
||||
let connect = match endpoint.supports_revision {
|
||||
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
|
||||
false => WsMeta::connect(&endpoint.url, None).await,
|
||||
};
|
||||
match connect {
|
||||
Ok((mut meta, stream)) => {
|
||||
socket = stream;
|
||||
events = {
|
||||
|
@ -471,19 +485,21 @@ pub(crate) fn router(
|
|||
}
|
||||
|
||||
impl Response {
|
||||
fn try_from(message: &Message) -> Result<Option<Self>> {
|
||||
fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
trace!("Received an unexpected text message; {text}");
|
||||
Ok(None)
|
||||
}
|
||||
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
|
||||
Error::ResponseFromBinary {
|
||||
binary: binary.clone(),
|
||||
error,
|
||||
}
|
||||
.into()
|
||||
}),
|
||||
Message::Binary(binary) => {
|
||||
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
|
||||
Error::ResponseFromBinary {
|
||||
binary: binary.clone(),
|
||||
error: bincode::ErrorKind::Custom(error.to_string()).into(),
|
||||
}
|
||||
.into()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,11 +29,7 @@ impl IntoEndpoint<Test> for () {
|
|||
type Client = Client;
|
||||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
Ok(Endpoint {
|
||||
url: Url::parse("test://")?,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
})
|
||||
Ok(Endpoint::new(Url::parse("test://")?))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ pub mod opt;
|
|||
mod conn;
|
||||
|
||||
pub use method::query::Response;
|
||||
use semver::Version;
|
||||
|
||||
use crate::api::conn::DbResponse;
|
||||
use crate::api::conn::Router;
|
||||
|
@ -32,6 +33,7 @@ use self::opt::EndpointKind;
|
|||
pub type Result<T> = std::result::Result<T, crate::Error>;
|
||||
|
||||
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c");
|
||||
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
|
||||
|
||||
/// Connection trait implemented by supported engines
|
||||
pub trait Connection: conn::Connection {}
|
||||
|
@ -93,11 +95,19 @@ where
|
|||
|
||||
fn into_future(self) -> Self::IntoFuture {
|
||||
Box::pin(async move {
|
||||
let endpoint = self.address?;
|
||||
let mut endpoint = self.address?;
|
||||
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
|
||||
let client = Client::connect(endpoint, self.capacity).await?;
|
||||
let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
|
||||
if endpoint_kind.is_remote() {
|
||||
client.check_server_version().await?;
|
||||
let mut version = client.version().await?;
|
||||
// we would like to be able to connect to pre-releases too
|
||||
version.pre = Default::default();
|
||||
client.check_server_version(&version).await?;
|
||||
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
|
||||
// Switch to revision based serialisation
|
||||
endpoint.supports_revision = true;
|
||||
client = Client::connect(endpoint, self.capacity).await?;
|
||||
}
|
||||
}
|
||||
Ok(client)
|
||||
})
|
||||
|
@ -117,19 +127,27 @@ where
|
|||
if self.router.get().is_some() {
|
||||
return Err(Error::AlreadyConnected.into());
|
||||
}
|
||||
let endpoint = self.address?;
|
||||
let mut endpoint = self.address?;
|
||||
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
|
||||
let arc = Client::connect(endpoint, self.capacity).await?.router;
|
||||
let cell = Arc::into_inner(arc).expect("new connection to have no references");
|
||||
let router = cell.into_inner().expect("router to be set");
|
||||
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
|
||||
let client = Surreal {
|
||||
router: self.router,
|
||||
let mut client = Surreal {
|
||||
router: Client::connect(endpoint.clone(), self.capacity).await?.router,
|
||||
engine: PhantomData::<Client>,
|
||||
};
|
||||
if endpoint_kind.is_remote() {
|
||||
client.check_server_version().await?;
|
||||
let mut version = client.version().await?;
|
||||
// we would like to be able to connect to pre-releases too
|
||||
version.pre = Default::default();
|
||||
client.check_server_version(&version).await?;
|
||||
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
|
||||
// Switch to revision based serialisation
|
||||
endpoint.supports_revision = true;
|
||||
client = Client::connect(endpoint, self.capacity).await?;
|
||||
}
|
||||
}
|
||||
let cell =
|
||||
Arc::into_inner(client.router).expect("new connection to have no references");
|
||||
let router = cell.into_inner().expect("router to be set");
|
||||
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -151,18 +169,15 @@ impl<C> Surreal<C>
|
|||
where
|
||||
C: Connection,
|
||||
{
|
||||
async fn check_server_version(&self) -> Result<()> {
|
||||
async fn check_server_version(&self, version: &Version) -> Result<()> {
|
||||
let (versions, build_meta) = SUPPORTED_VERSIONS;
|
||||
// invalid version requirements should be caught during development
|
||||
let req = VersionReq::parse(versions).expect("valid supported versions");
|
||||
let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
|
||||
let mut version = self.version().await?;
|
||||
// we would like to be able to connect to pre-releases too
|
||||
version.pre = Default::default();
|
||||
let server_build = &version.build;
|
||||
if !req.matches(&version) {
|
||||
if !req.matches(version) {
|
||||
return Err(Error::VersionMismatch {
|
||||
server_version: version,
|
||||
server_version: version.clone(),
|
||||
supported_versions: versions.to_owned(),
|
||||
}
|
||||
.into());
|
||||
|
|
|
@ -16,11 +16,11 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let protocol = "fdb://";
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(protocol).unwrap(),
|
||||
path: super::path_to_string(protocol, self),
|
||||
config: Default::default(),
|
||||
})
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = super::path_to_string(protocol, self);
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,11 +17,7 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let url = format!("http://{self}");
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
})
|
||||
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,11 +36,7 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let url = format!("https://{self}");
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
})
|
||||
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,11 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let protocol = "indxdb://";
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(protocol).unwrap(),
|
||||
path: super::path_to_string(protocol, self),
|
||||
config: Default::default(),
|
||||
})
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = super::path_to_string(protocol, self);
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -10,11 +10,12 @@ impl IntoEndpoint<Mem> for () {
|
|||
type Client = Db;
|
||||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
Ok(Endpoint {
|
||||
url: Url::parse("mem://").unwrap(),
|
||||
path: "memory".to_owned(),
|
||||
config: Default::default(),
|
||||
})
|
||||
let protocol = "mem://";
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = "memory".to_owned();
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ use url::Url;
|
|||
use super::Config;
|
||||
|
||||
/// A server address used to connect to the server
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)] // used by the embedded and remote connections
|
||||
pub struct Endpoint {
|
||||
#[doc(hidden)]
|
||||
|
@ -32,9 +32,20 @@ pub struct Endpoint {
|
|||
#[doc(hidden)]
|
||||
pub path: String,
|
||||
pub(crate) config: Config,
|
||||
// Whether or not the remote server supports revision based serialisation
|
||||
pub(crate) supports_revision: bool,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub(crate) fn new(url: Url) -> Self {
|
||||
Self {
|
||||
url,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
supports_revision: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn parse_kind(&self) -> Result<EndpointKind> {
|
||||
match EndpointKind::from(self.url.scheme()) {
|
||||
|
@ -151,6 +162,10 @@ impl EndpointKind {
|
|||
)
|
||||
}
|
||||
|
||||
pub(crate) fn is_ws(&self) -> bool {
|
||||
matches!(self, EndpointKind::Ws | EndpointKind::Wss)
|
||||
}
|
||||
|
||||
pub fn is_local(&self) -> bool {
|
||||
!self.is_remote()
|
||||
}
|
||||
|
|
|
@ -17,11 +17,11 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let protocol = "rocksdb://";
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(protocol).unwrap(),
|
||||
path: super::path_to_string(protocol, self),
|
||||
config: Default::default(),
|
||||
})
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = super::path_to_string(protocol, self);
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,11 +40,11 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let protocol = "file://";
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(protocol).unwrap(),
|
||||
path: super::path_to_string(protocol, self),
|
||||
config: Default::default(),
|
||||
})
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = super::path_to_string(protocol, self);
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,11 +16,11 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let protocol = "speedb://";
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(protocol).unwrap(),
|
||||
path: super::path_to_string(protocol, self),
|
||||
config: Default::default(),
|
||||
})
|
||||
let url = Url::parse(protocol)
|
||||
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
|
||||
let mut endpoint = Endpoint::new(url);
|
||||
endpoint.path = super::path_to_string(protocol, self);
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::api::engine::local::Db;
|
||||
use crate::api::engine::local::TiKv;
|
||||
use crate::api::err::Error;
|
||||
use crate::api::opt::Config;
|
||||
use crate::api::opt::Endpoint;
|
||||
use crate::api::opt::IntoEndpoint;
|
||||
|
@ -15,11 +16,9 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let url = format!("tikv://{self}");
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(&url).unwrap(),
|
||||
path: url,
|
||||
config: Default::default(),
|
||||
})
|
||||
let mut endpoint = Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url.clone()))?);
|
||||
endpoint.path = url;
|
||||
Ok(endpoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,11 +17,7 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let url = format!("ws://{self}");
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
})
|
||||
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,11 +36,7 @@ macro_rules! endpoints {
|
|||
|
||||
fn into_endpoint(self) -> Result<Endpoint> {
|
||||
let url = format!("wss://{self}");
|
||||
Ok(Endpoint {
|
||||
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
|
||||
path: String::new(),
|
||||
config: Default::default(),
|
||||
})
|
||||
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate::err::Error;
|
||||
use revision::revisioned;
|
||||
use revision::Revisioned;
|
||||
use serde::Serialize;
|
||||
use std::borrow::Cow;
|
||||
use surrealdb::sql::Value;
|
||||
|
@ -9,6 +11,34 @@ pub struct Failure {
|
|||
pub(crate) message: Cow<'static, str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
struct Inner {
|
||||
code: i64,
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl Revisioned for Failure {
|
||||
fn serialize_revisioned<W: std::io::Write>(
|
||||
&self,
|
||||
writer: &mut W,
|
||||
) -> Result<(), revision::Error> {
|
||||
let inner = Inner {
|
||||
code: self.code,
|
||||
message: self.message.as_ref().to_owned(),
|
||||
};
|
||||
inner.serialize_revisioned(writer)
|
||||
}
|
||||
|
||||
fn deserialize_revisioned<R: std::io::Read>(_reader: &mut R) -> Result<Self, revision::Error> {
|
||||
unreachable!("deserialization not supported for this type")
|
||||
}
|
||||
|
||||
fn revision() -> u16 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Failure {
|
||||
fn from(err: &str) -> Self {
|
||||
Failure::custom(err.to_string())
|
||||
|
|
|
@ -2,13 +2,22 @@ use crate::rpc::failure::Failure;
|
|||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
use revision::Revisioned;
|
||||
use surrealdb::sql::Value;
|
||||
|
||||
pub fn req(_msg: Message) -> Result<Request, Failure> {
|
||||
// This format is not yet implemented
|
||||
Err(Failure::INTERNAL_ERROR)
|
||||
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||
match msg {
|
||||
Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice())
|
||||
.map_err(|_| Failure::PARSE_ERROR)?
|
||||
.try_into(),
|
||||
_ => Err(Failure::INVALID_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
|
||||
// This format is not yet implemented
|
||||
Err(Failure::INTERNAL_ERROR)
|
||||
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||
// Serialize the response with full internal type information
|
||||
let mut buf = Vec::new();
|
||||
res.serialize_revisioned(&mut buf).unwrap();
|
||||
// Return the message length, and message as binary
|
||||
Ok((buf.len(), Message::Binary(buf)))
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::rpc::format::Format;
|
|||
use crate::telemetry::metrics::ws::record_rpc;
|
||||
use axum::extract::ws::Message;
|
||||
use opentelemetry::Context as TelemetryContext;
|
||||
use revision::revisioned;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as Json;
|
||||
use surrealdb::channel::Sender;
|
||||
|
@ -16,6 +17,7 @@ use tracing::Span;
|
|||
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
|
||||
// In future, they will possibly be merged to avoid having to keep them in sync.
|
||||
#[derive(Debug, Serialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub enum Data {
|
||||
/// Generally methods return a `sql::Value`
|
||||
Other(Value),
|
||||
|
@ -61,6 +63,7 @@ impl From<Data> for Value {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[revisioned(revision = 1)]
|
||||
pub struct Response {
|
||||
id: Option<Value>,
|
||||
result: Result<Data, Failure>,
|
||||
|
|
Loading…
Reference in a new issue