Implement revision types for client/server communication (#3474)

This commit is contained in:
Rushmore Mushambi 2024-02-12 13:52:36 +02:00 committed by GitHub
parent 2d2b3a40bb
commit d55d1a3b6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 321 additions and 150 deletions

2
Cargo.lock generated
View file

@ -5278,6 +5278,7 @@ dependencies = [
"rand 0.8.5",
"rcgen",
"reqwest",
"revision",
"rmp-serde",
"rmpv",
"rustyline",
@ -5330,6 +5331,7 @@ dependencies = [
"rand 0.8.5",
"regex",
"reqwest",
"revision",
"ring 0.17.7",
"rust_decimal",
"rustls",

View file

@ -65,6 +65,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
"blocking",
"gzip",
] }
revision = "0.5.0"
rmpv = "1.0.1"
rustyline = { version = "12.0.0", features = ["derive"] }
serde = { version = "1.0.193", features = ["derive"] }

View file

@ -1,9 +1,11 @@
use crate::sql::{Object, Uuid, Value};
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug, Display};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
#[revisioned(revision = 1)]
pub enum Action {
Create,
Update,
@ -21,6 +23,7 @@ impl Display for Action {
}
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[revisioned(revision = 1)]
pub struct Notification {
/// The id of the LIVE query to which this notification belongs
pub id: Uuid,

View file

@ -1,5 +1,7 @@
use crate::err::Error;
use crate::sql::value::Value;
use revision::revisioned;
use revision::Revisioned;
use serde::ser::SerializeStruct;
use serde::Deserialize;
use serde::Serialize;
@ -40,6 +42,7 @@ impl Response {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
#[revisioned(revision = 1)]
#[doc(hidden)]
pub enum Status {
Ok,
@ -66,3 +69,49 @@ impl Serialize for Response {
val.end()
}
}
#[derive(Debug, Serialize, Deserialize)]
#[revisioned(revision = 1)]
#[doc(hidden)]
pub struct QueryMethodResponse {
pub time: String,
pub status: Status,
pub result: Value,
}
impl From<&Response> for QueryMethodResponse {
fn from(res: &Response) -> Self {
let time = res.speed();
let (status, result) = match &res.result {
Ok(value) => (Status::Ok, value.clone()),
Err(error) => (Status::Err, Value::from(error.to_string())),
};
Self {
status,
result,
time,
}
}
}
#[doc(hidden)]
impl Revisioned for Response {
#[inline]
fn serialize_revisioned<W: std::io::Write>(
&self,
writer: &mut W,
) -> std::result::Result<(), revision::Error> {
QueryMethodResponse::from(self).serialize_revisioned(writer)
}
#[inline]
fn deserialize_revisioned<R: std::io::Read>(
_reader: &mut R,
) -> std::result::Result<Self, revision::Error> {
unreachable!("deserialising `Response` directly is not supported")
}
fn revision() -> u16 {
1
}
}

View file

@ -92,6 +92,7 @@ reqwest = { version = "0.11.22", default-features = false, features = [
"stream",
"multipart",
], optional = true }
revision = "0.5.0"
rust_decimal = { version = "1.33.1", features = ["maths", "serde-str"] }
rustls = { version = "0.21.10", optional = true }
semver = { version = "1.0.20", features = ["serde"] }

View file

@ -134,11 +134,9 @@ impl IntoEndpoint for &str {
)
}
};
Ok(Endpoint {
url,
path,
config: Default::default(),
})
let mut endpoint = Endpoint::new(url);
endpoint.path = path;
Ok(endpoint)
}
}

View file

@ -175,9 +175,10 @@ impl Connection for Any {
#[cfg(feature = "protocol-ws")]
{
features.insert(ExtraFeatures::LiveQueries);
let url = address.url.join(engine::remote::ws::PATH)?;
let mut endpoint = address;
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
let maybe_connector = address.config.tls_config.map(Connector::from);
let maybe_connector = endpoint.config.tls_config.clone().map(Connector::from);
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let maybe_connector = None;
@ -188,13 +189,13 @@ impl Connection for Any {
..Default::default()
};
let socket = engine::remote::ws::native::connect(
&url,
&endpoint,
Some(config),
maybe_connector.clone(),
)
.await?;
engine::remote::ws::native::router(
url,
endpoint,
maybe_connector,
capacity,
config,

View file

@ -151,9 +151,9 @@ impl Connection for Any {
#[cfg(feature = "protocol-ws")]
{
features.insert(ExtraFeatures::LiveQueries);
let mut address = address;
address.url = address.url.join(engine::remote::ws::PATH)?;
engine::remote::ws::wasm::router(address, capacity, conn_tx, route_rx);
let mut endpoint = address;
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
engine::remote::ws::wasm::router(endpoint, capacity, conn_tx, route_rx);
conn_rx.into_recv_async().await??;
}

View file

@ -15,18 +15,24 @@ use crate::api::Connect;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Notification;
use crate::dbs::QueryMethodResponse;
use crate::dbs::Status;
use crate::method::Stats;
use crate::opt::IntoEndpoint;
use crate::sql::Value;
use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use std::io::Read;
use std::marker::PhantomData;
use std::time::Duration;
pub(crate) const PATH: &str = "rpc";
const PING_INTERVAL: Duration = Duration::from_secs(5);
const PING_METHOD: &str = "ping";
const REVISION_HEADER: &str = "revision";
/// The WS scheme used to connect to `ws://` endpoints
#[derive(Debug)]
@ -78,12 +84,14 @@ impl Surreal<Client> {
}
#[derive(Clone, Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) struct Failure {
pub(crate) code: i64,
pub(crate) message: String,
}
#[derive(Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) enum Data {
Other(Value),
Query(Vec<QueryMethodResponse>),
@ -104,13 +112,6 @@ impl From<Failure> for Error {
}
}
#[derive(Debug, Deserialize)]
pub(crate) struct QueryMethodResponse {
time: String,
status: Status,
result: Value,
}
impl DbResponse {
fn from(result: ServerResult) -> Result<Self> {
match result.map_err(Error::from)? {
@ -148,7 +149,30 @@ impl DbResponse {
}
#[derive(Debug, Deserialize)]
#[revisioned(revision = 1)]
pub(crate) struct Response {
id: Option<Value>,
pub(crate) result: ServerResult,
}
fn serialize(value: &Value, revisioned: bool) -> Result<Vec<u8>> {
if revisioned {
let mut buf = Vec::new();
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
return Ok(buf);
}
crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
}
fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
where
A: Read,
T: Revisioned + DeserializeOwned,
{
if revisioned {
return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
}
let mut buf = Vec::new();
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}

View file

@ -1,4 +1,5 @@
use super::PATH;
use super::{deserialize, serialize};
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
@ -19,7 +20,6 @@ use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Strand;
use crate::sql::Value;
use flume::Receiver;
@ -28,6 +28,7 @@ use futures::SinkExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use indexmap::IndexMap;
use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
@ -42,14 +43,16 @@ use std::sync::OnceLock;
use tokio::net::TcpStream;
use tokio::time;
use tokio::time::MissedTickBehavior;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::error::Error as WsError;
use tokio_tungstenite::tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::Connector;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use trice::Instant;
use url::Url;
type WsResult<T> = std::result::Result<T, WsError>;
@ -78,17 +81,29 @@ impl From<Tls> for Connector {
}
pub(crate) async fn connect(
url: &Url,
endpoint: &Endpoint,
config: Option<WebSocketConfig>,
#[allow(unused_variables)] maybe_connector: Option<Connector>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let mut request = (&endpoint.url).into_client_request()?;
if endpoint.supports_revision {
request
.headers_mut()
.insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static(super::REVISION_HEADER));
}
#[cfg(any(feature = "native-tls", feature = "rustls"))]
let (socket, _) =
tokio_tungstenite::connect_async_tls_with_config(url, config, NAGLE_ALG, maybe_connector)
.await?;
let (socket, _) = tokio_tungstenite::connect_async_tls_with_config(
request,
config,
NAGLE_ALG,
maybe_connector,
)
.await?;
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let (socket, _) = tokio_tungstenite::connect_async_with_config(url, config, NAGLE_ALG).await?;
let (socket, _) = tokio_tungstenite::connect_async_with_config(request, config, NAGLE_ALG).await?;
Ok(socket)
}
@ -104,13 +119,13 @@ impl Connection for Client {
}
fn connect(
address: Endpoint,
mut address: Endpoint,
capacity: usize,
) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
Box::pin(async move {
let url = address.url.join(PATH)?;
address.url = address.url.join(PATH)?;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
let maybe_connector = address.config.tls_config.map(Connector::from);
let maybe_connector = address.config.tls_config.clone().map(Connector::from);
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
let maybe_connector = None;
@ -121,14 +136,14 @@ impl Connection for Client {
..Default::default()
};
let socket = connect(&url, Some(config), maybe_connector.clone()).await?;
let socket = connect(&address, Some(config), maybe_connector.clone()).await?;
let (route_tx, route_rx) = match capacity {
0 => flume::unbounded(),
capacity => flume::bounded(capacity),
};
router(url, maybe_connector, capacity, config, socket, route_rx);
router(address, maybe_connector, capacity, config, socket, route_rx);
let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries);
@ -164,7 +179,7 @@ impl Connection for Client {
#[allow(clippy::too_many_lines)]
pub(crate) fn router(
url: Url,
endpoint: Endpoint,
maybe_connector: Option<Connector>,
capacity: usize,
config: WebSocketConfig,
@ -176,7 +191,7 @@ pub(crate) fn router(
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request);
let value = serialize(&value).unwrap();
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
@ -270,7 +285,8 @@ pub(crate) fn router(
}
let payload = Value::from(request);
trace!("Request {payload}");
let payload = serialize(&payload).unwrap();
let payload =
serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
@ -314,7 +330,7 @@ pub(crate) fn router(
last_activity = Instant::now();
match result {
Ok(message) => {
match Response::try_from(&message) {
match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => {
// We are only interested in responses that are not empty
if let Some(response) = option {
@ -379,9 +395,12 @@ pub(crate) fn router(
);
let value =
Value::from(request);
let value =
serialize(&value)
.unwrap();
let value = serialize(
&value,
endpoint
.supports_revision,
)
.unwrap();
Message::Binary(value)
};
if let Err(error) =
@ -402,6 +421,7 @@ pub(crate) fn router(
}
Err(error) => {
#[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response {
id: Option<Value>,
}
@ -410,8 +430,10 @@ pub(crate) fn router(
if let Message::Binary(binary) = message {
if let Ok(Response {
id,
}) = deserialize(&binary)
{
}) = deserialize(
&mut &binary[..],
endpoint.supports_revision,
) {
// Return an error if an ID was returned
if let Some(Ok(id)) =
id.map(Value::coerce_to_i64)
@ -473,7 +495,7 @@ pub(crate) fn router(
'reconnect: loop {
trace!("Reconnecting...");
match connect(&url, Some(config), maybe_connector.clone()).await {
match connect(&endpoint, Some(config), maybe_connector.clone()).await {
Ok(s) => {
socket = s;
for (_, message) in &replay {
@ -512,19 +534,21 @@ pub(crate) fn router(
}
impl Response {
fn try_from(message: &Message) -> Result<Option<Self>> {
fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
match message {
Message::Text(text) => {
trace!("Received an unexpected text message; {text}");
Ok(None)
}
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
Error::ResponseFromBinary {
binary: binary.clone(),
error,
}
.into()
}),
Message::Binary(binary) => {
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
Error::ResponseFromBinary {
binary: binary.clone(),
error: bincode::ErrorKind::Custom(error.to_string()).into(),
}
.into()
})
}
Message::Ping(..) => {
trace!("Received a ping from the server");
Ok(None)

View file

@ -1,4 +1,5 @@
use super::PATH;
use super::{deserialize, serialize};
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
@ -17,7 +18,6 @@ use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::sql::serde::{deserialize, serialize};
use crate::sql::Strand;
use crate::sql::Value;
use flume::Receiver;
@ -29,6 +29,7 @@ use indexmap::IndexMap;
use pharos::Channel;
use pharos::Observable;
use pharos::ObserveConfig;
use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
@ -117,13 +118,17 @@ impl Connection for Client {
}
pub(crate) fn router(
address: Endpoint,
endpoint: Endpoint,
capacity: usize,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>,
) {
spawn_local(async move {
let (mut ws, mut socket) = match WsMeta::connect(&address.url, None).await {
let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
let (mut ws, mut socket) = match connect {
Ok(pair) => pair,
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
@ -151,7 +156,7 @@ pub(crate) fn router(
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request);
let value = serialize(&value).unwrap();
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
@ -244,7 +249,7 @@ pub(crate) fn router(
}
let payload = Value::from(request);
trace!("Request {payload}");
let payload = serialize(&payload).unwrap();
let payload = serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
@ -285,7 +290,7 @@ pub(crate) fn router(
}
Either::Response(message) => {
last_activity = Instant::now();
match Response::try_from(&message) {
match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => {
// We are only interested in responses that are not empty
if let Some(response) = option {
@ -335,7 +340,11 @@ pub(crate) fn router(
.into(),
);
let value = Value::from(request);
let value = serialize(&value).unwrap();
let value = serialize(
&value,
endpoint.supports_revision,
)
.unwrap();
Message::Binary(value)
};
if let Err(error) =
@ -355,6 +364,7 @@ pub(crate) fn router(
}
Err(error) => {
#[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response {
id: Option<Value>,
}
@ -363,7 +373,7 @@ pub(crate) fn router(
if let Message::Binary(binary) = message {
if let Ok(Response {
id,
}) = deserialize(&binary)
}) = deserialize(&mut &binary[..], endpoint.supports_revision)
{
// Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
@ -418,7 +428,11 @@ pub(crate) fn router(
'reconnect: loop {
trace!("Reconnecting...");
match WsMeta::connect(&address.url, None).await {
let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
match connect {
Ok((mut meta, stream)) => {
socket = stream;
events = {
@ -471,19 +485,21 @@ pub(crate) fn router(
}
impl Response {
fn try_from(message: &Message) -> Result<Option<Self>> {
fn try_from(message: &Message, supports_revision: bool) -> Result<Option<Self>> {
match message {
Message::Text(text) => {
trace!("Received an unexpected text message; {text}");
Ok(None)
}
Message::Binary(binary) => deserialize(binary).map(Some).map_err(|error| {
Error::ResponseFromBinary {
binary: binary.clone(),
error,
}
.into()
}),
Message::Binary(binary) => {
deserialize(&mut &binary[..], supports_revision).map(Some).map_err(|error| {
Error::ResponseFromBinary {
binary: binary.clone(),
error: bincode::ErrorKind::Custom(error.to_string()).into(),
}
.into()
})
}
}
}
}

View file

@ -29,11 +29,7 @@ impl IntoEndpoint<Test> for () {
type Client = Client;
fn into_endpoint(self) -> Result<Endpoint> {
Ok(Endpoint {
url: Url::parse("test://")?,
path: String::new(),
config: Default::default(),
})
Ok(Endpoint::new(Url::parse("test://")?))
}
}

View file

@ -10,6 +10,7 @@ pub mod opt;
mod conn;
pub use method::query::Response;
use semver::Version;
use crate::api::conn::DbResponse;
use crate::api::conn::Router;
@ -32,6 +33,7 @@ use self::opt::EndpointKind;
pub type Result<T> = std::result::Result<T, crate::Error>;
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c");
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
/// Connection trait implemented by supported engines
pub trait Connection: conn::Connection {}
@ -93,11 +95,19 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let endpoint = self.address?;
let mut endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let client = Client::connect(endpoint, self.capacity).await?;
let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
if endpoint_kind.is_remote() {
client.check_server_version().await?;
let mut version = client.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
client.check_server_version(&version).await?;
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
// Switch to revision based serialisation
endpoint.supports_revision = true;
client = Client::connect(endpoint, self.capacity).await?;
}
}
Ok(client)
})
@ -117,19 +127,27 @@ where
if self.router.get().is_some() {
return Err(Error::AlreadyConnected.into());
}
let endpoint = self.address?;
let mut endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let arc = Client::connect(endpoint, self.capacity).await?.router;
let cell = Arc::into_inner(arc).expect("new connection to have no references");
let router = cell.into_inner().expect("router to be set");
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
let client = Surreal {
router: self.router,
let mut client = Surreal {
router: Client::connect(endpoint.clone(), self.capacity).await?.router,
engine: PhantomData::<Client>,
};
if endpoint_kind.is_remote() {
client.check_server_version().await?;
let mut version = client.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
client.check_server_version(&version).await?;
if version >= REVISION_SUPPORTED_SERVER_VERSION && endpoint_kind.is_ws() {
// Switch to revision based serialisation
endpoint.supports_revision = true;
client = Client::connect(endpoint, self.capacity).await?;
}
}
let cell =
Arc::into_inner(client.router).expect("new connection to have no references");
let router = cell.into_inner().expect("router to be set");
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
Ok(())
})
}
@ -151,18 +169,15 @@ impl<C> Surreal<C>
where
C: Connection,
{
async fn check_server_version(&self) -> Result<()> {
async fn check_server_version(&self, version: &Version) -> Result<()> {
let (versions, build_meta) = SUPPORTED_VERSIONS;
// invalid version requirements should be caught during development
let req = VersionReq::parse(versions).expect("valid supported versions");
let build_meta = BuildMetadata::new(build_meta).expect("valid supported build metadata");
let mut version = self.version().await?;
// we would like to be able to connect to pre-releases too
version.pre = Default::default();
let server_build = &version.build;
if !req.matches(&version) {
if !req.matches(version) {
return Err(Error::VersionMismatch {
server_version: version,
server_version: version.clone(),
supported_versions: versions.to_owned(),
}
.into());

View file

@ -16,11 +16,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "fdb://";
Ok(Endpoint {
url: Url::parse(protocol).unwrap(),
path: super::path_to_string(protocol, self),
config: Default::default(),
})
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = super::path_to_string(protocol, self);
Ok(endpoint)
}
}

View file

@ -17,11 +17,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("http://{self}");
Ok(Endpoint {
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
}
}
@ -40,11 +36,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("https://{self}");
Ok(Endpoint {
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
}
}

View file

@ -14,11 +14,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "indxdb://";
Ok(Endpoint {
url: Url::parse(protocol).unwrap(),
path: super::path_to_string(protocol, self),
config: Default::default(),
})
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = super::path_to_string(protocol, self);
Ok(endpoint)
}
}

View file

@ -10,11 +10,12 @@ impl IntoEndpoint<Mem> for () {
type Client = Db;
fn into_endpoint(self) -> Result<Endpoint> {
Ok(Endpoint {
url: Url::parse("mem://").unwrap(),
path: "memory".to_owned(),
config: Default::default(),
})
let protocol = "mem://";
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = "memory".to_owned();
Ok(endpoint)
}
}

View file

@ -24,7 +24,7 @@ use url::Url;
use super::Config;
/// A server address used to connect to the server
#[derive(Debug)]
#[derive(Debug, Clone)]
#[allow(dead_code)] // used by the embedded and remote connections
pub struct Endpoint {
#[doc(hidden)]
@ -32,9 +32,20 @@ pub struct Endpoint {
#[doc(hidden)]
pub path: String,
pub(crate) config: Config,
// Whether or not the remote server supports revision based serialisation
pub(crate) supports_revision: bool,
}
impl Endpoint {
pub(crate) fn new(url: Url) -> Self {
Self {
url,
path: String::new(),
config: Default::default(),
supports_revision: false,
}
}
#[doc(hidden)]
pub fn parse_kind(&self) -> Result<EndpointKind> {
match EndpointKind::from(self.url.scheme()) {
@ -151,6 +162,10 @@ impl EndpointKind {
)
}
pub(crate) fn is_ws(&self) -> bool {
matches!(self, EndpointKind::Ws | EndpointKind::Wss)
}
pub fn is_local(&self) -> bool {
!self.is_remote()
}

View file

@ -17,11 +17,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "rocksdb://";
Ok(Endpoint {
url: Url::parse(protocol).unwrap(),
path: super::path_to_string(protocol, self),
config: Default::default(),
})
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = super::path_to_string(protocol, self);
Ok(endpoint)
}
}
@ -40,11 +40,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "file://";
Ok(Endpoint {
url: Url::parse(protocol).unwrap(),
path: super::path_to_string(protocol, self),
config: Default::default(),
})
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = super::path_to_string(protocol, self);
Ok(endpoint)
}
}

View file

@ -16,11 +16,11 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let protocol = "speedb://";
Ok(Endpoint {
url: Url::parse(protocol).unwrap(),
path: super::path_to_string(protocol, self),
config: Default::default(),
})
let url = Url::parse(protocol)
.unwrap_or_else(|_| unreachable!("`{protocol}` should be static and valid"));
let mut endpoint = Endpoint::new(url);
endpoint.path = super::path_to_string(protocol, self);
Ok(endpoint)
}
}

View file

@ -1,5 +1,6 @@
use crate::api::engine::local::Db;
use crate::api::engine::local::TiKv;
use crate::api::err::Error;
use crate::api::opt::Config;
use crate::api::opt::Endpoint;
use crate::api::opt::IntoEndpoint;
@ -15,11 +16,9 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("tikv://{self}");
Ok(Endpoint {
url: Url::parse(&url).unwrap(),
path: url,
config: Default::default(),
})
let mut endpoint = Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url.clone()))?);
endpoint.path = url;
Ok(endpoint)
}
}

View file

@ -17,11 +17,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("ws://{self}");
Ok(Endpoint {
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
}
}
@ -40,11 +36,7 @@ macro_rules! endpoints {
fn into_endpoint(self) -> Result<Endpoint> {
let url = format!("wss://{self}");
Ok(Endpoint {
url: Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?,
path: String::new(),
config: Default::default(),
})
Ok(Endpoint::new(Url::parse(&url).map_err(|_| Error::InvalidUrl(url))?))
}
}

View file

@ -1,4 +1,6 @@
use crate::err::Error;
use revision::revisioned;
use revision::Revisioned;
use serde::Serialize;
use std::borrow::Cow;
use surrealdb::sql::Value;
@ -9,6 +11,34 @@ pub struct Failure {
pub(crate) message: Cow<'static, str>,
}
#[derive(Clone, Debug, Serialize)]
#[revisioned(revision = 1)]
struct Inner {
code: i64,
message: String,
}
impl Revisioned for Failure {
fn serialize_revisioned<W: std::io::Write>(
&self,
writer: &mut W,
) -> Result<(), revision::Error> {
let inner = Inner {
code: self.code,
message: self.message.as_ref().to_owned(),
};
inner.serialize_revisioned(writer)
}
fn deserialize_revisioned<R: std::io::Read>(_reader: &mut R) -> Result<Self, revision::Error> {
unreachable!("deserialization not supported for this type")
}
fn revision() -> u16 {
1
}
}
impl From<&str> for Failure {
fn from(err: &str) -> Self {
Failure::custom(err.to_string())

View file

@ -2,13 +2,22 @@ use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use revision::Revisioned;
use surrealdb::sql::Value;
pub fn req(_msg: Message) -> Result<Request, Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Binary(val) => Value::deserialize_revisioned(&mut val.as_slice())
.map_err(|_| Failure::PARSE_ERROR)?
.try_into(),
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Serialize the response with full internal type information
let mut buf = Vec::new();
res.serialize_revisioned(&mut buf).unwrap();
// Return the message length, and message as binary
Ok((buf.len(), Message::Binary(buf)))
}

View file

@ -3,6 +3,7 @@ use crate::rpc::format::Format;
use crate::telemetry::metrics::ws::record_rpc;
use axum::extract::ws::Message;
use opentelemetry::Context as TelemetryContext;
use revision::revisioned;
use serde::Serialize;
use serde_json::Value as Json;
use surrealdb::channel::Sender;
@ -16,6 +17,7 @@ use tracing::Span;
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync.
#[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub enum Data {
/// Generally methods return a `sql::Value`
Other(Value),
@ -61,6 +63,7 @@ impl From<Data> for Value {
}
#[derive(Debug, Serialize)]
#[revisioned(revision = 1)]
pub struct Response {
id: Option<Value>,
result: Result<Data, Failure>,