Improve WebSocket protocol implementation (#3291)

This commit is contained in:
Tobie Morgan Hitchcock 2024-01-09 15:27:03 +00:00 committed by GitHub
parent 8e9bd3a2d6
commit a35fc0d04d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 2300 additions and 2338 deletions

23
Cargo.lock generated
View file

@ -4445,6 +4445,16 @@ dependencies = [
"serde",
]
[[package]]
name = "rmpv"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e0e0214a4a2b444ecce41a4025792fc31f77c7bb89c46d253953ea8c65701ec"
dependencies = [
"num-traits",
"rmp",
]
[[package]]
name = "roaring"
version = "0.10.2"
@ -4817,16 +4827,6 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.193"
@ -5235,6 +5235,7 @@ dependencies = [
"axum-server",
"base64 0.21.5",
"bytes",
"ciborium",
"clap",
"env_logger",
"futures",
@ -5257,10 +5258,10 @@ dependencies = [
"rcgen",
"reqwest",
"rmp-serde",
"rmpv",
"rustyline",
"semver",
"serde",
"serde_cbor",
"serde_json",
"serial_test",
"surrealdb",

View file

@ -40,6 +40,7 @@ axum-extra = { version = "0.7.7", features = ["query", "typed-routing"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
base64 = "0.21.5"
bytes = "1.5.0"
ciborium = "0.2.1"
clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] }
futures = "0.3.29"
futures-util = "0.3.29"
@ -55,9 +56,9 @@ opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
pin-project-lite = "0.2.13"
rand = "0.8.5"
reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] }
rmpv = "1.0.1"
rustyline = { version = "12.0.0", features = ["derive"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_cbor = "0.11.2"
serde_json = "1.0.108"
serde_pack = { version = "1.1.2", package = "rmp-serde" }
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }

View file

@ -34,13 +34,13 @@ env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set =
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
[tasks.ci-ws-integration]
category = "CI - INTEGRATION TESTS"
category = "WS - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
[tasks.ci-ml-integration]
category = "CI - INTEGRATION TESTS"
category = "ML - INTEGRATION TESTS"
command = "cargo"
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]

View file

@ -130,6 +130,7 @@ pub fn thing((arg1, arg2): (Value, Option<Value>)) -> Result<Value, Error> {
pub mod is {
use crate::err::Error;
use crate::sql::table::Table;
use crate::sql::value::Value;
use crate::sql::Geometry;
@ -215,7 +216,7 @@ pub mod is {
pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> {
Ok(match table {
Some(tb) => arg.is_record_of_table(tb).into(),
Some(tb) => arg.is_record_type(&[Table(tb)]).into(),
None => arg.is_record().into(),
})
}

View file

@ -7,6 +7,7 @@ use crate::sql::{escape::escape_rid, Array, Number, Object, Strand, Thing, Uuid,
use nanoid::nanoid;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt::{self, Display, Formatter};
use ulid::Ulid;
@ -106,6 +107,12 @@ impl From<Vec<Value>> for Id {
}
}
impl From<BTreeMap<String, Value>> for Id {
fn from(v: BTreeMap<String, Value>) -> Self {
Id::Object(v.into())
}
}
impl From<Number> for Id {
fn from(v: Number) -> Self {
match v {

View file

@ -23,6 +23,12 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
#[revisioned(revision = 1)]
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
impl From<BTreeMap<&str, Value>> for Object {
fn from(v: BTreeMap<&str, Value>) -> Self {
Self(v.into_iter().map(|(key, val)| (key.to_string(), val)).collect())
}
}
impl From<BTreeMap<String, Value>> for Object {
fn from(v: BTreeMap<String, Value>) -> Self {
Self(v)

View file

@ -2,6 +2,7 @@ use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::sql::Uuid;
use crate::sql::Value;
use derive::Store;
use revision::revisioned;
@ -34,6 +35,14 @@ impl KillStatement {
Value::Uuid(id) => *id,
Value::Param(param) => match param.compute(ctx, opt, txn, None).await? {
Value::Uuid(id) => id,
Value::Strand(id) => match uuid::Uuid::try_parse(&id) {
Ok(id) => Uuid(id),
_ => {
return Err(Error::KillStatement {
value: self.id.to_string(),
})
}
},
_ => {
return Err(Error::KillStatement {
value: self.id.to_string(),

View file

@ -458,6 +458,12 @@ impl From<Vec<bool>> for Value {
}
}
impl From<HashMap<&str, Value>> for Value {
fn from(v: HashMap<&str, Value>) -> Self {
Value::Object(Object::from(v))
}
}
impl From<HashMap<String, Value>> for Value {
fn from(v: HashMap<String, Value>) -> Self {
Value::Object(Object::from(v))
@ -470,6 +476,12 @@ impl From<BTreeMap<String, Value>> for Value {
}
}
impl From<BTreeMap<&str, Value>> for Value {
fn from(v: BTreeMap<&str, Value>) -> Self {
Value::Object(Object::from(v))
}
}
impl From<Option<Value>> for Value {
fn from(v: Option<Value>) -> Self {
match v {
@ -730,6 +742,18 @@ impl TryFrom<Value> for Object {
}
}
impl FromIterator<Value> for Value {
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
Value::Array(Array(iter.into_iter().collect()))
}
}
impl FromIterator<(String, Value)> for Value {
fn from_iter<I: IntoIterator<Item = (String, Value)>>(iter: I) -> Self {
Value::Object(Object(iter.into_iter().collect()))
}
}
impl Value {
// -----------------------------------
// Initial record value
@ -828,6 +852,11 @@ impl Value {
matches!(self, Value::Mock(_))
}
/// Check if this Value is a Param
pub fn is_param(&self) -> bool {
matches!(self, Value::Param(_))
}
/// Check if this Value is a Range
pub fn is_range(&self) -> bool {
matches!(self, Value::Range(_))
@ -952,11 +981,6 @@ impl Value {
}
}
/// Check if this Value is a Param
pub fn is_param(&self) -> bool {
matches!(self, Value::Param(_))
}
/// Check if this Value is a Geometry of a specific type
pub fn is_geometry_type(&self, types: &[String]) -> bool {
match self {
@ -1089,7 +1113,7 @@ impl Value {
/// Treat a string as a table name
pub fn could_be_table(self) -> Value {
match self {
Value::Strand(v) => Table::from(v.0).into(),
Value::Strand(v) => Value::Table(v.0.into()),
_ => self,
}
}

View file

@ -7,9 +7,6 @@ use base64::DecodeError as Base64Error;
use http::{HeaderName, StatusCode};
use reqwest::Error as ReqwestError;
use serde::Serialize;
use serde_cbor::error::Error as CborError;
use serde_json::error::Error as JsonError;
use serde_pack::encode::Error as PackError;
use std::io::Error as IoError;
use std::string::FromUtf8Error as Utf8Error;
use surrealdb::error::Db as SurrealDbError;
@ -37,12 +34,12 @@ pub enum Error {
#[error("There was a problem connecting with the storage engine")]
InvalidStorage,
#[error("There was a problem parsing the header {0}: {1}")]
InvalidHeader(HeaderName, TypedHeaderRejection),
#[error("The operation is unsupported")]
OperationUnsupported,
#[error("There was a problem parsing the header {0}: {1}")]
InvalidHeader(HeaderName, TypedHeaderRejection),
#[error("There was a problem with the database: {0}")]
Db(#[from] SurrealError),
@ -52,14 +49,14 @@ pub enum Error {
#[error("There was an error with the network: {0}")]
Axum(#[from] AxumError),
#[error("There was an error serializing to JSON: {0}")]
Json(#[from] JsonError),
#[error("There was an error with JSON serialization: {0}")]
Json(String),
#[error("There was an error serializing to CBOR: {0}")]
Cbor(#[from] CborError),
#[error("There was an error with CBOR serialization: {0}")]
Cbor(String),
#[error("There was an error serializing to MessagePack: {0}")]
Pack(#[from] PackError),
#[error("There was an error with MessagePack serialization: {0}")]
Pack(String),
#[error("There was an error with the remote request: {0}")]
Remote(#[from] ReqwestError),
@ -93,6 +90,42 @@ impl From<Utf8Error> for Error {
}
}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Error {
Error::Json(e.to_string())
}
}
impl From<serde_pack::encode::Error> for Error {
fn from(e: serde_pack::encode::Error) -> Error {
Error::Pack(e.to_string())
}
}
impl From<serde_pack::decode::Error> for Error {
fn from(e: serde_pack::decode::Error) -> Error {
Error::Pack(e.to_string())
}
}
impl From<ciborium::value::Error> for Error {
fn from(e: ciborium::value::Error) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl<T: std::fmt::Debug> From<ciborium::de::Error<T>> for Error {
fn from(e: ciborium::de::Error<T>) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl<T: std::fmt::Debug> From<ciborium::ser::Error<T>> for Error {
fn from(e: ciborium::ser::Error<T>) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl From<surrealdb::error::Db> for Error {
fn from(error: surrealdb::error::Db) -> Error {
if matches!(error, surrealdb::error::Db::InvalidAuth) {

View file

@ -39,8 +39,9 @@ pub fn cbor<T>(val: &T) -> Output
where
T: Serialize,
{
match serde_cbor::to_vec(val) {
Ok(v) => Output::Cbor(v),
let mut out = Vec::new();
match ciborium::into_writer(&val, &mut out) {
Ok(_) => Output::Cbor(out),
Err(_) => Output::Fail,
}
}

View file

@ -1,18 +1,21 @@
use crate::cnf;
use crate::err::Error;
use crate::rpc::connection::Connection;
use crate::rpc::format::Format;
use crate::rpc::format::PROTOCOLS;
use crate::rpc::WEBSOCKETS;
use axum::routing::get;
use axum::Extension;
use axum::Router;
use axum::{
extract::ws::{WebSocket, WebSocketUpgrade},
response::IntoResponse,
Extension, Router,
};
use http::HeaderValue;
use http_body::Body as HttpBody;
use surrealdb::dbs::Session;
use tower_http::request_id::RequestId;
use uuid::Uuid;
use axum::{
extract::ws::{WebSocket, WebSocketUpgrade},
response::IntoResponse,
};
pub(super) fn router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
@ -23,28 +26,53 @@ where
async fn handler(
ws: WebSocketUpgrade,
Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>,
Extension(req_id): Extension<RequestId>,
) -> impl IntoResponse {
ws
// Set the maximum frame size
) -> Result<impl IntoResponse, impl IntoResponse> {
// Check if there is a request id header specified
let id = match id.header_value().is_empty() {
// No request id was specified so create a new id
true => Uuid::new_v4(),
// A request id was specified to try to parse it
false => match id.header_value().to_str() {
// Attempt to parse the request id as a UUID
Ok(id) => match Uuid::try_parse(id) {
// The specified request id was a valid UUID
Ok(id) => id,
// The specified request id was not a UUID
Err(_) => return Err(Error::Request),
},
// The request id contained invalid characters
Err(_) => return Err(Error::Request),
},
};
// Check if a connection with this id already exists
if WEBSOCKETS.read().await.contains_key(&id) {
return Err(Error::Request);
}
// Now let's upgrade the WebSocket connection
Ok(ws
// Set the potential WebSocket protocols
.protocols(PROTOCOLS)
// Set the maximum WebSocket frame size
.max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE)
// Set the maximum message size
// Set the maximum WebSocket message size
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
// Set the potential WebSocket protocol formats
.protocols(["surrealql-binary", "json", "cbor", "messagepack"])
// Handle the WebSocket upgrade and process messages
.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
.on_upgrade(move |socket| handle_socket(socket, sess, id)))
}
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
// Check if there is a WebSocket protocol specified
let format = match ws.protocol().map(HeaderValue::to_str) {
// Any selected protocol will always be a valie value
Some(protocol) => protocol.unwrap().into(),
// No protocol format was specified
_ => Format::None,
};
//
// Create a new connection instance
let rpc = Connection::new(sess);
// Update the WebSocket ID with the Request ID
if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) {
// If the ID couldn't be updated, ignore the error and keep the default ID
let _ = rpc.write().await.update_ws_id(req_id).await;
}
let rpc = Connection::new(id, sess, format);
// Serve the socket connection requests
Connection::serve(rpc, ws).await;
}

View file

@ -11,7 +11,7 @@ pub trait Take {
impl Take for Array {
/// Convert the array to one argument
fn needs_one(self) -> Result<Value, ()> {
if self.is_empty() {
if self.len() != 1 {
return Err(());
}
let mut x = self.into_iter();
@ -22,7 +22,7 @@ impl Take for Array {
}
/// Convert the array to two arguments
fn needs_two(self) -> Result<(Value, Value), ()> {
if self.len() < 2 {
if self.len() != 2 {
return Err(());
}
let mut x = self.into_iter();
@ -34,7 +34,7 @@ impl Take for Array {
}
/// Convert the array to two arguments
fn needs_one_or_two(self) -> Result<(Value, Value), ()> {
if self.is_empty() {
if self.is_empty() && self.len() > 2 {
return Err(());
}
let mut x = self.into_iter();
@ -46,7 +46,7 @@ impl Take for Array {
}
/// Convert the array to three arguments
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> {
if self.is_empty() {
if self.is_empty() && self.len() > 3 {
return Err(());
}
let mut x = self.into_iter();

View file

@ -1,11 +1,12 @@
use super::request::parse_request;
use super::response::{failure, success, Data, Failure, IntoRpcResponse, OutputFormat};
use crate::cnf::PKG_NAME;
use crate::cnf::PKG_VERSION;
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY};
use crate::dbs::DB;
use crate::err::Error;
use crate::rpc::args::Take;
use crate::rpc::failure::Failure;
use crate::rpc::format::Format;
use crate::rpc::response::{failure, success, Data, IntoRpcResponse};
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext;
@ -33,9 +34,9 @@ use tracing::Span;
use uuid::Uuid;
pub struct Connection {
ws_id: Uuid,
id: Uuid,
session: Session,
format: OutputFormat,
format: Format,
vars: BTreeMap<String, Value>,
limiter: Arc<Semaphore>,
canceller: CancellationToken,
@ -43,34 +44,20 @@ pub struct Connection {
impl Connection {
/// Instantiate a new RPC
pub fn new(mut session: Session) -> Arc<RwLock<Connection>> {
// Create a new RPC variables store
let vars = BTreeMap::new();
// Set the default output format
let format = OutputFormat::Json;
pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
// Enable real-time mode
session.rt = true;
// Create and store the RPC connection
Arc::new(RwLock::new(Connection {
ws_id: Uuid::new_v4(),
id,
session,
format,
vars,
vars: BTreeMap::new(),
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
canceller: CancellationToken::new(),
}))
}
/// Update the WebSocket ID. If the ID already exists, do not update it.
pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box<dyn std::error::Error>> {
if WEBSOCKETS.read().await.contains_key(&ws_id) {
trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id);
return Err("websocket ID is in use".into());
}
self.ws_id = ws_id;
Ok(())
}
/// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
// Split the socket into send and recv
@ -79,19 +66,19 @@ impl Connection {
let (internal_sender, internal_receiver) =
channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS);
let ws_id = rpc.read().await.ws_id;
let id = rpc.read().await.id;
trace!("WebSocket {} connected", ws_id);
trace!("WebSocket {} connected", id);
if let Err(err) = telemetry::metrics::ws::on_connect() {
error!("Error running metrics::ws::on_connect hook: {}", err);
}
// Add this WebSocket to the list
WEBSOCKETS.write().await.insert(
ws_id,
WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()),
);
WEBSOCKETS
.write()
.await
.insert(id, WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()));
// Spawn async tasks for the WebSocket
let mut tasks = JoinSet::new();
@ -109,15 +96,15 @@ impl Connection {
internal_sender.close();
trace!("WebSocket {} disconnected", ws_id);
trace!("WebSocket {} disconnected", id);
// Remove this WebSocket from the list
WEBSOCKETS.write().await.remove(&ws_id);
WEBSOCKETS.write().await.remove(&id);
// Remove all live queries
let mut gc = Vec::new();
LIVE_QUERIES.write().await.retain(|key, value| {
if value == &ws_id {
if value == &id {
trace!("Removing live query: {}", key);
gc.push(*key);
return false;
@ -288,9 +275,9 @@ impl Connection {
msg = channel.recv() => {
if let Ok(notification) = msg {
// Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
if let Some(id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(id) {
// Serialize the message to send
let message = success(None, notification);
// Get the current output format
@ -309,27 +296,41 @@ impl Connection {
/// Handle individual WebSocket messages
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
// Get the current output format
let mut out_fmt = rpc.read().await.format;
let mut fmt = rpc.read().await.format;
// Prepare Span and Otel context
let span = span_for_request(&rpc.read().await.ws_id);
let span = span_for_request(&rpc.read().await.id);
// Acquire concurrent request rate limiter
let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap();
// Calculate the length of the message
let len = match msg {
Message::Text(ref msg) => {
// If no format was specified, default to JSON
if fmt.is_none() {
fmt = Format::Json;
rpc.write().await.format = fmt;
}
// Retrieve the length of the message
msg.len()
}
Message::Binary(ref msg) => {
// If no format was specified, default to Bincode
if fmt.is_none() {
fmt = Format::Bincode;
rpc.write().await.format = fmt;
}
// Retrieve the length of the message
msg.len()
}
_ => unreachable!(),
};
// Parse the request
async move {
let span = Span::current();
let req_cx = RequestContext::default();
let otel_cx = TelemetryContext::new().with_value(req_cx.clone());
match parse_request(msg).await {
// Parse the RPC request structure
match fmt.req(msg) {
Ok(req) => {
if let Some(fmt) = req.out_fmt {
if out_fmt != fmt {
// Update the default format
rpc.write().await.format = fmt;
out_fmt = fmt;
}
}
// Now that we know the method, we can update the span and create otel context
span.record("rpc.method", &req.method);
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
@ -338,17 +339,17 @@ impl Connection {
req.id.clone().map(Value::as_string).unwrap_or_default(),
);
let otel_cx = TelemetryContext::current_with_value(
req_cx.with_method(&req.method).with_size(req.size),
req_cx.with_method(&req.method).with_size(len),
);
// Process the message
let res =
Connection::process_message(rpc.clone(), &req.method, req.params).await;
// Process the response
res.into_response(req.id).send(out_fmt, &chn).with_context(otel_cx).await
res.into_response(req.id).send(fmt, &chn).with_context(otel_cx).await
}
Err(err) => {
// Process the response
failure(None, err).send(out_fmt, &chn).with_context(otel_cx).await
failure(None, err).send(fmt, &chn).with_context(otel_cx).await
}
}
}
@ -485,13 +486,6 @@ impl Connection {
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Specify the output format for text requests
"format" => match params.needs_one() {
Ok(Value::Strand(v)) => {
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Get the current server version
"version" => match params.len() {
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
@ -515,16 +509,6 @@ impl Connection {
// Methods for authentication
// ------------------------------
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() {
"json" | "application/json" => self.format = OutputFormat::Json,
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
"pack" | "application/pack" => self.format = OutputFormat::Pack,
_ => return Err(Error::InvalidType),
};
Ok(Value::None)
}
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
if let Value::Strand(ns) = ns {
self.session.ns = Some(ns.0);
@ -615,7 +599,7 @@ impl Connection {
let sql = "KILL $id";
// Specify the query parameters
let var = map! {
String::from("id") => id, // NOTE: id can be parameter
String::from("id") => id,
=> &self.vars
};
// Execute the query on the database
@ -910,15 +894,14 @@ impl Connection {
QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result {
// Match on Uuid type
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
LIVE_QUERIES.write().await.insert(lqid.0, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}
}
QueryType::Kill => {
if let Ok(Value::Uuid(lqid)) = &res.result {
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0);
if let Some(ws_id) = ws_id {
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}
}

70
src/rpc/failure.rs Normal file
View file

@ -0,0 +1,70 @@
use crate::err::Error;
use serde::Serialize;
use std::borrow::Cow;
use surrealdb::sql::Value;
#[derive(Clone, Debug, Serialize)]
pub struct Failure {
pub(crate) code: i64,
pub(crate) message: Cow<'static, str>,
}
impl From<&str> for Failure {
fn from(err: &str) -> Self {
Failure::custom(err.to_string())
}
}
impl From<Error> for Failure {
fn from(err: Error) -> Self {
Failure::custom(err.to_string())
}
}
impl From<Failure> for Value {
fn from(err: Failure) -> Self {
map! {
String::from("code") => Value::from(err.code),
String::from("message") => Value::from(err.message.to_string()),
}
.into()
}
}
#[allow(dead_code)]
impl Failure {
pub const PARSE_ERROR: Failure = Failure {
code: -32700,
message: Cow::Borrowed("Parse error"),
};
pub const INVALID_REQUEST: Failure = Failure {
code: -32600,
message: Cow::Borrowed("Invalid Request"),
};
pub const METHOD_NOT_FOUND: Failure = Failure {
code: -32601,
message: Cow::Borrowed("Method not found"),
};
pub const INVALID_PARAMS: Failure = Failure {
code: -32602,
message: Cow::Borrowed("Invalid params"),
};
pub const INTERNAL_ERROR: Failure = Failure {
code: -32603,
message: Cow::Borrowed("Internal error"),
};
pub fn custom<S>(message: S) -> Failure
where
Cow<'static, str>: From<S>,
{
Failure {
code: -32000,
message: message.into(),
}
}
}

22
src/rpc/format/bincode.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Value;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Binary(val) => {
deserialize::<Value>(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

View file

@ -0,0 +1,24 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val, &mut res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

22
src/rpc/format/json.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Text(res)))
}

70
src/rpc/format/mod.rs Normal file
View file

@ -0,0 +1,70 @@
mod bincode;
pub mod cbor;
mod json;
pub mod msgpack;
mod revision;
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub const PROTOCOLS: [&str; 5] = [
"json", // For basic JSON serialisation
"cbor", // For basic CBOR serialisation
"msgpack", // For basic Msgpack serialisation
"bincode", // For full internal serialisation
"revision", // For full versioned serialisation
];
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Format {
None, // No format is specified yet
Json, // For basic JSON serialisation
Cbor, // For basic CBOR serialisation
Msgpack, // For basic Msgpack serialisation
Bincode, // For full internal serialisation
Revision, // For full versioned serialisation
}
impl From<&str> for Format {
fn from(v: &str) -> Self {
match v {
s if s == PROTOCOLS[0] => Format::Json,
s if s == PROTOCOLS[1] => Format::Cbor,
s if s == PROTOCOLS[2] => Format::Msgpack,
s if s == PROTOCOLS[3] => Format::Bincode,
s if s == PROTOCOLS[4] => Format::Revision,
_ => Format::None,
}
}
}
impl Format {
/// Check if this format has been set
pub fn is_none(&self) -> bool {
matches!(self, Format::None)
}
/// Process a request using the specified format
pub fn req(&self, msg: Message) -> Result<Request, Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::req(msg),
Self::Cbor => cbor::req(msg),
Self::Msgpack => msgpack::req(msg),
Self::Bincode => bincode::req(msg),
Self::Revision => revision::req(msg),
}
}
/// Process a response using the specified format
pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::res(res),
Self::Cbor => cbor::res(res),
Self::Msgpack => msgpack::res(res),
Self::Bincode => bincode::res(res),
Self::Revision => revision::res(res),
}
}
}

View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the value into MsgPack binary data
let res = serde_pack::to_vec(&val).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

View file

@ -0,0 +1,14 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(_msg: Message) -> Result<Request, Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
}
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
}

View file

@ -1,12 +1,14 @@
pub mod args;
pub mod connection;
pub mod failure;
pub mod format;
pub mod request;
pub mod response;
use std::{collections::HashMap, time::Duration};
use axum::extract::ws::Message;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::time::Duration;
use surrealdb::channel::Sender;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

View file

@ -1,10 +1,7 @@
use axum::extract::ws::Message;
use surrealdb::sql::{serde::deserialize, Array, Value};
use crate::rpc::failure::Failure;
use once_cell::sync::Lazy;
use surrealdb::sql::Part;
use super::response::{Failure, OutputFormat};
use surrealdb::sql::{Array, Value};
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
@ -14,71 +11,36 @@ pub struct Request {
pub id: Option<Value>,
pub method: String,
pub params: Array,
pub size: usize,
pub out_fmt: Option<OutputFormat>,
}
/// Parse the RPC request
pub async fn parse_request(msg: Message) -> Result<Request, Failure> {
let mut out_fmt = None;
let (req, size) = match msg {
// This is a binary message
Message::Binary(val) => {
// Use binary output
out_fmt = Some(OutputFormat::Full);
match deserialize(&val) {
Ok(v) => (v, val.len()),
Err(_) => {
debug!("Error when trying to deserialize the request");
return Err(Failure::PARSE_ERROR);
}
}
}
// This is a text message
Message::Text(ref val) => {
// Parse the SurrealQL object
match surrealdb::sql::value(val) {
// The SurrealQL message parsed ok
Ok(v) => (v, val.len()),
// The SurrealQL message failed to parse
_ => return Err(Failure::PARSE_ERROR),
}
}
// Unsupported message type
_ => {
debug!("Unsupported message type: {:?}", msg);
return Err(Failure::custom("Unsupported message type"));
}
};
// Fetch the 'id' argument
let id = match req.pick(&*ID) {
v if v.is_none() => None,
v if v.is_null() => Some(v),
v if v.is_uuid() => Some(v),
v if v.is_number() => Some(v),
v if v.is_strand() => Some(v),
v if v.is_datetime() => Some(v),
_ => return Err(Failure::INVALID_REQUEST),
};
// Fetch the 'method' argument
let method = match req.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(),
_ => return Err(Failure::INVALID_REQUEST),
};
// Fetch the 'params' argument
let params = match req.pick(&*PARAMS) {
Value::Array(v) => v,
_ => Array::new(),
};
Ok(Request {
id,
method,
params,
size,
out_fmt,
})
impl TryFrom<Value> for Request {
type Error = Failure;
fn try_from(val: Value) -> Result<Self, Failure> {
// Fetch the 'id' argument
let id = match val.pick(&*ID) {
v if v.is_none() => None,
v if v.is_null() => Some(v),
v if v.is_uuid() => Some(v),
v if v.is_number() => Some(v),
v if v.is_strand() => Some(v),
v if v.is_datetime() => Some(v),
_ => return Err(Failure::INVALID_REQUEST),
};
// Fetch the 'method' argument
let method = match val.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(),
_ => return Err(Failure::INVALID_REQUEST),
};
// Fetch the 'params' argument
let params = match val.pick(&*PARAMS) {
Value::Array(v) => v,
_ => Array::new(),
};
// Return the parsed request
Ok(Request {
id,
method,
params,
})
}
}

View file

@ -1,10 +1,10 @@
use crate::err;
use crate::rpc::failure::Failure;
use crate::rpc::format::Format;
use crate::telemetry::metrics::ws::record_rpc;
use axum::extract::ws::Message;
use opentelemetry::Context as TelemetryContext;
use serde::Serialize;
use serde_json::{json, Value as Json};
use std::borrow::Cow;
use serde_json::Value as Json;
use surrealdb::channel::Sender;
use surrealdb::dbs;
use surrealdb::dbs::Notification;
@ -12,14 +12,6 @@ use surrealdb::sql;
use surrealdb::sql::Value;
use tracing::Span;
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum OutputFormat {
Json, // JSON
Cbor, // CBOR
Pack, // MessagePack
Full, // Full type serialization
}
/// The data returned by the database
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync.
@ -46,15 +38,25 @@ impl From<String> for Data {
}
}
impl From<Notification> for Data {
fn from(n: Notification) -> Self {
Data::Live(n)
}
}
impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v)
}
}
impl From<Notification> for Data {
fn from(n: Notification) -> Self {
Data::Live(n)
impl From<Data> for Value {
fn from(val: Data) -> Self {
match val {
Data::Query(v) => sql::to_value(v).unwrap(),
Data::Live(v) => sql::to_value(v).unwrap(),
Data::Other(v) => v,
}
}
}
@ -67,32 +69,31 @@ pub struct Response {
impl Response {
/// Convert and simplify the value into JSON
#[inline]
fn simplify(self) -> Json {
pub fn into_json(self) -> Json {
Json::from(self.into_value())
}
#[inline]
pub fn into_value(self) -> Value {
let mut value = match self.result {
Ok(data) => {
let value = match data {
Data::Query(vec) => sql::to_value(vec).unwrap(),
Data::Live(notification) => sql::to_value(notification).unwrap(),
Data::Other(value) => value,
};
json!({
"result": Json::from(value),
})
}
Err(failure) => json!({
"error": failure,
}),
Ok(val) => map! {
"result" => Value::from(val),
},
Err(err) => map! {
"error" => Value::from(err),
},
};
if let Some(id) = self.id {
value["id"] = id.into();
value.insert("id", id);
}
value
value.into()
}
/// Send the response to the WebSocket channel
pub async fn send(self, out: OutputFormat, chn: &Sender<Message>) {
pub async fn send(self, fmt: Format, chn: &Sender<Message>) {
// Create a new tracing span
let span = Span::current();
// Log the rpc response call
debug!("Process RPC response");
let is_error = self.result.is_err();
@ -105,73 +106,12 @@ impl Response {
span.record("rpc.error_code", err.code);
span.record("rpc.error_message", err.message.as_ref());
}
let (res_size, message) = match out {
OutputFormat::Json => {
let res = serde_json::to_string(&self.simplify()).unwrap();
(res.len(), Message::Text(res))
}
OutputFormat::Cbor => {
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
(res.len(), Message::Binary(res))
}
OutputFormat::Pack => {
let res = serde_pack::to_vec(&self.simplify()).unwrap();
(res.len(), Message::Binary(res))
}
OutputFormat::Full => {
let res = surrealdb::sql::serde::serialize(&self).unwrap();
(res.len(), Message::Binary(res))
}
// Process the response for the format
let (len, msg) = fmt.res(self).unwrap();
// Send the message to the write channel
if chn.send(msg).await.is_ok() {
record_rpc(&TelemetryContext::current(), len, is_error);
};
if chn.send(message).await.is_ok() {
record_rpc(&TelemetryContext::current(), res_size, is_error);
};
}
}
#[derive(Clone, Debug, Serialize)]
pub struct Failure {
code: i64,
message: Cow<'static, str>,
}
#[allow(dead_code)]
impl Failure {
pub const PARSE_ERROR: Failure = Failure {
code: -32700,
message: Cow::Borrowed("Parse error"),
};
pub const INVALID_REQUEST: Failure = Failure {
code: -32600,
message: Cow::Borrowed("Invalid Request"),
};
pub const METHOD_NOT_FOUND: Failure = Failure {
code: -32601,
message: Cow::Borrowed("Method not found"),
};
pub const INVALID_PARAMS: Failure = Failure {
code: -32602,
message: Cow::Borrowed("Invalid params"),
};
pub const INTERNAL_ERROR: Failure = Failure {
code: -32603,
message: Cow::Borrowed("Internal error"),
};
pub fn custom<S>(message: S) -> Failure
where
Cow<'static, str>: From<S>,
{
Failure {
code: -32000,
message: message.into(),
}
}
}
@ -191,12 +131,6 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
}
}
impl From<err::Error> for Failure {
fn from(err: err::Error) -> Self {
Failure::custom(err.to_string())
}
}
pub trait IntoRpcResponse {
fn into_response(self, id: Option<Value>) -> Response;
}

View file

@ -3,6 +3,8 @@ mod common;
mod cli_integration {
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use common::Format;
use common::Socket;
use std::fs;
use std::fs::File;
use std::time;
@ -735,15 +737,13 @@ mod cli_integration {
let (addr, mut server) = common::start_server_without_auth().await.unwrap();
// Create a long-lived WS connection so the server don't shutdown gracefully
let mut socket = common::connect_ws(&addr).await.expect("Failed to connect to server");
let mut socket = Socket::connect(&addr, None).await.expect("Failed to connect to server");
let json = serde_json::json!({
"id": "1",
"method": "query",
"params": ["SLEEP 30s;"],
});
common::ws_send_msg(&mut socket, serde_json::to_string(&json).unwrap())
.await
.expect("Failed to send WS message");
socket.send_message(Format::Json, json).await.expect("Failed to send WS message");
// Make sure the SLEEP query is being executed
tokio::select! {

14
tests/common/format.rs Normal file
View file

@ -0,0 +1,14 @@
use std::string::ToString;
#[derive(Debug, Copy, Clone)]
pub enum Format {
Json,
}
impl ToString for Format {
fn to_string(&self) -> String {
match self {
Self::Json => "json".to_owned(),
}
}
}

View file

@ -1,489 +1,17 @@
#![allow(unused_imports)]
#![allow(dead_code)]
pub mod error;
pub mod format;
pub mod server;
pub mod socket;
use crate::common::error::TestError;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::process::{Command, Stdio};
use std::time::Duration;
use std::{env, fs};
use tokio::net::TcpStream;
use tokio::time;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error, info};
pub const USER: &str = "root";
pub const PASS: &str = "root";
/// Child is a (maybe running) CLI process. It can be killed by dropping it
pub struct Child {
inner: Option<std::process::Child>,
stdout_path: String,
stderr_path: String,
}
impl Child {
/// Send some thing to the child's stdin
pub fn input(mut self, input: &str) -> Self {
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
use std::io::Write;
stdin.write_all(input.as_bytes()).unwrap();
self
}
pub fn kill(mut self) -> Self {
self.inner.as_mut().unwrap().kill().unwrap();
self
}
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
nix::sys::signal::kill(
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
signal,
)
}
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
self.inner.as_mut().unwrap().try_wait()
}
pub fn stdout(&self) -> String {
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
}
pub fn stderr(&self) -> String {
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
}
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
let mut buf = self.stdout();
buf.push_str(&self.stderr());
// Cleanup files after reading them
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
if status.success() {
Ok(buf)
} else {
Err(buf)
}
}
}
impl Drop for Child {
fn drop(&mut self) {
if let Some(inner) = self.inner.as_mut() {
let _ = inner.kill();
}
}
}
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
let mut path = std::env::current_exe().unwrap();
assert!(path.pop());
if path.ends_with("deps") {
assert!(path.pop());
}
// Note: Cargo automatically builds this binary for integration tests.
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
let mut cmd = Command::new(path);
if let Some(dir) = current_dir {
cmd.current_dir(&dir);
}
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
let stdout_path = tmp_file("server-stdout.log");
let stderr_path = tmp_file("server-stderr.log");
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
cmd.env_clear();
cmd.stdin(Stdio::piped());
cmd.stdout(stdout);
cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace());
Child {
inner: Some(cmd.spawn().unwrap()),
stdout_path,
stderr_path,
}
}
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
run_internal::<String>(args, None)
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir))
}
pub fn tmp_file(name: &str) -> String {
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
path.to_string_lossy().into_owned()
}
pub struct StartServerArguments {
pub auth: bool,
pub tls: bool,
pub wait_is_ready: bool,
pub enable_auth_level: bool,
pub tick_interval: time::Duration,
pub args: String,
}
impl Default for StartServerArguments {
fn default() -> Self {
Self {
auth: true,
tls: false,
wait_is_ready: true,
enable_auth_level: false,
tick_interval: time::Duration::new(1, 0),
args: "--allow-all".to_string(),
}
}
}
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
auth: false,
..Default::default()
})
.await
}
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
enable_auth_level: true,
..Default::default()
})
.await
}
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments::default()).await
}
pub async fn start_server(
StartServerArguments {
auth,
tls,
wait_is_ready,
enable_auth_level,
tick_interval,
args,
}: StartServerArguments,
) -> Result<(String, Child), Box<dyn Error>> {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let mut extra_args = args.clone();
if tls {
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
let key_path = tmp_file("key.pem");
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
}
if auth {
extra_args.push_str(" --auth");
}
if enable_auth_level {
extra_args.push_str(" --auth-level-enabled");
}
if !tick_interval.is_zero() {
let sec = tick_interval.as_secs();
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
}
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
info!("starting server with args: {start_args}");
// Configure where the logs go when running the test
let server = run_internal::<String>(&start_args, None);
if !wait_is_ready {
return Ok((addr, server));
}
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(1000));
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
info!("Server ready!");
return Ok((addr, server));
}
}
let server_out = server.kill().output().err().unwrap();
error!("server output: {server_out}");
Err("server failed to start".into())
}
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let (ws_stream, _) = connect_async(url).await?;
Ok(ws_stream)
}
pub async fn ws_send_msg(socket: &mut WsStream, msg_req: String) -> Result<(), Box<dyn Error>> {
let now = time::Instant::now();
debug!("Sending message: {msg_req}");
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout after 500ms waiting for the request to be sent".into());
}
res = socket.send(Message::Text(msg_req)) => {
debug!("Message sent in {:?}", now.elapsed());
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
Ok(())
}
pub async fn ws_recv_msg(socket: &mut WsStream) -> Result<serde_json::Value, Box<dyn Error>> {
ws_recv_msg_with_fmt(socket, Format::Json).await
}
/// When testing Live Queries, we may receive multiple messages unordered.
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
pub async fn ws_recv_all_msgs(
socket: &mut WsStream,
expected: usize,
timeout: Duration,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let mut res = Vec::new();
let deadline = time::Instant::now() + timeout;
loop {
tokio::select! {
_ = time::sleep_until(deadline) => {
debug!("Waited for {:?} and received {} messages", timeout, res.len());
if res.len() != expected {
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
}
}
msg = ws_recv_msg(socket) => {
res.push(msg?);
}
}
if res.len() == expected {
return Ok(res);
}
}
}
pub async fn ws_send_msg_and_wait_response(
socket: &mut WsStream,
msg_req: String,
) -> Result<serde_json::Value, Box<dyn Error>> {
ws_send_msg(socket, msg_req).await?;
ws_recv_msg_with_fmt(socket, Format::Json).await
}
pub enum Format {
Json,
Cbor,
Pack,
}
pub async fn ws_recv_msg_with_fmt(
socket: &mut WsStream,
format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
let now = time::Instant::now();
debug!("Waiting for response...");
// Parse and return response
let mut f = socket.try_filter(|msg| match format {
Format::Json => futures_util::future::ready(msg.is_text()),
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
});
tokio::select! {
_ = time::sleep(time::Duration::from_millis(5000)) => {
Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
}
res = f.select_next_some() => {
debug!("Response received in {:?}", now.elapsed());
match format {
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
}
}
}
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
pub async fn ws_signin(
socket: &mut WsStream,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
let request_id = uuid::Uuid::new_v4().to_string().replace('-', "");
let json = json!({
"id": request_id,
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_str()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_query(
socket: &mut WsStream,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "query",
"params": [query],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_array()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_use(
socket: &mut WsStream,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "use",
"params": [
ns, db
],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub use format::*;
pub use server::*;
pub use socket::*;
/// Check if the given message is a successful notification from LQ.
pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
pub fn is_notification(msg: &serde_json::Value) -> bool {
// Example of LQ notification:
//
// Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}}
@ -497,7 +25,7 @@ pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
}
/// Check if the given message is a notification from LQ and comes from the given LQ ID.
pub fn ws_msg_is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
ws_msg_is_notification(msg)
pub fn is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
is_notification(msg)
&& msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id)
}

242
tests/common/server.rs Normal file
View file

@ -0,0 +1,242 @@
use rand::{thread_rng, Rng};
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::process::{Command, Stdio};
use std::{env, fs};
use tokio::time;
use tracing::{debug, error, info};
pub const USER: &str = "root";
pub const PASS: &str = "root";
pub const NS: &str = "testns";
pub const DB: &str = "testdb";
/// Child is a (maybe running) CLI process. It can be killed by dropping it
pub struct Child {
inner: Option<std::process::Child>,
stdout_path: String,
stderr_path: String,
}
impl Child {
/// Send some thing to the child's stdin
pub fn input(mut self, input: &str) -> Self {
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
use std::io::Write;
stdin.write_all(input.as_bytes()).unwrap();
self
}
pub fn kill(mut self) -> Self {
self.inner.as_mut().unwrap().kill().unwrap();
self
}
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
nix::sys::signal::kill(
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
signal,
)
}
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
self.inner.as_mut().unwrap().try_wait()
}
pub fn stdout(&self) -> String {
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
}
pub fn stderr(&self) -> String {
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
}
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
let mut buf = self.stdout();
buf.push_str(&self.stderr());
// Cleanup files after reading them
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
if status.success() {
Ok(buf)
} else {
Err(buf)
}
}
}
impl Drop for Child {
fn drop(&mut self) {
if let Some(inner) = self.inner.as_mut() {
let _ = inner.kill();
}
}
}
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
let mut path = std::env::current_exe().unwrap();
assert!(path.pop());
if path.ends_with("deps") {
assert!(path.pop());
}
// Note: Cargo automatically builds this binary for integration tests.
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
let mut cmd = Command::new(path);
if let Some(dir) = current_dir {
cmd.current_dir(&dir);
}
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
let stdout_path = tmp_file("server-stdout.log");
let stderr_path = tmp_file("server-stderr.log");
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
cmd.env_clear();
cmd.stdin(Stdio::piped());
cmd.stdout(stdout);
cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace());
Child {
inner: Some(cmd.spawn().unwrap()),
stdout_path,
stderr_path,
}
}
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
run_internal::<String>(args, None)
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir))
}
pub fn tmp_file(name: &str) -> String {
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
path.to_string_lossy().into_owned()
}
pub struct StartServerArguments {
pub auth: bool,
pub tls: bool,
pub wait_is_ready: bool,
pub enable_auth_level: bool,
pub tick_interval: time::Duration,
pub args: String,
}
impl Default for StartServerArguments {
fn default() -> Self {
Self {
auth: true,
tls: false,
wait_is_ready: true,
enable_auth_level: false,
tick_interval: time::Duration::new(1, 0),
args: "--allow-all".to_string(),
}
}
}
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
auth: false,
..Default::default()
})
.await
}
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
enable_auth_level: true,
..Default::default()
})
.await
}
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments::default()).await
}
pub async fn start_server(
StartServerArguments {
auth,
tls,
wait_is_ready,
enable_auth_level,
tick_interval,
args,
}: StartServerArguments,
) -> Result<(String, Child), Box<dyn Error>> {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let mut extra_args = args.clone();
if tls {
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
let key_path = tmp_file("key.pem");
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
}
if auth {
extra_args.push_str(" --auth");
}
if enable_auth_level {
extra_args.push_str(" --auth-level-enabled");
}
if !tick_interval.is_zero() {
let sec = tick_interval.as_secs();
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
}
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
info!("starting server with args: {start_args}");
// Configure where the logs go when running the test
let server = run_internal::<String>(&start_args, None);
if !wait_is_ready {
return Ok((addr, server));
}
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(1000));
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
info!("Server ready!");
return Ok((addr, server));
}
}
let server_out = server.kill().output().err().unwrap();
error!("server output: {server_out}");
Err("server failed to start".into())
}

288
tests/common/socket.rs Normal file
View file

@ -0,0 +1,288 @@
use super::format::Format;
use crate::common::error::TestError;
use futures_util::{SinkExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;
use std::time::Duration;
use surrealdb::sql::Value;
use tokio::net::TcpStream;
use tokio::time;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error};
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
pub struct Socket {
pub stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
// pub struct Socket(pub WebSocketStream<MaybeTlsStream<TcpStream>>);
impl Socket {
/// Close the connection with the WebSocket server
pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
Ok(self.stream.close(None).await?)
}
/// Connect to a WebSocket server using a specific format
pub async fn connect(addr: &str, format: Option<Format>) -> Result<Self, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let mut req = url.into_client_request().unwrap();
if let Some(v) = format.map(|v| v.to_string()) {
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
}
let (stream, _) = connect_async(req).await?;
Ok(Self {
stream,
})
}
/// Send a text or binary message to the WebSocket server
pub async fn send_message(
&mut self,
format: Format,
message: serde_json::Value,
) -> Result<(), Box<dyn Error>> {
let now = time::Instant::now();
debug!("Sending message: {message}");
// Format the message
let msg = match format {
Format::Json => Message::Text(serde_json::to_string(&message)?),
};
// Send the message
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout after 500ms waiting for the request to be sent".into());
}
res = self.stream.send(msg) => {
debug!("Message sent in {:?}", now.elapsed());
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
Ok(())
}
/// Receive a text or binary message from the WebSocket server
pub async fn receive_message(
&mut self,
format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
let now = time::Instant::now();
debug!("Receiving response...");
loop {
tokio::select! {
_ = time::sleep(time::Duration::from_millis(5000)) => {
return Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
}
res = self.stream.try_next() => {
match res {
Ok(res) => match res {
Some(Message::Text(msg)) => {
debug!("Response {msg:?} received in {:?}", now.elapsed());
match format {
Format::Json => {
let msg = serde_json::from_str(&msg)?;
debug!("Received message: {msg}");
return Ok(msg);
},
}
},
Some(_) => {
continue;
}
None => {
return Err("Expected to receive a message".to_string().into());
}
},
Err(err) => {
return Err(format!("Error receiving the message: {}", err).into());
}
}
}
}
}
}
/// Send a text or binary message and receive a reponse from the WebSocket server
pub async fn send_and_receive_message(
&mut self,
format: Format,
message: serde_json::Value,
) -> Result<serde_json::Value, Box<dyn Error>> {
self.send_message(format, message).await?;
self.receive_message(format).await
}
/// When testing Live Queries, we may receive multiple messages unordered.
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
pub async fn receive_all_messages(
&mut self,
format: Format,
expected: usize,
timeout: Duration,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let mut res = Vec::new();
let deadline = time::Instant::now() + timeout;
loop {
tokio::select! {
_ = time::sleep_until(deadline) => {
debug!("Waited for {:?} and received {} messages", timeout, res.len());
if res.len() != expected {
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
}
}
msg = self.receive_message(format) => {
res.push(msg?);
}
}
if res.len() == expected {
return Ok(res);
}
}
}
/// Send a USE message to the server and check the response
pub async fn send_message_use(
&mut self,
format: Format,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "use",
"params": [
ns, db
],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
/// Send a generic query message to the server and check the response
pub async fn send_message_query(
&mut self,
format: Format,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "query",
"params": [query],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_array()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
/// Send a signin authentication query message to the server and check the response
pub async fn send_message_signin(
&mut self,
format: Format,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_str()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
}

1177
tests/common/tests.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -1258,8 +1258,8 @@ mod http_integration {
.send()
.await?;
assert_eq!(res.status(), 200);
let _: serde_cbor::Value = serde_cbor::from_slice(&res.bytes().await?).unwrap();
let res = res.bytes().await?.to_vec();
let _: ciborium::Value = ciborium::from_reader(res.as_slice()).unwrap();
}
// Creating a record with Accept PACK encoding is allowed
@ -1272,8 +1272,8 @@ mod http_integration {
.send()
.await?;
assert_eq!(res.status(), 200);
let _: serde_cbor::Value = serde_pack::from_slice(&res.bytes().await?).unwrap();
let res = res.bytes().await?.to_vec();
let _: rmpv::Value = rmpv::decode::read_value(&mut res.as_slice()).unwrap();
}
// Creating a record with Accept Surrealdb encoding is allowed

File diff suppressed because it is too large Load diff