Improve WebSocket protocol implementation (#3291)
This commit is contained in:
parent
8e9bd3a2d6
commit
a35fc0d04d
31 changed files with 2300 additions and 2338 deletions
23
Cargo.lock
generated
23
Cargo.lock
generated
|
@ -4445,6 +4445,16 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmpv"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e0e0214a4a2b444ecce41a4025792fc31f77c7bb89c46d253953ea8c65701ec"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rmp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "roaring"
|
||||
version = "0.10.2"
|
||||
|
@ -4817,16 +4827,6 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_cbor"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
|
||||
dependencies = [
|
||||
"half",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.193"
|
||||
|
@ -5235,6 +5235,7 @@ dependencies = [
|
|||
"axum-server",
|
||||
"base64 0.21.5",
|
||||
"bytes",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"env_logger",
|
||||
"futures",
|
||||
|
@ -5257,10 +5258,10 @@ dependencies = [
|
|||
"rcgen",
|
||||
"reqwest",
|
||||
"rmp-serde",
|
||||
"rmpv",
|
||||
"rustyline",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_cbor",
|
||||
"serde_json",
|
||||
"serial_test",
|
||||
"surrealdb",
|
||||
|
|
|
@ -40,6 +40,7 @@ axum-extra = { version = "0.7.7", features = ["query", "typed-routing"] }
|
|||
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
||||
base64 = "0.21.5"
|
||||
bytes = "1.5.0"
|
||||
ciborium = "0.2.1"
|
||||
clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] }
|
||||
futures = "0.3.29"
|
||||
futures-util = "0.3.29"
|
||||
|
@ -55,9 +56,9 @@ opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
|
|||
pin-project-lite = "0.2.13"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] }
|
||||
rmpv = "1.0.1"
|
||||
rustyline = { version = "12.0.0", features = ["derive"] }
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_cbor = "0.11.2"
|
||||
serde_json = "1.0.108"
|
||||
serde_pack = { version = "1.1.2", package = "rmp-serde" }
|
||||
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
|
||||
|
|
|
@ -34,13 +34,13 @@ env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set =
|
|||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
|
||||
|
||||
[tasks.ci-ws-integration]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
category = "WS - INTEGRATION TESTS"
|
||||
command = "cargo"
|
||||
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
||||
|
||||
[tasks.ci-ml-integration]
|
||||
category = "CI - INTEGRATION TESTS"
|
||||
category = "ML - INTEGRATION TESTS"
|
||||
command = "cargo"
|
||||
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
|
||||
|
|
|
@ -130,6 +130,7 @@ pub fn thing((arg1, arg2): (Value, Option<Value>)) -> Result<Value, Error> {
|
|||
|
||||
pub mod is {
|
||||
use crate::err::Error;
|
||||
use crate::sql::table::Table;
|
||||
use crate::sql::value::Value;
|
||||
use crate::sql::Geometry;
|
||||
|
||||
|
@ -215,7 +216,7 @@ pub mod is {
|
|||
|
||||
pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> {
|
||||
Ok(match table {
|
||||
Some(tb) => arg.is_record_of_table(tb).into(),
|
||||
Some(tb) => arg.is_record_type(&[Table(tb)]).into(),
|
||||
None => arg.is_record().into(),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::sql::{escape::escape_rid, Array, Number, Object, Strand, Thing, Uuid,
|
|||
use nanoid::nanoid;
|
||||
use revision::revisioned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
use ulid::Ulid;
|
||||
|
||||
|
@ -106,6 +107,12 @@ impl From<Vec<Value>> for Id {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<BTreeMap<String, Value>> for Id {
|
||||
fn from(v: BTreeMap<String, Value>) -> Self {
|
||||
Id::Object(v.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Number> for Id {
|
||||
fn from(v: Number) -> Self {
|
||||
match v {
|
||||
|
|
|
@ -23,6 +23,12 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
|
|||
#[revisioned(revision = 1)]
|
||||
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
|
||||
|
||||
impl From<BTreeMap<&str, Value>> for Object {
|
||||
fn from(v: BTreeMap<&str, Value>) -> Self {
|
||||
Self(v.into_iter().map(|(key, val)| (key.to_string(), val)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BTreeMap<String, Value>> for Object {
|
||||
fn from(v: BTreeMap<String, Value>) -> Self {
|
||||
Self(v)
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::ctx::Context;
|
|||
use crate::dbs::{Options, Transaction};
|
||||
use crate::doc::CursorDoc;
|
||||
use crate::err::Error;
|
||||
use crate::sql::Uuid;
|
||||
use crate::sql::Value;
|
||||
use derive::Store;
|
||||
use revision::revisioned;
|
||||
|
@ -34,6 +35,14 @@ impl KillStatement {
|
|||
Value::Uuid(id) => *id,
|
||||
Value::Param(param) => match param.compute(ctx, opt, txn, None).await? {
|
||||
Value::Uuid(id) => id,
|
||||
Value::Strand(id) => match uuid::Uuid::try_parse(&id) {
|
||||
Ok(id) => Uuid(id),
|
||||
_ => {
|
||||
return Err(Error::KillStatement {
|
||||
value: self.id.to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Err(Error::KillStatement {
|
||||
value: self.id.to_string(),
|
||||
|
|
|
@ -458,6 +458,12 @@ impl From<Vec<bool>> for Value {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<HashMap<&str, Value>> for Value {
|
||||
fn from(v: HashMap<&str, Value>) -> Self {
|
||||
Value::Object(Object::from(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HashMap<String, Value>> for Value {
|
||||
fn from(v: HashMap<String, Value>) -> Self {
|
||||
Value::Object(Object::from(v))
|
||||
|
@ -470,6 +476,12 @@ impl From<BTreeMap<String, Value>> for Value {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<BTreeMap<&str, Value>> for Value {
|
||||
fn from(v: BTreeMap<&str, Value>) -> Self {
|
||||
Value::Object(Object::from(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Option<Value>> for Value {
|
||||
fn from(v: Option<Value>) -> Self {
|
||||
match v {
|
||||
|
@ -730,6 +742,18 @@ impl TryFrom<Value> for Object {
|
|||
}
|
||||
}
|
||||
|
||||
impl FromIterator<Value> for Value {
|
||||
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
|
||||
Value::Array(Array(iter.into_iter().collect()))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<(String, Value)> for Value {
|
||||
fn from_iter<I: IntoIterator<Item = (String, Value)>>(iter: I) -> Self {
|
||||
Value::Object(Object(iter.into_iter().collect()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Value {
|
||||
// -----------------------------------
|
||||
// Initial record value
|
||||
|
@ -828,6 +852,11 @@ impl Value {
|
|||
matches!(self, Value::Mock(_))
|
||||
}
|
||||
|
||||
/// Check if this Value is a Param
|
||||
pub fn is_param(&self) -> bool {
|
||||
matches!(self, Value::Param(_))
|
||||
}
|
||||
|
||||
/// Check if this Value is a Range
|
||||
pub fn is_range(&self) -> bool {
|
||||
matches!(self, Value::Range(_))
|
||||
|
@ -952,11 +981,6 @@ impl Value {
|
|||
}
|
||||
}
|
||||
|
||||
/// Check if this Value is a Param
|
||||
pub fn is_param(&self) -> bool {
|
||||
matches!(self, Value::Param(_))
|
||||
}
|
||||
|
||||
/// Check if this Value is a Geometry of a specific type
|
||||
pub fn is_geometry_type(&self, types: &[String]) -> bool {
|
||||
match self {
|
||||
|
@ -1089,7 +1113,7 @@ impl Value {
|
|||
/// Treat a string as a table name
|
||||
pub fn could_be_table(self) -> Value {
|
||||
match self {
|
||||
Value::Strand(v) => Table::from(v.0).into(),
|
||||
Value::Strand(v) => Value::Table(v.0.into()),
|
||||
_ => self,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,9 +7,6 @@ use base64::DecodeError as Base64Error;
|
|||
use http::{HeaderName, StatusCode};
|
||||
use reqwest::Error as ReqwestError;
|
||||
use serde::Serialize;
|
||||
use serde_cbor::error::Error as CborError;
|
||||
use serde_json::error::Error as JsonError;
|
||||
use serde_pack::encode::Error as PackError;
|
||||
use std::io::Error as IoError;
|
||||
use std::string::FromUtf8Error as Utf8Error;
|
||||
use surrealdb::error::Db as SurrealDbError;
|
||||
|
@ -37,12 +34,12 @@ pub enum Error {
|
|||
#[error("There was a problem connecting with the storage engine")]
|
||||
InvalidStorage,
|
||||
|
||||
#[error("There was a problem parsing the header {0}: {1}")]
|
||||
InvalidHeader(HeaderName, TypedHeaderRejection),
|
||||
|
||||
#[error("The operation is unsupported")]
|
||||
OperationUnsupported,
|
||||
|
||||
#[error("There was a problem parsing the header {0}: {1}")]
|
||||
InvalidHeader(HeaderName, TypedHeaderRejection),
|
||||
|
||||
#[error("There was a problem with the database: {0}")]
|
||||
Db(#[from] SurrealError),
|
||||
|
||||
|
@ -52,14 +49,14 @@ pub enum Error {
|
|||
#[error("There was an error with the network: {0}")]
|
||||
Axum(#[from] AxumError),
|
||||
|
||||
#[error("There was an error serializing to JSON: {0}")]
|
||||
Json(#[from] JsonError),
|
||||
#[error("There was an error with JSON serialization: {0}")]
|
||||
Json(String),
|
||||
|
||||
#[error("There was an error serializing to CBOR: {0}")]
|
||||
Cbor(#[from] CborError),
|
||||
#[error("There was an error with CBOR serialization: {0}")]
|
||||
Cbor(String),
|
||||
|
||||
#[error("There was an error serializing to MessagePack: {0}")]
|
||||
Pack(#[from] PackError),
|
||||
#[error("There was an error with MessagePack serialization: {0}")]
|
||||
Pack(String),
|
||||
|
||||
#[error("There was an error with the remote request: {0}")]
|
||||
Remote(#[from] ReqwestError),
|
||||
|
@ -93,6 +90,42 @@ impl From<Utf8Error> for Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(e: serde_json::Error) -> Error {
|
||||
Error::Json(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_pack::encode::Error> for Error {
|
||||
fn from(e: serde_pack::encode::Error) -> Error {
|
||||
Error::Pack(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_pack::decode::Error> for Error {
|
||||
fn from(e: serde_pack::decode::Error) -> Error {
|
||||
Error::Pack(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ciborium::value::Error> for Error {
|
||||
fn from(e: ciborium::value::Error) -> Error {
|
||||
Error::Cbor(format!("{e}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug> From<ciborium::de::Error<T>> for Error {
|
||||
fn from(e: ciborium::de::Error<T>) -> Error {
|
||||
Error::Cbor(format!("{e}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: std::fmt::Debug> From<ciborium::ser::Error<T>> for Error {
|
||||
fn from(e: ciborium::ser::Error<T>) -> Error {
|
||||
Error::Cbor(format!("{e}"))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<surrealdb::error::Db> for Error {
|
||||
fn from(error: surrealdb::error::Db) -> Error {
|
||||
if matches!(error, surrealdb::error::Db::InvalidAuth) {
|
||||
|
|
|
@ -39,8 +39,9 @@ pub fn cbor<T>(val: &T) -> Output
|
|||
where
|
||||
T: Serialize,
|
||||
{
|
||||
match serde_cbor::to_vec(val) {
|
||||
Ok(v) => Output::Cbor(v),
|
||||
let mut out = Vec::new();
|
||||
match ciborium::into_writer(&val, &mut out) {
|
||||
Ok(_) => Output::Cbor(out),
|
||||
Err(_) => Output::Fail,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
use crate::cnf;
|
||||
use crate::err::Error;
|
||||
use crate::rpc::connection::Connection;
|
||||
use crate::rpc::format::Format;
|
||||
use crate::rpc::format::PROTOCOLS;
|
||||
use crate::rpc::WEBSOCKETS;
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use axum::{
|
||||
extract::ws::{WebSocket, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
Extension, Router,
|
||||
};
|
||||
use http::HeaderValue;
|
||||
use http_body::Body as HttpBody;
|
||||
use surrealdb::dbs::Session;
|
||||
use tower_http::request_id::RequestId;
|
||||
use uuid::Uuid;
|
||||
|
||||
use axum::{
|
||||
extract::ws::{WebSocket, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
|
@ -23,28 +26,53 @@ where
|
|||
|
||||
async fn handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(id): Extension<RequestId>,
|
||||
Extension(sess): Extension<Session>,
|
||||
Extension(req_id): Extension<RequestId>,
|
||||
) -> impl IntoResponse {
|
||||
ws
|
||||
// Set the maximum frame size
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Check if there is a request id header specified
|
||||
let id = match id.header_value().is_empty() {
|
||||
// No request id was specified so create a new id
|
||||
true => Uuid::new_v4(),
|
||||
// A request id was specified to try to parse it
|
||||
false => match id.header_value().to_str() {
|
||||
// Attempt to parse the request id as a UUID
|
||||
Ok(id) => match Uuid::try_parse(id) {
|
||||
// The specified request id was a valid UUID
|
||||
Ok(id) => id,
|
||||
// The specified request id was not a UUID
|
||||
Err(_) => return Err(Error::Request),
|
||||
},
|
||||
// The request id contained invalid characters
|
||||
Err(_) => return Err(Error::Request),
|
||||
},
|
||||
};
|
||||
// Check if a connection with this id already exists
|
||||
if WEBSOCKETS.read().await.contains_key(&id) {
|
||||
return Err(Error::Request);
|
||||
}
|
||||
// Now let's upgrade the WebSocket connection
|
||||
Ok(ws
|
||||
// Set the potential WebSocket protocols
|
||||
.protocols(PROTOCOLS)
|
||||
// Set the maximum WebSocket frame size
|
||||
.max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE)
|
||||
// Set the maximum message size
|
||||
// Set the maximum WebSocket message size
|
||||
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
|
||||
// Set the potential WebSocket protocol formats
|
||||
.protocols(["surrealql-binary", "json", "cbor", "messagepack"])
|
||||
// Handle the WebSocket upgrade and process messages
|
||||
.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
|
||||
.on_upgrade(move |socket| handle_socket(socket, sess, id)))
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
|
||||
async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
|
||||
// Check if there is a WebSocket protocol specified
|
||||
let format = match ws.protocol().map(HeaderValue::to_str) {
|
||||
// Any selected protocol will always be a valie value
|
||||
Some(protocol) => protocol.unwrap().into(),
|
||||
// No protocol format was specified
|
||||
_ => Format::None,
|
||||
};
|
||||
//
|
||||
// Create a new connection instance
|
||||
let rpc = Connection::new(sess);
|
||||
// Update the WebSocket ID with the Request ID
|
||||
if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) {
|
||||
// If the ID couldn't be updated, ignore the error and keep the default ID
|
||||
let _ = rpc.write().await.update_ws_id(req_id).await;
|
||||
}
|
||||
let rpc = Connection::new(id, sess, format);
|
||||
// Serve the socket connection requests
|
||||
Connection::serve(rpc, ws).await;
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ pub trait Take {
|
|||
impl Take for Array {
|
||||
/// Convert the array to one argument
|
||||
fn needs_one(self) -> Result<Value, ()> {
|
||||
if self.is_empty() {
|
||||
if self.len() != 1 {
|
||||
return Err(());
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
@ -22,7 +22,7 @@ impl Take for Array {
|
|||
}
|
||||
/// Convert the array to two arguments
|
||||
fn needs_two(self) -> Result<(Value, Value), ()> {
|
||||
if self.len() < 2 {
|
||||
if self.len() != 2 {
|
||||
return Err(());
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
@ -34,7 +34,7 @@ impl Take for Array {
|
|||
}
|
||||
/// Convert the array to two arguments
|
||||
fn needs_one_or_two(self) -> Result<(Value, Value), ()> {
|
||||
if self.is_empty() {
|
||||
if self.is_empty() && self.len() > 2 {
|
||||
return Err(());
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
@ -46,7 +46,7 @@ impl Take for Array {
|
|||
}
|
||||
/// Convert the array to three arguments
|
||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> {
|
||||
if self.is_empty() {
|
||||
if self.is_empty() && self.len() > 3 {
|
||||
return Err(());
|
||||
}
|
||||
let mut x = self.into_iter();
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
use super::request::parse_request;
|
||||
use super::response::{failure, success, Data, Failure, IntoRpcResponse, OutputFormat};
|
||||
use crate::cnf::PKG_NAME;
|
||||
use crate::cnf::PKG_VERSION;
|
||||
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY};
|
||||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::rpc::args::Take;
|
||||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::format::Format;
|
||||
use crate::rpc::response::{failure, success, Data, IntoRpcResponse};
|
||||
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
|
||||
use crate::telemetry;
|
||||
use crate::telemetry::metrics::ws::RequestContext;
|
||||
|
@ -33,9 +34,9 @@ use tracing::Span;
|
|||
use uuid::Uuid;
|
||||
|
||||
pub struct Connection {
|
||||
ws_id: Uuid,
|
||||
id: Uuid,
|
||||
session: Session,
|
||||
format: OutputFormat,
|
||||
format: Format,
|
||||
vars: BTreeMap<String, Value>,
|
||||
limiter: Arc<Semaphore>,
|
||||
canceller: CancellationToken,
|
||||
|
@ -43,34 +44,20 @@ pub struct Connection {
|
|||
|
||||
impl Connection {
|
||||
/// Instantiate a new RPC
|
||||
pub fn new(mut session: Session) -> Arc<RwLock<Connection>> {
|
||||
// Create a new RPC variables store
|
||||
let vars = BTreeMap::new();
|
||||
// Set the default output format
|
||||
let format = OutputFormat::Json;
|
||||
pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
|
||||
// Enable real-time mode
|
||||
session.rt = true;
|
||||
// Create and store the RPC connection
|
||||
Arc::new(RwLock::new(Connection {
|
||||
ws_id: Uuid::new_v4(),
|
||||
id,
|
||||
session,
|
||||
format,
|
||||
vars,
|
||||
vars: BTreeMap::new(),
|
||||
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
|
||||
canceller: CancellationToken::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Update the WebSocket ID. If the ID already exists, do not update it.
|
||||
pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if WEBSOCKETS.read().await.contains_key(&ws_id) {
|
||||
trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id);
|
||||
return Err("websocket ID is in use".into());
|
||||
}
|
||||
self.ws_id = ws_id;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Serve the RPC endpoint
|
||||
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
|
||||
// Split the socket into send and recv
|
||||
|
@ -79,19 +66,19 @@ impl Connection {
|
|||
let (internal_sender, internal_receiver) =
|
||||
channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS);
|
||||
|
||||
let ws_id = rpc.read().await.ws_id;
|
||||
let id = rpc.read().await.id;
|
||||
|
||||
trace!("WebSocket {} connected", ws_id);
|
||||
trace!("WebSocket {} connected", id);
|
||||
|
||||
if let Err(err) = telemetry::metrics::ws::on_connect() {
|
||||
error!("Error running metrics::ws::on_connect hook: {}", err);
|
||||
}
|
||||
|
||||
// Add this WebSocket to the list
|
||||
WEBSOCKETS.write().await.insert(
|
||||
ws_id,
|
||||
WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()),
|
||||
);
|
||||
WEBSOCKETS
|
||||
.write()
|
||||
.await
|
||||
.insert(id, WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()));
|
||||
|
||||
// Spawn async tasks for the WebSocket
|
||||
let mut tasks = JoinSet::new();
|
||||
|
@ -109,15 +96,15 @@ impl Connection {
|
|||
|
||||
internal_sender.close();
|
||||
|
||||
trace!("WebSocket {} disconnected", ws_id);
|
||||
trace!("WebSocket {} disconnected", id);
|
||||
|
||||
// Remove this WebSocket from the list
|
||||
WEBSOCKETS.write().await.remove(&ws_id);
|
||||
WEBSOCKETS.write().await.remove(&id);
|
||||
|
||||
// Remove all live queries
|
||||
let mut gc = Vec::new();
|
||||
LIVE_QUERIES.write().await.retain(|key, value| {
|
||||
if value == &ws_id {
|
||||
if value == &id {
|
||||
trace!("Removing live query: {}", key);
|
||||
gc.push(*key);
|
||||
return false;
|
||||
|
@ -288,9 +275,9 @@ impl Connection {
|
|||
msg = channel.recv() => {
|
||||
if let Ok(notification) = msg {
|
||||
// Find which WebSocket the notification belongs to
|
||||
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||
if let Some(id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||
// Check to see if the WebSocket exists
|
||||
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
|
||||
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(id) {
|
||||
// Serialize the message to send
|
||||
let message = success(None, notification);
|
||||
// Get the current output format
|
||||
|
@ -309,27 +296,41 @@ impl Connection {
|
|||
/// Handle individual WebSocket messages
|
||||
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
|
||||
// Get the current output format
|
||||
let mut out_fmt = rpc.read().await.format;
|
||||
let mut fmt = rpc.read().await.format;
|
||||
// Prepare Span and Otel context
|
||||
let span = span_for_request(&rpc.read().await.ws_id);
|
||||
let span = span_for_request(&rpc.read().await.id);
|
||||
// Acquire concurrent request rate limiter
|
||||
let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap();
|
||||
// Calculate the length of the message
|
||||
let len = match msg {
|
||||
Message::Text(ref msg) => {
|
||||
// If no format was specified, default to JSON
|
||||
if fmt.is_none() {
|
||||
fmt = Format::Json;
|
||||
rpc.write().await.format = fmt;
|
||||
}
|
||||
// Retrieve the length of the message
|
||||
msg.len()
|
||||
}
|
||||
Message::Binary(ref msg) => {
|
||||
// If no format was specified, default to Bincode
|
||||
if fmt.is_none() {
|
||||
fmt = Format::Bincode;
|
||||
rpc.write().await.format = fmt;
|
||||
}
|
||||
// Retrieve the length of the message
|
||||
msg.len()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
// Parse the request
|
||||
async move {
|
||||
let span = Span::current();
|
||||
let req_cx = RequestContext::default();
|
||||
let otel_cx = TelemetryContext::new().with_value(req_cx.clone());
|
||||
|
||||
match parse_request(msg).await {
|
||||
// Parse the RPC request structure
|
||||
match fmt.req(msg) {
|
||||
Ok(req) => {
|
||||
if let Some(fmt) = req.out_fmt {
|
||||
if out_fmt != fmt {
|
||||
// Update the default format
|
||||
rpc.write().await.format = fmt;
|
||||
out_fmt = fmt;
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we know the method, we can update the span and create otel context
|
||||
span.record("rpc.method", &req.method);
|
||||
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
|
||||
|
@ -338,17 +339,17 @@ impl Connection {
|
|||
req.id.clone().map(Value::as_string).unwrap_or_default(),
|
||||
);
|
||||
let otel_cx = TelemetryContext::current_with_value(
|
||||
req_cx.with_method(&req.method).with_size(req.size),
|
||||
req_cx.with_method(&req.method).with_size(len),
|
||||
);
|
||||
// Process the message
|
||||
let res =
|
||||
Connection::process_message(rpc.clone(), &req.method, req.params).await;
|
||||
// Process the response
|
||||
res.into_response(req.id).send(out_fmt, &chn).with_context(otel_cx).await
|
||||
res.into_response(req.id).send(fmt, &chn).with_context(otel_cx).await
|
||||
}
|
||||
Err(err) => {
|
||||
// Process the response
|
||||
failure(None, err).send(out_fmt, &chn).with_context(otel_cx).await
|
||||
failure(None, err).send(fmt, &chn).with_context(otel_cx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -485,13 +486,6 @@ impl Connection {
|
|||
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Specify the output format for text requests
|
||||
"format" => match params.needs_one() {
|
||||
Ok(Value::Strand(v)) => {
|
||||
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Get the current server version
|
||||
"version" => match params.len() {
|
||||
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
||||
|
@ -515,16 +509,6 @@ impl Connection {
|
|||
// Methods for authentication
|
||||
// ------------------------------
|
||||
|
||||
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
||||
match out.as_str() {
|
||||
"json" | "application/json" => self.format = OutputFormat::Json,
|
||||
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
|
||||
"pack" | "application/pack" => self.format = OutputFormat::Pack,
|
||||
_ => return Err(Error::InvalidType),
|
||||
};
|
||||
Ok(Value::None)
|
||||
}
|
||||
|
||||
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
||||
if let Value::Strand(ns) = ns {
|
||||
self.session.ns = Some(ns.0);
|
||||
|
@ -615,7 +599,7 @@ impl Connection {
|
|||
let sql = "KILL $id";
|
||||
// Specify the query parameters
|
||||
let var = map! {
|
||||
String::from("id") => id, // NOTE: id can be parameter
|
||||
String::from("id") => id,
|
||||
=> &self.vars
|
||||
};
|
||||
// Execute the query on the database
|
||||
|
@ -910,15 +894,14 @@ impl Connection {
|
|||
QueryType::Live => {
|
||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||
// Match on Uuid type
|
||||
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
|
||||
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
|
||||
LIVE_QUERIES.write().await.insert(lqid.0, self.id);
|
||||
trace!("Registered live query {} on websocket {}", lqid, self.id);
|
||||
}
|
||||
}
|
||||
QueryType::Kill => {
|
||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0);
|
||||
if let Some(ws_id) = ws_id {
|
||||
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
|
||||
if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) {
|
||||
trace!("Unregistered live query {} on websocket {}", lqid, id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
70
src/rpc/failure.rs
Normal file
70
src/rpc/failure.rs
Normal file
|
@ -0,0 +1,70 @@
|
|||
use crate::err::Error;
|
||||
use serde::Serialize;
|
||||
use std::borrow::Cow;
|
||||
use surrealdb::sql::Value;
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct Failure {
|
||||
pub(crate) code: i64,
|
||||
pub(crate) message: Cow<'static, str>,
|
||||
}
|
||||
|
||||
impl From<&str> for Failure {
|
||||
fn from(err: &str) -> Self {
|
||||
Failure::custom(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for Failure {
|
||||
fn from(err: Error) -> Self {
|
||||
Failure::custom(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Failure> for Value {
|
||||
fn from(err: Failure) -> Self {
|
||||
map! {
|
||||
String::from("code") => Value::from(err.code),
|
||||
String::from("message") => Value::from(err.message.to_string()),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Failure {
|
||||
pub const PARSE_ERROR: Failure = Failure {
|
||||
code: -32700,
|
||||
message: Cow::Borrowed("Parse error"),
|
||||
};
|
||||
|
||||
pub const INVALID_REQUEST: Failure = Failure {
|
||||
code: -32600,
|
||||
message: Cow::Borrowed("Invalid Request"),
|
||||
};
|
||||
|
||||
pub const METHOD_NOT_FOUND: Failure = Failure {
|
||||
code: -32601,
|
||||
message: Cow::Borrowed("Method not found"),
|
||||
};
|
||||
|
||||
pub const INVALID_PARAMS: Failure = Failure {
|
||||
code: -32602,
|
||||
message: Cow::Borrowed("Invalid params"),
|
||||
};
|
||||
|
||||
pub const INTERNAL_ERROR: Failure = Failure {
|
||||
code: -32603,
|
||||
message: Cow::Borrowed("Internal error"),
|
||||
};
|
||||
|
||||
pub fn custom<S>(message: S) -> Failure
|
||||
where
|
||||
Cow<'static, str>: From<S>,
|
||||
{
|
||||
Failure {
|
||||
code: -32000,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
22
src/rpc/format/bincode.rs
Normal file
22
src/rpc/format/bincode.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
use surrealdb::sql::serde::deserialize;
|
||||
use surrealdb::sql::Value;
|
||||
|
||||
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||
match msg {
|
||||
Message::Binary(val) => {
|
||||
deserialize::<Value>(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||
}
|
||||
_ => Err(Failure::INVALID_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||
// Serialize the response with full internal type information
|
||||
let res = surrealdb::sql::serde::serialize(&res).unwrap();
|
||||
// Return the message length, and message as binary
|
||||
Ok((res.len(), Message::Binary(res)))
|
||||
}
|
24
src/rpc/format/cbor/mod.rs
Normal file
24
src/rpc/format/cbor/mod.rs
Normal file
|
@ -0,0 +1,24 @@
|
|||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
|
||||
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||
match msg {
|
||||
Message::Text(val) => {
|
||||
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||
}
|
||||
_ => Err(Failure::INVALID_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||
// Convert the response into simplified JSON
|
||||
let val = res.into_json();
|
||||
// Create a new vector for encoding output
|
||||
let mut res = Vec::new();
|
||||
// Serialize the value into CBOR binary data
|
||||
ciborium::into_writer(&val, &mut res).unwrap();
|
||||
// Return the message length, and message as binary
|
||||
Ok((res.len(), Message::Binary(res)))
|
||||
}
|
22
src/rpc/format/json.rs
Normal file
22
src/rpc/format/json.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
|
||||
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||
match msg {
|
||||
Message::Text(val) => {
|
||||
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||
}
|
||||
_ => Err(Failure::INVALID_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||
// Convert the response into simplified JSON
|
||||
let val = res.into_json();
|
||||
// Serialize the response with simplified type information
|
||||
let res = serde_json::to_string(&val).unwrap();
|
||||
// Return the message length, and message as binary
|
||||
Ok((res.len(), Message::Text(res)))
|
||||
}
|
70
src/rpc/format/mod.rs
Normal file
70
src/rpc/format/mod.rs
Normal file
|
@ -0,0 +1,70 @@
|
|||
mod bincode;
|
||||
pub mod cbor;
|
||||
mod json;
|
||||
pub mod msgpack;
|
||||
mod revision;
|
||||
|
||||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
|
||||
pub const PROTOCOLS: [&str; 5] = [
|
||||
"json", // For basic JSON serialisation
|
||||
"cbor", // For basic CBOR serialisation
|
||||
"msgpack", // For basic Msgpack serialisation
|
||||
"bincode", // For full internal serialisation
|
||||
"revision", // For full versioned serialisation
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub enum Format {
|
||||
None, // No format is specified yet
|
||||
Json, // For basic JSON serialisation
|
||||
Cbor, // For basic CBOR serialisation
|
||||
Msgpack, // For basic Msgpack serialisation
|
||||
Bincode, // For full internal serialisation
|
||||
Revision, // For full versioned serialisation
|
||||
}
|
||||
|
||||
impl From<&str> for Format {
|
||||
fn from(v: &str) -> Self {
|
||||
match v {
|
||||
s if s == PROTOCOLS[0] => Format::Json,
|
||||
s if s == PROTOCOLS[1] => Format::Cbor,
|
||||
s if s == PROTOCOLS[2] => Format::Msgpack,
|
||||
s if s == PROTOCOLS[3] => Format::Bincode,
|
||||
s if s == PROTOCOLS[4] => Format::Revision,
|
||||
_ => Format::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Format {
|
||||
/// Check if this format has been set
|
||||
pub fn is_none(&self) -> bool {
|
||||
matches!(self, Format::None)
|
||||
}
|
||||
/// Process a request using the specified format
|
||||
pub fn req(&self, msg: Message) -> Result<Request, Failure> {
|
||||
match self {
|
||||
Self::None => unreachable!(), // We should never arrive at this code
|
||||
Self::Json => json::req(msg),
|
||||
Self::Cbor => cbor::req(msg),
|
||||
Self::Msgpack => msgpack::req(msg),
|
||||
Self::Bincode => bincode::req(msg),
|
||||
Self::Revision => revision::req(msg),
|
||||
}
|
||||
}
|
||||
/// Process a response using the specified format
|
||||
pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> {
|
||||
match self {
|
||||
Self::None => unreachable!(), // We should never arrive at this code
|
||||
Self::Json => json::res(res),
|
||||
Self::Cbor => cbor::res(res),
|
||||
Self::Msgpack => msgpack::res(res),
|
||||
Self::Bincode => bincode::res(res),
|
||||
Self::Revision => revision::res(res),
|
||||
}
|
||||
}
|
||||
}
|
22
src/rpc/format/msgpack/mod.rs
Normal file
22
src/rpc/format/msgpack/mod.rs
Normal file
|
@ -0,0 +1,22 @@
|
|||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
|
||||
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||
match msg {
|
||||
Message::Text(val) => {
|
||||
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||
}
|
||||
_ => Err(Failure::INVALID_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||
// Convert the response into simplified JSON
|
||||
let val = res.into_json();
|
||||
// Serialize the value into MsgPack binary data
|
||||
let res = serde_pack::to_vec(&val).unwrap();
|
||||
// Return the message length, and message as binary
|
||||
Ok((res.len(), Message::Binary(res)))
|
||||
}
|
14
src/rpc/format/revision.rs
Normal file
14
src/rpc/format/revision.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::request::Request;
|
||||
use crate::rpc::response::Response;
|
||||
use axum::extract::ws::Message;
|
||||
|
||||
pub fn req(_msg: Message) -> Result<Request, Failure> {
|
||||
// This format is not yet implemented
|
||||
Err(Failure::INTERNAL_ERROR)
|
||||
}
|
||||
|
||||
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
|
||||
// This format is not yet implemented
|
||||
Err(Failure::INTERNAL_ERROR)
|
||||
}
|
|
@ -1,12 +1,14 @@
|
|||
pub mod args;
|
||||
pub mod connection;
|
||||
pub mod failure;
|
||||
pub mod format;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use axum::extract::ws::Message;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use surrealdb::channel::Sender;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
use axum::extract::ws::Message;
|
||||
use surrealdb::sql::{serde::deserialize, Array, Value};
|
||||
|
||||
use crate::rpc::failure::Failure;
|
||||
use once_cell::sync::Lazy;
|
||||
use surrealdb::sql::Part;
|
||||
|
||||
use super::response::{Failure, OutputFormat};
|
||||
use surrealdb::sql::{Array, Value};
|
||||
|
||||
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
|
||||
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
|
||||
|
@ -14,71 +11,36 @@ pub struct Request {
|
|||
pub id: Option<Value>,
|
||||
pub method: String,
|
||||
pub params: Array,
|
||||
pub size: usize,
|
||||
pub out_fmt: Option<OutputFormat>,
|
||||
}
|
||||
|
||||
/// Parse the RPC request
|
||||
pub async fn parse_request(msg: Message) -> Result<Request, Failure> {
|
||||
let mut out_fmt = None;
|
||||
let (req, size) = match msg {
|
||||
// This is a binary message
|
||||
Message::Binary(val) => {
|
||||
// Use binary output
|
||||
out_fmt = Some(OutputFormat::Full);
|
||||
|
||||
match deserialize(&val) {
|
||||
Ok(v) => (v, val.len()),
|
||||
Err(_) => {
|
||||
debug!("Error when trying to deserialize the request");
|
||||
return Err(Failure::PARSE_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
// This is a text message
|
||||
Message::Text(ref val) => {
|
||||
// Parse the SurrealQL object
|
||||
match surrealdb::sql::value(val) {
|
||||
// The SurrealQL message parsed ok
|
||||
Ok(v) => (v, val.len()),
|
||||
// The SurrealQL message failed to parse
|
||||
_ => return Err(Failure::PARSE_ERROR),
|
||||
}
|
||||
}
|
||||
// Unsupported message type
|
||||
_ => {
|
||||
debug!("Unsupported message type: {:?}", msg);
|
||||
return Err(Failure::custom("Unsupported message type"));
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch the 'id' argument
|
||||
let id = match req.pick(&*ID) {
|
||||
v if v.is_none() => None,
|
||||
v if v.is_null() => Some(v),
|
||||
v if v.is_uuid() => Some(v),
|
||||
v if v.is_number() => Some(v),
|
||||
v if v.is_strand() => Some(v),
|
||||
v if v.is_datetime() => Some(v),
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
// Fetch the 'method' argument
|
||||
let method = match req.pick(&*METHOD) {
|
||||
Value::Strand(v) => v.to_raw(),
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
|
||||
// Fetch the 'params' argument
|
||||
let params = match req.pick(&*PARAMS) {
|
||||
Value::Array(v) => v,
|
||||
_ => Array::new(),
|
||||
};
|
||||
|
||||
Ok(Request {
|
||||
id,
|
||||
method,
|
||||
params,
|
||||
size,
|
||||
out_fmt,
|
||||
})
|
||||
impl TryFrom<Value> for Request {
|
||||
type Error = Failure;
|
||||
fn try_from(val: Value) -> Result<Self, Failure> {
|
||||
// Fetch the 'id' argument
|
||||
let id = match val.pick(&*ID) {
|
||||
v if v.is_none() => None,
|
||||
v if v.is_null() => Some(v),
|
||||
v if v.is_uuid() => Some(v),
|
||||
v if v.is_number() => Some(v),
|
||||
v if v.is_strand() => Some(v),
|
||||
v if v.is_datetime() => Some(v),
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
// Fetch the 'method' argument
|
||||
let method = match val.pick(&*METHOD) {
|
||||
Value::Strand(v) => v.to_raw(),
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
// Fetch the 'params' argument
|
||||
let params = match val.pick(&*PARAMS) {
|
||||
Value::Array(v) => v,
|
||||
_ => Array::new(),
|
||||
};
|
||||
// Return the parsed request
|
||||
Ok(Request {
|
||||
id,
|
||||
method,
|
||||
params,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use crate::err;
|
||||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::format::Format;
|
||||
use crate::telemetry::metrics::ws::record_rpc;
|
||||
use axum::extract::ws::Message;
|
||||
use opentelemetry::Context as TelemetryContext;
|
||||
use serde::Serialize;
|
||||
use serde_json::{json, Value as Json};
|
||||
use std::borrow::Cow;
|
||||
use serde_json::Value as Json;
|
||||
use surrealdb::channel::Sender;
|
||||
use surrealdb::dbs;
|
||||
use surrealdb::dbs::Notification;
|
||||
|
@ -12,14 +12,6 @@ use surrealdb::sql;
|
|||
use surrealdb::sql::Value;
|
||||
use tracing::Span;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub enum OutputFormat {
|
||||
Json, // JSON
|
||||
Cbor, // CBOR
|
||||
Pack, // MessagePack
|
||||
Full, // Full type serialization
|
||||
}
|
||||
|
||||
/// The data returned by the database
|
||||
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
|
||||
// In future, they will possibly be merged to avoid having to keep them in sync.
|
||||
|
@ -46,15 +38,25 @@ impl From<String> for Data {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Notification> for Data {
|
||||
fn from(n: Notification) -> Self {
|
||||
Data::Live(n)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<dbs::Response>> for Data {
|
||||
fn from(v: Vec<dbs::Response>) -> Self {
|
||||
Data::Query(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Notification> for Data {
|
||||
fn from(n: Notification) -> Self {
|
||||
Data::Live(n)
|
||||
impl From<Data> for Value {
|
||||
fn from(val: Data) -> Self {
|
||||
match val {
|
||||
Data::Query(v) => sql::to_value(v).unwrap(),
|
||||
Data::Live(v) => sql::to_value(v).unwrap(),
|
||||
Data::Other(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,32 +69,31 @@ pub struct Response {
|
|||
impl Response {
|
||||
/// Convert and simplify the value into JSON
|
||||
#[inline]
|
||||
fn simplify(self) -> Json {
|
||||
pub fn into_json(self) -> Json {
|
||||
Json::from(self.into_value())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_value(self) -> Value {
|
||||
let mut value = match self.result {
|
||||
Ok(data) => {
|
||||
let value = match data {
|
||||
Data::Query(vec) => sql::to_value(vec).unwrap(),
|
||||
Data::Live(notification) => sql::to_value(notification).unwrap(),
|
||||
Data::Other(value) => value,
|
||||
};
|
||||
json!({
|
||||
"result": Json::from(value),
|
||||
})
|
||||
}
|
||||
Err(failure) => json!({
|
||||
"error": failure,
|
||||
}),
|
||||
Ok(val) => map! {
|
||||
"result" => Value::from(val),
|
||||
},
|
||||
Err(err) => map! {
|
||||
"error" => Value::from(err),
|
||||
},
|
||||
};
|
||||
if let Some(id) = self.id {
|
||||
value["id"] = id.into();
|
||||
value.insert("id", id);
|
||||
}
|
||||
value
|
||||
value.into()
|
||||
}
|
||||
|
||||
/// Send the response to the WebSocket channel
|
||||
pub async fn send(self, out: OutputFormat, chn: &Sender<Message>) {
|
||||
pub async fn send(self, fmt: Format, chn: &Sender<Message>) {
|
||||
// Create a new tracing span
|
||||
let span = Span::current();
|
||||
|
||||
// Log the rpc response call
|
||||
debug!("Process RPC response");
|
||||
|
||||
let is_error = self.result.is_err();
|
||||
|
@ -105,73 +106,12 @@ impl Response {
|
|||
span.record("rpc.error_code", err.code);
|
||||
span.record("rpc.error_message", err.message.as_ref());
|
||||
}
|
||||
|
||||
let (res_size, message) = match out {
|
||||
OutputFormat::Json => {
|
||||
let res = serde_json::to_string(&self.simplify()).unwrap();
|
||||
(res.len(), Message::Text(res))
|
||||
}
|
||||
OutputFormat::Cbor => {
|
||||
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
||||
(res.len(), Message::Binary(res))
|
||||
}
|
||||
OutputFormat::Pack => {
|
||||
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
||||
(res.len(), Message::Binary(res))
|
||||
}
|
||||
OutputFormat::Full => {
|
||||
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
||||
(res.len(), Message::Binary(res))
|
||||
}
|
||||
// Process the response for the format
|
||||
let (len, msg) = fmt.res(self).unwrap();
|
||||
// Send the message to the write channel
|
||||
if chn.send(msg).await.is_ok() {
|
||||
record_rpc(&TelemetryContext::current(), len, is_error);
|
||||
};
|
||||
|
||||
if chn.send(message).await.is_ok() {
|
||||
record_rpc(&TelemetryContext::current(), res_size, is_error);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct Failure {
|
||||
code: i64,
|
||||
message: Cow<'static, str>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Failure {
|
||||
pub const PARSE_ERROR: Failure = Failure {
|
||||
code: -32700,
|
||||
message: Cow::Borrowed("Parse error"),
|
||||
};
|
||||
|
||||
pub const INVALID_REQUEST: Failure = Failure {
|
||||
code: -32600,
|
||||
message: Cow::Borrowed("Invalid Request"),
|
||||
};
|
||||
|
||||
pub const METHOD_NOT_FOUND: Failure = Failure {
|
||||
code: -32601,
|
||||
message: Cow::Borrowed("Method not found"),
|
||||
};
|
||||
|
||||
pub const INVALID_PARAMS: Failure = Failure {
|
||||
code: -32602,
|
||||
message: Cow::Borrowed("Invalid params"),
|
||||
};
|
||||
|
||||
pub const INTERNAL_ERROR: Failure = Failure {
|
||||
code: -32603,
|
||||
message: Cow::Borrowed("Internal error"),
|
||||
};
|
||||
|
||||
pub fn custom<S>(message: S) -> Failure
|
||||
where
|
||||
Cow<'static, str>: From<S>,
|
||||
{
|
||||
Failure {
|
||||
code: -32000,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,12 +131,6 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<err::Error> for Failure {
|
||||
fn from(err: err::Error) -> Self {
|
||||
Failure::custom(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoRpcResponse {
|
||||
fn into_response(self, id: Option<Value>) -> Response;
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ mod common;
|
|||
|
||||
mod cli_integration {
|
||||
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
||||
use common::Format;
|
||||
use common::Socket;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::time;
|
||||
|
@ -735,15 +737,13 @@ mod cli_integration {
|
|||
let (addr, mut server) = common::start_server_without_auth().await.unwrap();
|
||||
|
||||
// Create a long-lived WS connection so the server don't shutdown gracefully
|
||||
let mut socket = common::connect_ws(&addr).await.expect("Failed to connect to server");
|
||||
let mut socket = Socket::connect(&addr, None).await.expect("Failed to connect to server");
|
||||
let json = serde_json::json!({
|
||||
"id": "1",
|
||||
"method": "query",
|
||||
"params": ["SLEEP 30s;"],
|
||||
});
|
||||
common::ws_send_msg(&mut socket, serde_json::to_string(&json).unwrap())
|
||||
.await
|
||||
.expect("Failed to send WS message");
|
||||
socket.send_message(Format::Json, json).await.expect("Failed to send WS message");
|
||||
|
||||
// Make sure the SLEEP query is being executed
|
||||
tokio::select! {
|
||||
|
|
14
tests/common/format.rs
Normal file
14
tests/common/format.rs
Normal file
|
@ -0,0 +1,14 @@
|
|||
use std::string::ToString;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum Format {
|
||||
Json,
|
||||
}
|
||||
|
||||
impl ToString for Format {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::Json => "json".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,489 +1,17 @@
|
|||
#![allow(unused_imports)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod error;
|
||||
pub mod format;
|
||||
pub mod server;
|
||||
pub mod socket;
|
||||
|
||||
use crate::common::error::TestError;
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::Duration;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
pub const USER: &str = "root";
|
||||
pub const PASS: &str = "root";
|
||||
|
||||
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
||||
pub struct Child {
|
||||
inner: Option<std::process::Child>,
|
||||
stdout_path: String,
|
||||
stderr_path: String,
|
||||
}
|
||||
|
||||
impl Child {
|
||||
/// Send some thing to the child's stdin
|
||||
pub fn input(mut self, input: &str) -> Self {
|
||||
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
||||
use std::io::Write;
|
||||
stdin.write_all(input.as_bytes()).unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn kill(mut self) -> Self {
|
||||
self.inner.as_mut().unwrap().kill().unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
|
||||
nix::sys::signal::kill(
|
||||
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
|
||||
signal,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
|
||||
self.inner.as_mut().unwrap().try_wait()
|
||||
}
|
||||
|
||||
pub fn stdout(&self) -> String {
|
||||
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
|
||||
}
|
||||
|
||||
pub fn stderr(&self) -> String {
|
||||
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
|
||||
}
|
||||
|
||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||
/// returns successfully, Err otherwise.
|
||||
pub fn output(mut self) -> Result<String, String> {
|
||||
let status = self.inner.take().unwrap().wait().unwrap();
|
||||
|
||||
let mut buf = self.stdout();
|
||||
buf.push_str(&self.stderr());
|
||||
|
||||
// Cleanup files after reading them
|
||||
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
|
||||
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
|
||||
|
||||
if status.success() {
|
||||
Ok(buf)
|
||||
} else {
|
||||
Err(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Child {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = self.inner.as_mut() {
|
||||
let _ = inner.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||
let mut path = std::env::current_exe().unwrap();
|
||||
assert!(path.pop());
|
||||
if path.ends_with("deps") {
|
||||
assert!(path.pop());
|
||||
}
|
||||
|
||||
// Note: Cargo automatically builds this binary for integration tests.
|
||||
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
||||
|
||||
let mut cmd = Command::new(path);
|
||||
if let Some(dir) = current_dir {
|
||||
cmd.current_dir(&dir);
|
||||
}
|
||||
|
||||
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
|
||||
let stdout_path = tmp_file("server-stdout.log");
|
||||
let stderr_path = tmp_file("server-stderr.log");
|
||||
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
|
||||
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
|
||||
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
|
||||
|
||||
cmd.env_clear();
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(stdout);
|
||||
cmd.stderr(stderr);
|
||||
cmd.args(args.split_ascii_whitespace());
|
||||
|
||||
Child {
|
||||
inner: Some(cmd.spawn().unwrap()),
|
||||
stdout_path,
|
||||
stderr_path,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args
|
||||
pub fn run(args: &str) -> Child {
|
||||
run_internal::<String>(args, None)
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args inside a temporary directory
|
||||
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||
run_internal(args, Some(current_dir))
|
||||
}
|
||||
|
||||
pub fn tmp_file(name: &str) -> String {
|
||||
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
pub struct StartServerArguments {
|
||||
pub auth: bool,
|
||||
pub tls: bool,
|
||||
pub wait_is_ready: bool,
|
||||
pub enable_auth_level: bool,
|
||||
pub tick_interval: time::Duration,
|
||||
pub args: String,
|
||||
}
|
||||
|
||||
impl Default for StartServerArguments {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth: true,
|
||||
tls: false,
|
||||
wait_is_ready: true,
|
||||
enable_auth_level: false,
|
||||
tick_interval: time::Duration::new(1, 0),
|
||||
args: "--allow-all".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments {
|
||||
auth: false,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments {
|
||||
enable_auth_level: true,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments::default()).await
|
||||
}
|
||||
|
||||
pub async fn start_server(
|
||||
StartServerArguments {
|
||||
auth,
|
||||
tls,
|
||||
wait_is_ready,
|
||||
enable_auth_level,
|
||||
tick_interval,
|
||||
args,
|
||||
}: StartServerArguments,
|
||||
) -> Result<(String, Child), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let mut extra_args = args.clone();
|
||||
if tls {
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
let crt_path = tmp_file("crt.crt");
|
||||
let key_path = tmp_file("key.pem");
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
||||
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
||||
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
||||
|
||||
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
|
||||
}
|
||||
|
||||
if auth {
|
||||
extra_args.push_str(" --auth");
|
||||
}
|
||||
|
||||
if enable_auth_level {
|
||||
extra_args.push_str(" --auth-level-enabled");
|
||||
}
|
||||
|
||||
if !tick_interval.is_zero() {
|
||||
let sec = tick_interval.as_secs();
|
||||
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
|
||||
}
|
||||
|
||||
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
|
||||
|
||||
info!("starting server with args: {start_args}");
|
||||
|
||||
// Configure where the logs go when running the test
|
||||
let server = run_internal::<String>(&start_args, None);
|
||||
|
||||
if !wait_is_ready {
|
||||
return Ok((addr, server));
|
||||
}
|
||||
|
||||
// Wait 5 seconds for the server to start
|
||||
let mut interval = time::interval(time::Duration::from_millis(1000));
|
||||
info!("Waiting for server to start...");
|
||||
for _i in 0..10 {
|
||||
interval.tick().await;
|
||||
|
||||
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
||||
info!("Server ready!");
|
||||
return Ok((addr, server));
|
||||
}
|
||||
}
|
||||
|
||||
let server_out = server.kill().output().err().unwrap();
|
||||
error!("server output: {server_out}");
|
||||
Err("server failed to start".into())
|
||||
}
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
|
||||
let url = format!("ws://{}/rpc", addr);
|
||||
let (ws_stream, _) = connect_async(url).await?;
|
||||
Ok(ws_stream)
|
||||
}
|
||||
|
||||
pub async fn ws_send_msg(socket: &mut WsStream, msg_req: String) -> Result<(), Box<dyn Error>> {
|
||||
let now = time::Instant::now();
|
||||
debug!("Sending message: {msg_req}");
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(500)) => {
|
||||
return Err("timeout after 500ms waiting for the request to be sent".into());
|
||||
}
|
||||
res = socket.send(Message::Text(msg_req)) => {
|
||||
debug!("Message sent in {:?}", now.elapsed());
|
||||
if let Err(err) = res {
|
||||
return Err(format!("Error sending the message: {}", err).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn ws_recv_msg(socket: &mut WsStream) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
ws_recv_msg_with_fmt(socket, Format::Json).await
|
||||
}
|
||||
|
||||
/// When testing Live Queries, we may receive multiple messages unordered.
|
||||
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
|
||||
pub async fn ws_recv_all_msgs(
|
||||
socket: &mut WsStream,
|
||||
expected: usize,
|
||||
timeout: Duration,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||
let mut res = Vec::new();
|
||||
let deadline = time::Instant::now() + timeout;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = time::sleep_until(deadline) => {
|
||||
debug!("Waited for {:?} and received {} messages", timeout, res.len());
|
||||
if res.len() != expected {
|
||||
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
|
||||
}
|
||||
}
|
||||
msg = ws_recv_msg(socket) => {
|
||||
res.push(msg?);
|
||||
}
|
||||
}
|
||||
if res.len() == expected {
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_send_msg_and_wait_response(
|
||||
socket: &mut WsStream,
|
||||
msg_req: String,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
ws_send_msg(socket, msg_req).await?;
|
||||
ws_recv_msg_with_fmt(socket, Format::Json).await
|
||||
}
|
||||
|
||||
pub enum Format {
|
||||
Json,
|
||||
Cbor,
|
||||
Pack,
|
||||
}
|
||||
|
||||
pub async fn ws_recv_msg_with_fmt(
|
||||
socket: &mut WsStream,
|
||||
format: Format,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
let now = time::Instant::now();
|
||||
debug!("Waiting for response...");
|
||||
// Parse and return response
|
||||
let mut f = socket.try_filter(|msg| match format {
|
||||
Format::Json => futures_util::future::ready(msg.is_text()),
|
||||
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(5000)) => {
|
||||
Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
|
||||
}
|
||||
res = f.select_next_some() => {
|
||||
debug!("Response received in {:?}", now.elapsed());
|
||||
match format {
|
||||
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
|
||||
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
|
||||
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SigninParams<'a> {
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
sc: Option<&'a str>,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UseParams<'a> {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
}
|
||||
|
||||
pub async fn ws_signin(
|
||||
socket: &mut WsStream,
|
||||
user: &str,
|
||||
pass: &str,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
sc: Option<&str>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string().replace('-', "");
|
||||
let json = json!({
|
||||
"id": request_id,
|
||||
"method": "signin",
|
||||
"params": [
|
||||
SigninParams { user, pass, ns, db, sc }
|
||||
],
|
||||
});
|
||||
|
||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
let msg = ws_recv_msg(socket).await?;
|
||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
||||
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||
})?
|
||||
.as_str()
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_query(
|
||||
socket: &mut WsStream,
|
||||
query: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||
let json = json!({
|
||||
"id": "1",
|
||||
"method": "query",
|
||||
"params": [query],
|
||||
});
|
||||
|
||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
let msg = ws_recv_msg(socket).await?;
|
||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
||||
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||
})?
|
||||
.as_array()
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_use(
|
||||
socket: &mut WsStream,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
let json = json!({
|
||||
"id": "1",
|
||||
"method": "use",
|
||||
"params": [
|
||||
ns, db
|
||||
],
|
||||
});
|
||||
|
||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
let msg = ws_recv_msg(socket).await?;
|
||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
||||
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!(
|
||||
"expected a result from the received object, got this instead: {:?}",
|
||||
obj
|
||||
),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
pub use format::*;
|
||||
pub use server::*;
|
||||
pub use socket::*;
|
||||
|
||||
/// Check if the given message is a successful notification from LQ.
|
||||
pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
|
||||
pub fn is_notification(msg: &serde_json::Value) -> bool {
|
||||
// Example of LQ notification:
|
||||
//
|
||||
// Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}}
|
||||
|
@ -497,7 +25,7 @@ pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
|
|||
}
|
||||
|
||||
/// Check if the given message is a notification from LQ and comes from the given LQ ID.
|
||||
pub fn ws_msg_is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
|
||||
ws_msg_is_notification(msg)
|
||||
pub fn is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
|
||||
is_notification(msg)
|
||||
&& msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id)
|
||||
}
|
||||
|
|
242
tests/common/server.rs
Normal file
242
tests/common/server.rs
Normal file
|
@ -0,0 +1,242 @@
|
|||
use rand::{thread_rng, Rng};
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::{env, fs};
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
pub const USER: &str = "root";
|
||||
pub const PASS: &str = "root";
|
||||
pub const NS: &str = "testns";
|
||||
pub const DB: &str = "testdb";
|
||||
|
||||
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
||||
pub struct Child {
|
||||
inner: Option<std::process::Child>,
|
||||
stdout_path: String,
|
||||
stderr_path: String,
|
||||
}
|
||||
|
||||
impl Child {
|
||||
/// Send some thing to the child's stdin
|
||||
pub fn input(mut self, input: &str) -> Self {
|
||||
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
||||
use std::io::Write;
|
||||
stdin.write_all(input.as_bytes()).unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn kill(mut self) -> Self {
|
||||
self.inner.as_mut().unwrap().kill().unwrap();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
|
||||
nix::sys::signal::kill(
|
||||
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
|
||||
signal,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
|
||||
self.inner.as_mut().unwrap().try_wait()
|
||||
}
|
||||
|
||||
pub fn stdout(&self) -> String {
|
||||
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
|
||||
}
|
||||
|
||||
pub fn stderr(&self) -> String {
|
||||
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
|
||||
}
|
||||
|
||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||
/// returns successfully, Err otherwise.
|
||||
pub fn output(mut self) -> Result<String, String> {
|
||||
let status = self.inner.take().unwrap().wait().unwrap();
|
||||
|
||||
let mut buf = self.stdout();
|
||||
buf.push_str(&self.stderr());
|
||||
|
||||
// Cleanup files after reading them
|
||||
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
|
||||
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
|
||||
|
||||
if status.success() {
|
||||
Ok(buf)
|
||||
} else {
|
||||
Err(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Child {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = self.inner.as_mut() {
|
||||
let _ = inner.kill();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||
let mut path = std::env::current_exe().unwrap();
|
||||
assert!(path.pop());
|
||||
if path.ends_with("deps") {
|
||||
assert!(path.pop());
|
||||
}
|
||||
|
||||
// Note: Cargo automatically builds this binary for integration tests.
|
||||
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
||||
|
||||
let mut cmd = Command::new(path);
|
||||
if let Some(dir) = current_dir {
|
||||
cmd.current_dir(&dir);
|
||||
}
|
||||
|
||||
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
|
||||
let stdout_path = tmp_file("server-stdout.log");
|
||||
let stderr_path = tmp_file("server-stderr.log");
|
||||
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
|
||||
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
|
||||
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
|
||||
|
||||
cmd.env_clear();
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(stdout);
|
||||
cmd.stderr(stderr);
|
||||
cmd.args(args.split_ascii_whitespace());
|
||||
|
||||
Child {
|
||||
inner: Some(cmd.spawn().unwrap()),
|
||||
stdout_path,
|
||||
stderr_path,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args
|
||||
pub fn run(args: &str) -> Child {
|
||||
run_internal::<String>(args, None)
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args inside a temporary directory
|
||||
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||
run_internal(args, Some(current_dir))
|
||||
}
|
||||
|
||||
pub fn tmp_file(name: &str) -> String {
|
||||
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
|
||||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
pub struct StartServerArguments {
|
||||
pub auth: bool,
|
||||
pub tls: bool,
|
||||
pub wait_is_ready: bool,
|
||||
pub enable_auth_level: bool,
|
||||
pub tick_interval: time::Duration,
|
||||
pub args: String,
|
||||
}
|
||||
|
||||
impl Default for StartServerArguments {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
auth: true,
|
||||
tls: false,
|
||||
wait_is_ready: true,
|
||||
enable_auth_level: false,
|
||||
tick_interval: time::Duration::new(1, 0),
|
||||
args: "--allow-all".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments {
|
||||
auth: false,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments {
|
||||
enable_auth_level: true,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
|
||||
start_server(StartServerArguments::default()).await
|
||||
}
|
||||
|
||||
pub async fn start_server(
|
||||
StartServerArguments {
|
||||
auth,
|
||||
tls,
|
||||
wait_is_ready,
|
||||
enable_auth_level,
|
||||
tick_interval,
|
||||
args,
|
||||
}: StartServerArguments,
|
||||
) -> Result<(String, Child), Box<dyn Error>> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let port: u16 = rng.gen_range(13000..14000);
|
||||
let addr = format!("127.0.0.1:{port}");
|
||||
|
||||
let mut extra_args = args.clone();
|
||||
if tls {
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
let crt_path = tmp_file("crt.crt");
|
||||
let key_path = tmp_file("key.pem");
|
||||
|
||||
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
||||
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
||||
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
||||
|
||||
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
|
||||
}
|
||||
|
||||
if auth {
|
||||
extra_args.push_str(" --auth");
|
||||
}
|
||||
|
||||
if enable_auth_level {
|
||||
extra_args.push_str(" --auth-level-enabled");
|
||||
}
|
||||
|
||||
if !tick_interval.is_zero() {
|
||||
let sec = tick_interval.as_secs();
|
||||
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
|
||||
}
|
||||
|
||||
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
|
||||
|
||||
info!("starting server with args: {start_args}");
|
||||
|
||||
// Configure where the logs go when running the test
|
||||
let server = run_internal::<String>(&start_args, None);
|
||||
|
||||
if !wait_is_ready {
|
||||
return Ok((addr, server));
|
||||
}
|
||||
|
||||
// Wait 5 seconds for the server to start
|
||||
let mut interval = time::interval(time::Duration::from_millis(1000));
|
||||
info!("Waiting for server to start...");
|
||||
for _i in 0..10 {
|
||||
interval.tick().await;
|
||||
|
||||
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
||||
info!("Server ready!");
|
||||
return Ok((addr, server));
|
||||
}
|
||||
}
|
||||
|
||||
let server_out = server.kill().output().err().unwrap();
|
||||
error!("server output: {server_out}");
|
||||
Err("server failed to start".into())
|
||||
}
|
288
tests/common/socket.rs
Normal file
288
tests/common/socket.rs
Normal file
|
@ -0,0 +1,288 @@
|
|||
use super::format::Format;
|
||||
use crate::common::error::TestError;
|
||||
use futures_util::{SinkExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use surrealdb::sql::Value;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use tracing::{debug, error};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UseParams<'a> {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SigninParams<'a> {
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
sc: Option<&'a str>,
|
||||
}
|
||||
|
||||
pub struct Socket {
|
||||
pub stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
}
|
||||
|
||||
// pub struct Socket(pub WebSocketStream<MaybeTlsStream<TcpStream>>);
|
||||
|
||||
impl Socket {
|
||||
/// Close the connection with the WebSocket server
|
||||
pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
|
||||
Ok(self.stream.close(None).await?)
|
||||
}
|
||||
|
||||
/// Connect to a WebSocket server using a specific format
|
||||
pub async fn connect(addr: &str, format: Option<Format>) -> Result<Self, Box<dyn Error>> {
|
||||
let url = format!("ws://{}/rpc", addr);
|
||||
let mut req = url.into_client_request().unwrap();
|
||||
if let Some(v) = format.map(|v| v.to_string()) {
|
||||
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
|
||||
}
|
||||
let (stream, _) = connect_async(req).await?;
|
||||
Ok(Self {
|
||||
stream,
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a text or binary message to the WebSocket server
|
||||
pub async fn send_message(
|
||||
&mut self,
|
||||
format: Format,
|
||||
message: serde_json::Value,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let now = time::Instant::now();
|
||||
debug!("Sending message: {message}");
|
||||
// Format the message
|
||||
let msg = match format {
|
||||
Format::Json => Message::Text(serde_json::to_string(&message)?),
|
||||
};
|
||||
// Send the message
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(500)) => {
|
||||
return Err("timeout after 500ms waiting for the request to be sent".into());
|
||||
}
|
||||
res = self.stream.send(msg) => {
|
||||
debug!("Message sent in {:?}", now.elapsed());
|
||||
if let Err(err) = res {
|
||||
return Err(format!("Error sending the message: {}", err).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a text or binary message from the WebSocket server
|
||||
pub async fn receive_message(
|
||||
&mut self,
|
||||
format: Format,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
let now = time::Instant::now();
|
||||
debug!("Receiving response...");
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(5000)) => {
|
||||
return Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
|
||||
}
|
||||
res = self.stream.try_next() => {
|
||||
match res {
|
||||
Ok(res) => match res {
|
||||
Some(Message::Text(msg)) => {
|
||||
debug!("Response {msg:?} received in {:?}", now.elapsed());
|
||||
match format {
|
||||
Format::Json => {
|
||||
let msg = serde_json::from_str(&msg)?;
|
||||
debug!("Received message: {msg}");
|
||||
return Ok(msg);
|
||||
},
|
||||
}
|
||||
},
|
||||
Some(_) => {
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
return Err("Expected to receive a message".to_string().into());
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(format!("Error receiving the message: {}", err).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a text or binary message and receive a reponse from the WebSocket server
|
||||
pub async fn send_and_receive_message(
|
||||
&mut self,
|
||||
format: Format,
|
||||
message: serde_json::Value,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
self.send_message(format, message).await?;
|
||||
self.receive_message(format).await
|
||||
}
|
||||
|
||||
/// When testing Live Queries, we may receive multiple messages unordered.
|
||||
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
|
||||
pub async fn receive_all_messages(
|
||||
&mut self,
|
||||
format: Format,
|
||||
expected: usize,
|
||||
timeout: Duration,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||
let mut res = Vec::new();
|
||||
let deadline = time::Instant::now() + timeout;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = time::sleep_until(deadline) => {
|
||||
debug!("Waited for {:?} and received {} messages", timeout, res.len());
|
||||
if res.len() != expected {
|
||||
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
|
||||
}
|
||||
}
|
||||
msg = self.receive_message(format) => {
|
||||
res.push(msg?);
|
||||
}
|
||||
}
|
||||
if res.len() == expected {
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a USE message to the server and check the response
|
||||
pub async fn send_message_use(
|
||||
&mut self,
|
||||
format: Format,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
// Generate an ID
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
// Construct message
|
||||
let msg = json!({
|
||||
"id": id,
|
||||
"method": "use",
|
||||
"params": [
|
||||
ns, db
|
||||
],
|
||||
});
|
||||
// Send message and receive response
|
||||
let msg = self.send_and_receive_message(format, msg).await?;
|
||||
// Check response message structure
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!(
|
||||
"expected a result from the received object, got this instead: {:?}",
|
||||
obj
|
||||
),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a generic query message to the server and check the response
|
||||
pub async fn send_message_query(
|
||||
&mut self,
|
||||
format: Format,
|
||||
query: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||
// Generate an ID
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
// Construct message
|
||||
let msg = json!({
|
||||
"id": id,
|
||||
"method": "query",
|
||||
"params": [query],
|
||||
});
|
||||
// Send message and receive response
|
||||
let msg = self.send_and_receive_message(format, msg).await?;
|
||||
// Check response message structure
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||
})?
|
||||
.as_array()
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a signin authentication query message to the server and check the response
|
||||
pub async fn send_message_signin(
|
||||
&mut self,
|
||||
format: Format,
|
||||
user: &str,
|
||||
pass: &str,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
sc: Option<&str>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
// Generate an ID
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
// Construct message
|
||||
let msg = json!({
|
||||
"id": id,
|
||||
"method": "signin",
|
||||
"params": [
|
||||
SigninParams { user, pass, ns, db, sc }
|
||||
],
|
||||
});
|
||||
// Send message and receive response
|
||||
let msg = self.send_and_receive_message(format, msg).await?;
|
||||
// Check response message structure
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||
.get("result")
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||
})?
|
||||
.as_str()
|
||||
.ok_or(TestError::AssertionError {
|
||||
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||
})?
|
||||
.to_owned()),
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
1177
tests/common/tests.rs
Normal file
1177
tests/common/tests.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1258,8 +1258,8 @@ mod http_integration {
|
|||
.send()
|
||||
.await?;
|
||||
assert_eq!(res.status(), 200);
|
||||
|
||||
let _: serde_cbor::Value = serde_cbor::from_slice(&res.bytes().await?).unwrap();
|
||||
let res = res.bytes().await?.to_vec();
|
||||
let _: ciborium::Value = ciborium::from_reader(res.as_slice()).unwrap();
|
||||
}
|
||||
|
||||
// Creating a record with Accept PACK encoding is allowed
|
||||
|
@ -1272,8 +1272,8 @@ mod http_integration {
|
|||
.send()
|
||||
.await?;
|
||||
assert_eq!(res.status(), 200);
|
||||
|
||||
let _: serde_cbor::Value = serde_pack::from_slice(&res.bytes().await?).unwrap();
|
||||
let res = res.bytes().await?.to_vec();
|
||||
let _: rmpv::Value = rmpv::decode::read_value(&mut res.as_slice()).unwrap();
|
||||
}
|
||||
|
||||
// Creating a record with Accept Surrealdb encoding is allowed
|
||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue