[rpc] Better tracing for WebSockets (#2325)

This commit is contained in:
Salvador Girones Gil 2023-08-03 16:59:05 +02:00 committed by GitHub
parent ab72923fb5
commit e91011cc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 2617 additions and 650 deletions

View file

@ -133,7 +133,7 @@ jobs:
args: ci-clippy
cli:
name: Test command line
name: CLI integration tests
runs-on: ubuntu-latest
steps:
@ -163,7 +163,7 @@ jobs:
args: ci-cli-integration
http-server:
name: Test HTTP server
name: HTTP integration tests
runs-on: ubuntu-latest
steps:
@ -192,6 +192,28 @@ jobs:
command: make
args: ci-http-integration
ws-server:
name: WebSocket integration tests
runs-on: ubuntu-latest
steps:
- name: Install stable toolchain
uses: dtolnay/rust-toolchain@stable
- name: Checkout sources
uses: actions/checkout@v3
- name: Setup cache
uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt-get -y update
sudo apt-get -y install protobuf-compiler libprotobuf-dev
- name: Run cargo test
run: cargo test --locked --no-default-features --features storage-mem --workspace --test ws_integration
test:
name: Test workspace
runs-on: ubuntu-latest

782
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -67,6 +67,7 @@ tokio-util = { version = "0.7.8", features = ["io"] }
tower = "0.4.13"
tower-http = { version = "0.4.2", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
tracing = "0.1"
tracing-futures = { version = "0.2.5", features = ["tokio"], default-features = false }
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
urlencoding = "2.1.2"
@ -77,11 +78,14 @@ nix = "0.26.2"
[dev-dependencies]
assert_fs = "1.0.13"
env_logger = "0.10.0"
opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
rcgen = "0.10.0"
serial_test = "2.0.0"
temp-env = "0.3.4"
temp-env = { version = "0.3.4", features = ["async_closure"] }
test-log = { version = "0.2.12", features = ["trace"] }
tokio-stream = { version = "0.1", features = ["net"] }
tokio-tungstenite = { version = "0.18.0" }
tonic = "0.8.3"
[package.metadata.deb]

View file

@ -115,7 +115,7 @@ pub async fn init(
listen_addresses,
dbs,
web,
log: CustomEnvFilter(log),
log,
tick_interval,
no_banner,
..

View file

@ -1,8 +1,9 @@
use clap::builder::{NonEmptyStringValueParser, PossibleValue, TypedValueParser};
use clap::error::{ContextKind, ContextValue, ErrorKind};
use tracing::Level;
use tracing_subscriber::EnvFilter;
use crate::telemetry::filter_from_value;
#[derive(Debug)]
pub struct CustomEnvFilter(pub EnvFilter);
@ -37,20 +38,7 @@ impl TypedValueParser for CustomEnvFilterParser {
let inner = NonEmptyStringValueParser::new();
let v = inner.parse_ref(cmd, arg, value)?;
let filter = (match v.as_str() {
// Don't show any logs at all
"none" => Ok(EnvFilter::default()),
// Check if we should show all log levels
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
// Otherwise, let's only show errors
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
// Specify the log level for each code area
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::txn=error")),
// Let's try to parse the custom log level
_ => EnvFilter::builder().parse(v),
})
.map_err(|e| {
let filter = filter_from_value(v.as_str()).map_err(|e| {
let mut err = clap::Error::new(ErrorKind::ValueValidation).with_cmd(cmd);
err.insert(ContextKind::Custom, ContextValue::String(e.to_string()));
err.insert(

View file

@ -131,7 +131,7 @@ pub async fn init() -> Result<(), Error> {
// Setup the graceful shutdown with no timeout
let handle = Handle::new();
graceful_shutdown(handle.clone(), None);
let shutdown_handler = graceful_shutdown(handle.clone());
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
// configure certificate and private key used by https
@ -156,6 +156,12 @@ pub async fn init() -> Result<(), Error> {
.await?;
};
// Wait for the shutdown to finish
let _ = shutdown_handler.await;
// Flush all telemetry data
opentelemetry::global::shutdown_tracer_provider();
info!(target: LOG, "Web server stopped. Bye!");
Ok(())

View file

@ -7,26 +7,36 @@ use crate::err::Error;
use crate::rpc::args::Take;
use crate::rpc::paths::{ID, METHOD, PARAMS};
use crate::rpc::res;
use crate::rpc::res::Data;
use crate::rpc::res::Failure;
use crate::rpc::res::Output;
use crate::rpc::res::IntoRpcResponse;
use crate::rpc::res::OutputFormat;
use crate::rpc::CONN_CLOSED_ERR;
use crate::telemetry::traces::rpc::span_for_request;
use axum::routing::get;
use axum::Extension;
use axum::Router;
use futures::{SinkExt, StreamExt};
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use http_body::Body as HttpBody;
use once_cell::sync::Lazy;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use surrealdb::channel;
use surrealdb::channel::Sender;
use surrealdb::channel::{Receiver, Sender};
use surrealdb::dbs::{QueryType, Response, Session};
use surrealdb::sql::serde::deserialize;
use surrealdb::sql::Array;
use surrealdb::sql::Object;
use surrealdb::sql::Strand;
use surrealdb::sql::Value;
use tokio::sync::RwLock;
use tracing::instrument;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tower_http::request_id::RequestId;
use tracing::Span;
use uuid::Uuid;
use axum::{
@ -35,11 +45,12 @@ use axum::{
};
// Mapping of WebSocketID to WebSocket
type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>;
pub(crate) struct WebSocketRef(pub(crate) Sender<Message>, pub(crate) CancellationToken);
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
// Mapping of LiveQueryID to WebSocketID
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
pub(super) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
pub(super) fn router<S, B>() -> Router<S, B>
@ -50,22 +61,36 @@ where
Router::new().route("/rpc", get(handler))
}
async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse {
async fn handler(
ws: WebSocketUpgrade,
Extension(sess): Extension<Session>,
Extension(req_id): Extension<RequestId>,
) -> impl IntoResponse {
// finalize the upgrade process by returning upgrade callback.
// we can customize the callback by sending additional info such as address.
ws.on_upgrade(move |socket| handle_socket(socket, sess))
ws.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
}
async fn handle_socket(ws: WebSocket, sess: Session) {
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
let rpc = Rpc::new(sess);
Rpc::serve(rpc, ws).await
// If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID
match req_id.header_value().to_str().map(Uuid::parse_str) {
Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => {
rpc.write().await.ws_id = req_id
}
_ => (),
}
Rpc::serve(rpc, ws).await;
}
pub struct Rpc {
session: Session,
format: Output,
uuid: Uuid,
format: OutputFormat,
ws_id: Uuid,
vars: BTreeMap<String, Value>,
graceful_shutdown: CancellationToken,
}
impl Rpc {
@ -74,158 +99,247 @@ impl Rpc {
// Create a new RPC variables store
let vars = BTreeMap::new();
// Set the default output format
let format = Output::Json;
// Create a unique WebSocket id
let uuid = Uuid::new_v4();
// Enable real-time live queries
let format = OutputFormat::Json;
// Enable real-time mode
session.rt = true;
// Create and store the Rpc connection
Arc::new(RwLock::new(Rpc {
session,
format,
uuid,
ws_id: Uuid::new_v4(),
vars,
graceful_shutdown: CancellationToken::new(),
}))
}
/// Serve the RPC endpoint
pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) {
// Create a channel for sending messages
let (chn, mut rcv) = channel::new(MAX_CONCURRENT_CALLS);
// Split the socket into send and recv
let (mut wtx, mut wrx) = ws.split();
// Clone the channel for sending pings
let png = chn.clone();
// The WebSocket has connected
Rpc::connected(rpc.clone(), chn.clone()).await;
// Send Ping messages to the client
tokio::task::spawn(async move {
// Create the interval ticker
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
// Loop indefinitely
loop {
// Wait for the timer
interval.tick().await;
// Create the ping message
let msg = Message::Ping(vec![]);
// Send the message to the client
if png.send(msg).await.is_err() {
// Exit out of the loop
break;
}
}
});
// Send messages to the client
tokio::task::spawn(async move {
// Wait for the next message to send
while let Some(res) = rcv.next().await {
// Send the message to the client
if let Err(err) = wtx.send(res).await {
// Output the WebSocket error to the logs
trace!("WebSocket error: {:?}", err);
// It's already failed, so ignore error
let _ = wtx.close().await;
// Exit out of the loop
break;
}
}
});
// Send notifications to the client
let moved_rpc = rpc.clone();
tokio::task::spawn(async move {
let rpc = moved_rpc;
if let Some(channel) = DB.get().unwrap().notifications() {
while let Ok(notification) = channel.recv().await {
// Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(websocket) = WEBSOCKETS.read().await.get(ws_id) {
// Serialize the message to send
let message = res::success(None, notification);
// Get the current output format
let format = rpc.read().await.format.clone();
// Send the notification to the client
message.send(format, websocket.clone()).await;
}
}
}
}
});
// Get messages from the client
while let Some(msg) = wrx.next().await {
match msg {
// We've received a message from the client
// Ping is automatically handled by the WebSocket library
Ok(msg) => match msg {
Message::Text(_) => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
}
Message::Binary(_) => {
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
}
Message::Close(_) => {
break;
}
Message::Pong(_) => {
continue;
}
_ => {
// Ignore everything else
}
},
// There was an error receiving the message
Err(err) => {
// Output the WebSocket error to the logs
trace!("WebSocket error: {:?}", err);
// Exit out of the loop
break;
}
}
}
// The WebSocket has disconnected
Rpc::disconnected(rpc.clone()).await;
}
let (sender, receiver) = ws.split();
// Create an internal channel between the receiver and the sender
let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
let ws_id = rpc.read().await.ws_id;
async fn connected(rpc: Arc<RwLock<Rpc>>, chn: Sender<Message>) {
// Fetch the unique id of the WebSocket
let id = rpc.read().await.uuid;
// Log that the WebSocket has connected
trace!("WebSocket {} connected", id);
// Store this WebSocket in the list of WebSockets
WEBSOCKETS.write().await.insert(id, chn);
}
WEBSOCKETS.write().await.insert(
ws_id,
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
);
trace!("WebSocket {} connected", ws_id);
// Wait until all tasks finish
tokio::join!(
Self::ping(rpc.clone(), internal_sender.clone()),
Self::read(rpc.clone(), receiver, internal_sender.clone()),
Self::write(rpc.clone(), sender, internal_receiver.clone()),
Self::lq_notifications(rpc.clone()),
);
async fn disconnected(rpc: Arc<RwLock<Rpc>>) {
// Fetch the unique id of the WebSocket
let id = rpc.read().await.uuid;
// Log that the WebSocket has disconnected
trace!("WebSocket {} disconnected", id);
// Remove this WebSocket from the list of WebSockets
WEBSOCKETS.write().await.remove(&id);
// Remove all live queries
LIVE_QUERIES.write().await.retain(|key, value| {
if value == &id {
if value == &ws_id {
trace!("Removing live query: {}", key);
return false;
}
true
});
// Remove this WebSocket from the list of WebSockets
WEBSOCKETS.write().await.remove(&ws_id);
trace!("WebSocket {} disconnected", ws_id);
}
/// Call RPC methods from the WebSocket
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
/// Send Ping messages to the client
async fn ping(rpc: Arc<RwLock<Rpc>>, internal_sender: Sender<Message>) {
// Create the interval ticker
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
_ = interval.tick() => {
let msg = Message::Ping(vec![]);
// Send the message to the client and close the WebSocket connection if it fails
if internal_sender.send(msg).await.is_err() {
rpc.read().await.graceful_shutdown.cancel();
break;
}
},
_ = is_shutdown => break,
}
}
}
/// Read messages sent from the client
async fn read(
rpc: Arc<RwLock<Rpc>>,
mut receiver: SplitStream<WebSocket>,
internal_sender: Sender<Message>,
) {
// Collect all spawned tasks so we can wait for them at the end
let mut tasks = JoinSet::new();
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
msg = receiver.next() => {
if let Some(msg) = msg {
match msg {
// We've received a message from the client
// Ping/Pong is automatically handled by the WebSocket library
Ok(msg) => match msg {
Message::Text(_) => {
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
}
Message::Binary(_) => {
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
}
Message::Close(_) => {
// Respond with a close message
if let Err(err) = internal_sender.send(Message::Close(None)).await {
trace!("WebSocket error when replying to the Close frame: {:?}", err);
};
// Start the graceful shutdown of the WebSocket and close the channels
rpc.read().await.graceful_shutdown.cancel();
let _ = internal_sender.close();
break;
}
_ => {
// Ignore everything else
}
},
Err(err) => {
trace!("WebSocket error: {:?}", err);
// Start the graceful shutdown of the WebSocket and close the channels
rpc.read().await.graceful_shutdown.cancel();
let _ = internal_sender.close();
// Exit out of the loop
break;
}
}
}
}
_ = is_shutdown => break,
}
}
// Wait for all tasks to finish
while let Some(res) = tasks.join_next().await {
if let Err(err) = res {
error!("Error while handling RPC message: {}", err);
}
}
}
/// Write messages to the client
async fn write(
rpc: Arc<RwLock<Rpc>>,
mut sender: SplitSink<WebSocket, Message>,
mut internal_receiver: Receiver<Message>,
) {
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
let is_shutdown = cancel_token.cancelled();
tokio::select! {
// Wait for the next message to send
msg = internal_receiver.next() => {
if let Some(res) = msg {
// Send the message to the client
if let Err(err) = sender.send(res).await {
if err.to_string() != CONN_CLOSED_ERR {
debug!("WebSocket error: {:?}", err);
}
// Close the WebSocket connection
rpc.read().await.graceful_shutdown.cancel();
// Exit out of the loop
break;
}
}
},
_ = is_shutdown => break,
}
}
}
/// Send live query notifications to the client
async fn lq_notifications(rpc: Arc<RwLock<Rpc>>) {
if let Some(channel) = DB.get().unwrap().notifications() {
let cancel_token = rpc.read().await.graceful_shutdown.clone();
loop {
tokio::select! {
msg = channel.recv() => {
if let Ok(notification) = msg {
// Find which WebSocket the notification belongs to
if let Some(ws_id) = LIVE_QUERIES.read().await.get(&notification.id) {
// Check to see if the WebSocket exists
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
// Serialize the message to send
let message = res::success(None, notification);
// Get the current output format
let mut out = { rpc.read().await.format.clone() };
// Clone the RPC
let rpc = rpc.clone();
let format = rpc.read().await.format.clone();
// Send the notification to the client
message.send(format, ws.clone()).await
}
}
}
},
_ = cancel_token.cancelled() => break,
}
}
}
}
/// Handle individual WebSocket messages
async fn handle_msg(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
// Get the current output format
let mut out_fmt = rpc.read().await.format.clone();
let span = span_for_request(&rpc.read().await.ws_id);
let _enter = span.enter();
// Parse the request
match Self::parse_request(msg).await {
Ok((id, method, params, _out_fmt)) => {
span.record(
"rpc.jsonrpc.request_id",
id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
);
if let Some(_out_fmt) = _out_fmt {
out_fmt = _out_fmt;
}
// Process the request
let res = Self::process_request(rpc.clone(), &method, params).await;
// Process the response
res.into_response(id).send(out_fmt, chn).await
}
Err(err) => {
// Process the response
res::failure(None, err).send(out_fmt, chn).await
}
}
}
async fn parse_request(
msg: Message,
) -> Result<(Option<Value>, String, Array, Option<OutputFormat>), Failure> {
let mut out_fmt = None;
let req = match msg {
// This is a binary message
Message::Binary(val) => {
// Use binary output
out = Output::Full;
// Deserialize the input
Value::from(val)
out_fmt = Some(OutputFormat::Full);
match deserialize(&val) {
Ok(v) => v,
Err(_) => {
debug!("Error when trying to deserialize the request");
return Err(Failure::PARSE_ERROR);
}
}
}
// This is a text message
Message::Text(ref val) => {
@ -234,14 +348,15 @@ impl Rpc {
// The SurrealQL message parsed ok
Ok(v) => v,
// The SurrealQL message failed to parse
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
_ => return Err(Failure::PARSE_ERROR),
}
}
// Unsupported message type
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
_ => {
debug!("Unsupported message type: {:?}", msg);
return Err(res::Failure::custom("Unsupported message type"));
}
};
// Log the received request
trace!("RPC Received: {}", req);
// Fetch the 'id' argument
let id = match req.pick(&*ID) {
v if v.is_none() => None,
@ -250,149 +365,180 @@ impl Rpc {
v if v.is_number() => Some(v),
v if v.is_strand() => Some(v),
v if v.is_datetime() => Some(v),
_ => return res::failure(None, Failure::INVALID_REQUEST).send(out, chn).await,
_ => return Err(Failure::INVALID_REQUEST),
};
// Fetch the 'method' argument
let method = match req.pick(&*METHOD) {
Value::Strand(v) => v.to_raw(),
_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
_ => return Err(Failure::INVALID_REQUEST),
};
// Now that we know the method, we can update the span
Span::current().record("rpc.method", &method);
Span::current().record("otel.name", format!("surrealdb.rpc/{}", method));
// Fetch the 'params' argument
let params = match req.pick(&*PARAMS) {
Value::Array(v) => v,
_ => Array::new(),
};
Ok((id, method, params, out_fmt))
}
async fn process_request(
rpc: Arc<RwLock<Rpc>>,
method: &str,
params: Array,
) -> Result<Data, Failure> {
info!("Process RPC request");
// Match the method to a function
let res = match &method[..] {
// Handle a ping message
"ping" => Ok(Value::None),
match method {
// Handle a surrealdb ping message
//
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
// By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent.
//
"ping" => Ok(Value::None.into()),
// Retrieve the current auth record
"info" => match params.len() {
0 => rpc.read().await.info().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Switch to a specific namespace and database
"use" => match params.needs_two() {
Ok((ns, db)) => rpc.write().await.yuse(ns, db).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((ns, db)) => {
rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Signup to a specific authentication scope
"signup" => match params.needs_one() {
Ok(Value::Object(v)) => rpc.write().await.signup(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(Value::Object(v)) => {
rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Signin as a root, namespace, database or scope user
"signin" => match params.needs_one() {
Ok(Value::Object(v)) => rpc.write().await.signin(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(Value::Object(v)) => {
rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Invalidate the current authentication session
"invalidate" => match params.len() {
0 => rpc.write().await.invalidate().await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Authenticate using an authentication token
"authenticate" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.authenticate(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(Value::Strand(v)) => {
rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Kill a live query using a query id
"kill" => match params.needs_one() {
Ok(v) if v.is_uuid() => rpc.read().await.kill(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(v) if v.is_uuid() => {
rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Setup a live query on a specific table
"live" => match params.needs_one_or_two() {
Ok((v, d)) if v.is_table() => rpc.read().await.live(v, d).await,
Ok((v, d)) if v.is_strand() => rpc.read().await.live(v, d).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, d)) if v.is_table() => {
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
}
Ok((v, d)) if v.is_strand() => {
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Specify a connection-wide parameter
"let" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
},
// Specify a connection-wide parameter
"set" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
"let" | "set" => match params.needs_one_or_two() {
Ok((Value::Strand(s), v)) => {
rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Unset and clear a connection-wide parameter
"unset" => match params.needs_one() {
Ok(Value::Strand(s)) => rpc.write().await.unset(s).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(Value::Strand(s)) => {
rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Select a value or values from the database
"select" => match params.needs_one() {
Ok(v) => rpc.read().await.select(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Insert a value or values in the database
"insert" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.insert(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, o)) => {
rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Create a value or values in the database
"create" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.create(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, o)) => {
rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `CONTENT`
"update" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.update(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, o)) => {
rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `MERGE`
"change" | "merge" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.change(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, o)) => {
rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Update a value or values in the database using `PATCH`
"modify" | "patch" => match params.needs_one_or_two() {
Ok((v, o)) => rpc.read().await.modify(v, o).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok((v, o)) => {
rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Delete a value or values from the database
"delete" => match params.needs_one() {
Ok(v) => rpc.read().await.delete(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
_ => Err(Failure::INVALID_PARAMS),
},
// Specify the output format for text requests
"format" => match params.needs_one() {
Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
Ok(Value::Strand(v)) => {
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
}
_ => Err(Failure::INVALID_PARAMS),
},
// Get the current server version
"version" => match params.len() {
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
_ => Err(Failure::INVALID_PARAMS),
},
// Run a full SurrealQL query against the database
"query" => match params.needs_one_or_two() {
Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
return match rpc.read().await.query(s).await {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
}
};
rpc.read().await.query(s).await.map(Into::into).map_err(Into::into)
}
Ok((Value::Strand(s), Value::Object(o))) => {
return match rpc.read().await.query_with(s, o).await {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => {
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into)
}
};
}
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
_ => Err(Failure::INVALID_PARAMS),
},
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
};
// Return the final response
match res {
Ok(v) => res::success(id, v).send(out, chn).await,
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
_ => Err(Failure::METHOD_NOT_FOUND),
}
}
@ -402,15 +548,14 @@ impl Rpc {
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
match out.as_str() {
"json" | "application/json" => self.format = Output::Json,
"cbor" | "application/cbor" => self.format = Output::Cbor,
"pack" | "application/pack" => self.format = Output::Pack,
"json" | "application/json" => self.format = OutputFormat::Json,
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
"pack" | "application/pack" => self.format = OutputFormat::Pack,
_ => return Err(Error::InvalidType),
};
Ok(Value::None)
}
#[instrument(skip_all, name = "rpc use", fields(websocket=self.uuid.to_string()))]
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
if let Value::Strand(ns) = ns {
self.session.ns = Some(ns.0);
@ -421,7 +566,6 @@ impl Rpc {
Ok(Value::None)
}
#[instrument(skip_all, name = "rpc signup", fields(websocket=self.uuid.to_string()))]
async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
@ -430,7 +574,6 @@ impl Rpc {
.map_err(Into::into)
}
#[instrument(skip_all, name = "rpc signin", fields(websocket=self.uuid.to_string()))]
async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
@ -438,13 +581,11 @@ impl Rpc {
.map(Into::into)
.map_err(Into::into)
}
#[instrument(skip_all, name = "rpc invalidate", fields(websocket=self.uuid.to_string()))]
async fn invalidate(&mut self) -> Result<Value, Error> {
surrealdb::iam::clear::clear(&mut self.session)?;
Ok(Value::None)
}
#[instrument(skip_all, name = "rpc auth", fields(websocket=self.uuid.to_string()))]
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
let kvs = DB.get().unwrap();
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
@ -455,7 +596,6 @@ impl Rpc {
// Methods for identification
// ------------------------------
#[instrument(skip_all, name = "rpc info", fields(websocket=self.uuid.to_string()))]
async fn info(&self) -> Result<Value, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
@ -473,7 +613,6 @@ impl Rpc {
// Methods for setting variables
// ------------------------------
#[instrument(skip_all, name = "rpc set", fields(websocket=self.uuid.to_string()))]
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
match val {
// Remove the variable if undefined
@ -484,7 +623,6 @@ impl Rpc {
Ok(Value::Null)
}
#[instrument(skip_all, name = "rpc unset", fields(websocket=self.uuid.to_string()))]
async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
self.vars.remove(&key.0);
Ok(Value::Null)
@ -494,7 +632,6 @@ impl Rpc {
// Methods for live queries
// ------------------------------
#[instrument(skip_all, name = "rpc kill", fields(websocket=self.uuid.to_string()))]
async fn kill(&self, id: Value) -> Result<Value, Error> {
// Specify the SQL query string
let sql = "KILL $id";
@ -513,7 +650,6 @@ impl Rpc {
}
}
#[instrument(skip_all, name = "rpc live", fields(websocket=self.uuid.to_string()))]
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
// Specify the SQL query string
let sql = match diff.is_true() {
@ -539,7 +675,6 @@ impl Rpc {
// Methods for selecting
// ------------------------------
#[instrument(skip_all, name = "rpc select", fields(websocket=self.uuid.to_string()))]
async fn select(&self, what: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -567,7 +702,6 @@ impl Rpc {
// Methods for inserting
// ------------------------------
#[instrument(skip_all, name = "rpc insert", fields(websocket=self.uuid.to_string()))]
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -596,7 +730,6 @@ impl Rpc {
// Methods for creating
// ------------------------------
#[instrument(skip_all, name = "rpc create", fields(websocket=self.uuid.to_string()))]
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -625,7 +758,6 @@ impl Rpc {
// Methods for updating
// ------------------------------
#[instrument(skip_all, name = "rpc update", fields(websocket=self.uuid.to_string()))]
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -654,7 +786,6 @@ impl Rpc {
// Methods for changing
// ------------------------------
#[instrument(skip_all, name = "rpc change", fields(websocket=self.uuid.to_string()))]
async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -683,7 +814,6 @@ impl Rpc {
// Methods for modifying
// ------------------------------
#[instrument(skip_all, name = "rpc modify", fields(websocket=self.uuid.to_string()))]
async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -712,7 +842,6 @@ impl Rpc {
// Methods for deleting
// ------------------------------
#[instrument(skip_all, name = "rpc delete", fields(websocket=self.uuid.to_string()))]
async fn delete(&self, what: Value) -> Result<Value, Error> {
// Return a single result?
let one = what.is_thing();
@ -740,7 +869,6 @@ impl Rpc {
// Methods for querying
// ------------------------------
#[instrument(skip_all, name = "rpc query", fields(websocket=self.uuid.to_string()))]
async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
@ -756,7 +884,6 @@ impl Rpc {
Ok(res)
}
#[instrument(skip_all, name = "rpc query_with", fields(websocket=self.uuid.to_string()))]
async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
// Get a database reference
let kvs = DB.get().unwrap();
@ -781,8 +908,8 @@ impl Rpc {
QueryType::Live => {
if let Ok(Value::Uuid(lqid)) = &res.result {
// Match on Uuid type
LIVE_QUERIES.write().await.insert(lqid.0, self.uuid);
trace!("Registered live query {} on websocket {}", lqid, self.uuid);
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
}
}
QueryType::Kill => {

View file

@ -1,17 +1,57 @@
use std::time::Duration;
use axum_server::Handle;
use tokio::task::JoinHandle;
use crate::err::Error;
use crate::{
err::Error,
net::rpc::{WebSocketRef, WEBSOCKETS},
};
/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received.
pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) {
/// Start a graceful shutdown:
/// * Signal the Axum Handle when a shutdown signal is received.
/// * Stop all WebSocket connections.
///
/// A second signal will force an immediate shutdown.
pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
tokio::spawn(async move {
let result = listen().await.expect("Failed to listen to shutdown signal");
info!(target: super::LOG, "{} received. Start graceful shutdown...", result);
info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);
handle.graceful_shutdown(dur)
});
tokio::select! {
// Start a normal graceful shutdown
_ = async {
// First stop accepting new HTTP requests
http_handle.graceful_shutdown(None);
// Close all WebSocket connections. Queued messages will still be processed.
for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
cancel_token.cancel();
};
// Wait for all existing WebSocket connections to gracefully close
while WEBSOCKETS.read().await.len() > 0 {
tokio::time::sleep(Duration::from_millis(100)).await;
};
} => (),
// Force an immediate shutdown if a second signal is received
_ = async {
if let Ok(signal) = listen().await {
warn!(target: super::LOG, "{} received during graceful shutdown. Terminate immediately...", signal);
} else {
error!(target: super::LOG, "Failed to listen to shutdown signal. Terminate immediately...");
}
// Force an immediate shutdown
http_handle.shutdown();
// Close all WebSocket connections immediately
if let Ok(mut writer) = WEBSOCKETS.try_write() {
writer.drain();
}
} => (),
}
})
}
#[cfg(unix)]

View file

@ -1,119 +1,15 @@
use std::{fmt, time::Duration};
use axum::{
body::{boxed, Body, BoxBody},
extract::MatchedPath,
headers::{
authorization::{Basic, Bearer},
Authorization, Origin,
},
Extension, RequestPartsExt, TypedHeader,
};
use futures_util::future::BoxFuture;
use http::{header, request::Parts, StatusCode};
use axum::extract::MatchedPath;
use http::header;
use hyper::{Request, Response};
use surrealdb::{
dbs::Session,
iam::verify::{basic, token},
};
use tower_http::{
auth::AsyncAuthorizeRequest,
request_id::RequestId,
trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
};
use tracing::{field, Level, Span};
use crate::{dbs::DB, err::Error};
use super::{client_ip::ExtractClientIP, AppState};
///
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
///
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
///
/// ```rust
/// use tower_http::auth::RequireAuthorizationLayer;
/// use surrealdb::net::SurrealAuth;
/// use axum::Router;
///
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
///
/// let app = Router::new()
/// .route("/version", get(|| async { "0.1.0" }))
/// .layer(auth);
/// ```
#[derive(Clone, Copy)]
pub(super) struct SurrealAuth;
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
where
B: Send + Sync + 'static,
{
type RequestBody = B;
type ResponseBody = BoxBody;
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, request: Request<B>) -> Self::Future {
Box::pin(async {
let (mut parts, body) = request.into_parts();
match check_auth(&mut parts).await {
Ok(sess) => {
parts.extensions.insert(sess);
Ok(Request::from_parts(parts, body))
}
Err(err) => {
let unauthorized_response = Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(boxed(Body::from(err.to_string())))
.unwrap();
Err(unauthorized_response)
}
}
})
}
}
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
let kvs = DB.get().unwrap();
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
if !or.is_null() {
Some(or.to_string())
} else {
None
}
} else {
None
};
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
tracing::error!("Error extracting the app state: {:?}", err);
Error::InvalidAuth
})?;
let ExtractClientIP(ip) =
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
// Create session
#[rustfmt::skip]
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
// If Basic authentication data was supplied
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
basic(kvs, &mut session, au.username(), au.password()).await.map_err(|e| e.into())
} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
token(kvs, &mut session, au.token()).await.map_err(|e| e.into())
} else {
Err(Error::InvalidAuth)
}?;
Ok(session)
}
use super::client_ip::ExtractClientIP;
///
/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
@ -139,7 +35,6 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
fn make_span(&mut self, req: &Request<B>) -> Span {
// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
let span = tracing::info_span!(
target: "surreal::http",
"request",
otel.name = field::Empty,
otel.kind = "server",

View file

@ -1,3 +1,5 @@
pub mod args;
pub mod paths;
pub mod res;
pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally";

View file

@ -7,10 +7,13 @@ use surrealdb::dbs;
use surrealdb::dbs::Notification;
use surrealdb::sql;
use surrealdb::sql::Value;
use tracing::instrument;
use tracing::Span;
#[derive(Clone)]
pub enum Output {
use crate::err;
use crate::rpc::CONN_CLOSED_ERR;
#[derive(Debug, Clone)]
pub enum OutputFormat {
Json, // JSON
Cbor, // CBOR
Pack, // MessagePack
@ -37,6 +40,12 @@ impl From<Value> for Data {
}
}
impl From<String> for Data {
fn from(v: String) -> Self {
Data::Other(Value::from(v))
}
}
impl From<Vec<dbs::Response>> for Data {
fn from(v: Vec<dbs::Response>) -> Self {
Data::Query(v)
@ -82,28 +91,45 @@ impl Response {
}
/// Send the response to the WebSocket channel
#[instrument(skip_all, name = "rpc response", fields(response = ?self))]
pub async fn send(self, out: Output, chn: Sender<Message>) {
pub async fn send(self, out: OutputFormat, chn: Sender<Message>) {
let span = Span::current();
info!("Process RPC response");
if let Err(err) = &self.result {
span.record("otel.status_code", "Error");
span.record(
"otel.status_message",
format!("code: {}, message: {}", err.code, err.message),
);
span.record("rpc.jsonrpc.error_code", err.code);
span.record("rpc.jsonrpc.error_message", err.message.as_ref());
}
let message = match out {
Output::Json => {
OutputFormat::Json => {
let res = serde_json::to_string(&self.simplify()).unwrap();
Message::Text(res)
}
Output::Cbor => {
OutputFormat::Cbor => {
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
Message::Binary(res)
}
Output::Pack => {
OutputFormat::Pack => {
let res = serde_pack::to_vec(&self.simplify()).unwrap();
Message::Binary(res)
}
Output::Full => {
OutputFormat::Full => {
let res = surrealdb::sql::serde::serialize(&self).unwrap();
Message::Binary(res)
}
};
let _ = chn.send(message).await;
trace!("Response sent");
if let Err(err) = chn.send(message).await {
if err.to_string() != CONN_CLOSED_ERR {
error!("Error sending response: {}", err);
}
};
}
}
@ -113,6 +139,7 @@ pub struct Failure {
message: Cow<'static, str>,
}
#[allow(dead_code)]
impl Failure {
pub const PARSE_ERROR: Failure = Failure {
code: -32700,
@ -165,3 +192,26 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
result: Err(err),
}
}
impl From<err::Error> for Failure {
fn from(err: err::Error) -> Self {
Failure::custom(err.to_string())
}
}
pub trait IntoRpcResponse {
fn into_response(self, id: Option<Value>) -> Response;
}
impl<T, E> IntoRpcResponse for Result<T, E>
where
T: Into<Data>,
E: Into<Failure>,
{
fn into_response(self, id: Option<Value>) -> Response {
match self {
Ok(v) => success(id, v.into()),
Err(err) => failure(id, err.into()),
}
}
}

View file

@ -1,8 +1,10 @@
use tracing::Subscriber;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::{EnvFilter, Layer};
use tracing_subscriber::Layer;
pub fn new<S>(level: String) -> Box<dyn Layer<S> + Send + Sync>
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{
@ -11,6 +13,6 @@ where
.with_ansi(true)
.with_span_events(FmtSpan::NONE)
.with_writer(std::io::stderr)
.with_filter(EnvFilter::builder().parse(level).unwrap())
.with_filter(filter.0)
.boxed()
}

View file

@ -1,6 +1,6 @@
mod logs;
pub mod metrics;
mod traces;
pub mod traces;
use std::time::Duration;
@ -11,8 +11,7 @@ use opentelemetry::sdk::resource::{
};
use opentelemetry::sdk::Resource;
use opentelemetry::KeyValue;
use tracing::Subscriber;
use tracing_subscriber::fmt::format::FmtSpan;
use tracing::{Level, Subscriber};
use tracing_subscriber::prelude::*;
use tracing_subscriber::util::SubscriberInitExt;
#[cfg(feature = "has-storage")]
@ -39,53 +38,75 @@ pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
}
});
#[derive(Default, Debug, Clone)]
#[derive(Debug, Clone)]
pub struct Builder {
log_level: Option<String>,
filter: Option<CustomEnvFilter>,
filter: CustomEnvFilter,
}
pub fn builder() -> Builder {
Builder::default()
}
impl Default for Builder {
fn default() -> Self {
Self {
filter: CustomEnvFilter(EnvFilter::default()),
}
}
}
impl Builder {
/// Set the log level on the builder
pub fn with_log_level(mut self, log_level: &str) -> Self {
self.log_level = Some(log_level.to_string());
if let Ok(filter) = filter_from_value(log_level) {
self.filter = CustomEnvFilter(filter);
}
self
}
/// Set the filter on the builder
#[cfg(feature = "has-storage")]
pub fn with_filter(mut self, filter: EnvFilter) -> Self {
self.filter = Some(CustomEnvFilter(filter));
pub fn with_filter(mut self, filter: CustomEnvFilter) -> Self {
self.filter = filter;
self
}
/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
let registry = tracing_subscriber::registry();
let registry = registry.with(self.filter.map(|filter| {
tracing_subscriber::fmt::layer()
.compact()
.with_ansi(true)
.with_span_events(FmtSpan::NONE)
.with_writer(std::io::stderr)
.with_filter(filter.0)
.boxed()
}));
let registry = registry.with(self.log_level.map(logs::new));
let registry = registry.with(traces::new());
// Setup logging layer
let registry = registry.with(logs::new(self.filter.clone()));
// Setup tracing layer
let registry = registry.with(traces::new(self.filter));
Box::new(registry)
}
/// tracing pipeline
/// Install the tracing dispatcher globally
pub fn init(self) {
self.build().init()
}
}
/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
match v {
// Don't show any logs at all
"none" => Ok(EnvFilter::default()),
// Check if we should show all log levels
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
// Otherwise, let's only show errors
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
// Specify the log level for each code area
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::kvs::tx=error")),
// Let's try to parse the custom log level
_ => EnvFilter::builder().parse(v),
}
}
#[cfg(test)]
mod tests {
use opentelemetry::global::shutdown_tracer_provider;
@ -107,7 +128,7 @@ mod tests {
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
],
|| {
let _enter = telemetry::builder().build().set_default();
let _enter = telemetry::builder().with_log_level("info").build().set_default();
println!("Sending span...");
@ -123,7 +144,11 @@ mod tests {
}
println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request");
let req = tokio::select! {
req = req_rx.recv() => req.expect("missing export request"),
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
};
let first_span =
req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
assert_eq!("test-surreal-span", first_span.name);
@ -141,11 +166,10 @@ mod tests {
temp_env::with_vars(
vec![
("SURREAL_TRACING_TRACER", Some("otlp")),
("SURREAL_TRACING_FILTER", Some("debug")),
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
],
|| {
let _enter = telemetry::builder().build().set_default();
let _enter = telemetry::builder().with_log_level("debug").build().set_default();
println!("Sending spans...");
@ -169,7 +193,10 @@ mod tests {
}
println!("Waiting for request...");
let req = req_rx.recv().await.expect("missing export request");
let req = tokio::select! {
req = req_rx.recv() => req.expect("missing export request"),
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
};
let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;
assert_eq!(1, spans.len());

View file

@ -1,12 +1,15 @@
use tracing::Subscriber;
use tracing_subscriber::Layer;
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
pub mod otlp;
pub mod rpc;
const TRACING_TRACER_VAR: &str = "SURREAL_TRACING_TRACER";
// Returns a tracer based on the value of the TRACING_TRACER_VAR env var
pub fn new<S>() -> Option<Box<dyn Layer<S> + Send + Sync>>
pub fn new<S>(filter: CustomEnvFilter) -> Option<Box<dyn Layer<S> + Send + Sync>>
where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{
@ -20,7 +23,7 @@ where
// Init the registry with the OTLP tracer
"otlp" => {
debug!("Setup the OTLP tracer");
Some(otlp::new())
Some(otlp::new(filter))
}
tracer => {
panic!("unsupported tracer {}", tracer);

View file

@ -1,18 +1,18 @@
use opentelemetry::sdk::trace::Tracer;
use opentelemetry::trace::TraceError;
use opentelemetry_otlp::WithExportConfig;
use tracing::{Level, Subscriber};
use tracing_subscriber::{EnvFilter, Layer};
use tracing::Subscriber;
use tracing_subscriber::Layer;
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
use crate::{
cli::validator::parser::env_filter::CustomEnvFilter, telemetry::OTEL_DEFAULT_RESOURCE,
};
const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER";
pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
where
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
{
tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter()).boxed()
tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter.0).boxed()
}
fn tracer() -> Result<Tracer, TraceError> {
@ -24,16 +24,3 @@ fn tracer() -> Result<Tracer, TraceError> {
)
.install_batch(opentelemetry::runtime::Tokio)
}
/// Create a filter for the OTLP subscriber
///
/// It creates an EnvFilter based on the TRACING_FILTER_VAR's value
///
/// TRACING_FILTER_VAR accepts the same syntax as RUST_LOG
fn filter() -> EnvFilter {
EnvFilter::builder()
.with_env_var(TRACING_FILTER_VAR)
.with_default_directive(Level::INFO.into())
.from_env()
.unwrap()
}

View file

@ -0,0 +1,31 @@
use tracing::{field, Span};
use uuid::Uuid;
pub fn span_for_request(ws_id: &Uuid) -> Span {
let span = tracing::info_span!(
// Dynamic span names need to be 'recorded', can't be used on the macro. Use a static name here and overwrite later on
"rpc/call",
otel.name = field::Empty,
otel.kind = "server",
// To be populated by the request handler when the method is known
rpc.method = field::Empty,
rpc.service = "surrealdb",
rpc.system = "jsonrpc",
// JSON-RPC fields
rpc.jsonrpc.version = "2.0",
rpc.jsonrpc.request_id = field::Empty,
rpc.jsonrpc.error_code = field::Empty,
rpc.jsonrpc.error_message = field::Empty,
// SurrealDB custom fields
ws.id = %ws_id,
// Fields for error reporting
otel.status_code = field::Empty,
otel.status_message = field::Empty,
);
span
}

View file

@ -5,6 +5,8 @@ mod common;
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
use serial_test::serial;
use std::fs;
use test_log::test;
use tracing::info;
use common::{PASS, USER};
@ -32,13 +34,14 @@ fn nonexistent_option() {
assert!(common::run("version --turbo").output().is_err());
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn all_commands() {
// Commands without credentials when auth is disabled, should succeed
let (addr, _server) = common::start_server(false, false, true).await.unwrap();
let creds = ""; // Anonymous user
// Create a record
info!("* Create a record");
{
let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
assert_eq!(
@ -48,7 +51,7 @@ async fn all_commands() {
);
}
// Export to stdout
info!("* Export to stdout");
{
let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
let output = common::run(&args).output().expect("failed to run stdout export: {args}");
@ -56,7 +59,7 @@ async fn all_commands() {
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
}
// Export to file
info!("* Export to file");
let exported = {
let exported = common::tmp_file("exported.surql");
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@ -64,13 +67,13 @@ async fn all_commands() {
exported
};
// Import the exported file
info!("* Import the exported file");
{
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
common::run(&args).output().expect("failed to run import: {args}");
}
// Query from the import (pretty-printed this time)
info!("* Query from the import (pretty-printed this time)");
{
let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
assert_eq!(
@ -80,7 +83,7 @@ async fn all_commands() {
);
}
// Unfinished backup CLI
info!("* Unfinished backup CLI");
{
let file = common::tmp_file("backup.db");
let args = format!("backup {creds} http://{addr} {file}");
@ -90,7 +93,7 @@ async fn all_commands() {
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
}
// Multi-statement (and multi-line) query including error(s) over WS
info!("* Multi-statement (and multi-line) query including error(s) over WS");
{
let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
let output = common::run(&args)
@ -113,7 +116,7 @@ async fn all_commands() {
assert!(output.contains("thing:also_success"), "missing also_success in {output}")
}
// Multi-statement (and multi-line) transaction including error(s) over WS
info!("* Multi-statement (and multi-line) transaction including error(s) over WS");
{
let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
let output = common::run(&args)
@ -137,7 +140,7 @@ async fn all_commands() {
assert!(output.contains("rgument"), "missing argument error in {output}");
}
// Pass neither ns nor db
info!("* Pass neither ns nor db");
{
let args = format!("sql --conn http://{addr} {creds}");
let output = common::run(&args)
@ -147,7 +150,7 @@ async fn all_commands() {
assert!(output.contains("thing:one"), "missing thing:one in {output}");
}
// Pass only ns
info!("* Pass only ns");
{
let args = format!("sql --conn http://{addr} {creds} --ns N5");
let output = common::run(&args)
@ -157,16 +160,23 @@ async fn all_commands() {
assert!(output.contains("thing:one"), "missing thing:one in {output}");
}
// Pass only db and expect an error
info!("* Pass only db and expect an error");
{
let args = format!("sql --conn http://{addr} {creds} --db D5");
common::run(&args).output().expect_err("only db");
}
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn start_tls() {
// Capute the server's stdout/stderr
temp_env::async_with_vars(
[
("SURREAL_TEST_SERVER_STDOUT", Some("piped")),
("SURREAL_TEST_SERVER_STDERR", Some("piped")),
],
async {
let (_, server) = common::start_server(false, true, false).await.unwrap();
std::thread::sleep(std::time::Duration::from_millis(2000));
@ -174,9 +184,12 @@ async fn start_tls() {
// Test the crt/key args but the keys are self signed so don't actually connect.
assert!(output.contains("Started web server"), "couldn't start web server: {output}");
},
)
.await;
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn with_root_auth() {
// Commands with credentials when auth is enabled, should succeed
@ -184,7 +197,7 @@ async fn with_root_auth() {
let creds = format!("--user {USER} --pass {PASS}");
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
// Can query /sql over HTTP
info!("* Query over HTTP");
{
let args = format!("{sql_args} {creds}");
let input = "INFO FOR ROOT;";
@ -192,7 +205,7 @@ async fn with_root_auth() {
assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
}
// Can query /sql over WS
info!("* Query over WS");
{
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
let input = "INFO FOR ROOT;";
@ -200,7 +213,7 @@ async fn with_root_auth() {
assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
}
// KV user can do exports
info!("* Root user can do exports");
let exported = {
let exported = common::tmp_file("exported.surql");
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
@ -209,13 +222,13 @@ async fn with_root_auth() {
exported
};
// KV user can do imports
info!("* Root user can do imports");
{
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
}
// KV user can do backups
info!("* Root user can do backups");
{
let file = common::tmp_file("backup.db");
let args = format!("backup {creds} http://{addr} {file}");
@ -226,7 +239,7 @@ async fn with_root_auth() {
}
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn with_anon_auth() {
// Commands without credentials when auth is enabled, should fail
@ -234,7 +247,7 @@ async fn with_anon_auth() {
let creds = ""; // Anonymous user
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
// Can query /sql over HTTP
info!("* Query over HTTP");
{
let args = format!("{sql_args} {creds}");
let input = "";
@ -242,7 +255,7 @@ async fn with_anon_auth() {
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
}
// Can query /sql over HTTP
info!("* Query over WS");
{
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
let input = "";
@ -250,7 +263,7 @@ async fn with_anon_auth() {
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
}
// Can't do exports
info!("* Can't do exports");
{
let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
let output = common::run(&args).output();
@ -261,7 +274,7 @@ async fn with_anon_auth() {
);
}
// Can't do imports
info!("* Can't do imports");
{
let tmp_file = common::tmp_file("exported.surql");
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {tmp_file}");
@ -273,7 +286,7 @@ async fn with_anon_auth() {
);
}
// Can't do backups
info!("* Can't do backups");
{
let args = format!("backup {creds} http://{addr}");
let output = common::run(&args).output();

View file

@ -1,10 +1,18 @@
#![allow(dead_code)]
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;
use std::fs;
use std::fs::File;
use std::path::Path;
use std::process::{Command, Stdio};
use std::{env, fs};
use tokio::net::TcpStream;
use tokio::time;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
use tracing::{error, info};
pub const USER: &str = "root";
pub const PASS: &str = "root";
@ -52,7 +60,12 @@ impl Drop for Child {
}
}
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
pub fn run_internal<P: AsRef<Path>>(
args: &str,
current_dir: Option<P>,
stdout: Stdio,
stderr: Stdio,
) -> Child {
let mut path = std::env::current_exe().unwrap();
assert!(path.pop());
if path.ends_with("deps") {
@ -68,8 +81,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
}
cmd.env_clear();
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
cmd.stdout(stdout);
cmd.stderr(stderr);
cmd.args(args.split_ascii_whitespace());
Child {
inner: Some(cmd.spawn().unwrap()),
@ -78,12 +91,12 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
/// Run the CLI with the given args
pub fn run(args: &str) -> Child {
run_internal::<String>(args, None)
run_internal::<String>(args, None, Stdio::piped(), Stdio::piped())
}
/// Run the CLI with the given args inside a temporary directory
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
run_internal(args, Some(current_dir))
run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped())
}
pub fn tmp_file(name: &str) -> String {
@ -91,6 +104,19 @@ pub fn tmp_file(name: &str) -> String {
path.to_string_lossy().into_owned()
}
fn parse_server_stdio_from_var(var: &str) -> Result<Stdio, Box<dyn Error>> {
match env::var(var).as_deref() {
Ok("inherit") => Ok(Stdio::inherit()),
Ok("null") => Ok(Stdio::null()),
Ok("piped") => Ok(Stdio::piped()),
Ok(val) if val.starts_with("file://") => {
Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?))
}
Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()),
_ => Ok(Stdio::null()),
}
}
pub async fn start_server(
auth: bool,
tls: bool,
@ -118,11 +144,14 @@ pub async fn start_server(
extra_args.push_str(" --auth");
}
let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}");
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
println!("starting server with args: {start_args}");
info!("starting server with args: {start_args}");
let server = run(&start_args);
// Configure where the logs go when running the test
let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?;
let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?;
let server = run_internal::<String>(&start_args, None, stdout, stderr);
if !wait_is_ready {
return Ok((addr, server));
@ -130,17 +159,178 @@ pub async fn start_server(
// Wait 5 seconds for the server to start
let mut interval = time::interval(time::Duration::from_millis(500));
println!("Waiting for server to start...");
info!("Waiting for server to start...");
for _i in 0..10 {
interval.tick().await;
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
println!("Server ready!");
info!("Server ready!");
return Ok((addr, server));
}
}
let server_out = server.kill().output().err().unwrap();
println!("server output: {server_out}");
error!("server output: {server_out}");
Err("server failed to start".into())
}
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
let url = format!("ws://{}/rpc", addr);
let (ws_stream, _) = connect_async(url).await?;
Ok(ws_stream)
}
pub async fn ws_send_msg(
socket: &mut WsStream,
msg_req: String,
) -> Result<serde_json::Value, Box<dyn Error>> {
// Use JSON format by default
ws_send_msg_with_fmt(socket, msg_req, Format::Json).await
}
pub enum Format {
Json,
Cbor,
Pack,
}
pub async fn ws_send_msg_with_fmt(
socket: &mut WsStream,
msg_req: String,
response_format: Format,
) -> Result<serde_json::Value, Box<dyn Error>> {
tokio::select! {
_ = time::sleep(time::Duration::from_millis(500)) => {
return Err("timeout waiting for the request to be sent".into());
}
res = socket.send(Message::Text(msg_req)) => {
if let Err(err) = res {
return Err(format!("Error sending the message: {}", err).into());
}
}
}
let mut f = socket.try_filter(|msg| match response_format {
Format::Json => futures_util::future::ready(msg.is_text()),
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
});
tokio::select! {
_ = time::sleep(time::Duration::from_millis(2000)) => {
Err("timeout waiting for the response".into())
}
res = f.select_next_some() => {
match response_format {
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
}
}
}
}
#[derive(Serialize, Deserialize)]
struct SigninParams<'a> {
user: &'a str,
pass: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
sc: Option<&'a str>,
}
#[derive(Serialize, Deserialize)]
struct UseParams<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
ns: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
db: Option<&'a str>,
}
pub async fn ws_signin(
socket: &mut WsStream,
user: &str,
pass: &str,
ns: Option<&str>,
db: Option<&str>,
sc: Option<&str>,
) -> Result<String, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "signin",
"params": [
SigninParams { user, pass, ns, db, sc }
],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_query(
socket: &mut WsStream,
query: &str,
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "query",
"params": [query],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().as_array().unwrap().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}
pub async fn ws_use(
socket: &mut WsStream,
ns: Option<&str>,
db: Option<&str>,
) -> Result<serde_json::Value, Box<dyn Error>> {
let json = json!({
"id": "1",
"method": "use",
"params": [
ns, db
],
});
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
match msg.as_object() {
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
}
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
Ok(obj.get("result").unwrap().to_owned())
}
_ => {
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
Err(format!("unexpected response: {:?}", msg).into())
}
}
}

View file

@ -7,10 +7,11 @@ use http::{header, Method};
use reqwest::Client;
use serde_json::json;
use serial_test::serial;
use test_log::test;
use crate::common::{PASS, USER};
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -52,7 +53,7 @@ async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -133,14 +134,14 @@ async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> {
// TODO: test the client IP extractor
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -184,7 +185,7 @@ async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -196,7 +197,7 @@ async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -269,7 +270,7 @@ async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -303,7 +304,7 @@ async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -372,7 +373,7 @@ async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -425,13 +426,17 @@ async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res.status(), 200, "body: {}", res.text().await?);
let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body);
assert!(
body["token"].as_str().unwrap().starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"),
"body: {}",
body
);
}
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -543,7 +548,7 @@ async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -577,7 +582,7 @@ async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn version_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -619,7 +624,7 @@ async fn seed_table(
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -696,7 +701,7 @@ async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -759,7 +764,7 @@ async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -829,7 +834,7 @@ async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -899,7 +904,7 @@ async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -953,7 +958,7 @@ async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -994,7 +999,7 @@ async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1091,7 +1096,7 @@ async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1164,7 +1169,7 @@ async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
@ -1242,7 +1247,7 @@ async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[tokio::test]
#[test(tokio::test)]
#[serial]
async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server(true, false, true).await.unwrap();

1109
tests/ws_integration.rs Normal file

File diff suppressed because it is too large Load diff