[rpc] Better tracing for WebSockets (#2325)

This commit is contained in:
Salvador Girones Gil 2023-08-03 16:59:05 +02:00 committed by GitHub
parent ab72923fb5
commit e91011cc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 2617 additions and 650 deletions

View file

@ -133,7 +133,7 @@ jobs:
args: ci-clippy args: ci-clippy
cli: cli:
name: Test command line name: CLI integration tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -163,7 +163,7 @@ jobs:
args: ci-cli-integration args: ci-cli-integration
http-server: http-server:
name: Test HTTP server name: HTTP integration tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -192,6 +192,28 @@ jobs:
command: make command: make
args: ci-http-integration args: ci-http-integration
ws-server:
name: WebSocket integration tests
runs-on: ubuntu-latest
steps:
- name: Install stable toolchain
uses: dtolnay/rust-toolchain@stable
- name: Checkout sources
uses: actions/checkout@v3
- name: Setup cache
uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt-get -y update
sudo apt-get -y install protobuf-compiler libprotobuf-dev
- name: Run cargo test
run: cargo test --locked --no-default-features --features storage-mem --workspace --test ws_integration
test: test:
name: Test workspace name: Test workspace
runs-on: ubuntu-latest runs-on: ubuntu-latest

782
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -67,6 +67,7 @@ tokio-util = { version = "0.7.8", features = ["io"] }
tower = "0.4.13" tower = "0.4.13"
tower-http = { version = "0.4.2", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] } tower-http = { version = "0.4.2", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
tracing = "0.1" tracing = "0.1"
tracing-futures = { version = "0.2.5", features = ["tokio"], default-features = false }
tracing-opentelemetry = "0.19.0" tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
urlencoding = "2.1.2" urlencoding = "2.1.2"
@ -77,11 +78,14 @@ nix = "0.26.2"
[dev-dependencies] [dev-dependencies]
assert_fs = "1.0.13" assert_fs = "1.0.13"
env_logger = "0.10.0"
opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] } opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
rcgen = "0.10.0" rcgen = "0.10.0"
serial_test = "2.0.0" serial_test = "2.0.0"
temp-env = "0.3.4" temp-env = { version = "0.3.4", features = ["async_closure"] }
test-log = { version = "0.2.12", features = ["trace"] }
tokio-stream = { version = "0.1", features = ["net"] } tokio-stream = { version = "0.1", features = ["net"] }
tokio-tungstenite = { version = "0.18.0" }
tonic = "0.8.3" tonic = "0.8.3"
[package.metadata.deb] [package.metadata.deb]

View file

@ -115,7 +115,7 @@ pub async fn init(
listen_addresses, listen_addresses,
dbs, dbs,
web, web,
log: CustomEnvFilter(log), log,
tick_interval, tick_interval,
no_banner, no_banner,
.. ..

View file

@ -1,8 +1,9 @@
use clap::builder::{NonEmptyStringValueParser, PossibleValue, TypedValueParser}; use clap::builder::{NonEmptyStringValueParser, PossibleValue, TypedValueParser};
use clap::error::{ContextKind, ContextValue, ErrorKind}; use clap::error::{ContextKind, ContextValue, ErrorKind};
use tracing::Level;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use crate::telemetry::filter_from_value;
#[derive(Debug)] #[derive(Debug)]
pub struct CustomEnvFilter(pub EnvFilter); pub struct CustomEnvFilter(pub EnvFilter);
@ -37,20 +38,7 @@ impl TypedValueParser for CustomEnvFilterParser {
let inner = NonEmptyStringValueParser::new(); let inner = NonEmptyStringValueParser::new();
let v = inner.parse_ref(cmd, arg, value)?; let v = inner.parse_ref(cmd, arg, value)?;
let filter = (match v.as_str() { let filter = filter_from_value(v.as_str()).map_err(|e| {
// Don't show any logs at all
"none" => Ok(EnvFilter::default()),
// Check if we should show all log levels
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
// Otherwise, let's only show errors
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
// Specify the log level for each code area
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::txn=error")),
// Let's try to parse the custom log level
_ => EnvFilter::builder().parse(v),
})
.map_err(|e| {
let mut err = clap::Error::new(ErrorKind::ValueValidation).with_cmd(cmd); let mut err = clap::Error::new(ErrorKind::ValueValidation).with_cmd(cmd);
err.insert(ContextKind::Custom, ContextValue::String(e.to_string())); err.insert(ContextKind::Custom, ContextValue::String(e.to_string()));
err.insert( err.insert(

View file

@ -131,7 +131,7 @@ pub async fn init() -> Result<(), Error> {
// Setup the graceful shutdown with no timeout // Setup the graceful shutdown with no timeout
let handle = Handle::new(); let handle = Handle::new();
graceful_shutdown(handle.clone(), None); let shutdown_handler = graceful_shutdown(handle.clone());
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) { if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
// configure certificate and private key used by https // configure certificate and private key used by https
@ -156,6 +156,12 @@ pub async fn init() -> Result<(), Error> {
.await?; .await?;
}; };
// Wait for the shutdown to finish
let _ = shutdown_handler.await;
// Flush all telemetry data
opentelemetry::global::shutdown_tracer_provider();
info!(target: LOG, "Web server stopped. Bye!"); info!(target: LOG, "Web server stopped. Bye!");
Ok(()) Ok(())

View file

@ -7,26 +7,36 @@ use crate::err::Error;
use crate::rpc::args::Take; use crate::rpc::args::Take;
use crate::rpc::paths::{ID, METHOD, PARAMS}; use crate::rpc::paths::{ID, METHOD, PARAMS};
use crate::rpc::res; use crate::rpc::res;
use crate::rpc::res::Data;
use crate::rpc::res::Failure; use crate::rpc::res::Failure;
use crate::rpc::res::Output; use crate::rpc::res::IntoRpcResponse;
use crate::rpc::res::OutputFormat;
use crate::rpc::CONN_CLOSED_ERR;
use crate::telemetry::traces::rpc::span_for_request;
use axum::routing::get; use axum::routing::get;
use axum::Extension; use axum::Extension;
use axum::Router; use axum::Router;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use surrealdb::channel; use surrealdb::channel;
use surrealdb::channel::Sender; use surrealdb::channel::{Receiver, Sender};
use surrealdb::dbs::{QueryType, Response, Session}; use surrealdb::dbs::{QueryType, Response, Session};
use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Array; use surrealdb::sql::Array;
use surrealdb::sql::Object; use surrealdb::sql::Object;
use surrealdb::sql::Strand; use surrealdb::sql::Strand;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::instrument; use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tower_http::request_id::RequestId;
use tracing::Span;
use uuid::Uuid; use uuid::Uuid;
use axum::{ use axum::{
@ -35,11 +45,12 @@ use axum::{
}; };
// Mapping of WebSocketID to WebSocket // Mapping of WebSocketID to WebSocket
type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>; pub(crate) struct WebSocketRef(pub(crate) Sender<Message>, pub(crate) CancellationToken);
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
// Mapping of LiveQueryID to WebSocketID // Mapping of LiveQueryID to WebSocketID
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>; type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default); pub(super) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default); static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
pub(super) fn router<S, B>() -> Router<S, B> pub(super) fn router<S, B>() -> Router<S, B>
@ -50,22 +61,36 @@ where
Router::new().route("/rpc", get(handler)) Router::new().route("/rpc", get(handler))
} }
async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse { async fn handler(
ws: WebSocketUpgrade,
Extension(sess): Extension<Session>,
Extension(req_id): Extension<RequestId>,
) -> impl IntoResponse {
// finalize the upgrade process by returning upgrade callback. // finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address. // we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, sess)) ws.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
} }
async fn handle_socket(ws: WebSocket, sess: Session) { async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
let rpc = Rpc::new(sess); let rpc = Rpc::new(sess);
Rpc::serve(rpc, ws).await
// If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID
match req_id.header_value().to_str().map(Uuid::parse_str) {
Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => {
rpc.write().await.ws_id = req_id
}
_ => (),
}
Rpc::serve(rpc, ws).await;
} }
pub struct Rpc { pub struct Rpc {
session: Session, session: Session,
format: Output, format: OutputFormat,
uuid: Uuid, ws_id: Uuid,
vars: BTreeMap<String, Value>, vars: BTreeMap<String, Value>,
graceful_shutdown: CancellationToken,
} }
impl Rpc { impl Rpc {
@ -74,158 +99,247 @@ impl Rpc {
// Create a new RPC variables store // Create a new RPC variables store
let vars = BTreeMap::new(); let vars = BTreeMap::new();
// Set the default output format // Set the default output format
let format = Output::Json; let format = OutputFormat::Json;
// Create a unique WebSocket id // Enable real-time mode
let uuid = Uuid::new_v4();
// Enable real-time live queries
session.rt = true; session.rt = true;
// Create and store the Rpc connection // Create and store the Rpc connection
Arc::new(RwLock::new(Rpc { Arc::new(RwLock::new(Rpc {
session, session,
format, format,
uuid, ws_id: Uuid::new_v4(),
vars, vars,
graceful_shutdown: CancellationToken::new(),
})) }))
} }
/// Serve the RPC endpoint /// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) { pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) {
// Create a channel for sending messages
let (chn, mut rcv) = channel::new(MAX_CONCURRENT_CALLS);
// Split the socket into send and recv // Split the socket into send and recv
let (mut wtx, mut wrx) = ws.split(); let (sender, receiver) = ws.split();
// Clone the channel for sending pings // Create an internal channel between the receiver and the sender
let png = chn.clone(); let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
// The WebSocket has connected
Rpc::connected(rpc.clone(), chn.clone()).await; let ws_id = rpc.read().await.ws_id;
// Send Ping messages to the client
tokio::task::spawn(async move {
// Create the interval ticker
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
// Loop indefinitely
loop {
// Wait for the timer
interval.tick().await;
// Create the ping message
let msg = Message::Ping(vec![]);
// Send the message to the client
if png.send(msg).await.is_err() {
// Exit out of the loop
break;
}
}
});
// Send messages to the client
tokio::task::spawn(async move {
// Wait for the next message to send
while let Some(res) = rcv.next().await {
// Send the message to the client
if let Err(err) = wtx.send(res).await {
// Output the WebSocket error to the logs
trace!("WebSocket error: {:?}", err);
// It's already failed, so ignore error
let _ = wtx.close().await;
// Exit out of the loop
break;
}
}
});
// Send notifications to the client
let moved_rpc = rpc.clone();
tokio::task::spawn(async move {
let rpc = moved_rpc;
if let Some(channel) = DB.get().unwrap().notifications() {
while let Ok(notification) = channel.recv().await {
// Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(websocket) = WEBSOCKETS.read().await.get(ws_id) {
// Serialize the message to send
let message = res::success(None, notification);
// Get the current output format
let format = rpc.read().await.format.clone();
// Send the notification to the client
message.send(format, websocket.clone()).await;
}
}
}
}
});
// Get messages from the client
while let Some(msg) = wrx.next().await {
match msg {
// We've received a message from the client
// Ping is automatically handled by the WebSocket library
Ok(msg) => match msg {
Message::Text(_) => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
}
Message::Binary(_) => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
}
Message::Close(_) => {
break;
}
Message::Pong(_) => {
continue;
}
_ => {
// Ignore everything else
}
},
// There was an error receiving the message
Err(err) => {
// Output the WebSocket error to the logs
trace!("WebSocket error: {:?}", err);
// Exit out of the loop
break;
}
}
}
// The WebSocket has disconnected
Rpc::disconnected(rpc.clone()).await;
}
async fn connected(rpc: Arc<RwLock<Rpc>>, chn: Sender<Message>) {
// Fetch the unique id of the WebSocket
let id = rpc.read().await.uuid;
// Log that the WebSocket has connected
trace!("WebSocket {} connected", id);
// Store this WebSocket in the list of WebSockets // Store this WebSocket in the list of WebSockets
WEBSOCKETS.write().await.insert(id, chn); WEBSOCKETS.write().await.insert(
} ws_id,
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
);
trace!("WebSocket {} connected", ws_id);
// Wait until all tasks finish
tokio::join!(
Self::ping(rpc.clone(), internal_sender.clone()),
Self::read(rpc.clone(), receiver, internal_sender.clone()),
Self::write(rpc.clone(), sender, internal_receiver.clone()),
Self::lq_notifications(rpc.clone()),
);
async fn disconnected(rpc: Arc<RwLock<Rpc>>) {
// Fetch the unique id of the WebSocket
let id = rpc.read().await.uuid;
// Log that the WebSocket has disconnected
trace!("WebSocket {} disconnected", id);
// Remove this WebSocket from the list of WebSockets
WEBSOCKETS.write().await.remove(&id);
// Remove all live queries // Remove all live queries
LIVE_QUERIES.write().await.retain(|key, value| { LIVE_QUERIES.write().await.retain(|key, value| {
if value == &id { if value == &ws_id {
trace!("Removing live query: {}", key); trace!("Removing live query: {}", key);
return false; return false;
} }
true true
}); });
// Remove this WebSocket from the list of WebSockets
WEBSOCKETS.write().await.remove(&ws_id);
trace!("WebSocket {} disconnected", ws_id);
} }
/// Call RPC methods from the WebSocket /// Send Ping messages to the client
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) { async fn ping(rpc: Arc<RwLock<Rpc>>, internal_sender: Sender<Message>) {
// Create the interval ticker
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
_ = interval.tick() => {
let msg = Message::Ping(vec![]);
// Send the message to the client and close the WebSocket connection if it fails
if internal_sender.send(msg).await.is_err() {
rpc.read().await.graceful_shutdown.cancel();
break;
}
},
_ = is_shutdown => break,
}
}
}
/// Read messages sent from the client
async fn read(
rpc: Arc<RwLock<Rpc>>,
mut receiver: SplitStream<WebSocket>,
internal_sender: Sender<Message>,
) {
// Collect all spawned tasks so we can wait for them at the end
let mut tasks = JoinSet::new();
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
msg = receiver.next() => {
if let Some(msg) = msg {
match msg {
// We've received a message from the client
// Ping/Pong is automatically handled by the WebSocket library
Ok(msg) => match msg {
Message::Text(_) => {
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
}
Message::Binary(_) => {
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
}
Message::Close(_) => {
// Respond with a close message
if let Err(err) = internal_sender.send(Message::Close(None)).await {
trace!("WebSocket error when replying to the Close frame: {:?}", err);
};
// Start the graceful shutdown of the WebSocket and close the channels
rpc.read().await.graceful_shutdown.cancel();
let _ = internal_sender.close();
break;
}
_ => {
// Ignore everything else
}
},
Err(err) => {
trace!("WebSocket error: {:?}", err);
// Start the graceful shutdown of the WebSocket and close the channels
rpc.read().await.graceful_shutdown.cancel();
let _ = internal_sender.close();
// Exit out of the loop
break;
}
}
}
}
_ = is_shutdown => break,
}
}
// Wait for all tasks to finish
while let Some(res) = tasks.join_next().await {
if let Err(err) = res {
error!("Error while handling RPC message: {}", err);
}
}
}
/// Write messages to the client
async fn write(
rpc: Arc<RwLock<Rpc>>,
mut sender: SplitSink<WebSocket, Message>,
mut internal_receiver: Receiver<Message>,
) {
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
// Wait for the next message to send
msg = internal_receiver.next() => {
if let Some(res) = msg {
// Send the message to the client
if let Err(err) = sender.send(res).await {
if err.to_string() != CONN_CLOSED_ERR {
debug!("WebSocket error: {:?}", err);
}
// Close the WebSocket connection
rpc.read().await.graceful_shutdown.cancel();
// Exit out of the loop
break;
}
}
},
_ = is_shutdown => break,
}
}
}
/// Send live query notifications to the client
async fn lq_notifications(rpc: Arc<RwLock<Rpc>>) {
if let Some(channel) = DB.get().unwrap().notifications() {
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
tokio::select! {
msg = channel.recv() => {
if let Ok(notification) = msg {
// Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
// Serialize the message to send
let message = res::success(None, notification);
// Get the current output format
let format = rpc.read().await.format.clone();
// Send the notification to the client
message.send(format, ws.clone()).await
}
}
}
},
_ = cancel_token.cancelled() => break,
}
}
}
}
/// Handle individual WebSocket messages
async fn handle_msg(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
// Get the current output format // Get the current output format
let mut out = { rpc.read().await.format.clone() }; let mut out_fmt = rpc.read().await.format.clone();
// Clone the RPC let span = span_for_request(&rpc.read().await.ws_id);
let rpc = rpc.clone(); let _enter = span.enter();
// Parse the request // Parse the request
match Self::parse_request(msg).await {
Ok((id, method, params, _out_fmt)) => {
span.record(
"rpc.jsonrpc.request_id",
id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
);
if let Some(_out_fmt) = _out_fmt {
out_fmt = _out_fmt;
}
// Process the request
let res = Self::process_request(rpc.clone(), &method, params).await;
// Process the response
res.into_response(id).send(out_fmt, chn).await
}
Err(err) => {
// Process the response
res::failure(None, err).send(out_fmt, chn).await
}
}
}
async fn parse_request(
msg: Message,
) -> Result<(Option<Value>, String, Array, Option<OutputFormat>), Failure> {
let mut out_fmt = None;
let req = match msg { let req = match msg {
// This is a binary message // This is a binary message
Message::Binary(val) => { Message::Binary(val) => {
// Use binary output // Use binary output
out = Output::Full; out_fmt = Some(OutputFormat::Full);
// Deserialize the input
Value::from(val) match deserialize(&val) {
Ok(v) => v,
Err(_) => {
debug!("Error when trying to deserialize the request");
return Err(Failure::PARSE_ERROR);
}
}
} }
// This is a text message // This is a text message
Message::Text(ref val) => { Message::Text(ref val) => {
@ -234,14 +348,15 @@ impl Rpc {
// The SurrealQL message parsed ok // The SurrealQL message parsed ok
Ok(v) => v, Ok(v) => v,
// The SurrealQL message failed to parse // The SurrealQL message failed to parse
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await, _ => return Err(Failure::PARSE_ERROR),
} }
} }
// Unsupported message type // Unsupported message type
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await, _ => {
debug!("Unsupported message type: {:?}", msg);
return Err(res::Failure::custom("Unsupported message type"));
}
}; };
// Log the received request
trace!("RPC Received: {}", req);
// Fetch the 'id' argument // Fetch the 'id' argument
let id = match req.pick(&*ID) { let id = match req.pick(&*ID) {
v if v.is_none() => None, v if v.is_none() => None,
@ -250,149 +365,180 @@ impl Rpc {
v if v.is_number() => Some(v), v if v.is_number() => Some(v),
v if v.is_strand() => Some(v), v if v.is_strand() => Some(v),
v if v.is_datetime() => Some(v), v if v.is_datetime() => Some(v),
_ => return res::failure(None, Failure::INVALID_REQUEST).send(out, chn).await, _ => return Err(Failure::INVALID_REQUEST),
}; };
// Fetch the 'method' argument // Fetch the 'method' argument
let method = match req.pick(&*METHOD) { let method = match req.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(), Value::Strand(v) => v.to_raw(),
_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await, _ => return Err(Failure::INVALID_REQUEST),
}; };
// Now that we know the method, we can update the span
Span::current().record("rpc.method", &method);
Span::current().record("otel.name", format!("surrealdb.rpc/{}", method));
// Fetch the 'params' argument // Fetch the 'params' argument
let params = match req.pick(&*PARAMS) { let params = match req.pick(&*PARAMS) {
Value::Array(v) => v, Value::Array(v) => v,
_ => Array::new(), _ => Array::new(),
}; };
Ok((id, method, params, out_fmt))
}
async fn process_request(
rpc: Arc<RwLock<Rpc>>,
method: &str,
params: Array,
) -> Result<Data, Failure> {
info!("Process RPC request");
// Match the method to a function // Match the method to a function
let res = match &method[..] { match method {
// Handle a ping message // Handle a surrealdb ping message
"ping" => Ok(Value::None), //
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
// By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent.
//
"ping" => Ok(Value::None.into()),
// Retrieve the current auth record // Retrieve the current auth record
"info" => match params.len() { "info" => match params.len() {
0 => rpc.read().await.info().await, 0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
// Switch to a specific namespace and database // Switch to a specific namespace and database
"use" => match params.needs_two() { "use" => match params.needs_two() {
Ok((ns, db)) => rpc.write().await.yuse(ns, db).await, Ok((ns, db)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Signup to a specific authentication scope // Signup to a specific authentication scope
"signup" => match params.needs_one() { "signup" => match params.needs_one() {
Ok(Value::Object(v)) => rpc.write().await.signup(v).await, Ok(Value::Object(v)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Signin as a root, namespace, database or scope user // Signin as a root, namespace, database or scope user
"signin" => match params.needs_one() { "signin" => match params.needs_one() {
Ok(Value::Object(v)) => rpc.write().await.signin(v).await, Ok(Value::Object(v)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Invalidate the current authentication session // Invalidate the current authentication session
"invalidate" => match params.len() { "invalidate" => match params.len() {
0 => rpc.write().await.invalidate().await, 0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
// Authenticate using an authentication token // Authenticate using an authentication token
"authenticate" => match params.needs_one() { "authenticate" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.authenticate(v).await, Ok(Value::Strand(v)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Kill a live query using a query id // Kill a live query using a query id
"kill" => match params.needs_one() { "kill" => match params.needs_one() {
Ok(v) if v.is_uuid() => rpc.read().await.kill(v).await, Ok(v) if v.is_uuid() => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Setup a live query on a specific table // Setup a live query on a specific table
"live" => match params.needs_one_or_two() { "live" => match params.needs_one_or_two() {
Ok((v, d)) if v.is_table() => rpc.read().await.live(v, d).await, Ok((v, d)) if v.is_table() => {
Ok((v, d)) if v.is_strand() => rpc.read().await.live(v, d).await, rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, }
Ok((v, d)) if v.is_strand() => {
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Specify a connection-wide parameter // Specify a connection-wide parameter
"let" => match params.needs_one_or_two() { "let" | "set" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await, Ok((Value::Strand(s), v)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
}, }
// Specify a connection-wide parameter _ => Err(Failure::INVALID_PARAMS),
"set" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
}, },
// Unset and clear a connection-wide parameter // Unset and clear a connection-wide parameter
"unset" => match params.needs_one() { "unset" => match params.needs_one() {
Ok(Value::Strand(s)) => rpc.write().await.unset(s).await, Ok(Value::Strand(s)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Select a value or values from the database // Select a value or values from the database
"select" => match params.needs_one() { "select" => match params.needs_one() {
Ok(v) => rpc.read().await.select(v).await, Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
// Insert a value or values in the database // Insert a value or values in the database
"insert" => match params.needs_one_or_two() { "insert" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.insert(v, o).await, Ok((v, o)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Create a value or values in the database // Create a value or values in the database
"create" => match params.needs_one_or_two() { "create" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.create(v, o).await, Ok((v, o)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Update a value or values in the database using `CONTENT` // Update a value or values in the database using `CONTENT`
"update" => match params.needs_one_or_two() { "update" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.update(v, o).await, Ok((v, o)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Update a value or values in the database using `MERGE` // Update a value or values in the database using `MERGE`
"change" | "merge" => match params.needs_one_or_two() { "change" | "merge" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.change(v, o).await, Ok((v, o)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Update a value or values in the database using `PATCH` // Update a value or values in the database using `PATCH`
"modify" | "patch" => match params.needs_one_or_two() { "modify" | "patch" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.modify(v, o).await, Ok((v, o)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Delete a value or values from the database // Delete a value or values from the database
"delete" => match params.needs_one() { "delete" => match params.needs_one() {
Ok(v) => rpc.read().await.delete(v).await, Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
// Specify the output format for text requests // Specify the output format for text requests
"format" => match params.needs_one() { "format" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.format(v).await, Ok(Value::Strand(v)) => {
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
}, },
// Get the current server version // Get the current server version
"version" => match params.len() { "version" => match params.len() {
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()), 0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
// Run a full SurrealQL query against the database // Run a full SurrealQL query against the database
"query" => match params.needs_one_or_two() { "query" => match params.needs_one_or_two() {
Ok((Value::Strand(s), o)) if o.is_none_or_null() => { Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
return match rpc.read().await.query(s).await { rpc.read().await.query(s).await.map(Into::into).map_err(Into::into)
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
};
} }
Ok((Value::Strand(s), Value::Object(o))) => { Ok((Value::Strand(s), Value::Object(o))) => {
return match rpc.read().await.query_with(s, o).await { rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into)
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
};
} }
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await, _ => Err(Failure::INVALID_PARAMS),
}, },
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await, _ => Err(Failure::METHOD_NOT_FOUND),
};
// Return the final response
match res {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
} }
} }
@ -402,15 +548,14 @@ impl Rpc {
async fn format(&mut self, out: Strand) -> Result<Value, Error> { async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() { match out.as_str() {
"json" | "application/json" => self.format = Output::Json, "json" | "application/json" => self.format = OutputFormat::Json,
"cbor" | "application/cbor" => self.format = Output::Cbor, "cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
"pack" | "application/pack" => self.format = Output::Pack, "pack" | "application/pack" => self.format = OutputFormat::Pack,
_ => return Err(Error::InvalidType), _ => return Err(Error::InvalidType),
}; };
Ok(Value::None) Ok(Value::None)
} }
#[instrument(skip_all, name = "rpc use", fields(websocket=self.uuid.to_string()))]
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> { async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
if let Value::Strand(ns) = ns { if let Value::Strand(ns) = ns {
self.session.ns = Some(ns.0); self.session.ns = Some(ns.0);
@ -421,7 +566,6 @@ impl Rpc {
Ok(Value::None) Ok(Value::None)
} }
#[instrument(skip_all, name = "rpc signup", fields(websocket=self.uuid.to_string()))]
async fn signup(&mut self, vars: Object) -> Result<Value, Error> { async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
surrealdb::iam::signup::signup(kvs, &mut self.session, vars) surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
@ -430,7 +574,6 @@ impl Rpc {
.map_err(Into::into) .map_err(Into::into)
} }
#[instrument(skip_all, name = "rpc signin", fields(websocket=self.uuid.to_string()))]
async fn signin(&mut self, vars: Object) -> Result<Value, Error> { async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
surrealdb::iam::signin::signin(kvs, &mut self.session, vars) surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
@ -438,13 +581,11 @@ impl Rpc {
.map(Into::into) .map(Into::into)
.map_err(Into::into) .map_err(Into::into)
} }
#[instrument(skip_all, name = "rpc invalidate", fields(websocket=self.uuid.to_string()))]
async fn invalidate(&mut self) -> Result<Value, Error> { async fn invalidate(&mut self) -> Result<Value, Error> {
surrealdb::iam::clear::clear(&mut self.session)?; surrealdb::iam::clear::clear(&mut self.session)?;
Ok(Value::None) Ok(Value::None)
} }
#[instrument(skip_all, name = "rpc auth", fields(websocket=self.uuid.to_string()))]
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> { async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?; surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
@ -455,7 +596,6 @@ impl Rpc {
// Methods for identification // Methods for identification
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc info", fields(websocket=self.uuid.to_string()))]
async fn info(&self) -> Result<Value, Error> { async fn info(&self) -> Result<Value, Error> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
@ -473,7 +613,6 @@ impl Rpc {
// Methods for setting variables // Methods for setting variables
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc set", fields(websocket=self.uuid.to_string()))]
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> { async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
match val { match val {
// Remove the variable if undefined // Remove the variable if undefined
@ -484,7 +623,6 @@ impl Rpc {
Ok(Value::Null) Ok(Value::Null)
} }
#[instrument(skip_all, name = "rpc unset", fields(websocket=self.uuid.to_string()))]
async fn unset(&mut self, key: Strand) -> Result<Value, Error> { async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
self.vars.remove(&key.0); self.vars.remove(&key.0);
Ok(Value::Null) Ok(Value::Null)
@ -494,7 +632,6 @@ impl Rpc {
// Methods for live queries // Methods for live queries
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc kill", fields(websocket=self.uuid.to_string()))]
async fn kill(&self, id: Value) -> Result<Value, Error> { async fn kill(&self, id: Value) -> Result<Value, Error> {
// Specify the SQL query string // Specify the SQL query string
let sql = "KILL $id"; let sql = "KILL $id";
@ -513,7 +650,6 @@ impl Rpc {
} }
} }
#[instrument(skip_all, name = "rpc live", fields(websocket=self.uuid.to_string()))]
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> { async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
// Specify the SQL query string // Specify the SQL query string
let sql = match diff.is_true() { let sql = match diff.is_true() {
@ -539,7 +675,6 @@ impl Rpc {
// Methods for selecting // Methods for selecting
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc select", fields(websocket=self.uuid.to_string()))]
async fn select(&self, what: Value) -> Result<Value, Error> { async fn select(&self, what: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -567,7 +702,6 @@ impl Rpc {
// Methods for inserting // Methods for inserting
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc insert", fields(websocket=self.uuid.to_string()))]
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> { async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -596,7 +730,6 @@ impl Rpc {
// Methods for creating // Methods for creating
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc create", fields(websocket=self.uuid.to_string()))]
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> { async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -625,7 +758,6 @@ impl Rpc {
// Methods for updating // Methods for updating
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc update", fields(websocket=self.uuid.to_string()))]
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> { async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -654,7 +786,6 @@ impl Rpc {
// Methods for changing // Methods for changing
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc change", fields(websocket=self.uuid.to_string()))]
async fn change(&self, what: Value, data: Value) -> Result<Value, Error> { async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -683,7 +814,6 @@ impl Rpc {
// Methods for modifying // Methods for modifying
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc modify", fields(websocket=self.uuid.to_string()))]
async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> { async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -712,7 +842,6 @@ impl Rpc {
// Methods for deleting // Methods for deleting
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc delete", fields(websocket=self.uuid.to_string()))]
async fn delete(&self, what: Value) -> Result<Value, Error> { async fn delete(&self, what: Value) -> Result<Value, Error> {
// Return a single result? // Return a single result?
let one = what.is_thing(); let one = what.is_thing();
@ -740,7 +869,6 @@ impl Rpc {
// Methods for querying // Methods for querying
// ------------------------------ // ------------------------------
#[instrument(skip_all, name = "rpc query", fields(websocket=self.uuid.to_string()))]
async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> { async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
@ -756,7 +884,6 @@ impl Rpc {
Ok(res) Ok(res)
} }
#[instrument(skip_all, name = "rpc query_with", fields(websocket=self.uuid.to_string()))]
async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> { async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = DB.get().unwrap();
@ -781,8 +908,8 @@ impl Rpc {
QueryType::Live => { QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result { if let Ok(Value::Uuid(lqid)) = &res.result {
// Match on Uuid type // Match on Uuid type
LIVE_QUERIES.write().await.insert(lqid.0, self.uuid); LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
trace!("Registered live query {} on websocket {}", lqid, self.uuid); trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
} }
} }
QueryType::Kill => { QueryType::Kill => {

View file

@ -1,17 +1,57 @@
use std::time::Duration; use std::time::Duration;
use axum_server::Handle; use axum_server::Handle;
use tokio::task::JoinHandle;
use crate::err::Error; use crate::{
err::Error,
net::rpc::{WebSocketRef, WEBSOCKETS},
};
/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received. /// Start a graceful shutdown:
pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) { /// * Signal the Axum Handle when a shutdown signal is received.
/// * Stop all WebSocket connections.
///
/// A second signal will force an immediate shutdown.
pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
tokio::spawn(async move { tokio::spawn(async move {
let result = listen().await.expect("Failed to listen to shutdown signal"); let result = listen().await.expect("Failed to listen to shutdown signal");
info!(target: super::LOG, "{} received. Start graceful shutdown...", result); info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);
handle.graceful_shutdown(dur) tokio::select! {
}); // Start a normal graceful shutdown
_ = async {
// First stop accepting new HTTP requests
http_handle.graceful_shutdown(None);
// Close all WebSocket connections. Queued messages will still be processed.
for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
cancel_token.cancel();
};
// Wait for all existing WebSocket connections to gracefully close
while WEBSOCKETS.read().await.len() > 0 {
tokio::time::sleep(Duration::from_millis(100)).await;
};
} => (),
// Force an immediate shutdown if a second signal is received
_ = async {
if let Ok(signal) = listen().await {
warn!(target: super::LOG, "{} received during graceful shutdown. Terminate immediately...", signal);
} else {
error!(target: super::LOG, "Failed to listen to shutdown signal. Terminate immediately...");
}
// Force an immediate shutdown
http_handle.shutdown();
// Close all WebSocket connections immediately
if let Ok(mut writer) = WEBSOCKETS.try_write() {
writer.drain();
}
} => (),
}
})
} }
#[cfg(unix)] #[cfg(unix)]

View file

@ -1,119 +1,15 @@
use std::{fmt, time::Duration}; use std::{fmt, time::Duration};
use axum::{ use axum::extract::MatchedPath;
body::{boxed, Body, BoxBody}, use http::header;
extract::MatchedPath,
headers::{
authorization::{Basic, Bearer},
Authorization, Origin,
},
Extension, RequestPartsExt, TypedHeader,
};
use futures_util::future::BoxFuture;
use http::{header, request::Parts, StatusCode};
use hyper::{Request, Response}; use hyper::{Request, Response};
use surrealdb::{
dbs::Session,
iam::verify::{basic, token},
};
use tower_http::{ use tower_http::{
auth::AsyncAuthorizeRequest,
request_id::RequestId, request_id::RequestId,
trace::{MakeSpan, OnFailure, OnRequest, OnResponse}, trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
}; };
use tracing::{field, Level, Span}; use tracing::{field, Level, Span};
use crate::{dbs::DB, err::Error}; use super::client_ip::ExtractClientIP;
use super::{client_ip::ExtractClientIP, AppState};
///
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
///
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
///
/// ```rust
/// use tower_http::auth::RequireAuthorizationLayer;
/// use surrealdb::net::SurrealAuth;
/// use axum::Router;
///
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
///
/// let app = Router::new()
/// .route("/version", get(|| async { "0.1.0" }))
/// .layer(auth);
/// ```
#[derive(Clone, Copy)]
pub(super) struct SurrealAuth;
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
where
B: Send + Sync + 'static,
{
type RequestBody = B;
type ResponseBody = BoxBody;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, request: Request<B>) -> Self::Future {
Box::pin(async {
let (mut parts, body) = request.into_parts();
match check_auth(&mut parts).await {
Ok(sess) => {
parts.extensions.insert(sess);
Ok(Request::from_parts(parts, body))
}
Err(err) => {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(boxed(Body::from(err.to_string())))
.unwrap();
Err(unauthorized_response)
}
}
})
}
}
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
let kvs = DB.get().unwrap();
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
if !or.is_null() {
Some(or.to_string())
} else {
None
}
} else {
None
};
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
tracing::error!("Error extracting the app state: {:?}", err);
Error::InvalidAuth
})?;
let ExtractClientIP(ip) =
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
// Create session
#[rustfmt::skip]
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
// If Basic authentication data was supplied
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
basic(kvs, &mut session, au.username(), au.password()).await.map_err(|e| e.into())
} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
token(kvs, &mut session, au.token()).await.map_err(|e| e.into())
} else {
Err(Error::InvalidAuth)
}?;
Ok(session)
}
/// ///
/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer. /// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
@ -139,7 +35,6 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
fn make_span(&mut self, req: &Request<B>) -> Span { fn make_span(&mut self, req: &Request<B>) -> Span {
// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md // The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
let span = tracing::info_span!( let span = tracing::info_span!(
target: "surreal::http",
"request", "request",
otel.name = field::Empty, otel.name = field::Empty,
otel.kind = "server", otel.kind = "server",
@ -154,10 +49,10 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
network.protocol.name = "http", network.protocol.name = "http",
network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"), network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"),
client.address = field::Empty, client.address = field::Empty,
client.port = field::Empty, client.port = field::Empty,
client.socket.address = field::Empty, client.socket.address = field::Empty,
server.address = field::Empty, server.address = field::Empty,
server.port = field::Empty, server.port = field::Empty,
// set on the response hook // set on the response hook
http.latency.ms = field::Empty, http.latency.ms = field::Empty,
http.response.status_code = field::Empty, http.response.status_code = field::Empty,

View file

@ -1,3 +1,5 @@
pub mod args; pub mod args;
pub mod paths; pub mod paths;
pub mod res; pub mod res;
pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally";

View file

@ -7,10 +7,13 @@ use surrealdb::dbs;
use surrealdb::dbs::Notification; use surrealdb::dbs::Notification;
use surrealdb::sql; use surrealdb::sql;
use surrealdb::sql::Value; use surrealdb::sql::Value;
use tracing::instrument; use tracing::Span;
#[derive(Clone)] use crate::err;
pub enum Output { use crate::rpc::CONN_CLOSED_ERR;
#[derive(Debug, Clone)]
pub enum OutputFormat {
Json, // JSON Json, // JSON
Cbor, // CBOR Cbor, // CBOR
Pack, // MessagePack Pack, // MessagePack
@ -37,6 +40,12 @@ impl From<Value> for Data {
} }
} }
impl From<String> for Data {
fn from(v: String) -> Self {
Data::Other(Value::from(v))
}
}
impl From<Vec<dbs::Response>> for Data { impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self { fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v) Data::Query(v)
@ -82,28 +91,45 @@ impl Response {
} }
/// Send the response to the WebSocket channel /// Send the response to the WebSocket channel
#[instrument(skip_all, name = "rpc response", fields(response = ?self))] pub async fn send(self, out: OutputFormat, chn: Sender<Message>) {
pub async fn send(self, out: Output, chn: Sender<Message>) { let span = Span::current();
info!("Process RPC response");
if let Err(err) = &self.result {
span.record("otel.status_code", "Error");
span.record(
"otel.status_message",
format!("code: {}, message: {}", err.code, err.message),
);
span.record("rpc.jsonrpc.error_code", err.code);
span.record("rpc.jsonrpc.error_message", err.message.as_ref());
}
let message = match out { let message = match out {
Output::Json => { OutputFormat::Json => {
let res = serde_json::to_string(&self.simplify()).unwrap(); let res = serde_json::to_string(&self.simplify()).unwrap();
Message::Text(res) Message::Text(res)
} }
Output::Cbor => { OutputFormat::Cbor => {
let res = serde_cbor::to_vec(&self.simplify()).unwrap(); let res = serde_cbor::to_vec(&self.simplify()).unwrap();
Message::Binary(res) Message::Binary(res)
} }
Output::Pack => { OutputFormat::Pack => {
let res = serde_pack::to_vec(&self.simplify()).unwrap(); let res = serde_pack::to_vec(&self.simplify()).unwrap();
Message::Binary(res) Message::Binary(res)
} }
Output::Full => { OutputFormat::Full => {
let res = surrealdb::sql::serde::serialize(&self).unwrap(); let res = surrealdb::sql::serde::serialize(&self).unwrap();
Message::Binary(res) Message::Binary(res)
} }
}; };
let _ = chn.send(message).await;
trace!("Response sent"); if let Err(err) = chn.send(message).await {
if err.to_string() != CONN_CLOSED_ERR {
error!("Error sending response: {}", err);
}
};
} }
} }
@ -113,6 +139,7 @@ pub struct Failure {
message: Cow<'static, str>, message: Cow<'static, str>,
} }
#[allow(dead_code)]
impl Failure { impl Failure {
pub const PARSE_ERROR: Failure = Failure { pub const PARSE_ERROR: Failure = Failure {
code: -32700, code: -32700,
@ -165,3 +192,26 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
result: Err(err), result: Err(err),
} }
} }
impl From<err::Error> for Failure {
fn from(err: err::Error) -> Self {
Failure::custom(err.to_string())
}
}
pub trait IntoRpcResponse {
fn into_response(self, id: Option<Value>) -> Response;
}
impl<T, E> IntoRpcResponse for Result<T, E>
where
T: Into<Data>,
E: Into<Failure>,
{
fn into_response(self, id: Option<Value>) -> Response {
match self {
Ok(v) => success(id, v.into()),
Err(err) => failure(id, err.into()),
}
}
}

View file

@ -1,8 +1,10 @@
use tracing::Subscriber; use tracing::Subscriber;
use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::Layer;
pub fn new<S>(level: String) -> Box<dyn Layer<S> + Send + Sync> use crate::cli::validator::parser::env_filter::CustomEnvFilter;
pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
where where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync, S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{ {
@ -11,6 +13,6 @@ where
.with_ansi(true) .with_ansi(true)
.with_span_events(FmtSpan::NONE) .with_span_events(FmtSpan::NONE)
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_filter(EnvFilter::builder().parse(level).unwrap()) .with_filter(filter.0)
.boxed() .boxed()
} }

View file

@ -1,6 +1,6 @@
mod logs; mod logs;
pub mod metrics; pub mod metrics;
mod traces; pub mod traces;
use std::time::Duration; use std::time::Duration;
@ -11,8 +11,7 @@ use opentelemetry::sdk::resource::{
}; };
use opentelemetry::sdk::Resource; use opentelemetry::sdk::Resource;
use opentelemetry::KeyValue; use opentelemetry::KeyValue;
use tracing::Subscriber; use tracing::{Level, Subscriber};
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
#[cfg(feature = "has-storage")] #[cfg(feature = "has-storage")]
@ -39,53 +38,75 @@ pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
} }
}); });
#[derive(Default, Debug, Clone)] #[derive(Debug, Clone)]
pub struct Builder { pub struct Builder {
log_level: Option<String>, filter: CustomEnvFilter,
filter: Option<CustomEnvFilter>,
} }
pub fn builder() -> Builder { pub fn builder() -> Builder {
Builder::default() Builder::default()
} }
impl Default for Builder {
fn default() -> Self {
Self {
filter: CustomEnvFilter(EnvFilter::default()),
}
}
}
impl Builder { impl Builder {
/// Set the log level on the builder /// Set the log level on the builder
pub fn with_log_level(mut self, log_level: &str) -> Self { pub fn with_log_level(mut self, log_level: &str) -> Self {
self.log_level = Some(log_level.to_string()); if let Ok(filter) = filter_from_value(log_level) {
self.filter = CustomEnvFilter(filter);
}
self self
} }
/// Set the filter on the builder /// Set the filter on the builder
#[cfg(feature = "has-storage")] #[cfg(feature = "has-storage")]
pub fn with_filter(mut self, filter: EnvFilter) -> Self { pub fn with_filter(mut self, filter: CustomEnvFilter) -> Self {
self.filter = Some(CustomEnvFilter(filter)); self.filter = filter;
self self
} }
/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber /// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> { pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
let registry = tracing_subscriber::registry(); let registry = tracing_subscriber::registry();
let registry = registry.with(self.filter.map(|filter| {
tracing_subscriber::fmt::layer() // Setup logging layer
.compact() let registry = registry.with(logs::new(self.filter.clone()));
.with_ansi(true)
.with_span_events(FmtSpan::NONE) // Setup tracing layer
.with_writer(std::io::stderr) let registry = registry.with(traces::new(self.filter));
.with_filter(filter.0)
.boxed()
}));
let registry = registry.with(self.log_level.map(logs::new));
let registry = registry.with(traces::new());
Box::new(registry) Box::new(registry)
} }
/// tracing pipeline /// Install the tracing dispatcher globally
pub fn init(self) { pub fn init(self) {
self.build().init() self.build().init()
} }
} }
/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
match v {
// Don't show any logs at all
"none" => Ok(EnvFilter::default()),
// Check if we should show all log levels
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
// Otherwise, let's only show errors
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
// Specify the log level for each code area
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::kvs::tx=error")),
// Let's try to parse the custom log level
_ => EnvFilter::builder().parse(v),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use opentelemetry::global::shutdown_tracer_provider; use opentelemetry::global::shutdown_tracer_provider;
@ -107,7 +128,7 @@ mod tests {
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
], ],
|| { || {
let _enter = telemetry::builder().build().set_default(); let _enter = telemetry::builder().with_log_level("info").build().set_default();
println!("Sending span..."); println!("Sending span...");
@ -123,7 +144,11 @@ mod tests {
} }
println!("Waiting for request..."); println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request"); let req = tokio::select! {
req = req_rx.recv() => req.expect("missing export request"),
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
};
let first_span = let first_span =
req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap(); req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
assert_eq!("test-surreal-span", first_span.name); assert_eq!("test-surreal-span", first_span.name);
@ -141,11 +166,10 @@ mod tests {
temp_env::with_vars( temp_env::with_vars(
vec![ vec![
("SURREAL_TRACING_TRACER", Some("otlp")), ("SURREAL_TRACING_TRACER", Some("otlp")),
("SURREAL_TRACING_FILTER", Some("debug")),
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())), ("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
], ],
|| { || {
let _enter = telemetry::builder().build().set_default(); let _enter = telemetry::builder().with_log_level("debug").build().set_default();
println!("Sending spans..."); println!("Sending spans...");
@ -169,7 +193,10 @@ mod tests {
} }
println!("Waiting for request..."); println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request"); let req = tokio::select! {
req = req_rx.recv() => req.expect("missing export request"),
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
};
let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans; let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;
assert_eq!(1, spans.len()); assert_eq!(1, spans.len());

View file

@ -1,12 +1,15 @@
use tracing::Subscriber; use tracing::Subscriber;
use tracing_subscriber::Layer; use tracing_subscriber::Layer;
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
pub mod otlp; pub mod otlp;
pub mod rpc;
const TRACING_TRACER_VAR: &str = "SURREAL_TRACING_TRACER"; const TRACING_TRACER_VAR: &str = "SURREAL_TRACING_TRACER";
// Returns a tracer based on the value of the TRACING_TRACER_VAR env var // Returns a tracer based on the value of the TRACING_TRACER_VAR env var
pub fn new<S>() -> Option<Box<dyn Layer<S> + Send + Sync>> pub fn new<S>(filter: CustomEnvFilter) -> Option<Box<dyn Layer<S> + Send + Sync>>
where where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync, S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{ {
@ -20,7 +23,7 @@ where
// Init the registry with the OTLP tracer // Init the registry with the OTLP tracer
"otlp" => { "otlp" => {
debug!("Setup the OTLP tracer"); debug!("Setup the OTLP tracer");
Some(otlp::new()) Some(otlp::new(filter))
} }
tracer => { tracer => {
panic!("unsupported tracer {}", tracer); panic!("unsupported tracer {}", tracer);

View file

@ -1,18 +1,18 @@
use opentelemetry::sdk::trace::Tracer; use opentelemetry::sdk::trace::Tracer;
use opentelemetry::trace::TraceError; use opentelemetry::trace::TraceError;
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use tracing::{Level, Subscriber}; use tracing::Subscriber;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::Layer;
use crate::telemetry::OTEL_DEFAULT_RESOURCE; use crate::{
cli::validator::parser::env_filter::CustomEnvFilter, telemetry::OTEL_DEFAULT_RESOURCE,
};
const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER"; pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
where where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync, S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{ {
tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter()).boxed() tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter.0).boxed()
} }
fn tracer() -> Result<Tracer, TraceError> { fn tracer() -> Result<Tracer, TraceError> {
@ -24,16 +24,3 @@ fn tracer() -> Result<Tracer, TraceError> {
) )
.install_batch(opentelemetry::runtime::Tokio) .install_batch(opentelemetry::runtime::Tokio)
} }
/// Create a filter for the OTLP subscriber
///
/// It creates an EnvFilter based on the TRACING_FILTER_VAR's value
///
/// TRACING_FILTER_VAR accepts the same syntax as RUST_LOG
fn filter() -> EnvFilter {
EnvFilter::builder()
.with_env_var(TRACING_FILTER_VAR)
.with_default_directive(Level::INFO.into())
.from_env()
.unwrap()
}

View file

@ -0,0 +1,31 @@
use tracing::{field, Span};
use uuid::Uuid;
pub fn span_for_request(ws_id: &Uuid) -> Span {
let span = tracing::info_span!(
// Dynamic span names need to be 'recorded', can't be used on the macro. Use a static name here and overwrite later on
"rpc/call",
otel.name = field::Empty,
otel.kind = "server",
// To be populated by the request handler when the method is known
rpc.method = field::Empty,
rpc.service = "surrealdb",
rpc.system = "jsonrpc",
// JSON-RPC fields
rpc.jsonrpc.version = "2.0",
rpc.jsonrpc.request_id = field::Empty,
rpc.jsonrpc.error_code = field::Empty,
rpc.jsonrpc.error_message = field::Empty,
// SurrealDB custom fields
ws.id = %ws_id,
// Fields for error reporting
otel.status_code = field::Empty,
otel.status_message = field::Empty,
);
span
}

View file

@ -5,6 +5,8 @@ mod common;
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild}; use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use serial_test::serial; use serial_test::serial;
use std::fs; use std::fs;
use test_log::test;
use tracing::info;
use common::{PASS, USER}; use common::{PASS, USER};
@ -32,13 +34,14 @@ fn nonexistent_option() {
assert!(common::run("version --turbo").output().is_err()); assert!(common::run("version --turbo").output().is_err());
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn all_commands() { async fn all_commands() {
// Commands without credentials when auth is disabled, should succeed // Commands without credentials when auth is disabled, should succeed
let (addr, _server) = common::start_server(false, false, true).await.unwrap(); let (addr, _server) = common::start_server(false, false, true).await.unwrap();
let creds = ""; // Anonymous user let creds = ""; // Anonymous user
// Create a record
info!("* Create a record");
{ {
let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi"); let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
assert_eq!( assert_eq!(
@ -48,7 +51,7 @@ async fn all_commands() {
); );
} }
// Export to stdout info!("* Export to stdout");
{ {
let args = format!("export --conn http://{addr} {creds} --ns N --db D -"); let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
let output = common::run(&args).output().expect("failed to run stdout export: {args}"); let output = common::run(&args).output().expect("failed to run stdout export: {args}");
@ -56,7 +59,7 @@ async fn all_commands() {
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };")); assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
} }
// Export to file info!("* Export to file");
let exported = { let exported = {
let exported = common::tmp_file("exported.surql"); let exported = common::tmp_file("exported.surql");
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}"); let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@ -64,13 +67,13 @@ async fn all_commands() {
exported exported
}; };
// Import the exported file info!("* Import the exported file");
{ {
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}"); let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
common::run(&args).output().expect("failed to run import: {args}"); common::run(&args).output().expect("failed to run import: {args}");
} }
// Query from the import (pretty-printed this time) info!("* Query from the import (pretty-printed this time)");
{ {
let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty"); let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
assert_eq!( assert_eq!(
@ -80,7 +83,7 @@ async fn all_commands() {
); );
} }
// Unfinished backup CLI info!("* Unfinished backup CLI");
{ {
let file = common::tmp_file("backup.db"); let file = common::tmp_file("backup.db");
let args = format!("backup {creds} http://{addr} {file}"); let args = format!("backup {creds} http://{addr} {file}");
@ -90,7 +93,7 @@ async fn all_commands() {
assert_eq!(fs::read_to_string(file).unwrap(), "Save"); assert_eq!(fs::read_to_string(file).unwrap(), "Save");
} }
// Multi-statement (and multi-line) query including error(s) over WS info!("* Multi-statement (and multi-line) query including error(s) over WS");
{ {
let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty"); let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
let output = common::run(&args) let output = common::run(&args)
@ -113,7 +116,7 @@ async fn all_commands() {
assert!(output.contains("thing:also_success"), "missing also_success in {output}") assert!(output.contains("thing:also_success"), "missing also_success in {output}")
} }
// Multi-statement (and multi-line) transaction including error(s) over WS info!("* Multi-statement (and multi-line) transaction including error(s) over WS");
{ {
let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty"); let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
let output = common::run(&args) let output = common::run(&args)
@ -137,7 +140,7 @@ async fn all_commands() {
assert!(output.contains("rgument"), "missing argument error in {output}"); assert!(output.contains("rgument"), "missing argument error in {output}");
} }
// Pass neither ns nor db info!("* Pass neither ns nor db");
{ {
let args = format!("sql --conn http://{addr} {creds}"); let args = format!("sql --conn http://{addr} {creds}");
let output = common::run(&args) let output = common::run(&args)
@ -147,7 +150,7 @@ async fn all_commands() {
assert!(output.contains("thing:one"), "missing thing:one in {output}"); assert!(output.contains("thing:one"), "missing thing:one in {output}");
} }
// Pass only ns info!("* Pass only ns");
{ {
let args = format!("sql --conn http://{addr} {creds} --ns N5"); let args = format!("sql --conn http://{addr} {creds} --ns N5");
let output = common::run(&args) let output = common::run(&args)
@ -157,26 +160,36 @@ async fn all_commands() {
assert!(output.contains("thing:one"), "missing thing:one in {output}"); assert!(output.contains("thing:one"), "missing thing:one in {output}");
} }
// Pass only db and expect an error info!("* Pass only db and expect an error");
{ {
let args = format!("sql --conn http://{addr} {creds} --db D5"); let args = format!("sql --conn http://{addr} {creds} --db D5");
common::run(&args).output().expect_err("only db"); common::run(&args).output().expect_err("only db");
} }
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn start_tls() { async fn start_tls() {
let (_, server) = common::start_server(false, true, false).await.unwrap(); // Capute the server's stdout/stderr
temp_env::async_with_vars(
[
("SURREAL_TEST_SERVER_STDOUT", Some("piped")),
("SURREAL_TEST_SERVER_STDERR", Some("piped")),
],
async {
let (_, server) = common::start_server(false, true, false).await.unwrap();
std::thread::sleep(std::time::Duration::from_millis(2000)); std::thread::sleep(std::time::Duration::from_millis(2000));
let output = server.kill().output().err().unwrap(); let output = server.kill().output().err().unwrap();
// Test the crt/key args but the keys are self signed so don't actually connect. // Test the crt/key args but the keys are self signed so don't actually connect.
assert!(output.contains("Started web server"), "couldn't start web server: {output}"); assert!(output.contains("Started web server"), "couldn't start web server: {output}");
},
)
.await;
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn with_root_auth() { async fn with_root_auth() {
// Commands with credentials when auth is enabled, should succeed // Commands with credentials when auth is enabled, should succeed
@ -184,7 +197,7 @@ async fn with_root_auth() {
let creds = format!("--user {USER} --pass {PASS}"); let creds = format!("--user {USER} --pass {PASS}");
let sql_args = format!("sql --conn http://{addr} --multi --pretty"); let sql_args = format!("sql --conn http://{addr} --multi --pretty");
// Can query /sql over HTTP info!("* Query over HTTP");
{ {
let args = format!("{sql_args} {creds}"); let args = format!("{sql_args} {creds}");
let input = "INFO FOR ROOT;"; let input = "INFO FOR ROOT;";
@ -192,7 +205,7 @@ async fn with_root_auth() {
assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap()); assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
} }
// Can query /sql over WS info!("* Query over WS");
{ {
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}"); let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
let input = "INFO FOR ROOT;"; let input = "INFO FOR ROOT;";
@ -200,7 +213,7 @@ async fn with_root_auth() {
assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap()); assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
} }
// KV user can do exports info!("* Root user can do exports");
let exported = { let exported = {
let exported = common::tmp_file("exported.surql"); let exported = common::tmp_file("exported.surql");
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}"); let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@ -209,13 +222,13 @@ async fn with_root_auth() {
exported exported
}; };
// KV user can do imports info!("* Root user can do imports");
{ {
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}"); let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}")); common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
} }
// KV user can do backups info!("* Root user can do backups");
{ {
let file = common::tmp_file("backup.db"); let file = common::tmp_file("backup.db");
let args = format!("backup {creds} http://{addr} {file}"); let args = format!("backup {creds} http://{addr} {file}");
@ -226,7 +239,7 @@ async fn with_root_auth() {
} }
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn with_anon_auth() { async fn with_anon_auth() {
// Commands without credentials when auth is enabled, should fail // Commands without credentials when auth is enabled, should fail
@ -234,7 +247,7 @@ async fn with_anon_auth() {
let creds = ""; // Anonymous user let creds = ""; // Anonymous user
let sql_args = format!("sql --conn http://{addr} --multi --pretty"); let sql_args = format!("sql --conn http://{addr} --multi --pretty");
// Can query /sql over HTTP info!("* Query over HTTP");
{ {
let args = format!("{sql_args} {creds}"); let args = format!("{sql_args} {creds}");
let input = ""; let input = "";
@ -242,7 +255,7 @@ async fn with_anon_auth() {
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output); assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
} }
// Can query /sql over HTTP info!("* Query over WS");
{ {
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}"); let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
let input = ""; let input = "";
@ -250,7 +263,7 @@ async fn with_anon_auth() {
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output); assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
} }
// Can't do exports info!("* Can't do exports");
{ {
let args = format!("export --conn http://{addr} {creds} --ns N --db D -"); let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
let output = common::run(&args).output(); let output = common::run(&args).output();
@ -261,7 +274,7 @@ async fn with_anon_auth() {
); );
} }
// Can't do imports info!("* Can't do imports");
{ {
let tmp_file = common::tmp_file("exported.surql"); let tmp_file = common::tmp_file("exported.surql");
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {tmp_file}"); let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {tmp_file}");
@ -273,7 +286,7 @@ async fn with_anon_auth() {
); );
} }
// Can't do backups info!("* Can't do backups");
{ {
let args = format!("backup {creds} http://{addr}"); let args = format!("backup {creds} http://{addr}");
let output = common::run(&args).output(); let output = common::run(&args).output();

View file

@ -1,10 +1,18 @@
#![allow(dead_code)] #![allow(dead_code)]
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error; use std::error::Error;
use std::fs; use std::fs::File;
use std::path::Path; use std::path::Path;
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
use std::{env, fs};
use tokio::net::TcpStream;
use tokio::time; use tokio::time;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{error, info};
pub const USER: &str = "root"; pub const USER: &str = "root";
pub const PASS: &str = "root"; pub const PASS: &str = "root";
@ -52,7 +60,12 @@ impl Drop for Child {
} }
} }
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child { pub fn run_internal<P: AsRef<Path>>(
args: &str,
current_dir: Option<P>,
stdout: Stdio,
stderr: Stdio,
) -> Child {
let mut path = std::env::current_exe().unwrap(); let mut path = std::env::current_exe().unwrap();
assert!(path.pop()); assert!(path.pop());
if path.ends_with("deps") { if path.ends_with("deps") {
@ -68,8 +81,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
} }
cmd.env_clear(); cmd.env_clear();
cmd.stdin(Stdio::piped()); cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped()); cmd.stdout(stdout);
cmd.stderr(Stdio::piped()); cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace()); cmd.args(args.split_ascii_whitespace());
Child { Child {
inner: Some(cmd.spawn().unwrap()), inner: Some(cmd.spawn().unwrap()),
@ -78,12 +91,12 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
/// Run the CLI with the given args /// Run the CLI with the given args
pub fn run(args: &str) -> Child { pub fn run(args: &str) -> Child {
run_internal::<String>(args, None) run_internal::<String>(args, None, Stdio::piped(), Stdio::piped())
} }
/// Run the CLI with the given args inside a temporary directory /// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child { pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir)) run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped())
} }
pub fn tmp_file(name: &str) -> String { pub fn tmp_file(name: &str) -> String {
@ -91,6 +104,19 @@ pub fn tmp_file(name: &str) -> String {
path.to_string_lossy().into_owned() path.to_string_lossy().into_owned()
} }
fn parse_server_stdio_from_var(var: &str) -> Result<Stdio, Box<dyn Error>> {
match env::var(var).as_deref() {
Ok("inherit") => Ok(Stdio::inherit()),
Ok("null") => Ok(Stdio::null()),
Ok("piped") => Ok(Stdio::piped()),
Ok(val) if val.starts_with("file://") => {
Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?))
}
Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()),
_ => Ok(Stdio::null()),
}
}
pub async fn start_server( pub async fn start_server(
auth: bool, auth: bool,
tls: bool, tls: bool,
@ -118,11 +144,14 @@ pub async fn start_server(
extra_args.push_str(" --auth"); extra_args.push_str(" --auth");
} }
let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}"); let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
println!("starting server with args: {start_args}"); info!("starting server with args: {start_args}");
let server = run(&start_args); // Configure where the logs go when running the test
let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?;
let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?;
let server = run_internal::<String>(&start_args, None, stdout, stderr);
if !wait_is_ready { if !wait_is_ready {
return Ok((addr, server)); return Ok((addr, server));
@ -130,17 +159,178 @@ pub async fn start_server(
// Wait 5 seconds for the server to start // Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(500)); let mut interval = time::interval(time::Duration::from_millis(500));
println!("Waiting for server to start..."); info!("Waiting for server to start...");
for _i in 0..10 { for _i in 0..10 {
interval.tick().await; interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() { if run(&format!("isready --conn http://{addr}")).output().is_ok() {
println!("Server ready!"); info!("Server ready!");
return Ok((addr, server)); return Ok((addr, server));
} }
} }
let server_out = server.kill().output().err().unwrap(); let server_out = server.kill().output().err().unwrap();
println!("server output: {server_out}"); error!("server output: {server_out}");
Err("server failed to start".into()) Err("server failed to start".into())
} }
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let (ws_stream, _) = connect_async(url).await?;
Ok(ws_stream)
}
pub async fn ws_send_msg(
socket: &mut WsStream,
msg_req: String,
) -> Result<serde_json::Value, Box<dyn Error>> {
// Use JSON format by default
ws_send_msg_with_fmt(socket, msg_req, Format::Json).await
}
pub enum Format {
Json,
Cbor,
Pack,
}
pub async fn ws_send_msg_with_fmt(
socket: &mut WsStream,
msg_req: String,
response_format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout waiting for the request to be sent".into());
}
res = socket.send(Message::Text(msg_req)) => {
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
let mut f = socket.try_filter(|msg| match response_format {
Format::Json => futures_util::future::ready(msg.is_text()),
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
});
tokio::select! {
_ = time::sleep(time::Duration::from_millis(2000)) => {
Err("timeout waiting for the response".into())
}
res = f.select_next_some() => {
match response_format {
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
}
}
}
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
pub async fn ws_signin(
socket: &mut WsStream,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_query(
socket: &mut WsStream,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "query",
"params": [query],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().as_array().unwrap().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_use(
socket: &mut WsStream,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "use",
"params": [
ns, db
],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}

View file

@ -7,10 +7,11 @@ use http::{header, Method};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use serial_test::serial; use serial_test::serial;
use test_log::test;
use crate::common::{PASS, USER}; use crate::common::{PASS, USER};
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> { async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -52,7 +53,7 @@ async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> { async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -133,14 +134,14 @@ async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> { async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> {
// TODO: test the client IP extractor // TODO: test the client IP extractor
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -184,7 +185,7 @@ async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -196,7 +197,7 @@ async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -269,7 +270,7 @@ async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -303,7 +304,7 @@ async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -372,7 +373,7 @@ async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -425,13 +426,17 @@ async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res.status(), 200, "body: {}", res.text().await?); assert_eq!(res.status(), 200, "body: {}", res.text().await?);
let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap(); let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body); assert!(
body["token"].as_str().unwrap().starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"),
"body: {}",
body
);
} }
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -543,7 +548,7 @@ async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -577,7 +582,7 @@ async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn version_endpoint() -> Result<(), Box<dyn std::error::Error>> { async fn version_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -619,7 +624,7 @@ async fn seed_table(
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -696,7 +701,7 @@ async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -759,7 +764,7 @@ async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -829,7 +834,7 @@ async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -899,7 +904,7 @@ async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -953,7 +958,7 @@ async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -994,7 +999,7 @@ async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1091,7 +1096,7 @@ async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1164,7 +1169,7 @@ async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1242,7 +1247,7 @@ async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) Ok(())
} }
#[tokio::test] #[test(tokio::test)]
#[serial] #[serial]
async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> { async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap(); let (addr, _server) = common::start_server(true, false, true).await.unwrap();

1109
tests/ws_integration.rs Normal file

File diff suppressed because it is too large Load diff