Improve WebSocket protocol implementation (#3291)

This commit is contained in:
Tobie Morgan Hitchcock 2024-01-09 15:27:03 +00:00 committed by GitHub
parent 8e9bd3a2d6
commit a35fc0d04d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 2300 additions and 2338 deletions

23
Cargo.lock generated
View file

@ -4445,6 +4445,16 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "rmpv"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e0e0214a4a2b444ecce41a4025792fc31f77c7bb89c46d253953ea8c65701ec"
dependencies = [
"num-traits",
"rmp",
]
[[package]] [[package]]
name = "roaring" name = "roaring"
version = "0.10.2" version = "0.10.2"
@ -4817,16 +4827,6 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_cbor"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
dependencies = [
"half",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.193" version = "1.0.193"
@ -5235,6 +5235,7 @@ dependencies = [
"axum-server", "axum-server",
"base64 0.21.5", "base64 0.21.5",
"bytes", "bytes",
"ciborium",
"clap", "clap",
"env_logger", "env_logger",
"futures", "futures",
@ -5257,10 +5258,10 @@ dependencies = [
"rcgen", "rcgen",
"reqwest", "reqwest",
"rmp-serde", "rmp-serde",
"rmpv",
"rustyline", "rustyline",
"semver", "semver",
"serde", "serde",
"serde_cbor",
"serde_json", "serde_json",
"serial_test", "serial_test",
"surrealdb", "surrealdb",

View file

@ -40,6 +40,7 @@ axum-extra = { version = "0.7.7", features = ["query", "typed-routing"] }
axum-server = { version = "0.5.1", features = ["tls-rustls"] } axum-server = { version = "0.5.1", features = ["tls-rustls"] }
base64 = "0.21.5" base64 = "0.21.5"
bytes = "1.5.0" bytes = "1.5.0"
ciborium = "0.2.1"
clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] } clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] }
futures = "0.3.29" futures = "0.3.29"
futures-util = "0.3.29" futures-util = "0.3.29"
@ -55,9 +56,9 @@ opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
pin-project-lite = "0.2.13" pin-project-lite = "0.2.13"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] } reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] }
rmpv = "1.0.1"
rustyline = { version = "12.0.0", features = ["derive"] } rustyline = { version = "12.0.0", features = ["derive"] }
serde = { version = "1.0.193", features = ["derive"] } serde = { version = "1.0.193", features = ["derive"] }
serde_cbor = "0.11.2"
serde_json = "1.0.108" serde_json = "1.0.108"
serde_pack = { version = "1.1.2", package = "rmp-serde" } serde_pack = { version = "1.1.2", package = "rmp-serde" }
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] } surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }

View file

@ -34,13 +34,13 @@ env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set =
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"] args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
[tasks.ci-ws-integration] [tasks.ci-ws-integration]
category = "CI - INTEGRATION TESTS" category = "WS - INTEGRATION TESTS"
command = "cargo" command = "cargo"
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"] args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
[tasks.ci-ml-integration] [tasks.ci-ml-integration]
category = "CI - INTEGRATION TESTS" category = "ML - INTEGRATION TESTS"
command = "cargo" command = "cargo"
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } } env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"] args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]

View file

@ -130,6 +130,7 @@ pub fn thing((arg1, arg2): (Value, Option<Value>)) -> Result<Value, Error> {
pub mod is { pub mod is {
use crate::err::Error; use crate::err::Error;
use crate::sql::table::Table;
use crate::sql::value::Value; use crate::sql::value::Value;
use crate::sql::Geometry; use crate::sql::Geometry;
@ -215,7 +216,7 @@ pub mod is {
pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> { pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> {
Ok(match table { Ok(match table {
Some(tb) => arg.is_record_of_table(tb).into(), Some(tb) => arg.is_record_type(&[Table(tb)]).into(),
None => arg.is_record().into(), None => arg.is_record().into(),
}) })
} }

View file

@ -7,6 +7,7 @@ use crate::sql::{escape::escape_rid, Array, Number, Object, Strand, Thing, Uuid,
use nanoid::nanoid; use nanoid::nanoid;
use revision::revisioned; use revision::revisioned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use ulid::Ulid; use ulid::Ulid;
@ -106,6 +107,12 @@ impl From<Vec<Value>> for Id {
} }
} }
impl From<BTreeMap<String, Value>> for Id {
fn from(v: BTreeMap<String, Value>) -> Self {
Id::Object(v.into())
}
}
impl From<Number> for Id { impl From<Number> for Id {
fn from(v: Number) -> Self { fn from(v: Number) -> Self {
match v { match v {

View file

@ -23,6 +23,12 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
#[revisioned(revision = 1)] #[revisioned(revision = 1)]
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>); pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
impl From<BTreeMap<&str, Value>> for Object {
fn from(v: BTreeMap<&str, Value>) -> Self {
Self(v.into_iter().map(|(key, val)| (key.to_string(), val)).collect())
}
}
impl From<BTreeMap<String, Value>> for Object { impl From<BTreeMap<String, Value>> for Object {
fn from(v: BTreeMap<String, Value>) -> Self { fn from(v: BTreeMap<String, Value>) -> Self {
Self(v) Self(v)

View file

@ -2,6 +2,7 @@ use crate::ctx::Context;
use crate::dbs::{Options, Transaction}; use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc; use crate::doc::CursorDoc;
use crate::err::Error; use crate::err::Error;
use crate::sql::Uuid;
use crate::sql::Value; use crate::sql::Value;
use derive::Store; use derive::Store;
use revision::revisioned; use revision::revisioned;
@ -34,6 +35,14 @@ impl KillStatement {
Value::Uuid(id) => *id, Value::Uuid(id) => *id,
Value::Param(param) => match param.compute(ctx, opt, txn, None).await? { Value::Param(param) => match param.compute(ctx, opt, txn, None).await? {
Value::Uuid(id) => id, Value::Uuid(id) => id,
Value::Strand(id) => match uuid::Uuid::try_parse(&id) {
Ok(id) => Uuid(id),
_ => {
return Err(Error::KillStatement {
value: self.id.to_string(),
})
}
},
_ => { _ => {
return Err(Error::KillStatement { return Err(Error::KillStatement {
value: self.id.to_string(), value: self.id.to_string(),

View file

@ -458,6 +458,12 @@ impl From<Vec<bool>> for Value {
} }
} }
impl From<HashMap<&str, Value>> for Value {
fn from(v: HashMap<&str, Value>) -> Self {
Value::Object(Object::from(v))
}
}
impl From<HashMap<String, Value>> for Value { impl From<HashMap<String, Value>> for Value {
fn from(v: HashMap<String, Value>) -> Self { fn from(v: HashMap<String, Value>) -> Self {
Value::Object(Object::from(v)) Value::Object(Object::from(v))
@ -470,6 +476,12 @@ impl From<BTreeMap<String, Value>> for Value {
} }
} }
impl From<BTreeMap<&str, Value>> for Value {
fn from(v: BTreeMap<&str, Value>) -> Self {
Value::Object(Object::from(v))
}
}
impl From<Option<Value>> for Value { impl From<Option<Value>> for Value {
fn from(v: Option<Value>) -> Self { fn from(v: Option<Value>) -> Self {
match v { match v {
@ -730,6 +742,18 @@ impl TryFrom<Value> for Object {
} }
} }
impl FromIterator<Value> for Value {
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
Value::Array(Array(iter.into_iter().collect()))
}
}
impl FromIterator<(String, Value)> for Value {
fn from_iter<I: IntoIterator<Item = (String, Value)>>(iter: I) -> Self {
Value::Object(Object(iter.into_iter().collect()))
}
}
impl Value { impl Value {
// ----------------------------------- // -----------------------------------
// Initial record value // Initial record value
@ -828,6 +852,11 @@ impl Value {
matches!(self, Value::Mock(_)) matches!(self, Value::Mock(_))
} }
/// Check if this Value is a Param
pub fn is_param(&self) -> bool {
matches!(self, Value::Param(_))
}
/// Check if this Value is a Range /// Check if this Value is a Range
pub fn is_range(&self) -> bool { pub fn is_range(&self) -> bool {
matches!(self, Value::Range(_)) matches!(self, Value::Range(_))
@ -952,11 +981,6 @@ impl Value {
} }
} }
/// Check if this Value is a Param
pub fn is_param(&self) -> bool {
matches!(self, Value::Param(_))
}
/// Check if this Value is a Geometry of a specific type /// Check if this Value is a Geometry of a specific type
pub fn is_geometry_type(&self, types: &[String]) -> bool { pub fn is_geometry_type(&self, types: &[String]) -> bool {
match self { match self {
@ -1089,7 +1113,7 @@ impl Value {
/// Treat a string as a table name /// Treat a string as a table name
pub fn could_be_table(self) -> Value { pub fn could_be_table(self) -> Value {
match self { match self {
Value::Strand(v) => Table::from(v.0).into(), Value::Strand(v) => Value::Table(v.0.into()),
_ => self, _ => self,
} }
} }

View file

@ -7,9 +7,6 @@ use base64::DecodeError as Base64Error;
use http::{HeaderName, StatusCode}; use http::{HeaderName, StatusCode};
use reqwest::Error as ReqwestError; use reqwest::Error as ReqwestError;
use serde::Serialize; use serde::Serialize;
use serde_cbor::error::Error as CborError;
use serde_json::error::Error as JsonError;
use serde_pack::encode::Error as PackError;
use std::io::Error as IoError; use std::io::Error as IoError;
use std::string::FromUtf8Error as Utf8Error; use std::string::FromUtf8Error as Utf8Error;
use surrealdb::error::Db as SurrealDbError; use surrealdb::error::Db as SurrealDbError;
@ -37,12 +34,12 @@ pub enum Error {
#[error("There was a problem connecting with the storage engine")] #[error("There was a problem connecting with the storage engine")]
InvalidStorage, InvalidStorage,
#[error("There was a problem parsing the header {0}: {1}")]
InvalidHeader(HeaderName, TypedHeaderRejection),
#[error("The operation is unsupported")] #[error("The operation is unsupported")]
OperationUnsupported, OperationUnsupported,
#[error("There was a problem parsing the header {0}: {1}")]
InvalidHeader(HeaderName, TypedHeaderRejection),
#[error("There was a problem with the database: {0}")] #[error("There was a problem with the database: {0}")]
Db(#[from] SurrealError), Db(#[from] SurrealError),
@ -52,14 +49,14 @@ pub enum Error {
#[error("There was an error with the network: {0}")] #[error("There was an error with the network: {0}")]
Axum(#[from] AxumError), Axum(#[from] AxumError),
#[error("There was an error serializing to JSON: {0}")] #[error("There was an error with JSON serialization: {0}")]
Json(#[from] JsonError), Json(String),
#[error("There was an error serializing to CBOR: {0}")] #[error("There was an error with CBOR serialization: {0}")]
Cbor(#[from] CborError), Cbor(String),
#[error("There was an error serializing to MessagePack: {0}")] #[error("There was an error with MessagePack serialization: {0}")]
Pack(#[from] PackError), Pack(String),
#[error("There was an error with the remote request: {0}")] #[error("There was an error with the remote request: {0}")]
Remote(#[from] ReqwestError), Remote(#[from] ReqwestError),
@ -93,6 +90,42 @@ impl From<Utf8Error> for Error {
} }
} }
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Error {
Error::Json(e.to_string())
}
}
impl From<serde_pack::encode::Error> for Error {
fn from(e: serde_pack::encode::Error) -> Error {
Error::Pack(e.to_string())
}
}
impl From<serde_pack::decode::Error> for Error {
fn from(e: serde_pack::decode::Error) -> Error {
Error::Pack(e.to_string())
}
}
impl From<ciborium::value::Error> for Error {
fn from(e: ciborium::value::Error) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl<T: std::fmt::Debug> From<ciborium::de::Error<T>> for Error {
fn from(e: ciborium::de::Error<T>) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl<T: std::fmt::Debug> From<ciborium::ser::Error<T>> for Error {
fn from(e: ciborium::ser::Error<T>) -> Error {
Error::Cbor(format!("{e}"))
}
}
impl From<surrealdb::error::Db> for Error { impl From<surrealdb::error::Db> for Error {
fn from(error: surrealdb::error::Db) -> Error { fn from(error: surrealdb::error::Db) -> Error {
if matches!(error, surrealdb::error::Db::InvalidAuth) { if matches!(error, surrealdb::error::Db::InvalidAuth) {

View file

@ -39,8 +39,9 @@ pub fn cbor<T>(val: &T) -> Output
where where
T: Serialize, T: Serialize,
{ {
match serde_cbor::to_vec(val) { let mut out = Vec::new();
Ok(v) => Output::Cbor(v), match ciborium::into_writer(&val, &mut out) {
Ok(_) => Output::Cbor(out),
Err(_) => Output::Fail, Err(_) => Output::Fail,
} }
} }

View file

@ -1,18 +1,21 @@
use crate::cnf; use crate::cnf;
use crate::err::Error;
use crate::rpc::connection::Connection; use crate::rpc::connection::Connection;
use crate::rpc::format::Format;
use crate::rpc::format::PROTOCOLS;
use crate::rpc::WEBSOCKETS;
use axum::routing::get; use axum::routing::get;
use axum::Extension; use axum::{
use axum::Router; extract::ws::{WebSocket, WebSocketUpgrade},
response::IntoResponse,
Extension, Router,
};
use http::HeaderValue;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use tower_http::request_id::RequestId; use tower_http::request_id::RequestId;
use uuid::Uuid; use uuid::Uuid;
use axum::{
extract::ws::{WebSocket, WebSocketUpgrade},
response::IntoResponse,
};
pub(super) fn router<S, B>() -> Router<S, B> pub(super) fn router<S, B>() -> Router<S, B>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
@ -23,28 +26,53 @@ where
async fn handler( async fn handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>, Extension(sess): Extension<Session>,
Extension(req_id): Extension<RequestId>, ) -> Result<impl IntoResponse, impl IntoResponse> {
) -> impl IntoResponse { // Check if there is a request id header specified
ws let id = match id.header_value().is_empty() {
// Set the maximum frame size // No request id was specified so create a new id
true => Uuid::new_v4(),
// A request id was specified to try to parse it
false => match id.header_value().to_str() {
// Attempt to parse the request id as a UUID
Ok(id) => match Uuid::try_parse(id) {
// The specified request id was a valid UUID
Ok(id) => id,
// The specified request id was not a UUID
Err(_) => return Err(Error::Request),
},
// The request id contained invalid characters
Err(_) => return Err(Error::Request),
},
};
// Check if a connection with this id already exists
if WEBSOCKETS.read().await.contains_key(&id) {
return Err(Error::Request);
}
// Now let's upgrade the WebSocket connection
Ok(ws
// Set the potential WebSocket protocols
.protocols(PROTOCOLS)
// Set the maximum WebSocket frame size
.max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE) .max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE)
// Set the maximum message size // Set the maximum WebSocket message size
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE) .max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
// Set the potential WebSocket protocol formats
.protocols(["surrealql-binary", "json", "cbor", "messagepack"])
// Handle the WebSocket upgrade and process messages // Handle the WebSocket upgrade and process messages
.on_upgrade(move |socket| handle_socket(socket, sess, req_id)) .on_upgrade(move |socket| handle_socket(socket, sess, id)))
} }
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) { async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
// Check if there is a WebSocket protocol specified
let format = match ws.protocol().map(HeaderValue::to_str) {
// Any selected protocol will always be a valie value
Some(protocol) => protocol.unwrap().into(),
// No protocol format was specified
_ => Format::None,
};
//
// Create a new connection instance // Create a new connection instance
let rpc = Connection::new(sess); let rpc = Connection::new(id, sess, format);
// Update the WebSocket ID with the Request ID
if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) {
// If the ID couldn't be updated, ignore the error and keep the default ID
let _ = rpc.write().await.update_ws_id(req_id).await;
}
// Serve the socket connection requests // Serve the socket connection requests
Connection::serve(rpc, ws).await; Connection::serve(rpc, ws).await;
} }

View file

@ -11,7 +11,7 @@ pub trait Take {
impl Take for Array { impl Take for Array {
/// Convert the array to one argument /// Convert the array to one argument
fn needs_one(self) -> Result<Value, ()> { fn needs_one(self) -> Result<Value, ()> {
if self.is_empty() { if self.len() != 1 {
return Err(()); return Err(());
} }
let mut x = self.into_iter(); let mut x = self.into_iter();
@ -22,7 +22,7 @@ impl Take for Array {
} }
/// Convert the array to two arguments /// Convert the array to two arguments
fn needs_two(self) -> Result<(Value, Value), ()> { fn needs_two(self) -> Result<(Value, Value), ()> {
if self.len() < 2 { if self.len() != 2 {
return Err(()); return Err(());
} }
let mut x = self.into_iter(); let mut x = self.into_iter();
@ -34,7 +34,7 @@ impl Take for Array {
} }
/// Convert the array to two arguments /// Convert the array to two arguments
fn needs_one_or_two(self) -> Result<(Value, Value), ()> { fn needs_one_or_two(self) -> Result<(Value, Value), ()> {
if self.is_empty() { if self.is_empty() && self.len() > 2 {
return Err(()); return Err(());
} }
let mut x = self.into_iter(); let mut x = self.into_iter();
@ -46,7 +46,7 @@ impl Take for Array {
} }
/// Convert the array to three arguments /// Convert the array to three arguments
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> { fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> {
if self.is_empty() { if self.is_empty() && self.len() > 3 {
return Err(()); return Err(());
} }
let mut x = self.into_iter(); let mut x = self.into_iter();

View file

@ -1,11 +1,12 @@
use super::request::parse_request;
use super::response::{failure, success, Data, Failure, IntoRpcResponse, OutputFormat};
use crate::cnf::PKG_NAME; use crate::cnf::PKG_NAME;
use crate::cnf::PKG_VERSION; use crate::cnf::PKG_VERSION;
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY}; use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY};
use crate::dbs::DB; use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::rpc::args::Take; use crate::rpc::args::Take;
use crate::rpc::failure::Failure;
use crate::rpc::format::Format;
use crate::rpc::response::{failure, success, Data, IntoRpcResponse};
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
use crate::telemetry; use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext; use crate::telemetry::metrics::ws::RequestContext;
@ -33,9 +34,9 @@ use tracing::Span;
use uuid::Uuid; use uuid::Uuid;
pub struct Connection { pub struct Connection {
ws_id: Uuid, id: Uuid,
session: Session, session: Session,
format: OutputFormat, format: Format,
vars: BTreeMap<String, Value>, vars: BTreeMap<String, Value>,
limiter: Arc<Semaphore>, limiter: Arc<Semaphore>,
canceller: CancellationToken, canceller: CancellationToken,
@ -43,34 +44,20 @@ pub struct Connection {
impl Connection { impl Connection {
/// Instantiate a new RPC /// Instantiate a new RPC
pub fn new(mut session: Session) -> Arc<RwLock<Connection>> { pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
// Create a new RPC variables store
let vars = BTreeMap::new();
// Set the default output format
let format = OutputFormat::Json;
// Enable real-time mode // Enable real-time mode
session.rt = true; session.rt = true;
// Create and store the RPC connection // Create and store the RPC connection
Arc::new(RwLock::new(Connection { Arc::new(RwLock::new(Connection {
ws_id: Uuid::new_v4(), id,
session, session,
format, format,
vars, vars: BTreeMap::new(),
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)), limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
canceller: CancellationToken::new(), canceller: CancellationToken::new(),
})) }))
} }
/// Update the WebSocket ID. If the ID already exists, do not update it.
pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box<dyn std::error::Error>> {
if WEBSOCKETS.read().await.contains_key(&ws_id) {
trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id);
return Err("websocket ID is in use".into());
}
self.ws_id = ws_id;
Ok(())
}
/// Serve the RPC endpoint /// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) { pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
// Split the socket into send and recv // Split the socket into send and recv
@ -79,19 +66,19 @@ impl Connection {
let (internal_sender, internal_receiver) = let (internal_sender, internal_receiver) =
channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS); channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS);
let ws_id = rpc.read().await.ws_id; let id = rpc.read().await.id;
trace!("WebSocket {} connected", ws_id); trace!("WebSocket {} connected", id);
if let Err(err) = telemetry::metrics::ws::on_connect() { if let Err(err) = telemetry::metrics::ws::on_connect() {
error!("Error running metrics::ws::on_connect hook: {}", err); error!("Error running metrics::ws::on_connect hook: {}", err);
} }
// Add this WebSocket to the list // Add this WebSocket to the list
WEBSOCKETS.write().await.insert( WEBSOCKETS
ws_id, .write()
WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()), .await
); .insert(id, WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()));
// Spawn async tasks for the WebSocket // Spawn async tasks for the WebSocket
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();
@ -109,15 +96,15 @@ impl Connection {
internal_sender.close(); internal_sender.close();
trace!("WebSocket {} disconnected", ws_id); trace!("WebSocket {} disconnected", id);
// Remove this WebSocket from the list // Remove this WebSocket from the list
WEBSOCKETS.write().await.remove(&ws_id); WEBSOCKETS.write().await.remove(&id);
// Remove all live queries // Remove all live queries
let mut gc = Vec::new(); let mut gc = Vec::new();
LIVE_QUERIES.write().await.retain(|key, value| { LIVE_QUERIES.write().await.retain(|key, value| {
if value == &ws_id { if value == &id {
trace!("Removing live query: {}", key); trace!("Removing live query: {}", key);
gc.push(*key); gc.push(*key);
return false; return false;
@ -288,9 +275,9 @@ impl Connection {
msg = channel.recv() => { msg = channel.recv() => {
if let Ok(notification) = msg { if let Ok(notification) = msg {
// Find which WebSocket the notification belongs to // Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) { if let Some(id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists // Check to see if the WebSocket exists
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) { if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(id) {
// Serialize the message to send // Serialize the message to send
let message = success(None, notification); let message = success(None, notification);
// Get the current output format // Get the current output format
@ -309,27 +296,41 @@ impl Connection {
/// Handle individual WebSocket messages /// Handle individual WebSocket messages
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) { async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
// Get the current output format // Get the current output format
let mut out_fmt = rpc.read().await.format; let mut fmt = rpc.read().await.format;
// Prepare Span and Otel context // Prepare Span and Otel context
let span = span_for_request(&rpc.read().await.ws_id); let span = span_for_request(&rpc.read().await.id);
// Acquire concurrent request rate limiter // Acquire concurrent request rate limiter
let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap(); let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap();
// Calculate the length of the message
let len = match msg {
Message::Text(ref msg) => {
// If no format was specified, default to JSON
if fmt.is_none() {
fmt = Format::Json;
rpc.write().await.format = fmt;
}
// Retrieve the length of the message
msg.len()
}
Message::Binary(ref msg) => {
// If no format was specified, default to Bincode
if fmt.is_none() {
fmt = Format::Bincode;
rpc.write().await.format = fmt;
}
// Retrieve the length of the message
msg.len()
}
_ => unreachable!(),
};
// Parse the request // Parse the request
async move { async move {
let span = Span::current(); let span = Span::current();
let req_cx = RequestContext::default(); let req_cx = RequestContext::default();
let otel_cx = TelemetryContext::new().with_value(req_cx.clone()); let otel_cx = TelemetryContext::new().with_value(req_cx.clone());
// Parse the RPC request structure
match parse_request(msg).await { match fmt.req(msg) {
Ok(req) => { Ok(req) => {
if let Some(fmt) = req.out_fmt {
if out_fmt != fmt {
// Update the default format
rpc.write().await.format = fmt;
out_fmt = fmt;
}
}
// Now that we know the method, we can update the span and create otel context // Now that we know the method, we can update the span and create otel context
span.record("rpc.method", &req.method); span.record("rpc.method", &req.method);
span.record("otel.name", format!("surrealdb.rpc/{}", req.method)); span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
@ -338,17 +339,17 @@ impl Connection {
req.id.clone().map(Value::as_string).unwrap_or_default(), req.id.clone().map(Value::as_string).unwrap_or_default(),
); );
let otel_cx = TelemetryContext::current_with_value( let otel_cx = TelemetryContext::current_with_value(
req_cx.with_method(&req.method).with_size(req.size), req_cx.with_method(&req.method).with_size(len),
); );
// Process the message // Process the message
let res = let res =
Connection::process_message(rpc.clone(), &req.method, req.params).await; Connection::process_message(rpc.clone(), &req.method, req.params).await;
// Process the response // Process the response
res.into_response(req.id).send(out_fmt, &chn).with_context(otel_cx).await res.into_response(req.id).send(fmt, &chn).with_context(otel_cx).await
} }
Err(err) => { Err(err) => {
// Process the response // Process the response
failure(None, err).send(out_fmt, &chn).with_context(otel_cx).await failure(None, err).send(fmt, &chn).with_context(otel_cx).await
} }
} }
} }
@ -485,13 +486,6 @@ impl Connection {
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into), Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS), _ => Err(Failure::INVALID_PARAMS),
}, },
// Specify the output format for text requests
"format" => match params.needs_one() {
Ok(Value::Strand(v)) => {
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Get the current server version // Get the current server version
"version" => match params.len() { "version" => match params.len() {
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()), 0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
@ -515,16 +509,6 @@ impl Connection {
// Methods for authentication // Methods for authentication
// ------------------------------ // ------------------------------
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() {
"json" | "application/json" => self.format = OutputFormat::Json,
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
"pack" | "application/pack" => self.format = OutputFormat::Pack,
_ => return Err(Error::InvalidType),
};
Ok(Value::None)
}
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> { async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
if let Value::Strand(ns) = ns { if let Value::Strand(ns) = ns {
self.session.ns = Some(ns.0); self.session.ns = Some(ns.0);
@ -615,7 +599,7 @@ impl Connection {
let sql = "KILL $id"; let sql = "KILL $id";
// Specify the query parameters // Specify the query parameters
let var = map! { let var = map! {
String::from("id") => id, // NOTE: id can be parameter String::from("id") => id,
=> &self.vars => &self.vars
}; };
// Execute the query on the database // Execute the query on the database
@ -910,15 +894,14 @@ impl Connection {
QueryType::Live => { QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result { if let Ok(Value::Uuid(lqid)) = &res.result {
// Match on Uuid type // Match on Uuid type
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id); LIVE_QUERIES.write().await.insert(lqid.0, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.ws_id); trace!("Registered live query {} on websocket {}", lqid, self.id);
} }
} }
QueryType::Kill => { QueryType::Kill => {
if let Ok(Value::Uuid(lqid)) = &res.result { if let Ok(Value::Uuid(lqid)) = &res.result {
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0); if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) {
if let Some(ws_id) = ws_id { trace!("Unregistered live query {} on websocket {}", lqid, id);
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
} }
} }
} }

70
src/rpc/failure.rs Normal file
View file

@ -0,0 +1,70 @@
use crate::err::Error;
use serde::Serialize;
use std::borrow::Cow;
use surrealdb::sql::Value;
#[derive(Clone, Debug, Serialize)]
pub struct Failure {
pub(crate) code: i64,
pub(crate) message: Cow<'static, str>,
}
impl From<&str> for Failure {
fn from(err: &str) -> Self {
Failure::custom(err.to_string())
}
}
impl From<Error> for Failure {
fn from(err: Error) -> Self {
Failure::custom(err.to_string())
}
}
impl From<Failure> for Value {
fn from(err: Failure) -> Self {
map! {
String::from("code") => Value::from(err.code),
String::from("message") => Value::from(err.message.to_string()),
}
.into()
}
}
#[allow(dead_code)]
impl Failure {
pub const PARSE_ERROR: Failure = Failure {
code: -32700,
message: Cow::Borrowed("Parse error"),
};
pub const INVALID_REQUEST: Failure = Failure {
code: -32600,
message: Cow::Borrowed("Invalid Request"),
};
pub const METHOD_NOT_FOUND: Failure = Failure {
code: -32601,
message: Cow::Borrowed("Method not found"),
};
pub const INVALID_PARAMS: Failure = Failure {
code: -32602,
message: Cow::Borrowed("Invalid params"),
};
pub const INTERNAL_ERROR: Failure = Failure {
code: -32603,
message: Cow::Borrowed("Internal error"),
};
pub fn custom<S>(message: S) -> Failure
where
Cow<'static, str>: From<S>,
{
Failure {
code: -32000,
message: message.into(),
}
}
}

22
src/rpc/format/bincode.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Value;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Binary(val) => {
deserialize::<Value>(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Serialize the response with full internal type information
let res = surrealdb::sql::serde::serialize(&res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

View file

@ -0,0 +1,24 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Create a new vector for encoding output
let mut res = Vec::new();
// Serialize the value into CBOR binary data
ciborium::into_writer(&val, &mut res).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

22
src/rpc/format/json.rs Normal file
View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the response with simplified type information
let res = serde_json::to_string(&val).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Text(res)))
}

70
src/rpc/format/mod.rs Normal file
View file

@ -0,0 +1,70 @@
mod bincode;
pub mod cbor;
mod json;
pub mod msgpack;
mod revision;
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub const PROTOCOLS: [&str; 5] = [
"json", // For basic JSON serialisation
"cbor", // For basic CBOR serialisation
"msgpack", // For basic Msgpack serialisation
"bincode", // For full internal serialisation
"revision", // For full versioned serialisation
];
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Format {
None, // No format is specified yet
Json, // For basic JSON serialisation
Cbor, // For basic CBOR serialisation
Msgpack, // For basic Msgpack serialisation
Bincode, // For full internal serialisation
Revision, // For full versioned serialisation
}
impl From<&str> for Format {
fn from(v: &str) -> Self {
match v {
s if s == PROTOCOLS[0] => Format::Json,
s if s == PROTOCOLS[1] => Format::Cbor,
s if s == PROTOCOLS[2] => Format::Msgpack,
s if s == PROTOCOLS[3] => Format::Bincode,
s if s == PROTOCOLS[4] => Format::Revision,
_ => Format::None,
}
}
}
impl Format {
/// Check if this format has been set
pub fn is_none(&self) -> bool {
matches!(self, Format::None)
}
/// Process a request using the specified format
pub fn req(&self, msg: Message) -> Result<Request, Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::req(msg),
Self::Cbor => cbor::req(msg),
Self::Msgpack => msgpack::req(msg),
Self::Bincode => bincode::req(msg),
Self::Revision => revision::req(msg),
}
}
/// Process a response using the specified format
pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> {
match self {
Self::None => unreachable!(), // We should never arrive at this code
Self::Json => json::res(res),
Self::Cbor => cbor::res(res),
Self::Msgpack => msgpack::res(res),
Self::Bincode => bincode::res(res),
Self::Revision => revision::res(res),
}
}
}

View file

@ -0,0 +1,22 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(msg: Message) -> Result<Request, Failure> {
match msg {
Message::Text(val) => {
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
}
_ => Err(Failure::INVALID_REQUEST),
}
}
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
// Convert the response into simplified JSON
let val = res.into_json();
// Serialize the value into MsgPack binary data
let res = serde_pack::to_vec(&val).unwrap();
// Return the message length, and message as binary
Ok((res.len(), Message::Binary(res)))
}

View file

@ -0,0 +1,14 @@
use crate::rpc::failure::Failure;
use crate::rpc::request::Request;
use crate::rpc::response::Response;
use axum::extract::ws::Message;
pub fn req(_msg: Message) -> Result<Request, Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
}
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
// This format is not yet implemented
Err(Failure::INTERNAL_ERROR)
}

View file

@ -1,12 +1,14 @@
pub mod args; pub mod args;
pub mod connection; pub mod connection;
pub mod failure;
pub mod format;
pub mod request; pub mod request;
pub mod response; pub mod response;
use std::{collections::HashMap, time::Duration};
use axum::extract::ws::Message; use axum::extract::ws::Message;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::time::Duration;
use surrealdb::channel::Sender; use surrealdb::channel::Sender;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;

View file

@ -1,10 +1,7 @@
use axum::extract::ws::Message; use crate::rpc::failure::Failure;
use surrealdb::sql::{serde::deserialize, Array, Value};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use surrealdb::sql::Part; use surrealdb::sql::Part;
use surrealdb::sql::{Array, Value};
use super::response::{Failure, OutputFormat};
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
@ -14,46 +11,13 @@ pub struct Request {
pub id: Option<Value>, pub id: Option<Value>,
pub method: String, pub method: String,
pub params: Array, pub params: Array,
pub size: usize,
pub out_fmt: Option<OutputFormat>,
} }
/// Parse the RPC request impl TryFrom<Value> for Request {
pub async fn parse_request(msg: Message) -> Result<Request, Failure> { type Error = Failure;
let mut out_fmt = None; fn try_from(val: Value) -> Result<Self, Failure> {
let (req, size) = match msg {
// This is a binary message
Message::Binary(val) => {
// Use binary output
out_fmt = Some(OutputFormat::Full);
match deserialize(&val) {
Ok(v) => (v, val.len()),
Err(_) => {
debug!("Error when trying to deserialize the request");
return Err(Failure::PARSE_ERROR);
}
}
}
// This is a text message
Message::Text(ref val) => {
// Parse the SurrealQL object
match surrealdb::sql::value(val) {
// The SurrealQL message parsed ok
Ok(v) => (v, val.len()),
// The SurrealQL message failed to parse
_ => return Err(Failure::PARSE_ERROR),
}
}
// Unsupported message type
_ => {
debug!("Unsupported message type: {:?}", msg);
return Err(Failure::custom("Unsupported message type"));
}
};
// Fetch the 'id' argument // Fetch the 'id' argument
let id = match req.pick(&*ID) { let id = match val.pick(&*ID) {
v if v.is_none() => None, v if v.is_none() => None,
v if v.is_null() => Some(v), v if v.is_null() => Some(v),
v if v.is_uuid() => Some(v), v if v.is_uuid() => Some(v),
@ -63,22 +27,20 @@ pub async fn parse_request(msg: Message) -> Result<Request, Failure> {
_ => return Err(Failure::INVALID_REQUEST), _ => return Err(Failure::INVALID_REQUEST),
}; };
// Fetch the 'method' argument // Fetch the 'method' argument
let method = match req.pick(&*METHOD) { let method = match val.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(), Value::Strand(v) => v.to_raw(),
_ => return Err(Failure::INVALID_REQUEST), _ => return Err(Failure::INVALID_REQUEST),
}; };
// Fetch the 'params' argument // Fetch the 'params' argument
let params = match req.pick(&*PARAMS) { let params = match val.pick(&*PARAMS) {
Value::Array(v) => v, Value::Array(v) => v,
_ => Array::new(), _ => Array::new(),
}; };
// Return the parsed request
Ok(Request { Ok(Request {
id, id,
method, method,
params, params,
size,
out_fmt,
}) })
}
} }

View file

@ -1,10 +1,10 @@
use crate::err; use crate::rpc::failure::Failure;
use crate::rpc::format::Format;
use crate::telemetry::metrics::ws::record_rpc; use crate::telemetry::metrics::ws::record_rpc;
use axum::extract::ws::Message; use axum::extract::ws::Message;
use opentelemetry::Context as TelemetryContext; use opentelemetry::Context as TelemetryContext;
use serde::Serialize; use serde::Serialize;
use serde_json::{json, Value as Json}; use serde_json::Value as Json;
use std::borrow::Cow;
use surrealdb::channel::Sender; use surrealdb::channel::Sender;
use surrealdb::dbs; use surrealdb::dbs;
use surrealdb::dbs::Notification; use surrealdb::dbs::Notification;
@ -12,14 +12,6 @@ use surrealdb::sql;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tracing::Span; use tracing::Span;
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum OutputFormat {
Json, // JSON
Cbor, // CBOR
Pack, // MessagePack
Full, // Full type serialization
}
/// The data returned by the database /// The data returned by the database
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data` // The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
// In future, they will possibly be merged to avoid having to keep them in sync. // In future, they will possibly be merged to avoid having to keep them in sync.
@ -46,15 +38,25 @@ impl From<String> for Data {
} }
} }
impl From<Notification> for Data {
fn from(n: Notification) -> Self {
Data::Live(n)
}
}
impl From<Vec<dbs::Response>> for Data { impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self { fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v) Data::Query(v)
} }
} }
impl From<Notification> for Data { impl From<Data> for Value {
fn from(n: Notification) -> Self { fn from(val: Data) -> Self {
Data::Live(n) match val {
Data::Query(v) => sql::to_value(v).unwrap(),
Data::Live(v) => sql::to_value(v).unwrap(),
Data::Other(v) => v,
}
} }
} }
@ -67,32 +69,31 @@ pub struct Response {
impl Response { impl Response {
/// Convert and simplify the value into JSON /// Convert and simplify the value into JSON
#[inline] #[inline]
fn simplify(self) -> Json { pub fn into_json(self) -> Json {
let mut value = match self.result { Json::from(self.into_value())
Ok(data) => {
let value = match data {
Data::Query(vec) => sql::to_value(vec).unwrap(),
Data::Live(notification) => sql::to_value(notification).unwrap(),
Data::Other(value) => value,
};
json!({
"result": Json::from(value),
})
} }
Err(failure) => json!({
"error": failure, #[inline]
}), pub fn into_value(self) -> Value {
let mut value = match self.result {
Ok(val) => map! {
"result" => Value::from(val),
},
Err(err) => map! {
"error" => Value::from(err),
},
}; };
if let Some(id) = self.id { if let Some(id) = self.id {
value["id"] = id.into(); value.insert("id", id);
} }
value value.into()
} }
/// Send the response to the WebSocket channel /// Send the response to the WebSocket channel
pub async fn send(self, out: OutputFormat, chn: &Sender<Message>) { pub async fn send(self, fmt: Format, chn: &Sender<Message>) {
// Create a new tracing span
let span = Span::current(); let span = Span::current();
// Log the rpc response call
debug!("Process RPC response"); debug!("Process RPC response");
let is_error = self.result.is_err(); let is_error = self.result.is_err();
@ -105,73 +106,12 @@ impl Response {
span.record("rpc.error_code", err.code); span.record("rpc.error_code", err.code);
span.record("rpc.error_message", err.message.as_ref()); span.record("rpc.error_message", err.message.as_ref());
} }
// Process the response for the format
let (res_size, message) = match out { let (len, msg) = fmt.res(self).unwrap();
OutputFormat::Json => { // Send the message to the write channel
let res = serde_json::to_string(&self.simplify()).unwrap(); if chn.send(msg).await.is_ok() {
(res.len(), Message::Text(res)) record_rpc(&TelemetryContext::current(), len, is_error);
}
OutputFormat::Cbor => {
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
(res.len(), Message::Binary(res))
}
OutputFormat::Pack => {
let res = serde_pack::to_vec(&self.simplify()).unwrap();
(res.len(), Message::Binary(res))
}
OutputFormat::Full => {
let res = surrealdb::sql::serde::serialize(&self).unwrap();
(res.len(), Message::Binary(res))
}
}; };
if chn.send(message).await.is_ok() {
record_rpc(&TelemetryContext::current(), res_size, is_error);
};
}
}
#[derive(Clone, Debug, Serialize)]
pub struct Failure {
code: i64,
message: Cow<'static, str>,
}
#[allow(dead_code)]
impl Failure {
pub const PARSE_ERROR: Failure = Failure {
code: -32700,
message: Cow::Borrowed("Parse error"),
};
pub const INVALID_REQUEST: Failure = Failure {
code: -32600,
message: Cow::Borrowed("Invalid Request"),
};
pub const METHOD_NOT_FOUND: Failure = Failure {
code: -32601,
message: Cow::Borrowed("Method not found"),
};
pub const INVALID_PARAMS: Failure = Failure {
code: -32602,
message: Cow::Borrowed("Invalid params"),
};
pub const INTERNAL_ERROR: Failure = Failure {
code: -32603,
message: Cow::Borrowed("Internal error"),
};
pub fn custom<S>(message: S) -> Failure
where
Cow<'static, str>: From<S>,
{
Failure {
code: -32000,
message: message.into(),
}
} }
} }
@ -191,12 +131,6 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
} }
} }
impl From<err::Error> for Failure {
fn from(err: err::Error) -> Self {
Failure::custom(err.to_string())
}
}
pub trait IntoRpcResponse { pub trait IntoRpcResponse {
fn into_response(self, id: Option<Value>) -> Response; fn into_response(self, id: Option<Value>) -> Response;
} }

View file

@ -3,6 +3,8 @@ mod common;
mod cli_integration { mod cli_integration {
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild}; use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use common::Format;
use common::Socket;
use std::fs; use std::fs;
use std::fs::File; use std::fs::File;
use std::time; use std::time;
@ -735,15 +737,13 @@ mod cli_integration {
let (addr, mut server) = common::start_server_without_auth().await.unwrap(); let (addr, mut server) = common::start_server_without_auth().await.unwrap();
// Create a long-lived WS connection so the server don't shutdown gracefully // Create a long-lived WS connection so the server don't shutdown gracefully
let mut socket = common::connect_ws(&addr).await.expect("Failed to connect to server"); let mut socket = Socket::connect(&addr, None).await.expect("Failed to connect to server");
let json = serde_json::json!({ let json = serde_json::json!({
"id": "1", "id": "1",
"method": "query", "method": "query",
"params": ["SLEEP 30s;"], "params": ["SLEEP 30s;"],
}); });
common::ws_send_msg(&mut socket, serde_json::to_string(&json).unwrap()) socket.send_message(Format::Json, json).await.expect("Failed to send WS message");
.await
.expect("Failed to send WS message");
// Make sure the SLEEP query is being executed // Make sure the SLEEP query is being executed
tokio::select! { tokio::select! {

14
tests/common/format.rs Normal file
View file

@ -0,0 +1,14 @@
use std::string::ToString;
#[derive(Debug, Copy, Clone)]
pub enum Format {
Json,
}
impl ToString for Format {
fn to_string(&self) -> String {
match self {
Self::Json => "json".to_owned(),
}
}
}

View file

@ -1,489 +1,17 @@
#![allow(unused_imports)]
#![allow(dead_code)] #![allow(dead_code)]
pub mod error; pub mod error;
pub mod format;
pub mod server;
pub mod socket;
use crate::common::error::TestError; pub use format::*;
use futures_util::{SinkExt, StreamExt, TryStreamExt}; pub use server::*;
use rand::{thread_rng, Rng}; pub use socket::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::process::{Command, Stdio};
use std::time::Duration;
use std::{env, fs};
use tokio::net::TcpStream;
use tokio::time;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error, info};
pub const USER: &str = "root";
pub const PASS: &str = "root";
/// Child is a (maybe running) CLI process. It can be killed by dropping it
pub struct Child {
inner: Option<std::process::Child>,
stdout_path: String,
stderr_path: String,
}
impl Child {
/// Send some thing to the child's stdin
pub fn input(mut self, input: &str) -> Self {
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
use std::io::Write;
stdin.write_all(input.as_bytes()).unwrap();
self
}
pub fn kill(mut self) -> Self {
self.inner.as_mut().unwrap().kill().unwrap();
self
}
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
nix::sys::signal::kill(
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
signal,
)
}
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
self.inner.as_mut().unwrap().try_wait()
}
pub fn stdout(&self) -> String {
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
}
pub fn stderr(&self) -> String {
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
}
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
let mut buf = self.stdout();
buf.push_str(&self.stderr());
// Cleanup files after reading them
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
if status.success() {
Ok(buf)
} else {
Err(buf)
}
}
}
impl Drop for Child {
fn drop(&mut self) {
if let Some(inner) = self.inner.as_mut() {
let _ = inner.kill();
}
}
}
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
let mut path = std::env::current_exe().unwrap();
assert!(path.pop());
if path.ends_with("deps") {
assert!(path.pop());
}
// Note: Cargo automatically builds this binary for integration tests.
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
let mut cmd = Command::new(path);
if let Some(dir) = current_dir {
cmd.current_dir(&dir);
}
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
let stdout_path = tmp_file("server-stdout.log");
let stderr_path = tmp_file("server-stderr.log");
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
cmd.env_clear();
cmd.stdin(Stdio::piped());
cmd.stdout(stdout);
cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace());
Child {
inner: Some(cmd.spawn().unwrap()),
stdout_path,
stderr_path,
}
}
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
run_internal::<String>(args, None)
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir))
}
pub fn tmp_file(name: &str) -> String {
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
path.to_string_lossy().into_owned()
}
pub struct StartServerArguments {
pub auth: bool,
pub tls: bool,
pub wait_is_ready: bool,
pub enable_auth_level: bool,
pub tick_interval: time::Duration,
pub args: String,
}
impl Default for StartServerArguments {
fn default() -> Self {
Self {
auth: true,
tls: false,
wait_is_ready: true,
enable_auth_level: false,
tick_interval: time::Duration::new(1, 0),
args: "--allow-all".to_string(),
}
}
}
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
auth: false,
..Default::default()
})
.await
}
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
enable_auth_level: true,
..Default::default()
})
.await
}
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments::default()).await
}
pub async fn start_server(
StartServerArguments {
auth,
tls,
wait_is_ready,
enable_auth_level,
tick_interval,
args,
}: StartServerArguments,
) -> Result<(String, Child), Box<dyn Error>> {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let mut extra_args = args.clone();
if tls {
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
let key_path = tmp_file("key.pem");
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
}
if auth {
extra_args.push_str(" --auth");
}
if enable_auth_level {
extra_args.push_str(" --auth-level-enabled");
}
if !tick_interval.is_zero() {
let sec = tick_interval.as_secs();
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
}
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
info!("starting server with args: {start_args}");
// Configure where the logs go when running the test
let server = run_internal::<String>(&start_args, None);
if !wait_is_ready {
return Ok((addr, server));
}
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(1000));
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
info!("Server ready!");
return Ok((addr, server));
}
}
let server_out = server.kill().output().err().unwrap();
error!("server output: {server_out}");
Err("server failed to start".into())
}
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let (ws_stream, _) = connect_async(url).await?;
Ok(ws_stream)
}
pub async fn ws_send_msg(socket: &mut WsStream, msg_req: String) -> Result<(), Box<dyn Error>> {
let now = time::Instant::now();
debug!("Sending message: {msg_req}");
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout after 500ms waiting for the request to be sent".into());
}
res = socket.send(Message::Text(msg_req)) => {
debug!("Message sent in {:?}", now.elapsed());
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
Ok(())
}
pub async fn ws_recv_msg(socket: &mut WsStream) -> Result<serde_json::Value, Box<dyn Error>> {
ws_recv_msg_with_fmt(socket, Format::Json).await
}
/// When testing Live Queries, we may receive multiple messages unordered.
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
pub async fn ws_recv_all_msgs(
socket: &mut WsStream,
expected: usize,
timeout: Duration,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let mut res = Vec::new();
let deadline = time::Instant::now() + timeout;
loop {
tokio::select! {
_ = time::sleep_until(deadline) => {
debug!("Waited for {:?} and received {} messages", timeout, res.len());
if res.len() != expected {
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
}
}
msg = ws_recv_msg(socket) => {
res.push(msg?);
}
}
if res.len() == expected {
return Ok(res);
}
}
}
pub async fn ws_send_msg_and_wait_response(
socket: &mut WsStream,
msg_req: String,
) -> Result<serde_json::Value, Box<dyn Error>> {
ws_send_msg(socket, msg_req).await?;
ws_recv_msg_with_fmt(socket, Format::Json).await
}
pub enum Format {
Json,
Cbor,
Pack,
}
pub async fn ws_recv_msg_with_fmt(
socket: &mut WsStream,
format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
let now = time::Instant::now();
debug!("Waiting for response...");
// Parse and return response
let mut f = socket.try_filter(|msg| match format {
Format::Json => futures_util::future::ready(msg.is_text()),
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
});
tokio::select! {
_ = time::sleep(time::Duration::from_millis(5000)) => {
Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
}
res = f.select_next_some() => {
debug!("Response received in {:?}", now.elapsed());
match format {
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
}
}
}
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
pub async fn ws_signin(
socket: &mut WsStream,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
let request_id = uuid::Uuid::new_v4().to_string().replace('-', "");
let json = json!({
"id": request_id,
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_str()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_query(
socket: &mut WsStream,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "query",
"params": [query],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_array()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_use(
socket: &mut WsStream,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "use",
"params": [
ns, db
],
});
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
let msg = ws_recv_msg(socket).await?;
debug!("ws_query result json={json:?} msg={msg:?}");
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
/// Check if the given message is a successful notification from LQ. /// Check if the given message is a successful notification from LQ.
pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool { pub fn is_notification(msg: &serde_json::Value) -> bool {
// Example of LQ notification: // Example of LQ notification:
// //
// Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}} // Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}}
@ -497,7 +25,7 @@ pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
} }
/// Check if the given message is a notification from LQ and comes from the given LQ ID. /// Check if the given message is a notification from LQ and comes from the given LQ ID.
pub fn ws_msg_is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool { pub fn is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
ws_msg_is_notification(msg) is_notification(msg)
&& msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id) && msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id)
} }

242
tests/common/server.rs Normal file
View file

@ -0,0 +1,242 @@
use rand::{thread_rng, Rng};
use std::error::Error;
use std::fs::File;
use std::path::Path;
use std::process::{Command, Stdio};
use std::{env, fs};
use tokio::time;
use tracing::{debug, error, info};
pub const USER: &str = "root";
pub const PASS: &str = "root";
pub const NS: &str = "testns";
pub const DB: &str = "testdb";
/// Child is a (maybe running) CLI process. It can be killed by dropping it
pub struct Child {
inner: Option<std::process::Child>,
stdout_path: String,
stderr_path: String,
}
impl Child {
/// Send some thing to the child's stdin
pub fn input(mut self, input: &str) -> Self {
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
use std::io::Write;
stdin.write_all(input.as_bytes()).unwrap();
self
}
pub fn kill(mut self) -> Self {
self.inner.as_mut().unwrap().kill().unwrap();
self
}
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
nix::sys::signal::kill(
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
signal,
)
}
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
self.inner.as_mut().unwrap().try_wait()
}
pub fn stdout(&self) -> String {
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
}
pub fn stderr(&self) -> String {
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
}
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
let mut buf = self.stdout();
buf.push_str(&self.stderr());
// Cleanup files after reading them
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
if status.success() {
Ok(buf)
} else {
Err(buf)
}
}
}
impl Drop for Child {
fn drop(&mut self) {
if let Some(inner) = self.inner.as_mut() {
let _ = inner.kill();
}
}
}
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
let mut path = std::env::current_exe().unwrap();
assert!(path.pop());
if path.ends_with("deps") {
assert!(path.pop());
}
// Note: Cargo automatically builds this binary for integration tests.
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
let mut cmd = Command::new(path);
if let Some(dir) = current_dir {
cmd.current_dir(&dir);
}
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
let stdout_path = tmp_file("server-stdout.log");
let stderr_path = tmp_file("server-stderr.log");
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
cmd.env_clear();
cmd.stdin(Stdio::piped());
cmd.stdout(stdout);
cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace());
Child {
inner: Some(cmd.spawn().unwrap()),
stdout_path,
stderr_path,
}
}
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
run_internal::<String>(args, None)
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir))
}
pub fn tmp_file(name: &str) -> String {
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
path.to_string_lossy().into_owned()
}
pub struct StartServerArguments {
pub auth: bool,
pub tls: bool,
pub wait_is_ready: bool,
pub enable_auth_level: bool,
pub tick_interval: time::Duration,
pub args: String,
}
impl Default for StartServerArguments {
fn default() -> Self {
Self {
auth: true,
tls: false,
wait_is_ready: true,
enable_auth_level: false,
tick_interval: time::Duration::new(1, 0),
args: "--allow-all".to_string(),
}
}
}
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
auth: false,
..Default::default()
})
.await
}
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments {
enable_auth_level: true,
..Default::default()
})
.await
}
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
start_server(StartServerArguments::default()).await
}
pub async fn start_server(
StartServerArguments {
auth,
tls,
wait_is_ready,
enable_auth_level,
tick_interval,
args,
}: StartServerArguments,
) -> Result<(String, Child), Box<dyn Error>> {
let mut rng = thread_rng();
let port: u16 = rng.gen_range(13000..14000);
let addr = format!("127.0.0.1:{port}");
let mut extra_args = args.clone();
if tls {
// Test the crt/key args but the keys are self signed so don't actually connect.
let crt_path = tmp_file("crt.crt");
let key_path = tmp_file("key.pem");
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
}
if auth {
extra_args.push_str(" --auth");
}
if enable_auth_level {
extra_args.push_str(" --auth-level-enabled");
}
if !tick_interval.is_zero() {
let sec = tick_interval.as_secs();
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
}
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
info!("starting server with args: {start_args}");
// Configure where the logs go when running the test
let server = run_internal::<String>(&start_args, None);
if !wait_is_ready {
return Ok((addr, server));
}
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(1000));
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
info!("Server ready!");
return Ok((addr, server));
}
}
let server_out = server.kill().output().err().unwrap();
error!("server output: {server_out}");
Err("server failed to start".into())
}

288
tests/common/socket.rs Normal file
View file

@ -0,0 +1,288 @@
use super::format::Format;
use crate::common::error::TestError;
use futures_util::{SinkExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;
use std::time::Duration;
use surrealdb::sql::Value;
use tokio::net::TcpStream;
use tokio::time;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error};
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
pub struct Socket {
pub stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
// pub struct Socket(pub WebSocketStream<MaybeTlsStream<TcpStream>>);
impl Socket {
/// Close the connection with the WebSocket server
pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
Ok(self.stream.close(None).await?)
}
/// Connect to a WebSocket server using a specific format
pub async fn connect(addr: &str, format: Option<Format>) -> Result<Self, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let mut req = url.into_client_request().unwrap();
if let Some(v) = format.map(|v| v.to_string()) {
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
}
let (stream, _) = connect_async(req).await?;
Ok(Self {
stream,
})
}
/// Send a text or binary message to the WebSocket server
pub async fn send_message(
&mut self,
format: Format,
message: serde_json::Value,
) -> Result<(), Box<dyn Error>> {
let now = time::Instant::now();
debug!("Sending message: {message}");
// Format the message
let msg = match format {
Format::Json => Message::Text(serde_json::to_string(&message)?),
};
// Send the message
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout after 500ms waiting for the request to be sent".into());
}
res = self.stream.send(msg) => {
debug!("Message sent in {:?}", now.elapsed());
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
Ok(())
}
/// Receive a text or binary message from the WebSocket server
pub async fn receive_message(
&mut self,
format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
let now = time::Instant::now();
debug!("Receiving response...");
loop {
tokio::select! {
_ = time::sleep(time::Duration::from_millis(5000)) => {
return Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
}
res = self.stream.try_next() => {
match res {
Ok(res) => match res {
Some(Message::Text(msg)) => {
debug!("Response {msg:?} received in {:?}", now.elapsed());
match format {
Format::Json => {
let msg = serde_json::from_str(&msg)?;
debug!("Received message: {msg}");
return Ok(msg);
},
}
},
Some(_) => {
continue;
}
None => {
return Err("Expected to receive a message".to_string().into());
}
},
Err(err) => {
return Err(format!("Error receiving the message: {}", err).into());
}
}
}
}
}
}
/// Send a text or binary message and receive a reponse from the WebSocket server
pub async fn send_and_receive_message(
&mut self,
format: Format,
message: serde_json::Value,
) -> Result<serde_json::Value, Box<dyn Error>> {
self.send_message(format, message).await?;
self.receive_message(format).await
}
/// When testing Live Queries, we may receive multiple messages unordered.
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
pub async fn receive_all_messages(
&mut self,
format: Format,
expected: usize,
timeout: Duration,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let mut res = Vec::new();
let deadline = time::Instant::now() + timeout;
loop {
tokio::select! {
_ = time::sleep_until(deadline) => {
debug!("Waited for {:?} and received {} messages", timeout, res.len());
if res.len() != expected {
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
}
}
msg = self.receive_message(format) => {
res.push(msg?);
}
}
if res.len() == expected {
return Ok(res);
}
}
}
/// Send a USE message to the server and check the response
pub async fn send_message_use(
&mut self,
format: Format,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "use",
"params": [
ns, db
],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!(
"expected a result from the received object, got this instead: {:?}",
obj
),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
/// Send a generic query message to the server and check the response
pub async fn send_message_query(
&mut self,
format: Format,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "query",
"params": [query],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_array()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
/// Send a signin authentication query message to the server and check the response
pub async fn send_message_signin(
&mut self,
format: Format,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
// Generate an ID
let id = uuid::Uuid::new_v4().to_string();
// Construct message
let msg = json!({
"id": id,
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
// Send message and receive response
let msg = self.send_and_receive_message(format, msg).await?;
// Check response message structure
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
.get("result")
.ok_or(TestError::AssertionError {
message: format!("expected a result from the received object, got this instead: {:?}", obj),
})?
.as_str()
.ok_or(TestError::AssertionError {
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
})?
.to_owned()),
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
}

1177
tests/common/tests.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -1258,8 +1258,8 @@ mod http_integration {
.send() .send()
.await?; .await?;
assert_eq!(res.status(), 200); assert_eq!(res.status(), 200);
let res = res.bytes().await?.to_vec();
let _: serde_cbor::Value = serde_cbor::from_slice(&res.bytes().await?).unwrap(); let _: ciborium::Value = ciborium::from_reader(res.as_slice()).unwrap();
} }
// Creating a record with Accept PACK encoding is allowed // Creating a record with Accept PACK encoding is allowed
@ -1272,8 +1272,8 @@ mod http_integration {
.send() .send()
.await?; .await?;
assert_eq!(res.status(), 200); assert_eq!(res.status(), 200);
let res = res.bytes().await?.to_vec();
let _: serde_cbor::Value = serde_pack::from_slice(&res.bytes().await?).unwrap(); let _: rmpv::Value = rmpv::decode::read_value(&mut res.as_slice()).unwrap();
} }
// Creating a record with Accept Surrealdb encoding is allowed // Creating a record with Accept Surrealdb encoding is allowed

File diff suppressed because it is too large Load diff