Improve WebSocket protocol implementation (#3291)
This commit is contained in:
parent
8e9bd3a2d6
commit
a35fc0d04d
31 changed files with 2300 additions and 2338 deletions
23
Cargo.lock
generated
23
Cargo.lock
generated
|
@ -4445,6 +4445,16 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rmpv"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2e0e0214a4a2b444ecce41a4025792fc31f77c7bb89c46d253953ea8c65701ec"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
"rmp",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "roaring"
|
name = "roaring"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
|
@ -4817,16 +4827,6 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "serde_cbor"
|
|
||||||
version = "0.11.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5"
|
|
||||||
dependencies = [
|
|
||||||
"half",
|
|
||||||
"serde",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_derive"
|
name = "serde_derive"
|
||||||
version = "1.0.193"
|
version = "1.0.193"
|
||||||
|
@ -5235,6 +5235,7 @@ dependencies = [
|
||||||
"axum-server",
|
"axum-server",
|
||||||
"base64 0.21.5",
|
"base64 0.21.5",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
"ciborium",
|
||||||
"clap",
|
"clap",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"futures",
|
"futures",
|
||||||
|
@ -5257,10 +5258,10 @@ dependencies = [
|
||||||
"rcgen",
|
"rcgen",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rmp-serde",
|
"rmp-serde",
|
||||||
|
"rmpv",
|
||||||
"rustyline",
|
"rustyline",
|
||||||
"semver",
|
"semver",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_cbor",
|
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serial_test",
|
"serial_test",
|
||||||
"surrealdb",
|
"surrealdb",
|
||||||
|
|
|
@ -40,6 +40,7 @@ axum-extra = { version = "0.7.7", features = ["query", "typed-routing"] }
|
||||||
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
||||||
base64 = "0.21.5"
|
base64 = "0.21.5"
|
||||||
bytes = "1.5.0"
|
bytes = "1.5.0"
|
||||||
|
ciborium = "0.2.1"
|
||||||
clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] }
|
clap = { version = "4.4.11", features = ["env", "derive", "wrap_help", "unicode"] }
|
||||||
futures = "0.3.29"
|
futures = "0.3.29"
|
||||||
futures-util = "0.3.29"
|
futures-util = "0.3.29"
|
||||||
|
@ -55,9 +56,9 @@ opentelemetry-otlp = { version = "0.12.0", features = ["metrics"] }
|
||||||
pin-project-lite = "0.2.13"
|
pin-project-lite = "0.2.13"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] }
|
reqwest = { version = "0.11.22", default-features = false, features = ["blocking", "gzip"] }
|
||||||
|
rmpv = "1.0.1"
|
||||||
rustyline = { version = "12.0.0", features = ["derive"] }
|
rustyline = { version = "12.0.0", features = ["derive"] }
|
||||||
serde = { version = "1.0.193", features = ["derive"] }
|
serde = { version = "1.0.193", features = ["derive"] }
|
||||||
serde_cbor = "0.11.2"
|
|
||||||
serde_json = "1.0.108"
|
serde_json = "1.0.108"
|
||||||
serde_pack = { version = "1.1.2", package = "rmp-serde" }
|
serde_pack = { version = "1.1.2", package = "rmp-serde" }
|
||||||
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
|
surrealdb = { path = "lib", features = ["protocol-http", "protocol-ws", "rustls"] }
|
||||||
|
|
|
@ -34,13 +34,13 @@ env = { RUST_LOG={ value = "http_integration=debug", condition = { env_not_set =
|
||||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
|
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression", "--workspace", "--test", "http_integration", "--", "http_integration", "--nocapture"]
|
||||||
|
|
||||||
[tasks.ci-ws-integration]
|
[tasks.ci-ws-integration]
|
||||||
category = "CI - INTEGRATION TESTS"
|
category = "WS - INTEGRATION TESTS"
|
||||||
command = "cargo"
|
command = "cargo"
|
||||||
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
env = { RUST_LOG={ value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||||
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration", "--nocapture"]
|
||||||
|
|
||||||
[tasks.ci-ml-integration]
|
[tasks.ci-ml-integration]
|
||||||
category = "CI - INTEGRATION TESTS"
|
category = "ML - INTEGRATION TESTS"
|
||||||
command = "cargo"
|
command = "cargo"
|
||||||
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
env = { RUST_LOG={ value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
|
||||||
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
|
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
|
||||||
|
|
|
@ -130,6 +130,7 @@ pub fn thing((arg1, arg2): (Value, Option<Value>)) -> Result<Value, Error> {
|
||||||
|
|
||||||
pub mod is {
|
pub mod is {
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
|
use crate::sql::table::Table;
|
||||||
use crate::sql::value::Value;
|
use crate::sql::value::Value;
|
||||||
use crate::sql::Geometry;
|
use crate::sql::Geometry;
|
||||||
|
|
||||||
|
@ -215,7 +216,7 @@ pub mod is {
|
||||||
|
|
||||||
pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> {
|
pub fn record((arg, table): (Value, Option<String>)) -> Result<Value, Error> {
|
||||||
Ok(match table {
|
Ok(match table {
|
||||||
Some(tb) => arg.is_record_of_table(tb).into(),
|
Some(tb) => arg.is_record_type(&[Table(tb)]).into(),
|
||||||
None => arg.is_record().into(),
|
None => arg.is_record().into(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ use crate::sql::{escape::escape_rid, Array, Number, Object, Strand, Thing, Uuid,
|
||||||
use nanoid::nanoid;
|
use nanoid::nanoid;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::BTreeMap;
|
||||||
use std::fmt::{self, Display, Formatter};
|
use std::fmt::{self, Display, Formatter};
|
||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
@ -106,6 +107,12 @@ impl From<Vec<Value>> for Id {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<BTreeMap<String, Value>> for Id {
|
||||||
|
fn from(v: BTreeMap<String, Value>) -> Self {
|
||||||
|
Id::Object(v.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Number> for Id {
|
impl From<Number> for Id {
|
||||||
fn from(v: Number) -> Self {
|
fn from(v: Number) -> Self {
|
||||||
match v {
|
match v {
|
||||||
|
|
|
@ -23,6 +23,12 @@ pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Object";
|
||||||
#[revisioned(revision = 1)]
|
#[revisioned(revision = 1)]
|
||||||
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
|
pub struct Object(#[serde(with = "no_nul_bytes_in_keys")] pub BTreeMap<String, Value>);
|
||||||
|
|
||||||
|
impl From<BTreeMap<&str, Value>> for Object {
|
||||||
|
fn from(v: BTreeMap<&str, Value>) -> Self {
|
||||||
|
Self(v.into_iter().map(|(key, val)| (key.to_string(), val)).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<BTreeMap<String, Value>> for Object {
|
impl From<BTreeMap<String, Value>> for Object {
|
||||||
fn from(v: BTreeMap<String, Value>) -> Self {
|
fn from(v: BTreeMap<String, Value>) -> Self {
|
||||||
Self(v)
|
Self(v)
|
||||||
|
|
|
@ -2,6 +2,7 @@ use crate::ctx::Context;
|
||||||
use crate::dbs::{Options, Transaction};
|
use crate::dbs::{Options, Transaction};
|
||||||
use crate::doc::CursorDoc;
|
use crate::doc::CursorDoc;
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
|
use crate::sql::Uuid;
|
||||||
use crate::sql::Value;
|
use crate::sql::Value;
|
||||||
use derive::Store;
|
use derive::Store;
|
||||||
use revision::revisioned;
|
use revision::revisioned;
|
||||||
|
@ -34,6 +35,14 @@ impl KillStatement {
|
||||||
Value::Uuid(id) => *id,
|
Value::Uuid(id) => *id,
|
||||||
Value::Param(param) => match param.compute(ctx, opt, txn, None).await? {
|
Value::Param(param) => match param.compute(ctx, opt, txn, None).await? {
|
||||||
Value::Uuid(id) => id,
|
Value::Uuid(id) => id,
|
||||||
|
Value::Strand(id) => match uuid::Uuid::try_parse(&id) {
|
||||||
|
Ok(id) => Uuid(id),
|
||||||
|
_ => {
|
||||||
|
return Err(Error::KillStatement {
|
||||||
|
value: self.id.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
_ => {
|
_ => {
|
||||||
return Err(Error::KillStatement {
|
return Err(Error::KillStatement {
|
||||||
value: self.id.to_string(),
|
value: self.id.to_string(),
|
||||||
|
|
|
@ -458,6 +458,12 @@ impl From<Vec<bool>> for Value {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<HashMap<&str, Value>> for Value {
|
||||||
|
fn from(v: HashMap<&str, Value>) -> Self {
|
||||||
|
Value::Object(Object::from(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<HashMap<String, Value>> for Value {
|
impl From<HashMap<String, Value>> for Value {
|
||||||
fn from(v: HashMap<String, Value>) -> Self {
|
fn from(v: HashMap<String, Value>) -> Self {
|
||||||
Value::Object(Object::from(v))
|
Value::Object(Object::from(v))
|
||||||
|
@ -470,6 +476,12 @@ impl From<BTreeMap<String, Value>> for Value {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<BTreeMap<&str, Value>> for Value {
|
||||||
|
fn from(v: BTreeMap<&str, Value>) -> Self {
|
||||||
|
Value::Object(Object::from(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Option<Value>> for Value {
|
impl From<Option<Value>> for Value {
|
||||||
fn from(v: Option<Value>) -> Self {
|
fn from(v: Option<Value>) -> Self {
|
||||||
match v {
|
match v {
|
||||||
|
@ -730,6 +742,18 @@ impl TryFrom<Value> for Object {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FromIterator<Value> for Value {
|
||||||
|
fn from_iter<I: IntoIterator<Item = Value>>(iter: I) -> Self {
|
||||||
|
Value::Array(Array(iter.into_iter().collect()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromIterator<(String, Value)> for Value {
|
||||||
|
fn from_iter<I: IntoIterator<Item = (String, Value)>>(iter: I) -> Self {
|
||||||
|
Value::Object(Object(iter.into_iter().collect()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Value {
|
impl Value {
|
||||||
// -----------------------------------
|
// -----------------------------------
|
||||||
// Initial record value
|
// Initial record value
|
||||||
|
@ -828,6 +852,11 @@ impl Value {
|
||||||
matches!(self, Value::Mock(_))
|
matches!(self, Value::Mock(_))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if this Value is a Param
|
||||||
|
pub fn is_param(&self) -> bool {
|
||||||
|
matches!(self, Value::Param(_))
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if this Value is a Range
|
/// Check if this Value is a Range
|
||||||
pub fn is_range(&self) -> bool {
|
pub fn is_range(&self) -> bool {
|
||||||
matches!(self, Value::Range(_))
|
matches!(self, Value::Range(_))
|
||||||
|
@ -952,11 +981,6 @@ impl Value {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if this Value is a Param
|
|
||||||
pub fn is_param(&self) -> bool {
|
|
||||||
matches!(self, Value::Param(_))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this Value is a Geometry of a specific type
|
/// Check if this Value is a Geometry of a specific type
|
||||||
pub fn is_geometry_type(&self, types: &[String]) -> bool {
|
pub fn is_geometry_type(&self, types: &[String]) -> bool {
|
||||||
match self {
|
match self {
|
||||||
|
@ -1089,7 +1113,7 @@ impl Value {
|
||||||
/// Treat a string as a table name
|
/// Treat a string as a table name
|
||||||
pub fn could_be_table(self) -> Value {
|
pub fn could_be_table(self) -> Value {
|
||||||
match self {
|
match self {
|
||||||
Value::Strand(v) => Table::from(v.0).into(),
|
Value::Strand(v) => Value::Table(v.0.into()),
|
||||||
_ => self,
|
_ => self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,9 +7,6 @@ use base64::DecodeError as Base64Error;
|
||||||
use http::{HeaderName, StatusCode};
|
use http::{HeaderName, StatusCode};
|
||||||
use reqwest::Error as ReqwestError;
|
use reqwest::Error as ReqwestError;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_cbor::error::Error as CborError;
|
|
||||||
use serde_json::error::Error as JsonError;
|
|
||||||
use serde_pack::encode::Error as PackError;
|
|
||||||
use std::io::Error as IoError;
|
use std::io::Error as IoError;
|
||||||
use std::string::FromUtf8Error as Utf8Error;
|
use std::string::FromUtf8Error as Utf8Error;
|
||||||
use surrealdb::error::Db as SurrealDbError;
|
use surrealdb::error::Db as SurrealDbError;
|
||||||
|
@ -37,12 +34,12 @@ pub enum Error {
|
||||||
#[error("There was a problem connecting with the storage engine")]
|
#[error("There was a problem connecting with the storage engine")]
|
||||||
InvalidStorage,
|
InvalidStorage,
|
||||||
|
|
||||||
#[error("There was a problem parsing the header {0}: {1}")]
|
|
||||||
InvalidHeader(HeaderName, TypedHeaderRejection),
|
|
||||||
|
|
||||||
#[error("The operation is unsupported")]
|
#[error("The operation is unsupported")]
|
||||||
OperationUnsupported,
|
OperationUnsupported,
|
||||||
|
|
||||||
|
#[error("There was a problem parsing the header {0}: {1}")]
|
||||||
|
InvalidHeader(HeaderName, TypedHeaderRejection),
|
||||||
|
|
||||||
#[error("There was a problem with the database: {0}")]
|
#[error("There was a problem with the database: {0}")]
|
||||||
Db(#[from] SurrealError),
|
Db(#[from] SurrealError),
|
||||||
|
|
||||||
|
@ -52,14 +49,14 @@ pub enum Error {
|
||||||
#[error("There was an error with the network: {0}")]
|
#[error("There was an error with the network: {0}")]
|
||||||
Axum(#[from] AxumError),
|
Axum(#[from] AxumError),
|
||||||
|
|
||||||
#[error("There was an error serializing to JSON: {0}")]
|
#[error("There was an error with JSON serialization: {0}")]
|
||||||
Json(#[from] JsonError),
|
Json(String),
|
||||||
|
|
||||||
#[error("There was an error serializing to CBOR: {0}")]
|
#[error("There was an error with CBOR serialization: {0}")]
|
||||||
Cbor(#[from] CborError),
|
Cbor(String),
|
||||||
|
|
||||||
#[error("There was an error serializing to MessagePack: {0}")]
|
#[error("There was an error with MessagePack serialization: {0}")]
|
||||||
Pack(#[from] PackError),
|
Pack(String),
|
||||||
|
|
||||||
#[error("There was an error with the remote request: {0}")]
|
#[error("There was an error with the remote request: {0}")]
|
||||||
Remote(#[from] ReqwestError),
|
Remote(#[from] ReqwestError),
|
||||||
|
@ -93,6 +90,42 @@ impl From<Utf8Error> for Error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<serde_json::Error> for Error {
|
||||||
|
fn from(e: serde_json::Error) -> Error {
|
||||||
|
Error::Json(e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<serde_pack::encode::Error> for Error {
|
||||||
|
fn from(e: serde_pack::encode::Error) -> Error {
|
||||||
|
Error::Pack(e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<serde_pack::decode::Error> for Error {
|
||||||
|
fn from(e: serde_pack::decode::Error) -> Error {
|
||||||
|
Error::Pack(e.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ciborium::value::Error> for Error {
|
||||||
|
fn from(e: ciborium::value::Error) -> Error {
|
||||||
|
Error::Cbor(format!("{e}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: std::fmt::Debug> From<ciborium::de::Error<T>> for Error {
|
||||||
|
fn from(e: ciborium::de::Error<T>) -> Error {
|
||||||
|
Error::Cbor(format!("{e}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: std::fmt::Debug> From<ciborium::ser::Error<T>> for Error {
|
||||||
|
fn from(e: ciborium::ser::Error<T>) -> Error {
|
||||||
|
Error::Cbor(format!("{e}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<surrealdb::error::Db> for Error {
|
impl From<surrealdb::error::Db> for Error {
|
||||||
fn from(error: surrealdb::error::Db) -> Error {
|
fn from(error: surrealdb::error::Db) -> Error {
|
||||||
if matches!(error, surrealdb::error::Db::InvalidAuth) {
|
if matches!(error, surrealdb::error::Db::InvalidAuth) {
|
||||||
|
|
|
@ -39,8 +39,9 @@ pub fn cbor<T>(val: &T) -> Output
|
||||||
where
|
where
|
||||||
T: Serialize,
|
T: Serialize,
|
||||||
{
|
{
|
||||||
match serde_cbor::to_vec(val) {
|
let mut out = Vec::new();
|
||||||
Ok(v) => Output::Cbor(v),
|
match ciborium::into_writer(&val, &mut out) {
|
||||||
|
Ok(_) => Output::Cbor(out),
|
||||||
Err(_) => Output::Fail,
|
Err(_) => Output::Fail,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
use crate::cnf;
|
use crate::cnf;
|
||||||
|
use crate::err::Error;
|
||||||
use crate::rpc::connection::Connection;
|
use crate::rpc::connection::Connection;
|
||||||
|
use crate::rpc::format::Format;
|
||||||
|
use crate::rpc::format::PROTOCOLS;
|
||||||
|
use crate::rpc::WEBSOCKETS;
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use axum::Extension;
|
use axum::{
|
||||||
use axum::Router;
|
extract::ws::{WebSocket, WebSocketUpgrade},
|
||||||
|
response::IntoResponse,
|
||||||
|
Extension, Router,
|
||||||
|
};
|
||||||
|
use http::HeaderValue;
|
||||||
use http_body::Body as HttpBody;
|
use http_body::Body as HttpBody;
|
||||||
use surrealdb::dbs::Session;
|
use surrealdb::dbs::Session;
|
||||||
use tower_http::request_id::RequestId;
|
use tower_http::request_id::RequestId;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use axum::{
|
|
||||||
extract::ws::{WebSocket, WebSocketUpgrade},
|
|
||||||
response::IntoResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub(super) fn router<S, B>() -> Router<S, B>
|
pub(super) fn router<S, B>() -> Router<S, B>
|
||||||
where
|
where
|
||||||
B: HttpBody + Send + 'static,
|
B: HttpBody + Send + 'static,
|
||||||
|
@ -23,28 +26,53 @@ where
|
||||||
|
|
||||||
async fn handler(
|
async fn handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
|
Extension(id): Extension<RequestId>,
|
||||||
Extension(sess): Extension<Session>,
|
Extension(sess): Extension<Session>,
|
||||||
Extension(req_id): Extension<RequestId>,
|
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||||
) -> impl IntoResponse {
|
// Check if there is a request id header specified
|
||||||
ws
|
let id = match id.header_value().is_empty() {
|
||||||
// Set the maximum frame size
|
// No request id was specified so create a new id
|
||||||
|
true => Uuid::new_v4(),
|
||||||
|
// A request id was specified to try to parse it
|
||||||
|
false => match id.header_value().to_str() {
|
||||||
|
// Attempt to parse the request id as a UUID
|
||||||
|
Ok(id) => match Uuid::try_parse(id) {
|
||||||
|
// The specified request id was a valid UUID
|
||||||
|
Ok(id) => id,
|
||||||
|
// The specified request id was not a UUID
|
||||||
|
Err(_) => return Err(Error::Request),
|
||||||
|
},
|
||||||
|
// The request id contained invalid characters
|
||||||
|
Err(_) => return Err(Error::Request),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
// Check if a connection with this id already exists
|
||||||
|
if WEBSOCKETS.read().await.contains_key(&id) {
|
||||||
|
return Err(Error::Request);
|
||||||
|
}
|
||||||
|
// Now let's upgrade the WebSocket connection
|
||||||
|
Ok(ws
|
||||||
|
// Set the potential WebSocket protocols
|
||||||
|
.protocols(PROTOCOLS)
|
||||||
|
// Set the maximum WebSocket frame size
|
||||||
.max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE)
|
.max_frame_size(*cnf::WEBSOCKET_MAX_FRAME_SIZE)
|
||||||
// Set the maximum message size
|
// Set the maximum WebSocket message size
|
||||||
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
|
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
|
||||||
// Set the potential WebSocket protocol formats
|
|
||||||
.protocols(["surrealql-binary", "json", "cbor", "messagepack"])
|
|
||||||
// Handle the WebSocket upgrade and process messages
|
// Handle the WebSocket upgrade and process messages
|
||||||
.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
|
.on_upgrade(move |socket| handle_socket(socket, sess, id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
|
async fn handle_socket(ws: WebSocket, sess: Session, id: Uuid) {
|
||||||
|
// Check if there is a WebSocket protocol specified
|
||||||
|
let format = match ws.protocol().map(HeaderValue::to_str) {
|
||||||
|
// Any selected protocol will always be a valie value
|
||||||
|
Some(protocol) => protocol.unwrap().into(),
|
||||||
|
// No protocol format was specified
|
||||||
|
_ => Format::None,
|
||||||
|
};
|
||||||
|
//
|
||||||
// Create a new connection instance
|
// Create a new connection instance
|
||||||
let rpc = Connection::new(sess);
|
let rpc = Connection::new(id, sess, format);
|
||||||
// Update the WebSocket ID with the Request ID
|
|
||||||
if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) {
|
|
||||||
// If the ID couldn't be updated, ignore the error and keep the default ID
|
|
||||||
let _ = rpc.write().await.update_ws_id(req_id).await;
|
|
||||||
}
|
|
||||||
// Serve the socket connection requests
|
// Serve the socket connection requests
|
||||||
Connection::serve(rpc, ws).await;
|
Connection::serve(rpc, ws).await;
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ pub trait Take {
|
||||||
impl Take for Array {
|
impl Take for Array {
|
||||||
/// Convert the array to one argument
|
/// Convert the array to one argument
|
||||||
fn needs_one(self) -> Result<Value, ()> {
|
fn needs_one(self) -> Result<Value, ()> {
|
||||||
if self.is_empty() {
|
if self.len() != 1 {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
@ -22,7 +22,7 @@ impl Take for Array {
|
||||||
}
|
}
|
||||||
/// Convert the array to two arguments
|
/// Convert the array to two arguments
|
||||||
fn needs_two(self) -> Result<(Value, Value), ()> {
|
fn needs_two(self) -> Result<(Value, Value), ()> {
|
||||||
if self.len() < 2 {
|
if self.len() != 2 {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
@ -34,7 +34,7 @@ impl Take for Array {
|
||||||
}
|
}
|
||||||
/// Convert the array to two arguments
|
/// Convert the array to two arguments
|
||||||
fn needs_one_or_two(self) -> Result<(Value, Value), ()> {
|
fn needs_one_or_two(self) -> Result<(Value, Value), ()> {
|
||||||
if self.is_empty() {
|
if self.is_empty() && self.len() > 2 {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
@ -46,7 +46,7 @@ impl Take for Array {
|
||||||
}
|
}
|
||||||
/// Convert the array to three arguments
|
/// Convert the array to three arguments
|
||||||
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> {
|
fn needs_one_two_or_three(self) -> Result<(Value, Value, Value), ()> {
|
||||||
if self.is_empty() {
|
if self.is_empty() && self.len() > 3 {
|
||||||
return Err(());
|
return Err(());
|
||||||
}
|
}
|
||||||
let mut x = self.into_iter();
|
let mut x = self.into_iter();
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
use super::request::parse_request;
|
|
||||||
use super::response::{failure, success, Data, Failure, IntoRpcResponse, OutputFormat};
|
|
||||||
use crate::cnf::PKG_NAME;
|
use crate::cnf::PKG_NAME;
|
||||||
use crate::cnf::PKG_VERSION;
|
use crate::cnf::PKG_VERSION;
|
||||||
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY};
|
use crate::cnf::{WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY};
|
||||||
use crate::dbs::DB;
|
use crate::dbs::DB;
|
||||||
use crate::err::Error;
|
use crate::err::Error;
|
||||||
use crate::rpc::args::Take;
|
use crate::rpc::args::Take;
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::format::Format;
|
||||||
|
use crate::rpc::response::{failure, success, Data, IntoRpcResponse};
|
||||||
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
|
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
|
||||||
use crate::telemetry;
|
use crate::telemetry;
|
||||||
use crate::telemetry::metrics::ws::RequestContext;
|
use crate::telemetry::metrics::ws::RequestContext;
|
||||||
|
@ -33,9 +34,9 @@ use tracing::Span;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub struct Connection {
|
pub struct Connection {
|
||||||
ws_id: Uuid,
|
id: Uuid,
|
||||||
session: Session,
|
session: Session,
|
||||||
format: OutputFormat,
|
format: Format,
|
||||||
vars: BTreeMap<String, Value>,
|
vars: BTreeMap<String, Value>,
|
||||||
limiter: Arc<Semaphore>,
|
limiter: Arc<Semaphore>,
|
||||||
canceller: CancellationToken,
|
canceller: CancellationToken,
|
||||||
|
@ -43,34 +44,20 @@ pub struct Connection {
|
||||||
|
|
||||||
impl Connection {
|
impl Connection {
|
||||||
/// Instantiate a new RPC
|
/// Instantiate a new RPC
|
||||||
pub fn new(mut session: Session) -> Arc<RwLock<Connection>> {
|
pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc<RwLock<Connection>> {
|
||||||
// Create a new RPC variables store
|
|
||||||
let vars = BTreeMap::new();
|
|
||||||
// Set the default output format
|
|
||||||
let format = OutputFormat::Json;
|
|
||||||
// Enable real-time mode
|
// Enable real-time mode
|
||||||
session.rt = true;
|
session.rt = true;
|
||||||
// Create and store the RPC connection
|
// Create and store the RPC connection
|
||||||
Arc::new(RwLock::new(Connection {
|
Arc::new(RwLock::new(Connection {
|
||||||
ws_id: Uuid::new_v4(),
|
id,
|
||||||
session,
|
session,
|
||||||
format,
|
format,
|
||||||
vars,
|
vars: BTreeMap::new(),
|
||||||
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
|
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
|
||||||
canceller: CancellationToken::new(),
|
canceller: CancellationToken::new(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update the WebSocket ID. If the ID already exists, do not update it.
|
|
||||||
pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
if WEBSOCKETS.read().await.contains_key(&ws_id) {
|
|
||||||
trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id);
|
|
||||||
return Err("websocket ID is in use".into());
|
|
||||||
}
|
|
||||||
self.ws_id = ws_id;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serve the RPC endpoint
|
/// Serve the RPC endpoint
|
||||||
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
|
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
|
||||||
// Split the socket into send and recv
|
// Split the socket into send and recv
|
||||||
|
@ -79,19 +66,19 @@ impl Connection {
|
||||||
let (internal_sender, internal_receiver) =
|
let (internal_sender, internal_receiver) =
|
||||||
channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS);
|
channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS);
|
||||||
|
|
||||||
let ws_id = rpc.read().await.ws_id;
|
let id = rpc.read().await.id;
|
||||||
|
|
||||||
trace!("WebSocket {} connected", ws_id);
|
trace!("WebSocket {} connected", id);
|
||||||
|
|
||||||
if let Err(err) = telemetry::metrics::ws::on_connect() {
|
if let Err(err) = telemetry::metrics::ws::on_connect() {
|
||||||
error!("Error running metrics::ws::on_connect hook: {}", err);
|
error!("Error running metrics::ws::on_connect hook: {}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add this WebSocket to the list
|
// Add this WebSocket to the list
|
||||||
WEBSOCKETS.write().await.insert(
|
WEBSOCKETS
|
||||||
ws_id,
|
.write()
|
||||||
WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()),
|
.await
|
||||||
);
|
.insert(id, WebSocketRef(internal_sender.clone(), rpc.read().await.canceller.clone()));
|
||||||
|
|
||||||
// Spawn async tasks for the WebSocket
|
// Spawn async tasks for the WebSocket
|
||||||
let mut tasks = JoinSet::new();
|
let mut tasks = JoinSet::new();
|
||||||
|
@ -109,15 +96,15 @@ impl Connection {
|
||||||
|
|
||||||
internal_sender.close();
|
internal_sender.close();
|
||||||
|
|
||||||
trace!("WebSocket {} disconnected", ws_id);
|
trace!("WebSocket {} disconnected", id);
|
||||||
|
|
||||||
// Remove this WebSocket from the list
|
// Remove this WebSocket from the list
|
||||||
WEBSOCKETS.write().await.remove(&ws_id);
|
WEBSOCKETS.write().await.remove(&id);
|
||||||
|
|
||||||
// Remove all live queries
|
// Remove all live queries
|
||||||
let mut gc = Vec::new();
|
let mut gc = Vec::new();
|
||||||
LIVE_QUERIES.write().await.retain(|key, value| {
|
LIVE_QUERIES.write().await.retain(|key, value| {
|
||||||
if value == &ws_id {
|
if value == &id {
|
||||||
trace!("Removing live query: {}", key);
|
trace!("Removing live query: {}", key);
|
||||||
gc.push(*key);
|
gc.push(*key);
|
||||||
return false;
|
return false;
|
||||||
|
@ -288,9 +275,9 @@ impl Connection {
|
||||||
msg = channel.recv() => {
|
msg = channel.recv() => {
|
||||||
if let Ok(notification) = msg {
|
if let Ok(notification) = msg {
|
||||||
// Find which WebSocket the notification belongs to
|
// Find which WebSocket the notification belongs to
|
||||||
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
if let Some(id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||||
// Check to see if the WebSocket exists
|
// Check to see if the WebSocket exists
|
||||||
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
|
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(id) {
|
||||||
// Serialize the message to send
|
// Serialize the message to send
|
||||||
let message = success(None, notification);
|
let message = success(None, notification);
|
||||||
// Get the current output format
|
// Get the current output format
|
||||||
|
@ -309,27 +296,41 @@ impl Connection {
|
||||||
/// Handle individual WebSocket messages
|
/// Handle individual WebSocket messages
|
||||||
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
|
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
|
||||||
// Get the current output format
|
// Get the current output format
|
||||||
let mut out_fmt = rpc.read().await.format;
|
let mut fmt = rpc.read().await.format;
|
||||||
// Prepare Span and Otel context
|
// Prepare Span and Otel context
|
||||||
let span = span_for_request(&rpc.read().await.ws_id);
|
let span = span_for_request(&rpc.read().await.id);
|
||||||
// Acquire concurrent request rate limiter
|
// Acquire concurrent request rate limiter
|
||||||
let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap();
|
let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap();
|
||||||
|
// Calculate the length of the message
|
||||||
|
let len = match msg {
|
||||||
|
Message::Text(ref msg) => {
|
||||||
|
// If no format was specified, default to JSON
|
||||||
|
if fmt.is_none() {
|
||||||
|
fmt = Format::Json;
|
||||||
|
rpc.write().await.format = fmt;
|
||||||
|
}
|
||||||
|
// Retrieve the length of the message
|
||||||
|
msg.len()
|
||||||
|
}
|
||||||
|
Message::Binary(ref msg) => {
|
||||||
|
// If no format was specified, default to Bincode
|
||||||
|
if fmt.is_none() {
|
||||||
|
fmt = Format::Bincode;
|
||||||
|
rpc.write().await.format = fmt;
|
||||||
|
}
|
||||||
|
// Retrieve the length of the message
|
||||||
|
msg.len()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
// Parse the request
|
// Parse the request
|
||||||
async move {
|
async move {
|
||||||
let span = Span::current();
|
let span = Span::current();
|
||||||
let req_cx = RequestContext::default();
|
let req_cx = RequestContext::default();
|
||||||
let otel_cx = TelemetryContext::new().with_value(req_cx.clone());
|
let otel_cx = TelemetryContext::new().with_value(req_cx.clone());
|
||||||
|
// Parse the RPC request structure
|
||||||
match parse_request(msg).await {
|
match fmt.req(msg) {
|
||||||
Ok(req) => {
|
Ok(req) => {
|
||||||
if let Some(fmt) = req.out_fmt {
|
|
||||||
if out_fmt != fmt {
|
|
||||||
// Update the default format
|
|
||||||
rpc.write().await.format = fmt;
|
|
||||||
out_fmt = fmt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we know the method, we can update the span and create otel context
|
// Now that we know the method, we can update the span and create otel context
|
||||||
span.record("rpc.method", &req.method);
|
span.record("rpc.method", &req.method);
|
||||||
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
|
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
|
||||||
|
@ -338,17 +339,17 @@ impl Connection {
|
||||||
req.id.clone().map(Value::as_string).unwrap_or_default(),
|
req.id.clone().map(Value::as_string).unwrap_or_default(),
|
||||||
);
|
);
|
||||||
let otel_cx = TelemetryContext::current_with_value(
|
let otel_cx = TelemetryContext::current_with_value(
|
||||||
req_cx.with_method(&req.method).with_size(req.size),
|
req_cx.with_method(&req.method).with_size(len),
|
||||||
);
|
);
|
||||||
// Process the message
|
// Process the message
|
||||||
let res =
|
let res =
|
||||||
Connection::process_message(rpc.clone(), &req.method, req.params).await;
|
Connection::process_message(rpc.clone(), &req.method, req.params).await;
|
||||||
// Process the response
|
// Process the response
|
||||||
res.into_response(req.id).send(out_fmt, &chn).with_context(otel_cx).await
|
res.into_response(req.id).send(fmt, &chn).with_context(otel_cx).await
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
// Process the response
|
// Process the response
|
||||||
failure(None, err).send(out_fmt, &chn).with_context(otel_cx).await
|
failure(None, err).send(fmt, &chn).with_context(otel_cx).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -485,13 +486,6 @@ impl Connection {
|
||||||
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
|
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
},
|
},
|
||||||
// Specify the output format for text requests
|
|
||||||
"format" => match params.needs_one() {
|
|
||||||
Ok(Value::Strand(v)) => {
|
|
||||||
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Get the current server version
|
// Get the current server version
|
||||||
"version" => match params.len() {
|
"version" => match params.len() {
|
||||||
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
||||||
|
@ -515,16 +509,6 @@ impl Connection {
|
||||||
// Methods for authentication
|
// Methods for authentication
|
||||||
// ------------------------------
|
// ------------------------------
|
||||||
|
|
||||||
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
|
||||||
match out.as_str() {
|
|
||||||
"json" | "application/json" => self.format = OutputFormat::Json,
|
|
||||||
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
|
|
||||||
"pack" | "application/pack" => self.format = OutputFormat::Pack,
|
|
||||||
_ => return Err(Error::InvalidType),
|
|
||||||
};
|
|
||||||
Ok(Value::None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
||||||
if let Value::Strand(ns) = ns {
|
if let Value::Strand(ns) = ns {
|
||||||
self.session.ns = Some(ns.0);
|
self.session.ns = Some(ns.0);
|
||||||
|
@ -615,7 +599,7 @@ impl Connection {
|
||||||
let sql = "KILL $id";
|
let sql = "KILL $id";
|
||||||
// Specify the query parameters
|
// Specify the query parameters
|
||||||
let var = map! {
|
let var = map! {
|
||||||
String::from("id") => id, // NOTE: id can be parameter
|
String::from("id") => id,
|
||||||
=> &self.vars
|
=> &self.vars
|
||||||
};
|
};
|
||||||
// Execute the query on the database
|
// Execute the query on the database
|
||||||
|
@ -910,15 +894,14 @@ impl Connection {
|
||||||
QueryType::Live => {
|
QueryType::Live => {
|
||||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||||
// Match on Uuid type
|
// Match on Uuid type
|
||||||
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
|
LIVE_QUERIES.write().await.insert(lqid.0, self.id);
|
||||||
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
|
trace!("Registered live query {} on websocket {}", lqid, self.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
QueryType::Kill => {
|
QueryType::Kill => {
|
||||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||||
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0);
|
if let Some(id) = LIVE_QUERIES.write().await.remove(&lqid.0) {
|
||||||
if let Some(ws_id) = ws_id {
|
trace!("Unregistered live query {} on websocket {}", lqid, id);
|
||||||
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
70
src/rpc/failure.rs
Normal file
70
src/rpc/failure.rs
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
use crate::err::Error;
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use surrealdb::sql::Value;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize)]
|
||||||
|
pub struct Failure {
|
||||||
|
pub(crate) code: i64,
|
||||||
|
pub(crate) message: Cow<'static, str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Failure {
|
||||||
|
fn from(err: &str) -> Self {
|
||||||
|
Failure::custom(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Error> for Failure {
|
||||||
|
fn from(err: Error) -> Self {
|
||||||
|
Failure::custom(err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Failure> for Value {
|
||||||
|
fn from(err: Failure) -> Self {
|
||||||
|
map! {
|
||||||
|
String::from("code") => Value::from(err.code),
|
||||||
|
String::from("message") => Value::from(err.message.to_string()),
|
||||||
|
}
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
impl Failure {
|
||||||
|
pub const PARSE_ERROR: Failure = Failure {
|
||||||
|
code: -32700,
|
||||||
|
message: Cow::Borrowed("Parse error"),
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const INVALID_REQUEST: Failure = Failure {
|
||||||
|
code: -32600,
|
||||||
|
message: Cow::Borrowed("Invalid Request"),
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const METHOD_NOT_FOUND: Failure = Failure {
|
||||||
|
code: -32601,
|
||||||
|
message: Cow::Borrowed("Method not found"),
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const INVALID_PARAMS: Failure = Failure {
|
||||||
|
code: -32602,
|
||||||
|
message: Cow::Borrowed("Invalid params"),
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const INTERNAL_ERROR: Failure = Failure {
|
||||||
|
code: -32603,
|
||||||
|
message: Cow::Borrowed("Internal error"),
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn custom<S>(message: S) -> Failure
|
||||||
|
where
|
||||||
|
Cow<'static, str>: From<S>,
|
||||||
|
{
|
||||||
|
Failure {
|
||||||
|
code: -32000,
|
||||||
|
message: message.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
22
src/rpc/format/bincode.rs
Normal file
22
src/rpc/format/bincode.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
use surrealdb::sql::serde::deserialize;
|
||||||
|
use surrealdb::sql::Value;
|
||||||
|
|
||||||
|
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||||
|
match msg {
|
||||||
|
Message::Binary(val) => {
|
||||||
|
deserialize::<Value>(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_REQUEST),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
// Serialize the response with full internal type information
|
||||||
|
let res = surrealdb::sql::serde::serialize(&res).unwrap();
|
||||||
|
// Return the message length, and message as binary
|
||||||
|
Ok((res.len(), Message::Binary(res)))
|
||||||
|
}
|
24
src/rpc/format/cbor/mod.rs
Normal file
24
src/rpc/format/cbor/mod.rs
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
|
||||||
|
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||||
|
match msg {
|
||||||
|
Message::Text(val) => {
|
||||||
|
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_REQUEST),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
// Convert the response into simplified JSON
|
||||||
|
let val = res.into_json();
|
||||||
|
// Create a new vector for encoding output
|
||||||
|
let mut res = Vec::new();
|
||||||
|
// Serialize the value into CBOR binary data
|
||||||
|
ciborium::into_writer(&val, &mut res).unwrap();
|
||||||
|
// Return the message length, and message as binary
|
||||||
|
Ok((res.len(), Message::Binary(res)))
|
||||||
|
}
|
22
src/rpc/format/json.rs
Normal file
22
src/rpc/format/json.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
|
||||||
|
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||||
|
match msg {
|
||||||
|
Message::Text(val) => {
|
||||||
|
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_REQUEST),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
// Convert the response into simplified JSON
|
||||||
|
let val = res.into_json();
|
||||||
|
// Serialize the response with simplified type information
|
||||||
|
let res = serde_json::to_string(&val).unwrap();
|
||||||
|
// Return the message length, and message as binary
|
||||||
|
Ok((res.len(), Message::Text(res)))
|
||||||
|
}
|
70
src/rpc/format/mod.rs
Normal file
70
src/rpc/format/mod.rs
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
mod bincode;
|
||||||
|
pub mod cbor;
|
||||||
|
mod json;
|
||||||
|
pub mod msgpack;
|
||||||
|
mod revision;
|
||||||
|
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
|
||||||
|
pub const PROTOCOLS: [&str; 5] = [
|
||||||
|
"json", // For basic JSON serialisation
|
||||||
|
"cbor", // For basic CBOR serialisation
|
||||||
|
"msgpack", // For basic Msgpack serialisation
|
||||||
|
"bincode", // For full internal serialisation
|
||||||
|
"revision", // For full versioned serialisation
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||||
|
pub enum Format {
|
||||||
|
None, // No format is specified yet
|
||||||
|
Json, // For basic JSON serialisation
|
||||||
|
Cbor, // For basic CBOR serialisation
|
||||||
|
Msgpack, // For basic Msgpack serialisation
|
||||||
|
Bincode, // For full internal serialisation
|
||||||
|
Revision, // For full versioned serialisation
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Format {
|
||||||
|
fn from(v: &str) -> Self {
|
||||||
|
match v {
|
||||||
|
s if s == PROTOCOLS[0] => Format::Json,
|
||||||
|
s if s == PROTOCOLS[1] => Format::Cbor,
|
||||||
|
s if s == PROTOCOLS[2] => Format::Msgpack,
|
||||||
|
s if s == PROTOCOLS[3] => Format::Bincode,
|
||||||
|
s if s == PROTOCOLS[4] => Format::Revision,
|
||||||
|
_ => Format::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Format {
|
||||||
|
/// Check if this format has been set
|
||||||
|
pub fn is_none(&self) -> bool {
|
||||||
|
matches!(self, Format::None)
|
||||||
|
}
|
||||||
|
/// Process a request using the specified format
|
||||||
|
pub fn req(&self, msg: Message) -> Result<Request, Failure> {
|
||||||
|
match self {
|
||||||
|
Self::None => unreachable!(), // We should never arrive at this code
|
||||||
|
Self::Json => json::req(msg),
|
||||||
|
Self::Cbor => cbor::req(msg),
|
||||||
|
Self::Msgpack => msgpack::req(msg),
|
||||||
|
Self::Bincode => bincode::req(msg),
|
||||||
|
Self::Revision => revision::req(msg),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Process a response using the specified format
|
||||||
|
pub fn res(&self, res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
match self {
|
||||||
|
Self::None => unreachable!(), // We should never arrive at this code
|
||||||
|
Self::Json => json::res(res),
|
||||||
|
Self::Cbor => cbor::res(res),
|
||||||
|
Self::Msgpack => msgpack::res(res),
|
||||||
|
Self::Bincode => bincode::res(res),
|
||||||
|
Self::Revision => revision::res(res),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
22
src/rpc/format/msgpack/mod.rs
Normal file
22
src/rpc/format/msgpack/mod.rs
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
|
||||||
|
pub fn req(msg: Message) -> Result<Request, Failure> {
|
||||||
|
match msg {
|
||||||
|
Message::Text(val) => {
|
||||||
|
surrealdb::sql::value(&val).map_err(|_| Failure::PARSE_ERROR)?.try_into()
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_REQUEST),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn res(res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
// Convert the response into simplified JSON
|
||||||
|
let val = res.into_json();
|
||||||
|
// Serialize the value into MsgPack binary data
|
||||||
|
let res = serde_pack::to_vec(&val).unwrap();
|
||||||
|
// Return the message length, and message as binary
|
||||||
|
Ok((res.len(), Message::Binary(res)))
|
||||||
|
}
|
14
src/rpc/format/revision.rs
Normal file
14
src/rpc/format/revision.rs
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::request::Request;
|
||||||
|
use crate::rpc::response::Response;
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
|
||||||
|
pub fn req(_msg: Message) -> Result<Request, Failure> {
|
||||||
|
// This format is not yet implemented
|
||||||
|
Err(Failure::INTERNAL_ERROR)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn res(_res: Response) -> Result<(usize, Message), Failure> {
|
||||||
|
// This format is not yet implemented
|
||||||
|
Err(Failure::INTERNAL_ERROR)
|
||||||
|
}
|
|
@ -1,12 +1,14 @@
|
||||||
pub mod args;
|
pub mod args;
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
|
pub mod failure;
|
||||||
|
pub mod format;
|
||||||
pub mod request;
|
pub mod request;
|
||||||
pub mod response;
|
pub mod response;
|
||||||
|
|
||||||
use std::{collections::HashMap, time::Duration};
|
|
||||||
|
|
||||||
use axum::extract::ws::Message;
|
use axum::extract::ws::Message;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
use surrealdb::channel::Sender;
|
use surrealdb::channel::Sender;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
use axum::extract::ws::Message;
|
use crate::rpc::failure::Failure;
|
||||||
use surrealdb::sql::{serde::deserialize, Array, Value};
|
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use surrealdb::sql::Part;
|
use surrealdb::sql::Part;
|
||||||
|
use surrealdb::sql::{Array, Value};
|
||||||
use super::response::{Failure, OutputFormat};
|
|
||||||
|
|
||||||
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
|
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
|
||||||
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
|
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
|
||||||
|
@ -14,71 +11,36 @@ pub struct Request {
|
||||||
pub id: Option<Value>,
|
pub id: Option<Value>,
|
||||||
pub method: String,
|
pub method: String,
|
||||||
pub params: Array,
|
pub params: Array,
|
||||||
pub size: usize,
|
|
||||||
pub out_fmt: Option<OutputFormat>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse the RPC request
|
impl TryFrom<Value> for Request {
|
||||||
pub async fn parse_request(msg: Message) -> Result<Request, Failure> {
|
type Error = Failure;
|
||||||
let mut out_fmt = None;
|
fn try_from(val: Value) -> Result<Self, Failure> {
|
||||||
let (req, size) = match msg {
|
// Fetch the 'id' argument
|
||||||
// This is a binary message
|
let id = match val.pick(&*ID) {
|
||||||
Message::Binary(val) => {
|
v if v.is_none() => None,
|
||||||
// Use binary output
|
v if v.is_null() => Some(v),
|
||||||
out_fmt = Some(OutputFormat::Full);
|
v if v.is_uuid() => Some(v),
|
||||||
|
v if v.is_number() => Some(v),
|
||||||
match deserialize(&val) {
|
v if v.is_strand() => Some(v),
|
||||||
Ok(v) => (v, val.len()),
|
v if v.is_datetime() => Some(v),
|
||||||
Err(_) => {
|
_ => return Err(Failure::INVALID_REQUEST),
|
||||||
debug!("Error when trying to deserialize the request");
|
};
|
||||||
return Err(Failure::PARSE_ERROR);
|
// Fetch the 'method' argument
|
||||||
}
|
let method = match val.pick(&*METHOD) {
|
||||||
}
|
Value::Strand(v) => v.to_raw(),
|
||||||
}
|
_ => return Err(Failure::INVALID_REQUEST),
|
||||||
// This is a text message
|
};
|
||||||
Message::Text(ref val) => {
|
// Fetch the 'params' argument
|
||||||
// Parse the SurrealQL object
|
let params = match val.pick(&*PARAMS) {
|
||||||
match surrealdb::sql::value(val) {
|
Value::Array(v) => v,
|
||||||
// The SurrealQL message parsed ok
|
_ => Array::new(),
|
||||||
Ok(v) => (v, val.len()),
|
};
|
||||||
// The SurrealQL message failed to parse
|
// Return the parsed request
|
||||||
_ => return Err(Failure::PARSE_ERROR),
|
Ok(Request {
|
||||||
}
|
id,
|
||||||
}
|
method,
|
||||||
// Unsupported message type
|
params,
|
||||||
_ => {
|
})
|
||||||
debug!("Unsupported message type: {:?}", msg);
|
}
|
||||||
return Err(Failure::custom("Unsupported message type"));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Fetch the 'id' argument
|
|
||||||
let id = match req.pick(&*ID) {
|
|
||||||
v if v.is_none() => None,
|
|
||||||
v if v.is_null() => Some(v),
|
|
||||||
v if v.is_uuid() => Some(v),
|
|
||||||
v if v.is_number() => Some(v),
|
|
||||||
v if v.is_strand() => Some(v),
|
|
||||||
v if v.is_datetime() => Some(v),
|
|
||||||
_ => return Err(Failure::INVALID_REQUEST),
|
|
||||||
};
|
|
||||||
// Fetch the 'method' argument
|
|
||||||
let method = match req.pick(&*METHOD) {
|
|
||||||
Value::Strand(v) => v.to_raw(),
|
|
||||||
_ => return Err(Failure::INVALID_REQUEST),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Fetch the 'params' argument
|
|
||||||
let params = match req.pick(&*PARAMS) {
|
|
||||||
Value::Array(v) => v,
|
|
||||||
_ => Array::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Request {
|
|
||||||
id,
|
|
||||||
method,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
out_fmt,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use crate::err;
|
use crate::rpc::failure::Failure;
|
||||||
|
use crate::rpc::format::Format;
|
||||||
use crate::telemetry::metrics::ws::record_rpc;
|
use crate::telemetry::metrics::ws::record_rpc;
|
||||||
use axum::extract::ws::Message;
|
use axum::extract::ws::Message;
|
||||||
use opentelemetry::Context as TelemetryContext;
|
use opentelemetry::Context as TelemetryContext;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::{json, Value as Json};
|
use serde_json::Value as Json;
|
||||||
use std::borrow::Cow;
|
|
||||||
use surrealdb::channel::Sender;
|
use surrealdb::channel::Sender;
|
||||||
use surrealdb::dbs;
|
use surrealdb::dbs;
|
||||||
use surrealdb::dbs::Notification;
|
use surrealdb::dbs::Notification;
|
||||||
|
@ -12,14 +12,6 @@ use surrealdb::sql;
|
||||||
use surrealdb::sql::Value;
|
use surrealdb::sql::Value;
|
||||||
use tracing::Span;
|
use tracing::Span;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
|
||||||
pub enum OutputFormat {
|
|
||||||
Json, // JSON
|
|
||||||
Cbor, // CBOR
|
|
||||||
Pack, // MessagePack
|
|
||||||
Full, // Full type serialization
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The data returned by the database
|
/// The data returned by the database
|
||||||
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
|
// The variants here should be in exactly the same order as `surrealdb::engine::remote::ws::Data`
|
||||||
// In future, they will possibly be merged to avoid having to keep them in sync.
|
// In future, they will possibly be merged to avoid having to keep them in sync.
|
||||||
|
@ -46,15 +38,25 @@ impl From<String> for Data {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<Notification> for Data {
|
||||||
|
fn from(n: Notification) -> Self {
|
||||||
|
Data::Live(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Vec<dbs::Response>> for Data {
|
impl From<Vec<dbs::Response>> for Data {
|
||||||
fn from(v: Vec<dbs::Response>) -> Self {
|
fn from(v: Vec<dbs::Response>) -> Self {
|
||||||
Data::Query(v)
|
Data::Query(v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Notification> for Data {
|
impl From<Data> for Value {
|
||||||
fn from(n: Notification) -> Self {
|
fn from(val: Data) -> Self {
|
||||||
Data::Live(n)
|
match val {
|
||||||
|
Data::Query(v) => sql::to_value(v).unwrap(),
|
||||||
|
Data::Live(v) => sql::to_value(v).unwrap(),
|
||||||
|
Data::Other(v) => v,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,32 +69,31 @@ pub struct Response {
|
||||||
impl Response {
|
impl Response {
|
||||||
/// Convert and simplify the value into JSON
|
/// Convert and simplify the value into JSON
|
||||||
#[inline]
|
#[inline]
|
||||||
fn simplify(self) -> Json {
|
pub fn into_json(self) -> Json {
|
||||||
|
Json::from(self.into_value())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn into_value(self) -> Value {
|
||||||
let mut value = match self.result {
|
let mut value = match self.result {
|
||||||
Ok(data) => {
|
Ok(val) => map! {
|
||||||
let value = match data {
|
"result" => Value::from(val),
|
||||||
Data::Query(vec) => sql::to_value(vec).unwrap(),
|
},
|
||||||
Data::Live(notification) => sql::to_value(notification).unwrap(),
|
Err(err) => map! {
|
||||||
Data::Other(value) => value,
|
"error" => Value::from(err),
|
||||||
};
|
},
|
||||||
json!({
|
|
||||||
"result": Json::from(value),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(failure) => json!({
|
|
||||||
"error": failure,
|
|
||||||
}),
|
|
||||||
};
|
};
|
||||||
if let Some(id) = self.id {
|
if let Some(id) = self.id {
|
||||||
value["id"] = id.into();
|
value.insert("id", id);
|
||||||
}
|
}
|
||||||
value
|
value.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send the response to the WebSocket channel
|
/// Send the response to the WebSocket channel
|
||||||
pub async fn send(self, out: OutputFormat, chn: &Sender<Message>) {
|
pub async fn send(self, fmt: Format, chn: &Sender<Message>) {
|
||||||
|
// Create a new tracing span
|
||||||
let span = Span::current();
|
let span = Span::current();
|
||||||
|
// Log the rpc response call
|
||||||
debug!("Process RPC response");
|
debug!("Process RPC response");
|
||||||
|
|
||||||
let is_error = self.result.is_err();
|
let is_error = self.result.is_err();
|
||||||
|
@ -105,73 +106,12 @@ impl Response {
|
||||||
span.record("rpc.error_code", err.code);
|
span.record("rpc.error_code", err.code);
|
||||||
span.record("rpc.error_message", err.message.as_ref());
|
span.record("rpc.error_message", err.message.as_ref());
|
||||||
}
|
}
|
||||||
|
// Process the response for the format
|
||||||
let (res_size, message) = match out {
|
let (len, msg) = fmt.res(self).unwrap();
|
||||||
OutputFormat::Json => {
|
// Send the message to the write channel
|
||||||
let res = serde_json::to_string(&self.simplify()).unwrap();
|
if chn.send(msg).await.is_ok() {
|
||||||
(res.len(), Message::Text(res))
|
record_rpc(&TelemetryContext::current(), len, is_error);
|
||||||
}
|
|
||||||
OutputFormat::Cbor => {
|
|
||||||
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
|
||||||
(res.len(), Message::Binary(res))
|
|
||||||
}
|
|
||||||
OutputFormat::Pack => {
|
|
||||||
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
|
||||||
(res.len(), Message::Binary(res))
|
|
||||||
}
|
|
||||||
OutputFormat::Full => {
|
|
||||||
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
|
||||||
(res.len(), Message::Binary(res))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if chn.send(message).await.is_ok() {
|
|
||||||
record_rpc(&TelemetryContext::current(), res_size, is_error);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize)]
|
|
||||||
pub struct Failure {
|
|
||||||
code: i64,
|
|
||||||
message: Cow<'static, str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
impl Failure {
|
|
||||||
pub const PARSE_ERROR: Failure = Failure {
|
|
||||||
code: -32700,
|
|
||||||
message: Cow::Borrowed("Parse error"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const INVALID_REQUEST: Failure = Failure {
|
|
||||||
code: -32600,
|
|
||||||
message: Cow::Borrowed("Invalid Request"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const METHOD_NOT_FOUND: Failure = Failure {
|
|
||||||
code: -32601,
|
|
||||||
message: Cow::Borrowed("Method not found"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const INVALID_PARAMS: Failure = Failure {
|
|
||||||
code: -32602,
|
|
||||||
message: Cow::Borrowed("Invalid params"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const INTERNAL_ERROR: Failure = Failure {
|
|
||||||
code: -32603,
|
|
||||||
message: Cow::Borrowed("Internal error"),
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn custom<S>(message: S) -> Failure
|
|
||||||
where
|
|
||||||
Cow<'static, str>: From<S>,
|
|
||||||
{
|
|
||||||
Failure {
|
|
||||||
code: -32000,
|
|
||||||
message: message.into(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,12 +131,6 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<err::Error> for Failure {
|
|
||||||
fn from(err: err::Error) -> Self {
|
|
||||||
Failure::custom(err.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait IntoRpcResponse {
|
pub trait IntoRpcResponse {
|
||||||
fn into_response(self, id: Option<Value>) -> Response;
|
fn into_response(self, id: Option<Value>) -> Response;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ mod common;
|
||||||
|
|
||||||
mod cli_integration {
|
mod cli_integration {
|
||||||
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
||||||
|
use common::Format;
|
||||||
|
use common::Socket;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::time;
|
use std::time;
|
||||||
|
@ -735,15 +737,13 @@ mod cli_integration {
|
||||||
let (addr, mut server) = common::start_server_without_auth().await.unwrap();
|
let (addr, mut server) = common::start_server_without_auth().await.unwrap();
|
||||||
|
|
||||||
// Create a long-lived WS connection so the server don't shutdown gracefully
|
// Create a long-lived WS connection so the server don't shutdown gracefully
|
||||||
let mut socket = common::connect_ws(&addr).await.expect("Failed to connect to server");
|
let mut socket = Socket::connect(&addr, None).await.expect("Failed to connect to server");
|
||||||
let json = serde_json::json!({
|
let json = serde_json::json!({
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"method": "query",
|
"method": "query",
|
||||||
"params": ["SLEEP 30s;"],
|
"params": ["SLEEP 30s;"],
|
||||||
});
|
});
|
||||||
common::ws_send_msg(&mut socket, serde_json::to_string(&json).unwrap())
|
socket.send_message(Format::Json, json).await.expect("Failed to send WS message");
|
||||||
.await
|
|
||||||
.expect("Failed to send WS message");
|
|
||||||
|
|
||||||
// Make sure the SLEEP query is being executed
|
// Make sure the SLEEP query is being executed
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
|
14
tests/common/format.rs
Normal file
14
tests/common/format.rs
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
use std::string::ToString;
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub enum Format {
|
||||||
|
Json,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for Format {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Json => "json".to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,489 +1,17 @@
|
||||||
|
#![allow(unused_imports)]
|
||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod format;
|
||||||
|
pub mod server;
|
||||||
|
pub mod socket;
|
||||||
|
|
||||||
use crate::common::error::TestError;
|
pub use format::*;
|
||||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
pub use server::*;
|
||||||
use rand::{thread_rng, Rng};
|
pub use socket::*;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::json;
|
|
||||||
use std::error::Error;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::process::{Command, Stdio};
|
|
||||||
use std::time::Duration;
|
|
||||||
use std::{env, fs};
|
|
||||||
use tokio::net::TcpStream;
|
|
||||||
use tokio::time;
|
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
|
||||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
|
||||||
use tracing::{debug, error, info};
|
|
||||||
|
|
||||||
pub const USER: &str = "root";
|
|
||||||
pub const PASS: &str = "root";
|
|
||||||
|
|
||||||
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
|
||||||
pub struct Child {
|
|
||||||
inner: Option<std::process::Child>,
|
|
||||||
stdout_path: String,
|
|
||||||
stderr_path: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Child {
|
|
||||||
/// Send some thing to the child's stdin
|
|
||||||
pub fn input(mut self, input: &str) -> Self {
|
|
||||||
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
|
||||||
use std::io::Write;
|
|
||||||
stdin.write_all(input.as_bytes()).unwrap();
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn kill(mut self) -> Self {
|
|
||||||
self.inner.as_mut().unwrap().kill().unwrap();
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
|
|
||||||
nix::sys::signal::kill(
|
|
||||||
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
|
|
||||||
signal,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
|
|
||||||
self.inner.as_mut().unwrap().try_wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stdout(&self) -> String {
|
|
||||||
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stderr(&self) -> String {
|
|
||||||
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
|
||||||
/// returns successfully, Err otherwise.
|
|
||||||
pub fn output(mut self) -> Result<String, String> {
|
|
||||||
let status = self.inner.take().unwrap().wait().unwrap();
|
|
||||||
|
|
||||||
let mut buf = self.stdout();
|
|
||||||
buf.push_str(&self.stderr());
|
|
||||||
|
|
||||||
// Cleanup files after reading them
|
|
||||||
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
|
|
||||||
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
|
|
||||||
|
|
||||||
if status.success() {
|
|
||||||
Ok(buf)
|
|
||||||
} else {
|
|
||||||
Err(buf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Drop for Child {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
if let Some(inner) = self.inner.as_mut() {
|
|
||||||
let _ = inner.kill();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
|
||||||
let mut path = std::env::current_exe().unwrap();
|
|
||||||
assert!(path.pop());
|
|
||||||
if path.ends_with("deps") {
|
|
||||||
assert!(path.pop());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: Cargo automatically builds this binary for integration tests.
|
|
||||||
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
|
||||||
|
|
||||||
let mut cmd = Command::new(path);
|
|
||||||
if let Some(dir) = current_dir {
|
|
||||||
cmd.current_dir(&dir);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
|
|
||||||
let stdout_path = tmp_file("server-stdout.log");
|
|
||||||
let stderr_path = tmp_file("server-stderr.log");
|
|
||||||
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
|
|
||||||
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
|
|
||||||
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
|
|
||||||
|
|
||||||
cmd.env_clear();
|
|
||||||
cmd.stdin(Stdio::piped());
|
|
||||||
cmd.stdout(stdout);
|
|
||||||
cmd.stderr(stderr);
|
|
||||||
cmd.args(args.split_ascii_whitespace());
|
|
||||||
|
|
||||||
Child {
|
|
||||||
inner: Some(cmd.spawn().unwrap()),
|
|
||||||
stdout_path,
|
|
||||||
stderr_path,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the CLI with the given args
|
|
||||||
pub fn run(args: &str) -> Child {
|
|
||||||
run_internal::<String>(args, None)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the CLI with the given args inside a temporary directory
|
|
||||||
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
|
||||||
run_internal(args, Some(current_dir))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tmp_file(name: &str) -> String {
|
|
||||||
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
|
|
||||||
path.to_string_lossy().into_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct StartServerArguments {
|
|
||||||
pub auth: bool,
|
|
||||||
pub tls: bool,
|
|
||||||
pub wait_is_ready: bool,
|
|
||||||
pub enable_auth_level: bool,
|
|
||||||
pub tick_interval: time::Duration,
|
|
||||||
pub args: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for StartServerArguments {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
auth: true,
|
|
||||||
tls: false,
|
|
||||||
wait_is_ready: true,
|
|
||||||
enable_auth_level: false,
|
|
||||||
tick_interval: time::Duration::new(1, 0),
|
|
||||||
args: "--allow-all".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
|
|
||||||
start_server(StartServerArguments {
|
|
||||||
auth: false,
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
|
|
||||||
start_server(StartServerArguments {
|
|
||||||
enable_auth_level: true,
|
|
||||||
..Default::default()
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
|
|
||||||
start_server(StartServerArguments::default()).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_server(
|
|
||||||
StartServerArguments {
|
|
||||||
auth,
|
|
||||||
tls,
|
|
||||||
wait_is_ready,
|
|
||||||
enable_auth_level,
|
|
||||||
tick_interval,
|
|
||||||
args,
|
|
||||||
}: StartServerArguments,
|
|
||||||
) -> Result<(String, Child), Box<dyn Error>> {
|
|
||||||
let mut rng = thread_rng();
|
|
||||||
|
|
||||||
let port: u16 = rng.gen_range(13000..14000);
|
|
||||||
let addr = format!("127.0.0.1:{port}");
|
|
||||||
|
|
||||||
let mut extra_args = args.clone();
|
|
||||||
if tls {
|
|
||||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
|
||||||
let crt_path = tmp_file("crt.crt");
|
|
||||||
let key_path = tmp_file("key.pem");
|
|
||||||
|
|
||||||
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
|
||||||
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
|
||||||
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
|
||||||
|
|
||||||
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if auth {
|
|
||||||
extra_args.push_str(" --auth");
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable_auth_level {
|
|
||||||
extra_args.push_str(" --auth-level-enabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
if !tick_interval.is_zero() {
|
|
||||||
let sec = tick_interval.as_secs();
|
|
||||||
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
|
|
||||||
|
|
||||||
info!("starting server with args: {start_args}");
|
|
||||||
|
|
||||||
// Configure where the logs go when running the test
|
|
||||||
let server = run_internal::<String>(&start_args, None);
|
|
||||||
|
|
||||||
if !wait_is_ready {
|
|
||||||
return Ok((addr, server));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait 5 seconds for the server to start
|
|
||||||
let mut interval = time::interval(time::Duration::from_millis(1000));
|
|
||||||
info!("Waiting for server to start...");
|
|
||||||
for _i in 0..10 {
|
|
||||||
interval.tick().await;
|
|
||||||
|
|
||||||
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
|
||||||
info!("Server ready!");
|
|
||||||
return Ok((addr, server));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let server_out = server.kill().output().err().unwrap();
|
|
||||||
error!("server output: {server_out}");
|
|
||||||
Err("server failed to start".into())
|
|
||||||
}
|
|
||||||
|
|
||||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
|
||||||
|
|
||||||
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
|
|
||||||
let url = format!("ws://{}/rpc", addr);
|
|
||||||
let (ws_stream, _) = connect_async(url).await?;
|
|
||||||
Ok(ws_stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_send_msg(socket: &mut WsStream, msg_req: String) -> Result<(), Box<dyn Error>> {
|
|
||||||
let now = time::Instant::now();
|
|
||||||
debug!("Sending message: {msg_req}");
|
|
||||||
tokio::select! {
|
|
||||||
_ = time::sleep(time::Duration::from_millis(500)) => {
|
|
||||||
return Err("timeout after 500ms waiting for the request to be sent".into());
|
|
||||||
}
|
|
||||||
res = socket.send(Message::Text(msg_req)) => {
|
|
||||||
debug!("Message sent in {:?}", now.elapsed());
|
|
||||||
if let Err(err) = res {
|
|
||||||
return Err(format!("Error sending the message: {}", err).into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_recv_msg(socket: &mut WsStream) -> Result<serde_json::Value, Box<dyn Error>> {
|
|
||||||
ws_recv_msg_with_fmt(socket, Format::Json).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// When testing Live Queries, we may receive multiple messages unordered.
|
|
||||||
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
|
|
||||||
pub async fn ws_recv_all_msgs(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
expected: usize,
|
|
||||||
timeout: Duration,
|
|
||||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
|
||||||
let mut res = Vec::new();
|
|
||||||
let deadline = time::Instant::now() + timeout;
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
_ = time::sleep_until(deadline) => {
|
|
||||||
debug!("Waited for {:?} and received {} messages", timeout, res.len());
|
|
||||||
if res.len() != expected {
|
|
||||||
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg = ws_recv_msg(socket) => {
|
|
||||||
res.push(msg?);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if res.len() == expected {
|
|
||||||
return Ok(res);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_send_msg_and_wait_response(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
msg_req: String,
|
|
||||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
|
||||||
ws_send_msg(socket, msg_req).await?;
|
|
||||||
ws_recv_msg_with_fmt(socket, Format::Json).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum Format {
|
|
||||||
Json,
|
|
||||||
Cbor,
|
|
||||||
Pack,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_recv_msg_with_fmt(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
format: Format,
|
|
||||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
|
||||||
let now = time::Instant::now();
|
|
||||||
debug!("Waiting for response...");
|
|
||||||
// Parse and return response
|
|
||||||
let mut f = socket.try_filter(|msg| match format {
|
|
||||||
Format::Json => futures_util::future::ready(msg.is_text()),
|
|
||||||
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
|
|
||||||
});
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
_ = time::sleep(time::Duration::from_millis(5000)) => {
|
|
||||||
Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
|
|
||||||
}
|
|
||||||
res = f.select_next_some() => {
|
|
||||||
debug!("Response received in {:?}", now.elapsed());
|
|
||||||
match format {
|
|
||||||
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
|
|
||||||
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
|
|
||||||
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct SigninParams<'a> {
|
|
||||||
user: &'a str,
|
|
||||||
pass: &'a str,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
ns: Option<&'a str>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
db: Option<&'a str>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
sc: Option<&'a str>,
|
|
||||||
}
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
struct UseParams<'a> {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
ns: Option<&'a str>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
db: Option<&'a str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_signin(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
user: &str,
|
|
||||||
pass: &str,
|
|
||||||
ns: Option<&str>,
|
|
||||||
db: Option<&str>,
|
|
||||||
sc: Option<&str>,
|
|
||||||
) -> Result<String, Box<dyn Error>> {
|
|
||||||
let request_id = uuid::Uuid::new_v4().to_string().replace('-', "");
|
|
||||||
let json = json!({
|
|
||||||
"id": request_id,
|
|
||||||
"method": "signin",
|
|
||||||
"params": [
|
|
||||||
SigninParams { user, pass, ns, db, sc }
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
|
||||||
let msg = ws_recv_msg(socket).await?;
|
|
||||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
|
||||||
|
|
||||||
match msg.as_object() {
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
|
||||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
|
||||||
}
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
|
||||||
.get("result")
|
|
||||||
.ok_or(TestError::AssertionError {
|
|
||||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
|
||||||
})?
|
|
||||||
.as_str()
|
|
||||||
.ok_or(TestError::AssertionError {
|
|
||||||
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
|
||||||
})?
|
|
||||||
.to_owned()),
|
|
||||||
_ => {
|
|
||||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
|
||||||
Err(format!("unexpected response: {:?}", msg).into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_query(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
query: &str,
|
|
||||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
|
||||||
let json = json!({
|
|
||||||
"id": "1",
|
|
||||||
"method": "query",
|
|
||||||
"params": [query],
|
|
||||||
});
|
|
||||||
|
|
||||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
|
||||||
let msg = ws_recv_msg(socket).await?;
|
|
||||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
|
||||||
|
|
||||||
match msg.as_object() {
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
|
||||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
|
||||||
}
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
|
||||||
.get("result")
|
|
||||||
.ok_or(TestError::AssertionError {
|
|
||||||
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
|
||||||
})?
|
|
||||||
.as_array()
|
|
||||||
.ok_or(TestError::AssertionError {
|
|
||||||
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
|
||||||
})?
|
|
||||||
.to_owned()),
|
|
||||||
_ => {
|
|
||||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
|
||||||
Err(format!("unexpected response: {:?}", msg).into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ws_use(
|
|
||||||
socket: &mut WsStream,
|
|
||||||
ns: Option<&str>,
|
|
||||||
db: Option<&str>,
|
|
||||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
|
||||||
let json = json!({
|
|
||||||
"id": "1",
|
|
||||||
"method": "use",
|
|
||||||
"params": [
|
|
||||||
ns, db
|
|
||||||
],
|
|
||||||
});
|
|
||||||
|
|
||||||
ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
|
||||||
let msg = ws_recv_msg(socket).await?;
|
|
||||||
debug!("ws_query result json={json:?} msg={msg:?}");
|
|
||||||
|
|
||||||
match msg.as_object() {
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
|
||||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
|
||||||
}
|
|
||||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
|
||||||
.get("result")
|
|
||||||
.ok_or(TestError::AssertionError {
|
|
||||||
message: format!(
|
|
||||||
"expected a result from the received object, got this instead: {:?}",
|
|
||||||
obj
|
|
||||||
),
|
|
||||||
})?
|
|
||||||
.to_owned()),
|
|
||||||
_ => {
|
|
||||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
|
||||||
Err(format!("unexpected response: {:?}", msg).into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the given message is a successful notification from LQ.
|
/// Check if the given message is a successful notification from LQ.
|
||||||
pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
|
pub fn is_notification(msg: &serde_json::Value) -> bool {
|
||||||
// Example of LQ notification:
|
// Example of LQ notification:
|
||||||
//
|
//
|
||||||
// Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}}
|
// Object {"result": Object {"action": String("CREATE"), "id": String("04460f07-b0e1-4339-92db-049a94aeec10"), "result": Object {"id": String("table_FD40A9A361884C56B5908A934164884A:⟨an-id-goes-here⟩"), "name": String("ok")}}}
|
||||||
|
@ -497,7 +25,7 @@ pub fn ws_msg_is_notification(msg: &serde_json::Value) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if the given message is a notification from LQ and comes from the given LQ ID.
|
/// Check if the given message is a notification from LQ and comes from the given LQ ID.
|
||||||
pub fn ws_msg_is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
|
pub fn is_notification_from_lq(msg: &serde_json::Value, id: &str) -> bool {
|
||||||
ws_msg_is_notification(msg)
|
is_notification(msg)
|
||||||
&& msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id)
|
&& msg["result"].as_object().unwrap().get("id").unwrap().as_str() == Some(id)
|
||||||
}
|
}
|
||||||
|
|
242
tests/common/server.rs
Normal file
242
tests/common/server.rs
Normal file
|
@ -0,0 +1,242 @@
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::process::{Command, Stdio};
|
||||||
|
use std::{env, fs};
|
||||||
|
use tokio::time;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
|
pub const USER: &str = "root";
|
||||||
|
pub const PASS: &str = "root";
|
||||||
|
pub const NS: &str = "testns";
|
||||||
|
pub const DB: &str = "testdb";
|
||||||
|
|
||||||
|
/// Child is a (maybe running) CLI process. It can be killed by dropping it
|
||||||
|
pub struct Child {
|
||||||
|
inner: Option<std::process::Child>,
|
||||||
|
stdout_path: String,
|
||||||
|
stderr_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Child {
|
||||||
|
/// Send some thing to the child's stdin
|
||||||
|
pub fn input(mut self, input: &str) -> Self {
|
||||||
|
let stdin = self.inner.as_mut().unwrap().stdin.as_mut().unwrap();
|
||||||
|
use std::io::Write;
|
||||||
|
stdin.write_all(input.as_bytes()).unwrap();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn kill(mut self) -> Self {
|
||||||
|
self.inner.as_mut().unwrap().kill().unwrap();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
|
||||||
|
nix::sys::signal::kill(
|
||||||
|
nix::unistd::Pid::from_raw(self.inner.as_ref().unwrap().id() as i32),
|
||||||
|
signal,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn status(&mut self) -> std::io::Result<Option<std::process::ExitStatus>> {
|
||||||
|
self.inner.as_mut().unwrap().try_wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stdout(&self) -> String {
|
||||||
|
std::fs::read_to_string(&self.stdout_path).expect("Failed to read the stdout file")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stderr(&self) -> String {
|
||||||
|
std::fs::read_to_string(&self.stderr_path).expect("Failed to read the stderr file")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
|
||||||
|
/// returns successfully, Err otherwise.
|
||||||
|
pub fn output(mut self) -> Result<String, String> {
|
||||||
|
let status = self.inner.take().unwrap().wait().unwrap();
|
||||||
|
|
||||||
|
let mut buf = self.stdout();
|
||||||
|
buf.push_str(&self.stderr());
|
||||||
|
|
||||||
|
// Cleanup files after reading them
|
||||||
|
std::fs::remove_file(self.stdout_path.as_str()).unwrap();
|
||||||
|
std::fs::remove_file(self.stderr_path.as_str()).unwrap();
|
||||||
|
|
||||||
|
if status.success() {
|
||||||
|
Ok(buf)
|
||||||
|
} else {
|
||||||
|
Err(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for Child {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(inner) = self.inner.as_mut() {
|
||||||
|
let _ = inner.kill();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||||
|
let mut path = std::env::current_exe().unwrap();
|
||||||
|
assert!(path.pop());
|
||||||
|
if path.ends_with("deps") {
|
||||||
|
assert!(path.pop());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Cargo automatically builds this binary for integration tests.
|
||||||
|
path.push(format!("{}{}", env!("CARGO_PKG_NAME"), std::env::consts::EXE_SUFFIX));
|
||||||
|
|
||||||
|
let mut cmd = Command::new(path);
|
||||||
|
if let Some(dir) = current_dir {
|
||||||
|
cmd.current_dir(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use local files instead of pipes to avoid deadlocks. See https://github.com/rust-lang/rust/issues/45572
|
||||||
|
let stdout_path = tmp_file("server-stdout.log");
|
||||||
|
let stderr_path = tmp_file("server-stderr.log");
|
||||||
|
debug!("Redirecting output. args=`{args}` stdout={stdout_path} stderr={stderr_path})");
|
||||||
|
let stdout = Stdio::from(File::create(&stdout_path).unwrap());
|
||||||
|
let stderr = Stdio::from(File::create(&stderr_path).unwrap());
|
||||||
|
|
||||||
|
cmd.env_clear();
|
||||||
|
cmd.stdin(Stdio::piped());
|
||||||
|
cmd.stdout(stdout);
|
||||||
|
cmd.stderr(stderr);
|
||||||
|
cmd.args(args.split_ascii_whitespace());
|
||||||
|
|
||||||
|
Child {
|
||||||
|
inner: Some(cmd.spawn().unwrap()),
|
||||||
|
stdout_path,
|
||||||
|
stderr_path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the CLI with the given args
|
||||||
|
pub fn run(args: &str) -> Child {
|
||||||
|
run_internal::<String>(args, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the CLI with the given args inside a temporary directory
|
||||||
|
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||||
|
run_internal(args, Some(current_dir))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tmp_file(name: &str) -> String {
|
||||||
|
let path = Path::new(env!("OUT_DIR")).join(format!("{}-{}", rand::random::<u32>(), name));
|
||||||
|
path.to_string_lossy().into_owned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StartServerArguments {
|
||||||
|
pub auth: bool,
|
||||||
|
pub tls: bool,
|
||||||
|
pub wait_is_ready: bool,
|
||||||
|
pub enable_auth_level: bool,
|
||||||
|
pub tick_interval: time::Duration,
|
||||||
|
pub args: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for StartServerArguments {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
auth: true,
|
||||||
|
tls: false,
|
||||||
|
wait_is_ready: true,
|
||||||
|
enable_auth_level: false,
|
||||||
|
tick_interval: time::Duration::new(1, 0),
|
||||||
|
args: "--allow-all".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_server_without_auth() -> Result<(String, Child), Box<dyn Error>> {
|
||||||
|
start_server(StartServerArguments {
|
||||||
|
auth: false,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_server_with_auth_level() -> Result<(String, Child), Box<dyn Error>> {
|
||||||
|
start_server(StartServerArguments {
|
||||||
|
enable_auth_level: true,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_server_with_defaults() -> Result<(String, Child), Box<dyn Error>> {
|
||||||
|
start_server(StartServerArguments::default()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_server(
|
||||||
|
StartServerArguments {
|
||||||
|
auth,
|
||||||
|
tls,
|
||||||
|
wait_is_ready,
|
||||||
|
enable_auth_level,
|
||||||
|
tick_interval,
|
||||||
|
args,
|
||||||
|
}: StartServerArguments,
|
||||||
|
) -> Result<(String, Child), Box<dyn Error>> {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
|
||||||
|
let port: u16 = rng.gen_range(13000..14000);
|
||||||
|
let addr = format!("127.0.0.1:{port}");
|
||||||
|
|
||||||
|
let mut extra_args = args.clone();
|
||||||
|
if tls {
|
||||||
|
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||||
|
let crt_path = tmp_file("crt.crt");
|
||||||
|
let key_path = tmp_file("key.pem");
|
||||||
|
|
||||||
|
let cert = rcgen::generate_simple_self_signed(Vec::new()).unwrap();
|
||||||
|
fs::write(&crt_path, cert.serialize_pem().unwrap()).unwrap();
|
||||||
|
fs::write(&key_path, cert.serialize_private_key_pem().into_bytes()).unwrap();
|
||||||
|
|
||||||
|
extra_args.push_str(format!(" --web-crt {crt_path} --web-key {key_path}").as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth {
|
||||||
|
extra_args.push_str(" --auth");
|
||||||
|
}
|
||||||
|
|
||||||
|
if enable_auth_level {
|
||||||
|
extra_args.push_str(" --auth-level-enabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tick_interval.is_zero() {
|
||||||
|
let sec = tick_interval.as_secs();
|
||||||
|
extra_args.push_str(format!(" --tick-interval {sec}s").as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
|
||||||
|
|
||||||
|
info!("starting server with args: {start_args}");
|
||||||
|
|
||||||
|
// Configure where the logs go when running the test
|
||||||
|
let server = run_internal::<String>(&start_args, None);
|
||||||
|
|
||||||
|
if !wait_is_ready {
|
||||||
|
return Ok((addr, server));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait 5 seconds for the server to start
|
||||||
|
let mut interval = time::interval(time::Duration::from_millis(1000));
|
||||||
|
info!("Waiting for server to start...");
|
||||||
|
for _i in 0..10 {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
||||||
|
info!("Server ready!");
|
||||||
|
return Ok((addr, server));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let server_out = server.kill().output().err().unwrap();
|
||||||
|
error!("server output: {server_out}");
|
||||||
|
Err("server failed to start".into())
|
||||||
|
}
|
288
tests/common/socket.rs
Normal file
288
tests/common/socket.rs
Normal file
|
@ -0,0 +1,288 @@
|
||||||
|
use super::format::Format;
|
||||||
|
use crate::common::error::TestError;
|
||||||
|
use futures_util::{SinkExt, TryStreamExt};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::time::Duration;
|
||||||
|
use surrealdb::sql::Value;
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::time;
|
||||||
|
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||||
|
use tokio_tungstenite::tungstenite::Message;
|
||||||
|
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||||
|
use tracing::{debug, error};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct UseParams<'a> {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
ns: Option<&'a str>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
db: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct SigninParams<'a> {
|
||||||
|
user: &'a str,
|
||||||
|
pass: &'a str,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
ns: Option<&'a str>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
db: Option<&'a str>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
sc: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Socket {
|
||||||
|
pub stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub struct Socket(pub WebSocketStream<MaybeTlsStream<TcpStream>>);
|
||||||
|
|
||||||
|
impl Socket {
|
||||||
|
/// Close the connection with the WebSocket server
|
||||||
|
pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
|
||||||
|
Ok(self.stream.close(None).await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to a WebSocket server using a specific format
|
||||||
|
pub async fn connect(addr: &str, format: Option<Format>) -> Result<Self, Box<dyn Error>> {
|
||||||
|
let url = format!("ws://{}/rpc", addr);
|
||||||
|
let mut req = url.into_client_request().unwrap();
|
||||||
|
if let Some(v) = format.map(|v| v.to_string()) {
|
||||||
|
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
|
||||||
|
}
|
||||||
|
let (stream, _) = connect_async(req).await?;
|
||||||
|
Ok(Self {
|
||||||
|
stream,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a text or binary message to the WebSocket server
|
||||||
|
pub async fn send_message(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
message: serde_json::Value,
|
||||||
|
) -> Result<(), Box<dyn Error>> {
|
||||||
|
let now = time::Instant::now();
|
||||||
|
debug!("Sending message: {message}");
|
||||||
|
// Format the message
|
||||||
|
let msg = match format {
|
||||||
|
Format::Json => Message::Text(serde_json::to_string(&message)?),
|
||||||
|
};
|
||||||
|
// Send the message
|
||||||
|
tokio::select! {
|
||||||
|
_ = time::sleep(time::Duration::from_millis(500)) => {
|
||||||
|
return Err("timeout after 500ms waiting for the request to be sent".into());
|
||||||
|
}
|
||||||
|
res = self.stream.send(msg) => {
|
||||||
|
debug!("Message sent in {:?}", now.elapsed());
|
||||||
|
if let Err(err) = res {
|
||||||
|
return Err(format!("Error sending the message: {}", err).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Receive a text or binary message from the WebSocket server
|
||||||
|
pub async fn receive_message(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||||
|
let now = time::Instant::now();
|
||||||
|
debug!("Receiving response...");
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = time::sleep(time::Duration::from_millis(5000)) => {
|
||||||
|
return Err(Box::new(TestError::NetworkError {message: "timeout after 5s waiting for the response".to_string()}))
|
||||||
|
}
|
||||||
|
res = self.stream.try_next() => {
|
||||||
|
match res {
|
||||||
|
Ok(res) => match res {
|
||||||
|
Some(Message::Text(msg)) => {
|
||||||
|
debug!("Response {msg:?} received in {:?}", now.elapsed());
|
||||||
|
match format {
|
||||||
|
Format::Json => {
|
||||||
|
let msg = serde_json::from_str(&msg)?;
|
||||||
|
debug!("Received message: {msg}");
|
||||||
|
return Ok(msg);
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Some(_) => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err("Expected to receive a message".to_string().into());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
return Err(format!("Error receiving the message: {}", err).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a text or binary message and receive a reponse from the WebSocket server
|
||||||
|
pub async fn send_and_receive_message(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
message: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||||
|
self.send_message(format, message).await?;
|
||||||
|
self.receive_message(format).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// When testing Live Queries, we may receive multiple messages unordered.
|
||||||
|
/// This method captures all the expected messages before the given timeout. The result can be inspected later on to find the desired message.
|
||||||
|
pub async fn receive_all_messages(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
expected: usize,
|
||||||
|
timeout: Duration,
|
||||||
|
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||||
|
let mut res = Vec::new();
|
||||||
|
let deadline = time::Instant::now() + timeout;
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = time::sleep_until(deadline) => {
|
||||||
|
debug!("Waited for {:?} and received {} messages", timeout, res.len());
|
||||||
|
if res.len() != expected {
|
||||||
|
return Err(format!("Expected {} messages but got {} after {:?}: {:?}", expected, res.len(), timeout, res).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg = self.receive_message(format) => {
|
||||||
|
res.push(msg?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if res.len() == expected {
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a USE message to the server and check the response
|
||||||
|
pub async fn send_message_use(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
ns: Option<&str>,
|
||||||
|
db: Option<&str>,
|
||||||
|
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||||
|
// Generate an ID
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
// Construct message
|
||||||
|
let msg = json!({
|
||||||
|
"id": id,
|
||||||
|
"method": "use",
|
||||||
|
"params": [
|
||||||
|
ns, db
|
||||||
|
],
|
||||||
|
});
|
||||||
|
// Send message and receive response
|
||||||
|
let msg = self.send_and_receive_message(format, msg).await?;
|
||||||
|
// Check response message structure
|
||||||
|
match msg.as_object() {
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||||
|
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||||
|
}
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||||
|
.get("result")
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!(
|
||||||
|
"expected a result from the received object, got this instead: {:?}",
|
||||||
|
obj
|
||||||
|
),
|
||||||
|
})?
|
||||||
|
.to_owned()),
|
||||||
|
_ => {
|
||||||
|
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||||
|
Err(format!("unexpected response: {:?}", msg).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a generic query message to the server and check the response
|
||||||
|
pub async fn send_message_query(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
query: &str,
|
||||||
|
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||||
|
// Generate an ID
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
// Construct message
|
||||||
|
let msg = json!({
|
||||||
|
"id": id,
|
||||||
|
"method": "query",
|
||||||
|
"params": [query],
|
||||||
|
});
|
||||||
|
// Send message and receive response
|
||||||
|
let msg = self.send_and_receive_message(format, msg).await?;
|
||||||
|
// Check response message structure
|
||||||
|
match msg.as_object() {
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||||
|
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||||
|
}
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||||
|
.get("result")
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||||
|
})?
|
||||||
|
.as_array()
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!("expected the result object to be an array for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||||
|
})?
|
||||||
|
.to_owned()),
|
||||||
|
_ => {
|
||||||
|
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||||
|
Err(format!("unexpected response: {:?}", msg).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a signin authentication query message to the server and check the response
|
||||||
|
pub async fn send_message_signin(
|
||||||
|
&mut self,
|
||||||
|
format: Format,
|
||||||
|
user: &str,
|
||||||
|
pass: &str,
|
||||||
|
ns: Option<&str>,
|
||||||
|
db: Option<&str>,
|
||||||
|
sc: Option<&str>,
|
||||||
|
) -> Result<String, Box<dyn Error>> {
|
||||||
|
// Generate an ID
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
// Construct message
|
||||||
|
let msg = json!({
|
||||||
|
"id": id,
|
||||||
|
"method": "signin",
|
||||||
|
"params": [
|
||||||
|
SigninParams { user, pass, ns, db, sc }
|
||||||
|
],
|
||||||
|
});
|
||||||
|
// Send message and receive response
|
||||||
|
let msg = self.send_and_receive_message(format, msg).await?;
|
||||||
|
// Check response message structure
|
||||||
|
match msg.as_object() {
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||||
|
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||||
|
}
|
||||||
|
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => Ok(obj
|
||||||
|
.get("result")
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!("expected a result from the received object, got this instead: {:?}", obj),
|
||||||
|
})?
|
||||||
|
.as_str()
|
||||||
|
.ok_or(TestError::AssertionError {
|
||||||
|
message: format!("expected the result object to be a string for the received ws message, got this instead: {:?}", obj.get("result")).to_string(),
|
||||||
|
})?
|
||||||
|
.to_owned()),
|
||||||
|
_ => {
|
||||||
|
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||||
|
Err(format!("unexpected response: {:?}", msg).into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1177
tests/common/tests.rs
Normal file
1177
tests/common/tests.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1258,8 +1258,8 @@ mod http_integration {
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
assert_eq!(res.status(), 200);
|
assert_eq!(res.status(), 200);
|
||||||
|
let res = res.bytes().await?.to_vec();
|
||||||
let _: serde_cbor::Value = serde_cbor::from_slice(&res.bytes().await?).unwrap();
|
let _: ciborium::Value = ciborium::from_reader(res.as_slice()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creating a record with Accept PACK encoding is allowed
|
// Creating a record with Accept PACK encoding is allowed
|
||||||
|
@ -1272,8 +1272,8 @@ mod http_integration {
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
assert_eq!(res.status(), 200);
|
assert_eq!(res.status(), 200);
|
||||||
|
let res = res.bytes().await?.to_vec();
|
||||||
let _: serde_cbor::Value = serde_pack::from_slice(&res.bytes().await?).unwrap();
|
let _: rmpv::Value = rmpv::decode::read_value(&mut res.as_slice()).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creating a record with Accept Surrealdb encoding is allowed
|
// Creating a record with Accept Surrealdb encoding is allowed
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue