[rpc] Better tracing for WebSockets (#2325)
Parent: ab72923fb5
Commit: e91011cc78
20 changed files with 2617 additions and 650 deletions
.github/workflows/ci.yml (26 changes)

@@ -133,7 +133,7 @@ jobs:
           args: ci-clippy

   cli:
-    name: Test command line
+    name: CLI integration tests
     runs-on: ubuntu-latest
     steps:

@@ -163,7 +163,7 @@ jobs:
           args: ci-cli-integration

   http-server:
-    name: Test HTTP server
+    name: HTTP integration tests
     runs-on: ubuntu-latest
     steps:

@@ -192,6 +192,28 @@ jobs:
           command: make
           args: ci-http-integration

+  ws-server:
+    name: WebSocket integration tests
+    runs-on: ubuntu-latest
+    steps:
+
+      - name: Install stable toolchain
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Checkout sources
+        uses: actions/checkout@v3
+
+      - name: Setup cache
+        uses: Swatinem/rust-cache@v2
+
+      - name: Install dependencies
+        run: |
+          sudo apt-get -y update
+          sudo apt-get -y install protobuf-compiler libprotobuf-dev
+
+      - name: Run cargo test
+        run: cargo test --locked --no-default-features --features storage-mem --workspace --test ws_integration
+
   test:
     name: Test workspace
     runs-on: ubuntu-latest
Cargo.lock (generated, 782 changes)
File diff suppressed because it is too large.
@@ -67,6 +67,7 @@ tokio-util = { version = "0.7.8", features = ["io"] }
 tower = "0.4.13"
 tower-http = { version = "0.4.2", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
 tracing = "0.1"
+tracing-futures = { version = "0.2.5", features = ["tokio"], default-features = false }
 tracing-opentelemetry = "0.19.0"
 tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
 urlencoding = "2.1.2"

@@ -77,11 +78,14 @@ nix = "0.26.2"

 [dev-dependencies]
 assert_fs = "1.0.13"
+env_logger = "0.10.0"
 opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
 rcgen = "0.10.0"
 serial_test = "2.0.0"
-temp-env = "0.3.4"
+temp-env = { version = "0.3.4", features = ["async_closure"] }
+test-log = { version = "0.2.12", features = ["trace"] }
 tokio-stream = { version = "0.1", features = ["net"] }
+tokio-tungstenite = { version = "0.18.0" }
 tonic = "0.8.3"

 [package.metadata.deb]
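The new dev-dependencies above support the `ws_integration` test target that the new `ws-server` CI job runs. The test suite itself is not part of this excerpt; the sketch below only illustrates the kind of WebSocket client such a test would use with `tokio-tungstenite`. The address and the exact JSON payload are assumptions.

```rust
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};

// Hypothetical helper: connect to a running SurrealDB instance and issue one RPC call.
async fn rpc_ping() -> Result<String, Box<dyn std::error::Error>> {
	// Open the WebSocket connection to the RPC endpoint (assumed address)
	let (mut socket, _response) = connect_async("ws://127.0.0.1:8000/rpc").await?;
	// Send a text-mode RPC request (see the "ping" method handled in src/net/rpc.rs)
	socket.send(Message::Text(r#"{"id":"1","method":"ping","params":[]}"#.into())).await?;
	// Wait for the reply and return it as text
	while let Some(msg) = socket.next().await {
		if let Message::Text(reply) = msg? {
			return Ok(reply);
		}
	}
	Err("connection closed before a reply was received".into())
}
```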
@@ -115,7 +115,7 @@ pub async fn init(
 		listen_addresses,
 		dbs,
 		web,
-		log: CustomEnvFilter(log),
+		log,
 		tick_interval,
 		no_banner,
 		..
@@ -1,8 +1,9 @@
 use clap::builder::{NonEmptyStringValueParser, PossibleValue, TypedValueParser};
 use clap::error::{ContextKind, ContextValue, ErrorKind};
-use tracing::Level;
 use tracing_subscriber::EnvFilter;

+use crate::telemetry::filter_from_value;
+
 #[derive(Debug)]
 pub struct CustomEnvFilter(pub EnvFilter);

@@ -37,20 +38,7 @@ impl TypedValueParser for CustomEnvFilterParser {

 		let inner = NonEmptyStringValueParser::new();
 		let v = inner.parse_ref(cmd, arg, value)?;
-		let filter = (match v.as_str() {
-			// Don't show any logs at all
-			"none" => Ok(EnvFilter::default()),
-			// Check if we should show all log levels
-			"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
-			// Otherwise, let's only show errors
-			"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
-			// Specify the log level for each code area
-			"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
-				.parse(format!("error,surreal={v},surrealdb={v},surrealdb::txn=error")),
-			// Let's try to parse the custom log level
-			_ => EnvFilter::builder().parse(v),
-		})
-		.map_err(|e| {
+		let filter = filter_from_value(v.as_str()).map_err(|e| {
 			let mut err = clap::Error::new(ErrorKind::ValueValidation).with_cmd(cmd);
 			err.insert(ContextKind::Custom, ContextValue::String(e.to_string()));
 			err.insert(
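The inline match above moves into the telemetry module as `filter_from_value`. That module is not part of this excerpt, so the following is only a sketch of what it plausibly contains, reconstructed from the arms removed here; the return and error types are assumptions.

```rust
use tracing::Level;
use tracing_subscriber::filter::ParseError;
use tracing_subscriber::EnvFilter;

// Sketch only: mirrors the match arms removed from the CLI parser above.
pub fn filter_from_value(v: &str) -> Result<EnvFilter, ParseError> {
	match v {
		// Don't show any logs at all
		"none" => Ok(EnvFilter::default()),
		// Check if we should show all log levels
		"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
		// Otherwise, let's only show errors
		"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
		// Specify the log level for each code area
		"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
			.parse(format!("error,surreal={v},surrealdb={v},surrealdb::txn=error")),
		// Let's try to parse the custom log level
		_ => EnvFilter::builder().parse(v),
	}
}
```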
@@ -131,7 +131,7 @@ pub async fn init() -> Result<(), Error> {

 	// Setup the graceful shutdown with no timeout
 	let handle = Handle::new();
-	graceful_shutdown(handle.clone(), None);
+	let shutdown_handler = graceful_shutdown(handle.clone());

 	if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
 		// configure certificate and private key used by https

@@ -156,6 +156,12 @@ pub async fn init() -> Result<(), Error> {
 	.await?;
 	};

+	// Wait for the shutdown to finish
+	let _ = shutdown_handler.await;
+
+	// Flush all telemetry data
+	opentelemetry::global::shutdown_tracer_provider();
+
 	info!(target: LOG, "Web server stopped. Bye!");

 	Ok(())
src/net/rpc.rs (585 changes)

@@ -7,26 +7,36 @@ use crate::err::Error;
 use crate::rpc::args::Take;
 use crate::rpc::paths::{ID, METHOD, PARAMS};
 use crate::rpc::res;
+use crate::rpc::res::Data;
 use crate::rpc::res::Failure;
-use crate::rpc::res::Output;
+use crate::rpc::res::IntoRpcResponse;
+use crate::rpc::res::OutputFormat;
+use crate::rpc::CONN_CLOSED_ERR;
+use crate::telemetry::traces::rpc::span_for_request;
 use axum::routing::get;
 use axum::Extension;
 use axum::Router;
 use futures::{SinkExt, StreamExt};
+use futures_util::stream::SplitSink;
+use futures_util::stream::SplitStream;
 use http_body::Body as HttpBody;
 use once_cell::sync::Lazy;
 use std::collections::BTreeMap;
 use std::collections::HashMap;
 use std::sync::Arc;
 use surrealdb::channel;
-use surrealdb::channel::Sender;
+use surrealdb::channel::{Receiver, Sender};
 use surrealdb::dbs::{QueryType, Response, Session};
+use surrealdb::sql::serde::deserialize;
 use surrealdb::sql::Array;
 use surrealdb::sql::Object;
 use surrealdb::sql::Strand;
 use surrealdb::sql::Value;
 use tokio::sync::RwLock;
-use tracing::instrument;
+use tokio::task::JoinSet;
+use tokio_util::sync::CancellationToken;
+use tower_http::request_id::RequestId;
+use tracing::Span;
 use uuid::Uuid;

 use axum::{

@@ -35,11 +45,12 @@ use axum::{
 };

 // Mapping of WebSocketID to WebSocket
-type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>;
+pub(crate) struct WebSocketRef(pub(crate) Sender<Message>, pub(crate) CancellationToken);
+type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
 // Mapping of LiveQueryID to WebSocketID
 type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;

-static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
+pub(super) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
 static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);

 pub(super) fn router<S, B>() -> Router<S, B>
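Each connection is now registered as a `WebSocketRef`, pairing the channel used to push messages to that client with a `CancellationToken` that can stop all of its tasks. The snippet below only illustrates that token pattern in isolation (every name other than `CancellationToken` is made up); the real registration and cancellation fan-out appears later in this diff.

```rust
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
	// One token per hypothetical connection
	let token = CancellationToken::new();
	let worker_token = token.clone();

	// A per-connection task exits as soon as its token is cancelled
	let worker = tokio::spawn(async move {
		loop {
			tokio::select! {
				_ = worker_token.cancelled() => break,
				_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
					// ... do per-connection work here ...
				}
			}
		}
	});

	// Cancelling the stored token (e.g. during shutdown) stops the task
	token.cancel();
	worker.await.unwrap();
}
```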
@@ -50,22 +61,36 @@ where
 	Router::new().route("/rpc", get(handler))
 }

-async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse {
+async fn handler(
+	ws: WebSocketUpgrade,
+	Extension(sess): Extension<Session>,
+	Extension(req_id): Extension<RequestId>,
+) -> impl IntoResponse {
 	// finalize the upgrade process by returning upgrade callback.
 	// we can customize the callback by sending additional info such as address.
-	ws.on_upgrade(move |socket| handle_socket(socket, sess))
+	ws.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
 }

-async fn handle_socket(ws: WebSocket, sess: Session) {
+async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
 	let rpc = Rpc::new(sess);
-	Rpc::serve(rpc, ws).await
+
+	// If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID
+	match req_id.header_value().to_str().map(Uuid::parse_str) {
+		Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => {
+			rpc.write().await.ws_id = req_id
+		}
+		_ => (),
+	}
+
+	Rpc::serve(rpc, ws).await;
 }

 pub struct Rpc {
 	session: Session,
-	format: Output,
-	uuid: Uuid,
+	format: OutputFormat,
+	ws_id: Uuid,
 	vars: BTreeMap<String, Value>,
+	graceful_shutdown: CancellationToken,
 }

 impl Rpc {
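With the upgrade handler now receiving the `RequestId` extension, a client can pin the WebSocket id (and therefore the id that shows up in traces) by sending a UUID in the request-id header of the upgrade request. The header name depends on how the server's request-id layer is configured; `x-request-id` is an assumption here, and the snippet is only a sketch of such a client.

```rust
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;

async fn connect_with_id() -> Result<(), Box<dyn std::error::Error>> {
	// Build the upgrade request and attach a UUID as the request id.
	// "x-request-id" is an assumption about the configured header name.
	let mut request = "ws://127.0.0.1:8000/rpc".into_client_request()?;
	let id = uuid::Uuid::new_v4().to_string();
	request.headers_mut().insert("x-request-id", HeaderValue::from_str(&id)?);

	// If the server reuses the incoming request id, this connection's
	// WebSocket id should now be `id`.
	let (_socket, _response) = connect_async(request).await?;
	Ok(())
}
```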
@@ -74,158 +99,247 @@ impl Rpc {
 		// Create a new RPC variables store
 		let vars = BTreeMap::new();
 		// Set the default output format
-		let format = Output::Json;
+		let format = OutputFormat::Json;
-		// Create a unique WebSocket id
+		// Enable real-time mode
-		let uuid = Uuid::new_v4();
-		// Enable real-time live queries
 		session.rt = true;
 		// Create and store the Rpc connection
 		Arc::new(RwLock::new(Rpc {
 			session,
 			format,
-			uuid,
+			ws_id: Uuid::new_v4(),
 			vars,
+			graceful_shutdown: CancellationToken::new(),
 		}))
 	}

 	/// Serve the RPC endpoint
 	pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) {
-		// Create a channel for sending messages
-		let (chn, mut rcv) = channel::new(MAX_CONCURRENT_CALLS);
 		// Split the socket into send and recv
-		let (mut wtx, mut wrx) = ws.split();
+		let (sender, receiver) = ws.split();
-		// Clone the channel for sending pings
+		// Create an internal channel between the receiver and the sender
-		let png = chn.clone();
+		let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
-		// The WebSocket has connected
-		Rpc::connected(rpc.clone(), chn.clone()).await;
+		let ws_id = rpc.read().await.ws_id;
-		// Send Ping messages to the client
-		tokio::task::spawn(async move {
-			// Create the interval ticker
-			let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
-			// Loop indefinitely
-			loop {
-				// Wait for the timer
-				interval.tick().await;
-				// Create the ping message
-				let msg = Message::Ping(vec![]);
-				// Send the message to the client
-				if png.send(msg).await.is_err() {
-					// Exit out of the loop
-					break;
-				}
-			}
-		});
-		// Send messages to the client
-		tokio::task::spawn(async move {
-			// Wait for the next message to send
-			while let Some(res) = rcv.next().await {
-				// Send the message to the client
-				if let Err(err) = wtx.send(res).await {
-					// Output the WebSocket error to the logs
-					trace!("WebSocket error: {:?}", err);
-					// It's already failed, so ignore error
-					let _ = wtx.close().await;
-					// Exit out of the loop
-					break;
-				}
-			}
-		});
-		// Send notifications to the client
-		let moved_rpc = rpc.clone();
-		tokio::task::spawn(async move {
-			let rpc = moved_rpc;
-			if let Some(channel) = DB.get().unwrap().notifications() {
-				while let Ok(notification) = channel.recv().await {
-					// Find which WebSocket the notification belongs to
-					if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
-						// Check to see if the WebSocket exists
-						if let Some(websocket) = WEBSOCKETS.read().await.get(ws_id) {
-							// Serialize the message to send
-							let message = res::success(None, notification);
-							// Get the current output format
-							let format = rpc.read().await.format.clone();
-							// Send the notification to the client
-							message.send(format, websocket.clone()).await;
-						}
-					}
-				}
-			}
-		});
-		// Get messages from the client
-		while let Some(msg) = wrx.next().await {
-			match msg {
-				// We've received a message from the client
-				// Ping is automatically handled by the WebSocket library
-				Ok(msg) => match msg {
-					Message::Text(_) => {
-						tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
-					}
-					Message::Binary(_) => {
-						tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
-					}
-					Message::Close(_) => {
-						break;
-					}
-					Message::Pong(_) => {
-						continue;
-					}
-					_ => {
-						// Ignore everything else
-					}
-				},
-				// There was an error receiving the message
-				Err(err) => {
-					// Output the WebSocket error to the logs
-					trace!("WebSocket error: {:?}", err);
-					// Exit out of the loop
-					break;
-				}
-			}
-		}
-		// The WebSocket has disconnected
-		Rpc::disconnected(rpc.clone()).await;
-	}

-	async fn connected(rpc: Arc<RwLock<Rpc>>, chn: Sender<Message>) {
-		// Fetch the unique id of the WebSocket
-		let id = rpc.read().await.uuid;
-		// Log that the WebSocket has connected
-		trace!("WebSocket {} connected", id);
 		// Store this WebSocket in the list of WebSockets
-		WEBSOCKETS.write().await.insert(id, chn);
-	}
+		WEBSOCKETS.write().await.insert(
+			ws_id,
+			WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
+		);
+
+		trace!("WebSocket {} connected", ws_id);
+
+		// Wait until all tasks finish
+		tokio::join!(
+			Self::ping(rpc.clone(), internal_sender.clone()),
+			Self::read(rpc.clone(), receiver, internal_sender.clone()),
+			Self::write(rpc.clone(), sender, internal_receiver.clone()),
+			Self::lq_notifications(rpc.clone()),
+		);

-	async fn disconnected(rpc: Arc<RwLock<Rpc>>) {
-		// Fetch the unique id of the WebSocket
-		let id = rpc.read().await.uuid;
-		// Log that the WebSocket has disconnected
-		trace!("WebSocket {} disconnected", id);
-		// Remove this WebSocket from the list of WebSockets
-		WEBSOCKETS.write().await.remove(&id);
 		// Remove all live queries
 		LIVE_QUERIES.write().await.retain(|key, value| {
-			if value == &id {
+			if value == &ws_id {
 				trace!("Removing live query: {}", key);
 				return false;
 			}
 			true
 		});
+
+		// Remove this WebSocket from the list of WebSockets
+		WEBSOCKETS.write().await.remove(&ws_id);
+
+		trace!("WebSocket {} disconnected", ws_id);
 	}

-	/// Call RPC methods from the WebSocket
+	/// Send Ping messages to the client
-	async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
+	async fn ping(rpc: Arc<RwLock<Rpc>>, internal_sender: Sender<Message>) {
+		// Create the interval ticker
+		let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
+		let cancel_token = rpc.read().await.graceful_shutdown.clone();
+		loop {
+			let is_shutdown = cancel_token.cancelled();
+			tokio::select! {
+				_ = interval.tick() => {
+					let msg = Message::Ping(vec![]);
+
+					// Send the message to the client and close the WebSocket connection if it fails
+					if internal_sender.send(msg).await.is_err() {
+						rpc.read().await.graceful_shutdown.cancel();
+						break;
+					}
+				},
+				_ = is_shutdown => break,
+			}
+		}
+	}
+
+	/// Read messages sent from the client
+	async fn read(
+		rpc: Arc<RwLock<Rpc>>,
+		mut receiver: SplitStream<WebSocket>,
+		internal_sender: Sender<Message>,
+	) {
+		// Collect all spawned tasks so we can wait for them at the end
+		let mut tasks = JoinSet::new();
+		let cancel_token = rpc.read().await.graceful_shutdown.clone();
+		loop {
+			let is_shutdown = cancel_token.cancelled();
+			tokio::select! {
+				msg = receiver.next() => {
+					if let Some(msg) = msg {
+						match msg {
+							// We've received a message from the client
+							// Ping/Pong is automatically handled by the WebSocket library
+							Ok(msg) => match msg {
+								Message::Text(_) => {
+									tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
+								}
+								Message::Binary(_) => {
+									tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
+								}
+								Message::Close(_) => {
+									// Respond with a close message
+									if let Err(err) = internal_sender.send(Message::Close(None)).await {
+										trace!("WebSocket error when replying to the Close frame: {:?}", err);
+									};
+									// Start the graceful shutdown of the WebSocket and close the channels
+									rpc.read().await.graceful_shutdown.cancel();
+									let _ = internal_sender.close();
+									break;
+								}
+								_ => {
+									// Ignore everything else
+								}
+							},
+							Err(err) => {
+								trace!("WebSocket error: {:?}", err);
+								// Start the graceful shutdown of the WebSocket and close the channels
+								rpc.read().await.graceful_shutdown.cancel();
+								let _ = internal_sender.close();
+								// Exit out of the loop
+								break;
+							}
+						}
+					}
+				}
+				_ = is_shutdown => break,
+			}
+		}
+
+		// Wait for all tasks to finish
+		while let Some(res) = tasks.join_next().await {
+			if let Err(err) = res {
+				error!("Error while handling RPC message: {}", err);
+			}
+		}
+	}
+
+	/// Write messages to the client
+	async fn write(
+		rpc: Arc<RwLock<Rpc>>,
+		mut sender: SplitSink<WebSocket, Message>,
+		mut internal_receiver: Receiver<Message>,
+	) {
+		let cancel_token = rpc.read().await.graceful_shutdown.clone();
+		loop {
+			let is_shutdown = cancel_token.cancelled();
+			tokio::select! {
+				// Wait for the next message to send
+				msg = internal_receiver.next() => {
+					if let Some(res) = msg {
+						// Send the message to the client
+						if let Err(err) = sender.send(res).await {
+							if err.to_string() != CONN_CLOSED_ERR {
+								debug!("WebSocket error: {:?}", err);
+							}
+							// Close the WebSocket connection
+							rpc.read().await.graceful_shutdown.cancel();
+							// Exit out of the loop
+							break;
+						}
+					}
+				},
+				_ = is_shutdown => break,
+			}
+		}
+	}
+
+	/// Send live query notifications to the client
+	async fn lq_notifications(rpc: Arc<RwLock<Rpc>>) {
+		if let Some(channel) = DB.get().unwrap().notifications() {
+			let cancel_token = rpc.read().await.graceful_shutdown.clone();
+			loop {
+				tokio::select! {
+					msg = channel.recv() => {
+						if let Ok(notification) = msg {
+							// Find which WebSocket the notification belongs to
+							if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
+								// Check to see if the WebSocket exists
+								if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
+									// Serialize the message to send
+									let message = res::success(None, notification);
+									// Get the current output format
+									let format = rpc.read().await.format.clone();
+									// Send the notification to the client
+									message.send(format, ws.clone()).await
+								}
+							}
+						}
+					},
+					_ = cancel_token.cancelled() => break,
+				}
+			}
+		}
+	}
+
+	/// Handle individual WebSocket messages
+	async fn handle_msg(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
 		// Get the current output format
-		let mut out = { rpc.read().await.format.clone() };
+		let mut out_fmt = rpc.read().await.format.clone();
-		// Clone the RPC
+		let span = span_for_request(&rpc.read().await.ws_id);
-		let rpc = rpc.clone();
+		let _enter = span.enter();
 		// Parse the request
+		match Self::parse_request(msg).await {
+			Ok((id, method, params, _out_fmt)) => {
+				span.record(
+					"rpc.jsonrpc.request_id",
+					id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
+				);
+				if let Some(_out_fmt) = _out_fmt {
+					out_fmt = _out_fmt;
+				}
+
+				// Process the request
+				let res = Self::process_request(rpc.clone(), &method, params).await;
+
+				// Process the response
+				res.into_response(id).send(out_fmt, chn).await
+			}
+			Err(err) => {
+				// Process the response
+				res::failure(None, err).send(out_fmt, chn).await
+			}
+		}
+	}
+
+	async fn parse_request(
+		msg: Message,
+	) -> Result<(Option<Value>, String, Array, Option<OutputFormat>), Failure> {
+		let mut out_fmt = None;
 		let req = match msg {
 			// This is a binary message
 			Message::Binary(val) => {
 				// Use binary output
-				out = Output::Full;
+				out_fmt = Some(OutputFormat::Full);
-				// Deserialize the input
-				Value::from(val)
+				match deserialize(&val) {
+					Ok(v) => v,
+					Err(_) => {
+						debug!("Error when trying to deserialize the request");
+						return Err(Failure::PARSE_ERROR);
+					}
+				}
 			}
 			// This is a text message
 			Message::Text(ref val) => {
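`span_for_request` comes from the new telemetry module, which is not included in this excerpt. For the `span.record(...)` call above (and the `rpc.method` / `otel.name` records made later during parsing) to take effect, the span must declare those fields up front as empty. The sketch below shows roughly what such a helper could look like; everything beyond that requirement is an assumption.

```rust
use tracing::{field, info_span, Span};
use uuid::Uuid;

// Sketch only: a per-request span with empty fields that are filled in later
// via span.record(...) once the RPC method and request id are known.
pub fn span_for_request(ws_id: &Uuid) -> Span {
	info_span!(
		"rpc/call",
		otel.name = field::Empty,
		otel.kind = "server",
		ws.id = %ws_id,
		rpc.method = field::Empty,
		rpc.jsonrpc.request_id = field::Empty,
	)
}
```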
@@ -234,14 +348,15 @@ impl Rpc {
 					// The SurrealQL message parsed ok
 					Ok(v) => v,
 					// The SurrealQL message failed to parse
-					_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
+					_ => return Err(Failure::PARSE_ERROR),
 				}
 			}
 			// Unsupported message type
-			_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
+			_ => {
+				debug!("Unsupported message type: {:?}", msg);
+				return Err(res::Failure::custom("Unsupported message type"));
+			}
 		};
-		// Log the received request
-		trace!("RPC Received: {}", req);
 		// Fetch the 'id' argument
 		let id = match req.pick(&*ID) {
 			v if v.is_none() => None,

@@ -250,149 +365,180 @@ impl Rpc {
 			v if v.is_number() => Some(v),
 			v if v.is_strand() => Some(v),
 			v if v.is_datetime() => Some(v),
-			_ => return res::failure(None, Failure::INVALID_REQUEST).send(out, chn).await,
+			_ => return Err(Failure::INVALID_REQUEST),
 		};
 		// Fetch the 'method' argument
 		let method = match req.pick(&*METHOD) {
 			Value::Strand(v) => v.to_raw(),
-			_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
+			_ => return Err(Failure::INVALID_REQUEST),
 		};
+
+		// Now that we know the method, we can update the span
+		Span::current().record("rpc.method", &method);
+		Span::current().record("otel.name", format!("surrealdb.rpc/{}", method));
+
 		// Fetch the 'params' argument
 		let params = match req.pick(&*PARAMS) {
 			Value::Array(v) => v,
 			_ => Array::new(),
 		};
+
+		Ok((id, method, params, out_fmt))
+	}
+
+	async fn process_request(
+		rpc: Arc<RwLock<Rpc>>,
+		method: &str,
+		params: Array,
+	) -> Result<Data, Failure> {
+		info!("Process RPC request");
+
 		// Match the method to a function
-		let res = match &method[..] {
+		match method {
-			// Handle a ping message
+			// Handle a surrealdb ping message
-			"ping" => Ok(Value::None),
+			//
+			// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
+			// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
+			// By sending a ping message every few seconds from the client, we can force a connection check and trigger an on_close event if the ping can't be sent.
+			//
+			"ping" => Ok(Value::None.into()),
 			// Retrieve the current auth record
 			"info" => match params.len() {
-				0 => rpc.read().await.info().await,
+				0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Switch to a specific namespace and database
 			"use" => match params.needs_two() {
-				Ok((ns, db)) => rpc.write().await.yuse(ns, db).await,
+				Ok((ns, db)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Signup to a specific authentication scope
 			"signup" => match params.needs_one() {
-				Ok(Value::Object(v)) => rpc.write().await.signup(v).await,
+				Ok(Value::Object(v)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Signin as a root, namespace, database or scope user
 			"signin" => match params.needs_one() {
-				Ok(Value::Object(v)) => rpc.write().await.signin(v).await,
+				Ok(Value::Object(v)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Invalidate the current authentication session
 			"invalidate" => match params.len() {
-				0 => rpc.write().await.invalidate().await,
+				0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Authenticate using an authentication token
 			"authenticate" => match params.needs_one() {
-				Ok(Value::Strand(v)) => rpc.write().await.authenticate(v).await,
+				Ok(Value::Strand(v)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Kill a live query using a query id
 			"kill" => match params.needs_one() {
-				Ok(v) if v.is_uuid() => rpc.read().await.kill(v).await,
+				Ok(v) if v.is_uuid() => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Setup a live query on a specific table
 			"live" => match params.needs_one_or_two() {
-				Ok((v, d)) if v.is_table() => rpc.read().await.live(v, d).await,
+				Ok((v, d)) if v.is_table() => {
-				Ok((v, d)) if v.is_strand() => rpc.read().await.live(v, d).await,
+					rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				}
+				Ok((v, d)) if v.is_strand() => {
+					rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Specify a connection-wide parameter
-			"let" => match params.needs_one_or_two() {
+			"let" | "set" => match params.needs_one_or_two() {
-				Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
+				Ok((Value::Strand(s), v)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
-			},
+				}
-			// Specify a connection-wide parameter
+				_ => Err(Failure::INVALID_PARAMS),
-			"set" => match params.needs_one_or_two() {
-				Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
 			},
 			// Unset and clear a connection-wide parameter
 			"unset" => match params.needs_one() {
-				Ok(Value::Strand(s)) => rpc.write().await.unset(s).await,
+				Ok(Value::Strand(s)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Select a value or values from the database
 			"select" => match params.needs_one() {
-				Ok(v) => rpc.read().await.select(v).await,
+				Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Insert a value or values in the database
 			"insert" => match params.needs_one_or_two() {
-				Ok((v, o)) => rpc.read().await.insert(v, o).await,
+				Ok((v, o)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Create a value or values in the database
 			"create" => match params.needs_one_or_two() {
-				Ok((v, o)) => rpc.read().await.create(v, o).await,
+				Ok((v, o)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Update a value or values in the database using `CONTENT`
 			"update" => match params.needs_one_or_two() {
-				Ok((v, o)) => rpc.read().await.update(v, o).await,
+				Ok((v, o)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Update a value or values in the database using `MERGE`
 			"change" | "merge" => match params.needs_one_or_two() {
-				Ok((v, o)) => rpc.read().await.change(v, o).await,
+				Ok((v, o)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Update a value or values in the database using `PATCH`
 			"modify" | "patch" => match params.needs_one_or_two() {
-				Ok((v, o)) => rpc.read().await.modify(v, o).await,
+				Ok((v, o)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Delete a value or values from the database
 			"delete" => match params.needs_one() {
-				Ok(v) => rpc.read().await.delete(v).await,
+				Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Specify the output format for text requests
 			"format" => match params.needs_one() {
-				Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
+				Ok(Value::Strand(v)) => {
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+					rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
+				}
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Get the current server version
 			"version" => match params.len() {
 				0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
 			// Run a full SurrealQL query against the database
 			"query" => match params.needs_one_or_two() {
 				Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
-					return match rpc.read().await.query(s).await {
+					rpc.read().await.query(s).await.map(Into::into).map_err(Into::into)
-						Ok(v) => res::success(id, v).send(out, chn).await,
-						Err(e) => {
-							res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
-						}
-					};
 				}
 				Ok((Value::Strand(s), Value::Object(o))) => {
-					return match rpc.read().await.query_with(s, o).await {
+					rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into)
-						Ok(v) => res::success(id, v).send(out, chn).await,
-						Err(e) => {
-							res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
-						}
-					};
 				}
-				_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
+				_ => Err(Failure::INVALID_PARAMS),
 			},
-			_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
+			_ => Err(Failure::METHOD_NOT_FOUND),
-		};
-		// Return the final response
-		match res {
-			Ok(v) => res::success(id, v).send(out, chn).await,
-			Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
 		}
 	}

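The recurring `map(Into::into).map_err(Into::into)` pattern converts each handler's `Result<Value, Error>` (or `Result<Vec<Response>, Error>` for queries) into the `Result<Data, Failure>` returned by `process_request`, and `IntoRpcResponse` then turns that into a response sent in the chosen output format. Those types live in `src/rpc/res.rs`, which is not part of this excerpt; the shapes below are assumptions meant only to make the data flow easier to follow.

```rust
use surrealdb::dbs::Response;
use surrealdb::sql::Value;

// Assumed shapes, for illustration only (the real definitions are in src/rpc/res.rs).
pub enum Data {
	// A plain SurrealQL value, e.g. the result of "info" or "version"
	Other(Value),
	// The full output of a "query" call
	Query(Vec<Response>),
}

impl From<Value> for Data {
	fn from(v: Value) -> Self {
		Data::Other(v)
	}
}

impl From<Vec<Response>> for Data {
	fn from(v: Vec<Response>) -> Self {
		Data::Query(v)
	}
}

// With a corresponding `impl From<Error> for Failure`, a handler result such as
// Result<Value, Error> converts with: res.map(Into::into).map_err(Into::into)
```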
@@ -402,15 +548,14 @@ impl Rpc {

 	async fn format(&mut self, out: Strand) -> Result<Value, Error> {
 		match out.as_str() {
-			"json" | "application/json" => self.format = Output::Json,
+			"json" | "application/json" => self.format = OutputFormat::Json,
-			"cbor" | "application/cbor" => self.format = Output::Cbor,
+			"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
-			"pack" | "application/pack" => self.format = Output::Pack,
+			"pack" | "application/pack" => self.format = OutputFormat::Pack,
 			_ => return Err(Error::InvalidType),
 		};
 		Ok(Value::None)
 	}

-	#[instrument(skip_all, name = "rpc use", fields(websocket=self.uuid.to_string()))]
 	async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
 		if let Value::Strand(ns) = ns {
 			self.session.ns = Some(ns.0);

@@ -421,7 +566,6 @@ impl Rpc {
 		Ok(Value::None)
 	}

-	#[instrument(skip_all, name = "rpc signup", fields(websocket=self.uuid.to_string()))]
 	async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
 		let kvs = DB.get().unwrap();
 		surrealdb::iam::signup::signup(kvs, &mut self.session, vars)

@@ -430,7 +574,6 @@ impl Rpc {
 			.map_err(Into::into)
 	}

-	#[instrument(skip_all, name = "rpc signin", fields(websocket=self.uuid.to_string()))]
 	async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
 		let kvs = DB.get().unwrap();
 		surrealdb::iam::signin::signin(kvs, &mut self.session, vars)

@@ -438,13 +581,11 @@ impl Rpc {
 			.map(Into::into)
 			.map_err(Into::into)
 	}
-	#[instrument(skip_all, name = "rpc invalidate", fields(websocket=self.uuid.to_string()))]
 	async fn invalidate(&mut self) -> Result<Value, Error> {
 		surrealdb::iam::clear::clear(&mut self.session)?;
 		Ok(Value::None)
 	}

-	#[instrument(skip_all, name = "rpc auth", fields(websocket=self.uuid.to_string()))]
 	async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
 		let kvs = DB.get().unwrap();
 		surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;

@@ -455,7 +596,6 @@ impl Rpc {
 	// Methods for identification
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc info", fields(websocket=self.uuid.to_string()))]
 	async fn info(&self) -> Result<Value, Error> {
 		// Get a database reference
 		let kvs = DB.get().unwrap();

@@ -473,7 +613,6 @@ impl Rpc {
 	// Methods for setting variables
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc set", fields(websocket=self.uuid.to_string()))]
 	async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
 		match val {
 			// Remove the variable if undefined

@@ -484,7 +623,6 @@ impl Rpc {
 		Ok(Value::Null)
 	}

-	#[instrument(skip_all, name = "rpc unset", fields(websocket=self.uuid.to_string()))]
 	async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
 		self.vars.remove(&key.0);
 		Ok(Value::Null)

@@ -494,7 +632,6 @@ impl Rpc {
 	// Methods for live queries
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc kill", fields(websocket=self.uuid.to_string()))]
 	async fn kill(&self, id: Value) -> Result<Value, Error> {
 		// Specify the SQL query string
 		let sql = "KILL $id";

@@ -513,7 +650,6 @@ impl Rpc {
 		}
 	}

-	#[instrument(skip_all, name = "rpc live", fields(websocket=self.uuid.to_string()))]
 	async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
 		// Specify the SQL query string
 		let sql = match diff.is_true() {

@@ -539,7 +675,6 @@ impl Rpc {
 	// Methods for selecting
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc select", fields(websocket=self.uuid.to_string()))]
 	async fn select(&self, what: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -567,7 +702,6 @@ impl Rpc {
 	// Methods for inserting
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc insert", fields(websocket=self.uuid.to_string()))]
 	async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -596,7 +730,6 @@ impl Rpc {
 	// Methods for creating
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc create", fields(websocket=self.uuid.to_string()))]
 	async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -625,7 +758,6 @@ impl Rpc {
 	// Methods for updating
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc update", fields(websocket=self.uuid.to_string()))]
 	async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -654,7 +786,6 @@ impl Rpc {
 	// Methods for changing
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc change", fields(websocket=self.uuid.to_string()))]
 	async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -683,7 +814,6 @@ impl Rpc {
 	// Methods for modifying
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc modify", fields(websocket=self.uuid.to_string()))]
 	async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -712,7 +842,6 @@ impl Rpc {
 	// Methods for deleting
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc delete", fields(websocket=self.uuid.to_string()))]
 	async fn delete(&self, what: Value) -> Result<Value, Error> {
 		// Return a single result?
 		let one = what.is_thing();

@@ -740,7 +869,6 @@ impl Rpc {
 	// Methods for querying
 	// ------------------------------

-	#[instrument(skip_all, name = "rpc query", fields(websocket=self.uuid.to_string()))]
 	async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
 		// Get a database reference
 		let kvs = DB.get().unwrap();

@@ -756,7 +884,6 @@ impl Rpc {
 		Ok(res)
 	}

-	#[instrument(skip_all, name = "rpc query_with", fields(websocket=self.uuid.to_string()))]
 	async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
 		// Get a database reference
 		let kvs = DB.get().unwrap();

@@ -781,8 +908,8 @@ impl Rpc {
 				QueryType::Live => {
 					if let Ok(Value::Uuid(lqid)) = &res.result {
 						// Match on Uuid type
-						LIVE_QUERIES.write().await.insert(lqid.0, self.uuid);
+						LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
-						trace!("Registered live query {} on websocket {}", lqid, self.uuid);
+						trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
 					}
 				}
 				QueryType::Kill => {
@@ -1,17 +1,57 @@
 use std::time::Duration;

 use axum_server::Handle;
+use tokio::task::JoinHandle;

-use crate::err::Error;
+use crate::{
+	err::Error,
+	net::rpc::{WebSocketRef, WEBSOCKETS},
+};

-/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received.
-pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) {
+/// Start a graceful shutdown:
+/// * Signal the Axum Handle when a shutdown signal is received.
+/// * Stop all WebSocket connections.
+///
+/// A second signal will force an immediate shutdown.
+pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
 	tokio::spawn(async move {
 		let result = listen().await.expect("Failed to listen to shutdown signal");
-		info!(target: super::LOG, "{} received. Start graceful shutdown...", result);
+		info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);

-		handle.graceful_shutdown(dur)
-	});
+		tokio::select! {
+			// Start a normal graceful shutdown
+			_ = async {
+				// First stop accepting new HTTP requests
+				http_handle.graceful_shutdown(None);
+
+				// Close all WebSocket connections. Queued messages will still be processed.
+				for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
+					cancel_token.cancel();
+				};
+
+				// Wait for all existing WebSocket connections to gracefully close
+				while WEBSOCKETS.read().await.len() > 0 {
+					tokio::time::sleep(Duration::from_millis(100)).await;
+				};
+			} => (),
+			// Force an immediate shutdown if a second signal is received
+			_ = async {
+				if let Ok(signal) = listen().await {
+					warn!(target: super::LOG, "{} received during graceful shutdown. Terminate immediately...", signal);
+				} else {
+					error!(target: super::LOG, "Failed to listen to shutdown signal. Terminate immediately...");
+				}
+
+				// Force an immediate shutdown
+				http_handle.shutdown();
+
+				// Close all WebSocket connections immediately
+				if let Ok(mut writer) = WEBSOCKETS.try_write() {
+					writer.drain();
+				}
+			} => (),
+		}
+	})
 }

 #[cfg(unix)]
@@ -1,119 +1,15 @@
 use std::{fmt, time::Duration};

-use axum::{
-	body::{boxed, Body, BoxBody},
-	extract::MatchedPath,
-	headers::{
-		authorization::{Basic, Bearer},
-		Authorization, Origin,
-	},
-	Extension, RequestPartsExt, TypedHeader,
-};
-use futures_util::future::BoxFuture;
-use http::{header, request::Parts, StatusCode};
+use axum::extract::MatchedPath;
+use http::header;
 use hyper::{Request, Response};
-use surrealdb::{
-	dbs::Session,
-	iam::verify::{basic, token},
-};
 use tower_http::{
-	auth::AsyncAuthorizeRequest,
 	request_id::RequestId,
 	trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
 };
 use tracing::{field, Level, Span};

-use crate::{dbs::DB, err::Error};
+use super::client_ip::ExtractClientIP;

-use super::{client_ip::ExtractClientIP, AppState};
-
-///
-/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
-/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
-///
-/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
-///
-/// ```rust
-/// use tower_http::auth::RequireAuthorizationLayer;
-/// use surrealdb::net::SurrealAuth;
-/// use axum::Router;
-///
-/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
-///
-/// let app = Router::new()
-///     .route("/version", get(|| async { "0.1.0" }))
-///     .layer(auth);
-/// ```
-#[derive(Clone, Copy)]
-pub(super) struct SurrealAuth;
-
-impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
-where
-	B: Send + Sync + 'static,
-{
-	type RequestBody = B;
-	type ResponseBody = BoxBody;
-	type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
-
-	fn authorize(&mut self, request: Request<B>) -> Self::Future {
-		Box::pin(async {
-			let (mut parts, body) = request.into_parts();
-			match check_auth(&mut parts).await {
-				Ok(sess) => {
-					parts.extensions.insert(sess);
-					Ok(Request::from_parts(parts, body))
-				}
-				Err(err) => {
-					let unauthorized_response = Response::builder()
-						.status(StatusCode::UNAUTHORIZED)
-						.body(boxed(Body::from(err.to_string())))
-						.unwrap();
-					Err(unauthorized_response)
-				}
-			}
-		})
-	}
-}
-
-async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
-	let kvs = DB.get().unwrap();
-
-	let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
-		if !or.is_null() {
-			Some(or.to_string())
-		} else {
-			None
-		}
-	} else {
-		None
-	};
-
-	let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
-	let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
-	let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
-
-	let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
-		tracing::error!("Error extracting the app state: {:?}", err);
-		Error::InvalidAuth
-	})?;
-	let ExtractClientIP(ip) =
-		parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
-
-	// Create session
-	#[rustfmt::skip]
-	let mut session = Session { ip, or, id, ns, db, ..Default::default() };
-
-	// If Basic authentication data was supplied
-	if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
-		basic(kvs, &mut session, au.username(), au.password()).await.map_err(|e| e.into())
-	} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
-		token(kvs, &mut session, au.token()).await.map_err(|e| e.into())
-	} else {
-		Err(Error::InvalidAuth)
-	}?;
-
-	Ok(session)
-}
-
 ///
 /// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
@@ -139,7 +35,6 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
 	fn make_span(&mut self, req: &Request<B>) -> Span {
 		// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
 		let span = tracing::info_span!(
-			target: "surreal::http",
 			"request",
 			otel.name = field::Empty,
 			otel.kind = "server",
@@ -154,10 +49,10 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
 			network.protocol.name = "http",
 			network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"),
 			client.address = field::Empty,
 			client.port = field::Empty,
 			client.socket.address = field::Empty,
 			server.address = field::Empty,
 			server.port = field::Empty,
 			// set on the response hook
 			http.latency.ms = field::Empty,
 			http.response.status_code = field::Empty,
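The hooks above are meant for tower-http's `TraceLayer`. A self-contained sketch of the same idea using a closure instead of the `HttpTraceLayerHooks` type (the route and the field subset here are illustrative, not from this commit):

use axum::{body::Body, http::Request, routing::get, Router};
use tower_http::trace::TraceLayer;
use tracing::field;

// Closure-based stand-in for the hooks above: every request gets a span with
// OTEL-style fields, and the empty ones can be recorded later from the response hook.
fn traced_router() -> Router {
	Router::new()
		.route("/health", get(|| async { "OK" }))
		.layer(TraceLayer::new_for_http().make_span_with(|req: &Request<Body>| {
			tracing::info_span!(
				"request",
				otel.kind = "server",
				http.request.method = %req.method(),
				url.path = %req.uri().path(),
				http.response.status_code = field::Empty,
			)
		}))
}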
@@ -1,3 +1,5 @@
 pub mod args;
 pub mod paths;
 pub mod res;
+
+pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally";
@@ -7,10 +7,13 @@ use surrealdb::dbs;
 use surrealdb::dbs::Notification;
 use surrealdb::sql;
 use surrealdb::sql::Value;
-use tracing::instrument;
+use tracing::Span;

-#[derive(Clone)]
-pub enum Output {
+use crate::err;
+use crate::rpc::CONN_CLOSED_ERR;
+
+#[derive(Debug, Clone)]
+pub enum OutputFormat {
 	Json, // JSON
 	Cbor, // CBOR
 	Pack, // MessagePack
@@ -37,6 +40,12 @@ impl From<Value> for Data {
 	}
 }

+impl From<String> for Data {
+	fn from(v: String) -> Self {
+		Data::Other(Value::from(v))
+	}
+}
+
 impl From<Vec<dbs::Response>> for Data {
 	fn from(v: Vec<dbs::Response>) -> Self {
 		Data::Query(v)
@@ -82,28 +91,45 @@ impl Response {
 	}

 	/// Send the response to the WebSocket channel
-	#[instrument(skip_all, name = "rpc response", fields(response = ?self))]
-	pub async fn send(self, out: Output, chn: Sender<Message>) {
+	pub async fn send(self, out: OutputFormat, chn: Sender<Message>) {
+		let span = Span::current();
+
+		info!("Process RPC response");
+
+		if let Err(err) = &self.result {
+			span.record("otel.status_code", "Error");
+			span.record(
+				"otel.status_message",
+				format!("code: {}, message: {}", err.code, err.message),
+			);
+			span.record("rpc.jsonrpc.error_code", err.code);
+			span.record("rpc.jsonrpc.error_message", err.message.as_ref());
+		}
+
 		let message = match out {
-			Output::Json => {
+			OutputFormat::Json => {
 				let res = serde_json::to_string(&self.simplify()).unwrap();
 				Message::Text(res)
 			}
-			Output::Cbor => {
+			OutputFormat::Cbor => {
 				let res = serde_cbor::to_vec(&self.simplify()).unwrap();
 				Message::Binary(res)
 			}
-			Output::Pack => {
+			OutputFormat::Pack => {
 				let res = serde_pack::to_vec(&self.simplify()).unwrap();
 				Message::Binary(res)
 			}
-			Output::Full => {
+			OutputFormat::Full => {
 				let res = surrealdb::sql::serde::serialize(&self).unwrap();
 				Message::Binary(res)
 			}
 		};
-		let _ = chn.send(message).await;
-		trace!("Response sent");
+		if let Err(err) = chn.send(message).await {
+			if err.to_string() != CONN_CLOSED_ERR {
+				error!("Error sending response: {}", err);
+			}
+		};
 	}
 }
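One detail worth noting about the `span.record(...)` calls above: `tracing` only records values for fields that were declared when the span was created, which is why the RPC span declares its error fields as `field::Empty` up front. A small illustration:

use tracing::field;

// record() is a no-op for fields that were not declared on the span,
// so error fields are declared as Empty first and filled in later.
fn record_error_fields() {
	let span = tracing::info_span!(
		"rpc/call",
		rpc.jsonrpc.error_code = field::Empty,
		rpc.jsonrpc.error_message = field::Empty,
	);
	let _guard = span.enter();
	tracing::Span::current().record("rpc.jsonrpc.error_code", -32000);
	tracing::Span::current().record("rpc.jsonrpc.error_message", "example error");
}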
@@ -113,6 +139,7 @@ pub struct Failure {
 	message: Cow<'static, str>,
 }

+#[allow(dead_code)]
 impl Failure {
 	pub const PARSE_ERROR: Failure = Failure {
 		code: -32700,
@@ -165,3 +192,26 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
 		result: Err(err),
 	}
 }
+
+impl From<err::Error> for Failure {
+	fn from(err: err::Error) -> Self {
+		Failure::custom(err.to_string())
+	}
+}
+
+pub trait IntoRpcResponse {
+	fn into_response(self, id: Option<Value>) -> Response;
+}
+
+impl<T, E> IntoRpcResponse for Result<T, E>
+where
+	T: Into<Data>,
+	E: Into<Failure>,
+{
+	fn into_response(self, id: Option<Value>) -> Response {
+		match self {
+			Ok(v) => success(id, v.into()),
+			Err(err) => failure(id, err.into()),
+		}
+	}
+}
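Usage sketch for the new `IntoRpcResponse` trait, assuming the surrounding module's types (`Value`, `Data`, `Failure`, `Response`, `err::Error`) are in scope:

// Any Result whose Ok type converts into Data (e.g. String, via the new From<String> impl)
// and whose Err type converts into Failure (e.g. err::Error) becomes an RPC Response.
fn respond(id: Option<Value>, result: Result<String, err::Error>) -> Response {
	result.into_response(id)
}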
@@ -1,8 +1,10 @@
 use tracing::Subscriber;
 use tracing_subscriber::fmt::format::FmtSpan;
-use tracing_subscriber::{EnvFilter, Layer};
+use tracing_subscriber::Layer;

-pub fn new<S>(level: String) -> Box<dyn Layer<S> + Send + Sync>
+use crate::cli::validator::parser::env_filter::CustomEnvFilter;
+
+pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
 where
 	S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
 {
@@ -11,6 +13,6 @@ where
 		.with_ansi(true)
 		.with_span_events(FmtSpan::NONE)
 		.with_writer(std::io::stderr)
-		.with_filter(EnvFilter::builder().parse(level).unwrap())
+		.with_filter(filter.0)
 		.boxed()
 }
@@ -1,6 +1,6 @@
 mod logs;
 pub mod metrics;
-mod traces;
+pub mod traces;

 use std::time::Duration;

@@ -11,8 +11,7 @@ use opentelemetry::sdk::resource::{
 };
 use opentelemetry::sdk::Resource;
 use opentelemetry::KeyValue;
-use tracing::Subscriber;
-use tracing_subscriber::fmt::format::FmtSpan;
+use tracing::{Level, Subscriber};
 use tracing_subscriber::prelude::*;
 use tracing_subscriber::util::SubscriberInitExt;
 #[cfg(feature = "has-storage")]
@@ -39,53 +38,75 @@ pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
 	}
 });

-#[derive(Default, Debug, Clone)]
+#[derive(Debug, Clone)]
 pub struct Builder {
-	log_level: Option<String>,
-	filter: Option<CustomEnvFilter>,
+	filter: CustomEnvFilter,
 }

 pub fn builder() -> Builder {
 	Builder::default()
 }

+impl Default for Builder {
+	fn default() -> Self {
+		Self {
+			filter: CustomEnvFilter(EnvFilter::default()),
+		}
+	}
+}
+
 impl Builder {
 	/// Set the log level on the builder
 	pub fn with_log_level(mut self, log_level: &str) -> Self {
-		self.log_level = Some(log_level.to_string());
+		if let Ok(filter) = filter_from_value(log_level) {
+			self.filter = CustomEnvFilter(filter);
+		}
 		self
 	}

 	/// Set the filter on the builder
 	#[cfg(feature = "has-storage")]
-	pub fn with_filter(mut self, filter: EnvFilter) -> Self {
-		self.filter = Some(CustomEnvFilter(filter));
+	pub fn with_filter(mut self, filter: CustomEnvFilter) -> Self {
+		self.filter = filter;
 		self
 	}

 	/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
 	pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
 		let registry = tracing_subscriber::registry();
-		let registry = registry.with(self.filter.map(|filter| {
-			tracing_subscriber::fmt::layer()
-				.compact()
-				.with_ansi(true)
-				.with_span_events(FmtSpan::NONE)
-				.with_writer(std::io::stderr)
-				.with_filter(filter.0)
-				.boxed()
-		}));
-		let registry = registry.with(self.log_level.map(logs::new));
-		let registry = registry.with(traces::new());
+
+		// Setup logging layer
+		let registry = registry.with(logs::new(self.filter.clone()));
+
+		// Setup tracing layer
+		let registry = registry.with(traces::new(self.filter));
+
 		Box::new(registry)
 	}

-	/// tracing pipeline
+	/// Install the tracing dispatcher globally
 	pub fn init(self) {
 		self.build().init()
 	}
 }

+/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
+pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
+	match v {
+		// Don't show any logs at all
+		"none" => Ok(EnvFilter::default()),
+		// Check if we should show all log levels
+		"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
+		// Otherwise, let's only show errors
+		"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
+		// Specify the log level for each code area
+		"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
+			.parse(format!("error,surreal={v},surrealdb={v},surrealdb::kvs::tx=error")),
+		// Let's try to parse the custom log level
+		_ => EnvFilter::builder().parse(v),
+	}
+}
+
 #[cfg(test)]
 mod tests {
 	use opentelemetry::global::shutdown_tracer_provider;
@@ -107,7 +128,7 @@ mod tests {
 				("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
 			],
 			|| {
-				let _enter = telemetry::builder().build().set_default();
+				let _enter = telemetry::builder().with_log_level("info").build().set_default();

 				println!("Sending span...");

@@ -123,7 +144,11 @@ mod tests {
 			}

 			println!("Waiting for request...");
-			let req = req_rx.recv().await.expect("missing export request");
+			let req = tokio::select! {
+				req = req_rx.recv() => req.expect("missing export request"),
+				_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
+			};

 			let first_span =
 				req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
 			assert_eq!("test-surreal-span", first_span.name);
@@ -141,11 +166,10 @@ mod tests {
 		temp_env::with_vars(
 			vec![
 				("SURREAL_TRACING_TRACER", Some("otlp")),
-				("SURREAL_TRACING_FILTER", Some("debug")),
 				("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
 			],
 			|| {
-				let _enter = telemetry::builder().build().set_default();
+				let _enter = telemetry::builder().with_log_level("debug").build().set_default();

 				println!("Sending spans...");

@@ -169,7 +193,10 @@ mod tests {
 			}

 			println!("Waiting for request...");
-			let req = req_rx.recv().await.expect("missing export request");
+			let req = tokio::select! {
+				req = req_rx.recv() => req.expect("missing export request"),
+				_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
+			};
 			let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;

 			assert_eq!(1, spans.len());
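A short usage sketch for `filter_from_value` (assuming it is imported from `crate::telemetry`): the shorthand levels expand to per-crate directives, while any other string is parsed as raw `EnvFilter` directives.

use tracing_subscriber::EnvFilter;

fn build_filters() -> Result<(), tracing_subscriber::filter::ParseError> {
	// "info" expands to "error,surreal=info,surrealdb=info,surrealdb::kvs::tx=error".
	let _info: EnvFilter = filter_from_value("info")?;
	// Anything else is treated as EnvFilter directives.
	let _custom: EnvFilter = filter_from_value("surreal=trace,surrealdb=debug")?;
	Ok(())
}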
@@ -1,12 +1,15 @@
 use tracing::Subscriber;
 use tracing_subscriber::Layer;

+use crate::cli::validator::parser::env_filter::CustomEnvFilter;
+
 pub mod otlp;
+pub mod rpc;

 const TRACING_TRACER_VAR: &str = "SURREAL_TRACING_TRACER";

 // Returns a tracer based on the value of the TRACING_TRACER_VAR env var
-pub fn new<S>() -> Option<Box<dyn Layer<S> + Send + Sync>>
+pub fn new<S>(filter: CustomEnvFilter) -> Option<Box<dyn Layer<S> + Send + Sync>>
 where
 	S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
 {
@@ -20,7 +23,7 @@ where
 		// Init the registry with the OTLP tracer
 		"otlp" => {
 			debug!("Setup the OTLP tracer");
			Some(otlp::new(filter))
 		}
 		tracer => {
 			panic!("unsupported tracer {}", tracer);
@@ -1,18 +1,18 @@
 use opentelemetry::sdk::trace::Tracer;
 use opentelemetry::trace::TraceError;
 use opentelemetry_otlp::WithExportConfig;
-use tracing::{Level, Subscriber};
-use tracing_subscriber::{EnvFilter, Layer};
+use tracing::Subscriber;
+use tracing_subscriber::Layer;

-use crate::telemetry::OTEL_DEFAULT_RESOURCE;
+use crate::{
+	cli::validator::parser::env_filter::CustomEnvFilter, telemetry::OTEL_DEFAULT_RESOURCE,
+};

-const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER";
-
-pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
+pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
 where
 	S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
 {
-	tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter()).boxed()
+	tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter.0).boxed()
 }

 fn tracer() -> Result<Tracer, TraceError> {
@@ -24,16 +24,3 @@ fn tracer() -> Result<Tracer, TraceError> {
 		)
 		.install_batch(opentelemetry::runtime::Tokio)
 }
-
-/// Create a filter for the OTLP subscriber
-///
-/// It creates an EnvFilter based on the TRACING_FILTER_VAR's value
-///
-/// TRACING_FILTER_VAR accepts the same syntax as RUST_LOG
-fn filter() -> EnvFilter {
-	EnvFilter::builder()
-		.with_env_var(TRACING_FILTER_VAR)
-		.with_default_directive(Level::INFO.into())
-		.from_env()
-		.unwrap()
-}
31
src/telemetry/traces/rpc.rs
Normal file
@@ -0,0 +1,31 @@
+use tracing::{field, Span};
+use uuid::Uuid;
+
+pub fn span_for_request(ws_id: &Uuid) -> Span {
+	let span = tracing::info_span!(
+		// Dynamic span names need to be 'recorded', can't be used on the macro. Use a static name here and overwrite later on
+		"rpc/call",
+		otel.name = field::Empty,
+		otel.kind = "server",
+
+		// To be populated by the request handler when the method is known
+		rpc.method = field::Empty,
+		rpc.service = "surrealdb",
+		rpc.system = "jsonrpc",
+
+		// JSON-RPC fields
+		rpc.jsonrpc.version = "2.0",
+		rpc.jsonrpc.request_id = field::Empty,
+		rpc.jsonrpc.error_code = field::Empty,
+		rpc.jsonrpc.error_message = field::Empty,
+
+		// SurrealDB custom fields
+		ws.id = %ws_id,
+
+		// Fields for error reporting
+		otel.status_code = field::Empty,
+		otel.status_message = field::Empty,
+	);
+
+	span
+}
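A hedged sketch of how `span_for_request` is intended to be used: attach the span to the future handling one RPC message (for example with `tracing::Instrument`), then fill in the fields that were declared as `field::Empty` once they are known. The handler shown here is illustrative, not the commit's actual wiring.

use tracing::Instrument;
use uuid::Uuid;

async fn handle_rpc_message(ws_id: Uuid, method: String) {
	let span = span_for_request(&ws_id);
	async move {
		// Fill in the fields that could not be known when the span was created.
		let current = tracing::Span::current();
		current.record("rpc.method", method.as_str());
		current.record("otel.name", format!("surrealdb.rpc/{method}").as_str());
		// ... dispatch the method and send the response here ...
	}
	.instrument(span)
	.await;
}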
@@ -5,6 +5,8 @@ mod common;
 use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
 use serial_test::serial;
 use std::fs;
+use test_log::test;
+use tracing::info;

 use common::{PASS, USER};

@@ -32,13 +34,14 @@ fn nonexistent_option() {
 	assert!(common::run("version --turbo").output().is_err());
 }

-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn all_commands() {
 	// Commands without credentials when auth is disabled, should succeed
 	let (addr, _server) = common::start_server(false, false, true).await.unwrap();
 	let creds = ""; // Anonymous user
-	// Create a record
+
+	info!("* Create a record");
 	{
 		let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
 		assert_eq!(
@@ -48,7 +51,7 @@ async fn all_commands() {
 		);
 	}

-	// Export to stdout
+	info!("* Export to stdout");
 	{
 		let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
 		let output = common::run(&args).output().expect("failed to run stdout export: {args}");
@@ -56,7 +59,7 @@ async fn all_commands() {
 		assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
 	}

-	// Export to file
+	info!("* Export to file");
 	let exported = {
 		let exported = common::tmp_file("exported.surql");
 		let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@@ -64,13 +67,13 @@ async fn all_commands() {
 		exported
 	};

-	// Import the exported file
+	info!("* Import the exported file");
 	{
 		let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
 		common::run(&args).output().expect("failed to run import: {args}");
 	}

-	// Query from the import (pretty-printed this time)
+	info!("* Query from the import (pretty-printed this time)");
 	{
 		let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
 		assert_eq!(
@@ -80,7 +83,7 @@ async fn all_commands() {
 		);
 	}

-	// Unfinished backup CLI
+	info!("* Unfinished backup CLI");
 	{
 		let file = common::tmp_file("backup.db");
 		let args = format!("backup {creds} http://{addr} {file}");
@@ -90,7 +93,7 @@ async fn all_commands() {
 		assert_eq!(fs::read_to_string(file).unwrap(), "Save");
 	}

-	// Multi-statement (and multi-line) query including error(s) over WS
+	info!("* Multi-statement (and multi-line) query including error(s) over WS");
 	{
 		let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
 		let output = common::run(&args)
@@ -113,7 +116,7 @@ async fn all_commands() {
 		assert!(output.contains("thing:also_success"), "missing also_success in {output}")
 	}

-	// Multi-statement (and multi-line) transaction including error(s) over WS
+	info!("* Multi-statement (and multi-line) transaction including error(s) over WS");
 	{
 		let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
 		let output = common::run(&args)
@@ -137,7 +140,7 @@ async fn all_commands() {
 		assert!(output.contains("rgument"), "missing argument error in {output}");
 	}

-	// Pass neither ns nor db
+	info!("* Pass neither ns nor db");
 	{
 		let args = format!("sql --conn http://{addr} {creds}");
 		let output = common::run(&args)
@@ -147,7 +150,7 @@ async fn all_commands() {
 		assert!(output.contains("thing:one"), "missing thing:one in {output}");
 	}

-	// Pass only ns
+	info!("* Pass only ns");
 	{
 		let args = format!("sql --conn http://{addr} {creds} --ns N5");
 		let output = common::run(&args)
@@ -157,26 +160,36 @@ async fn all_commands() {
 		assert!(output.contains("thing:one"), "missing thing:one in {output}");
 	}

-	// Pass only db and expect an error
+	info!("* Pass only db and expect an error");
 	{
 		let args = format!("sql --conn http://{addr} {creds} --db D5");
 		common::run(&args).output().expect_err("only db");
 	}
 }

-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn start_tls() {
-	let (_, server) = common::start_server(false, true, false).await.unwrap();
+	// Capute the server's stdout/stderr
+	temp_env::async_with_vars(
+		[
+			("SURREAL_TEST_SERVER_STDOUT", Some("piped")),
+			("SURREAL_TEST_SERVER_STDERR", Some("piped")),
+		],
+		async {
+			let (_, server) = common::start_server(false, true, false).await.unwrap();

-	std::thread::sleep(std::time::Duration::from_millis(2000));
-	let output = server.kill().output().err().unwrap();
+			std::thread::sleep(std::time::Duration::from_millis(2000));
+			let output = server.kill().output().err().unwrap();

-	// Test the crt/key args but the keys are self signed so don't actually connect.
-	assert!(output.contains("Started web server"), "couldn't start web server: {output}");
+			// Test the crt/key args but the keys are self signed so don't actually connect.
+			assert!(output.contains("Started web server"), "couldn't start web server: {output}");
+		},
+	)
+	.await;
 }

-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn with_root_auth() {
 	// Commands with credentials when auth is enabled, should succeed
@@ -184,7 +197,7 @@ async fn with_root_auth() {
 	let creds = format!("--user {USER} --pass {PASS}");
 	let sql_args = format!("sql --conn http://{addr} --multi --pretty");

-	// Can query /sql over HTTP
+	info!("* Query over HTTP");
 	{
 		let args = format!("{sql_args} {creds}");
 		let input = "INFO FOR ROOT;";
@@ -192,7 +205,7 @@ async fn with_root_auth() {
 		assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
 	}

-	// Can query /sql over WS
+	info!("* Query over WS");
 	{
 		let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
 		let input = "INFO FOR ROOT;";
@@ -200,7 +213,7 @@ async fn with_root_auth() {
 		assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
 	}

-	// KV user can do exports
+	info!("* Root user can do exports");
 	let exported = {
 		let exported = common::tmp_file("exported.surql");
 		let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@@ -209,13 +222,13 @@ async fn with_root_auth() {
 		exported
 	};

-	// KV user can do imports
+	info!("* Root user can do imports");
 	{
 		let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
 		common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
 	}

-	// KV user can do backups
+	info!("* Root user can do backups");
 	{
 		let file = common::tmp_file("backup.db");
 		let args = format!("backup {creds} http://{addr} {file}");
@@ -226,7 +239,7 @@ async fn with_root_auth() {
 	}
 }

-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn with_anon_auth() {
 	// Commands without credentials when auth is enabled, should fail
@@ -234,7 +247,7 @@ async fn with_anon_auth() {
 	let creds = ""; // Anonymous user
 	let sql_args = format!("sql --conn http://{addr} --multi --pretty");

-	// Can query /sql over HTTP
+	info!("* Query over HTTP");
 	{
 		let args = format!("{sql_args} {creds}");
 		let input = "";
@@ -242,7 +255,7 @@ async fn with_anon_auth() {
 		assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
 	}

-	// Can query /sql over HTTP
+	info!("* Query over WS");
 	{
 		let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
 		let input = "";
@@ -250,7 +263,7 @@ async fn with_anon_auth() {
 		assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
 	}

-	// Can't do exports
+	info!("* Can't do exports");
 	{
 		let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
 		let output = common::run(&args).output();
@@ -261,7 +274,7 @@ async fn with_anon_auth() {
 		);
 	}

-	// Can't do imports
+	info!("* Can't do imports");
 	{
 		let tmp_file = common::tmp_file("exported.surql");
 		let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {tmp_file}");
@@ -273,7 +286,7 @@ async fn with_anon_auth() {
 		);
 	}

-	// Can't do backups
+	info!("* Can't do backups");
 	{
 		let args = format!("backup {creds} http://{addr}");
 		let output = common::run(&args).output();
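The `#[test(tokio::test)]` attribute used throughout these tests comes from the `test-log` crate (added with its `trace` feature), which installs a tracing subscriber per test so the `info!` markers above are captured; the filter is taken from `RUST_LOG` when running `cargo test -- --nocapture`. Minimal pattern, with an illustrative body:

use serial_test::serial;
use test_log::test;
use tracing::info;

#[test(tokio::test)]
#[serial]
async fn example_logged_test() {
	// Illustrative body: the real tests drive the CLI via the common helpers.
	info!("* example step");
}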
@@ -1,10 +1,18 @@
 #![allow(dead_code)]
+use futures_util::{SinkExt, StreamExt, TryStreamExt};
 use rand::{thread_rng, Rng};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
 use std::error::Error;
-use std::fs;
+use std::fs::File;
 use std::path::Path;
 use std::process::{Command, Stdio};
+use std::{env, fs};
+use tokio::net::TcpStream;
 use tokio::time;
+use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
+use tracing::{error, info};

 pub const USER: &str = "root";
 pub const PASS: &str = "root";
@@ -52,7 +60,12 @@ impl Drop for Child {
 	}
 }

-pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
+pub fn run_internal<P: AsRef<Path>>(
+	args: &str,
+	current_dir: Option<P>,
+	stdout: Stdio,
+	stderr: Stdio,
+) -> Child {
 	let mut path = std::env::current_exe().unwrap();
 	assert!(path.pop());
 	if path.ends_with("deps") {
@@ -68,8 +81,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
 	}
 	cmd.env_clear();
 	cmd.stdin(Stdio::piped());
-	cmd.stdout(Stdio::piped());
-	cmd.stderr(Stdio::piped());
+	cmd.stdout(stdout);
+	cmd.stderr(stderr);
 	cmd.args(args.split_ascii_whitespace());
 	Child {
 		inner: Some(cmd.spawn().unwrap()),
@@ -78,12 +91,12 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child

 /// Run the CLI with the given args
 pub fn run(args: &str) -> Child {
-	run_internal::<String>(args, None)
+	run_internal::<String>(args, None, Stdio::piped(), Stdio::piped())
 }

 /// Run the CLI with the given args inside a temporary directory
 pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
-	run_internal(args, Some(current_dir))
+	run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped())
 }

 pub fn tmp_file(name: &str) -> String {
@@ -91,6 +104,19 @@ pub fn tmp_file(name: &str) -> String {
 	path.to_string_lossy().into_owned()
 }

+fn parse_server_stdio_from_var(var: &str) -> Result<Stdio, Box<dyn Error>> {
+	match env::var(var).as_deref() {
+		Ok("inherit") => Ok(Stdio::inherit()),
+		Ok("null") => Ok(Stdio::null()),
+		Ok("piped") => Ok(Stdio::piped()),
+		Ok(val) if val.starts_with("file://") => {
+			Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?))
+		}
+		Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()),
+		_ => Ok(Stdio::null()),
+	}
+}
+
 pub async fn start_server(
 	auth: bool,
 	tls: bool,
@@ -118,11 +144,14 @@ pub async fn start_server(
 		extra_args.push_str(" --auth");
 	}

-	let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}");
+	let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");

-	println!("starting server with args: {start_args}");
+	info!("starting server with args: {start_args}");

-	let server = run(&start_args);
+	// Configure where the logs go when running the test
+	let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?;
+	let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?;
+	let server = run_internal::<String>(&start_args, None, stdout, stderr);

 	if !wait_is_ready {
 		return Ok((addr, server));
@@ -130,17 +159,178 @@ pub async fn start_server(

 	// Wait 5 seconds for the server to start
 	let mut interval = time::interval(time::Duration::from_millis(500));
-	println!("Waiting for server to start...");
+	info!("Waiting for server to start...");
 	for _i in 0..10 {
 		interval.tick().await;

 		if run(&format!("isready --conn http://{addr}")).output().is_ok() {
-			println!("Server ready!");
+			info!("Server ready!");
 			return Ok((addr, server));
 		}
 	}

 	let server_out = server.kill().output().err().unwrap();
-	println!("server output: {server_out}");
+	error!("server output: {server_out}");
 	Err("server failed to start".into())
 }
+
+type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
+
+pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
+	let url = format!("ws://{}/rpc", addr);
+	let (ws_stream, _) = connect_async(url).await?;
+	Ok(ws_stream)
+}
+
+pub async fn ws_send_msg(
+	socket: &mut WsStream,
+	msg_req: String,
+) -> Result<serde_json::Value, Box<dyn Error>> {
+	// Use JSON format by default
+	ws_send_msg_with_fmt(socket, msg_req, Format::Json).await
+}
+
+pub enum Format {
+	Json,
+	Cbor,
+	Pack,
+}
+
+pub async fn ws_send_msg_with_fmt(
+	socket: &mut WsStream,
+	msg_req: String,
+	response_format: Format,
+) -> Result<serde_json::Value, Box<dyn Error>> {
+	tokio::select! {
+		_ = time::sleep(time::Duration::from_millis(500)) => {
+			return Err("timeout waiting for the request to be sent".into());
+		}
+		res = socket.send(Message::Text(msg_req)) => {
+			if let Err(err) = res {
+				return Err(format!("Error sending the message: {}", err).into());
+			}
+		}
+	}
+
+	let mut f = socket.try_filter(|msg| match response_format {
+		Format::Json => futures_util::future::ready(msg.is_text()),
+		Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
+	});
+
+	tokio::select! {
+		_ = time::sleep(time::Duration::from_millis(2000)) => {
+			Err("timeout waiting for the response".into())
+		}
+		res = f.select_next_some() => {
+			match response_format {
+				Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
+				Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
+				Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
+			}
+		}
+	}
+}
+
+#[derive(Serialize, Deserialize)]
+struct SigninParams<'a> {
+	user: &'a str,
+	pass: &'a str,
+	#[serde(skip_serializing_if = "Option::is_none")]
+	ns: Option<&'a str>,
+	#[serde(skip_serializing_if = "Option::is_none")]
+	db: Option<&'a str>,
+	#[serde(skip_serializing_if = "Option::is_none")]
+	sc: Option<&'a str>,
+}
+#[derive(Serialize, Deserialize)]
+struct UseParams<'a> {
+	#[serde(skip_serializing_if = "Option::is_none")]
+	ns: Option<&'a str>,
+	#[serde(skip_serializing_if = "Option::is_none")]
+	db: Option<&'a str>,
+}
+
+pub async fn ws_signin(
+	socket: &mut WsStream,
+	user: &str,
+	pass: &str,
+	ns: Option<&str>,
+	db: Option<&str>,
+	sc: Option<&str>,
+) -> Result<String, Box<dyn Error>> {
+	let json = json!({
+		"id": "1",
+		"method": "signin",
+		"params": [
+			SigninParams { user, pass, ns, db, sc }
+		],
+	});
+
+	let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
+	match msg.as_object() {
+		Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
+			Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
+		}
+		Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
+			Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned())
+		}
+		_ => {
+			error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
+			Err(format!("unexpected response: {:?}", msg).into())
+		}
+	}
+}
+
+pub async fn ws_query(
+	socket: &mut WsStream,
+	query: &str,
+) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
+	let json = json!({
+		"id": "1",
+		"method": "query",
+		"params": [query],
+	});
+
+	let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
+
+	match msg.as_object() {
+		Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
+			Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
+		}
+		Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
+			Ok(obj.get("result").unwrap().as_array().unwrap().to_owned())
+		}
+		_ => {
+			error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
+			Err(format!("unexpected response: {:?}", msg).into())
+		}
+	}
+}
+
+pub async fn ws_use(
+	socket: &mut WsStream,
+	ns: Option<&str>,
+	db: Option<&str>,
+) -> Result<serde_json::Value, Box<dyn Error>> {
+	let json = json!({
+		"id": "1",
+		"method": "use",
+		"params": [
+			ns, db
+		],
+	});
+
+	let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
+	match msg.as_object() {
+		Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
+			Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
+		}
+		Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
+			Ok(obj.get("result").unwrap().to_owned())
+		}
+		_ => {
+			error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
+			Err(format!("unexpected response: {:?}", msg).into())
		}
	}
}
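A hedged sketch of how the new WebSocket helpers above compose inside a `ws_integration` test (argument meanings follow the calls used elsewhere in these tests; the query itself is illustrative):

async fn ws_smoke_test() -> Result<(), Box<dyn std::error::Error>> {
	// Start a server without auth/TLS and wait until it is ready.
	let (addr, _server) = common::start_server(false, false, true).await?;
	let mut socket = common::connect_ws(&addr).await?;

	// Sign in, select a namespace and database, then run a query over the socket.
	common::ws_signin(&mut socket, common::USER, common::PASS, None, None, None).await?;
	common::ws_use(&mut socket, Some("N"), Some("D")).await?;
	let res = common::ws_query(&mut socket, "CREATE thing:one;").await?;
	assert_eq!(res.len(), 1, "unexpected response: {res:?}");
	Ok(())
}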
@ -7,10 +7,11 @@ use http::{header, Method};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use serial_test::serial;
|
use serial_test::serial;
|
||||||
|
use test_log::test;
|
||||||
|
|
||||||
use crate::common::{PASS, USER};
|
use crate::common::{PASS, USER};
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
|
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -52,7 +53,7 @@ async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
|
async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -133,14 +134,14 @@ async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> {
|
async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// TODO: test the client IP extractor
|
// TODO: test the client IP extractor
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -184,7 +185,7 @@ async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -196,7 +197,7 @@ async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -269,7 +270,7 @@ async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -303,7 +304,7 @@ async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@ -372,7 +373,7 @@ async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[test(tokio::test)]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||||
|
@@ -425,13 +426,17 @@ async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 		assert_eq!(res.status(), 200, "body: {}", res.text().await?);
 
 		let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
-		assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body);
+		assert!(
+			body["token"].as_str().unwrap().starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"),
+			"body: {}",
+			body
+		);
 	}
 
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
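The tightened assertion in the hunk above no longer checks merely that a token string is non-empty: `eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9` is the base64url encoding of the JWT header `{"typ":"JWT","alg":"HS512"}`, so the signup test now verifies that the response token is an HS512-signed JWT. A small illustrative decode of that prefix, assuming the `base64` crate (which is not part of this diff):

```rust
// Decodes the expected JWT header prefix to show what the assertion pins down.
// The `base64` crate and this standalone helper are assumptions for illustration.
use base64::Engine;

fn main() {
	let prefix = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9";
	let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(prefix).unwrap();
	// Prints: {"typ":"JWT","alg":"HS512"}
	println!("{}", String::from_utf8(decoded).unwrap());
}
```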
@@ -543,7 +548,7 @@ async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -577,7 +582,7 @@ async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn version_endpoint() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -619,7 +624,7 @@ async fn seed_table(
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -696,7 +701,7 @@ async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -759,7 +764,7 @@ async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -829,7 +834,7 @@ async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -899,7 +904,7 @@ async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -953,7 +958,7 @@ async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -994,7 +999,7 @@ async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -1091,7 +1096,7 @@ async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -1164,7 +1169,7 @@ async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@@ -1242,7 +1247,7 @@ async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
 	Ok(())
 }
 
-#[tokio::test]
+#[test(tokio::test)]
 #[serial]
 async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> {
 	let (addr, _server) = common::start_server(true, false, true).await.unwrap();
1109	tests/ws_integration.rs	(new file)
File diff suppressed because it is too large
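The new `tests/ws_integration.rs` file (1,109 lines) is not rendered here, so its exact contents are unknown. As a rough idea of the shape such a test usually takes, connecting to the server's WebSocket RPC endpoint and exchanging JSON messages, here is a hypothetical sketch assuming `tokio-tungstenite`, `futures-util`, and `serde_json` as dev-dependencies; none of the names below are taken from the suppressed file:

```rust
// Hypothetical WebSocket RPC round-trip, not the actual suppressed test file.
// Assumes a server already listening on `addr`, e.g. started by a test helper.
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};

async fn ping_rpc(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
	let (mut socket, _response) = connect_async(format!("ws://{addr}/rpc")).await?;

	// Send a JSON-RPC style "ping" request and wait for the matching response.
	let request = serde_json::json!({ "id": "1", "method": "ping", "params": [] });
	socket.send(Message::Text(request.to_string())).await?;

	if let Some(msg) = socket.next().await {
		let reply: serde_json::Value = serde_json::from_str(msg?.to_text()?)?;
		assert_eq!(reply["id"], "1", "unexpected reply: {reply}");
	}
	Ok(())
}
```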