[rpc] Better tracing for WebSockets (#2325)
This commit is contained in:
parent
ab72923fb5
commit
e91011cc78
20 changed files with 2617 additions and 650 deletions
26
.github/workflows/ci.yml
vendored
26
.github/workflows/ci.yml
vendored
|
@ -133,7 +133,7 @@ jobs:
|
|||
args: ci-clippy
|
||||
|
||||
cli:
|
||||
name: Test command line
|
||||
name: CLI integration tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
|
@ -163,7 +163,7 @@ jobs:
|
|||
args: ci-cli-integration
|
||||
|
||||
http-server:
|
||||
name: Test HTTP server
|
||||
name: HTTP integration tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
|
@ -192,6 +192,28 @@ jobs:
|
|||
command: make
|
||||
args: ci-http-integration
|
||||
|
||||
ws-server:
|
||||
name: WebSocket integration tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- name: Install stable toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup cache
|
||||
uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get -y update
|
||||
sudo apt-get -y install protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Run cargo test
|
||||
run: cargo test --locked --no-default-features --features storage-mem --workspace --test ws_integration
|
||||
|
||||
test:
|
||||
name: Test workspace
|
||||
runs-on: ubuntu-latest
|
||||
|
|
782
Cargo.lock
generated
782
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -67,6 +67,7 @@ tokio-util = { version = "0.7.8", features = ["io"] }
|
|||
tower = "0.4.13"
|
||||
tower-http = { version = "0.4.2", features = ["trace", "sensitive-headers", "auth", "request-id", "util", "catch-panic", "cors", "set-header", "limit", "add-extension"] }
|
||||
tracing = "0.1"
|
||||
tracing-futures = { version = "0.2.5", features = ["tokio"], default-features = false }
|
||||
tracing-opentelemetry = "0.19.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
|
||||
urlencoding = "2.1.2"
|
||||
|
@ -77,11 +78,14 @@ nix = "0.26.2"
|
|||
|
||||
[dev-dependencies]
|
||||
assert_fs = "1.0.13"
|
||||
env_logger = "0.10.0"
|
||||
opentelemetry-proto = { version = "0.2.0", features = ["gen-tonic", "traces", "metrics", "logs"] }
|
||||
rcgen = "0.10.0"
|
||||
serial_test = "2.0.0"
|
||||
temp-env = "0.3.4"
|
||||
temp-env = { version = "0.3.4", features = ["async_closure"] }
|
||||
test-log = { version = "0.2.12", features = ["trace"] }
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
tokio-tungstenite = { version = "0.18.0" }
|
||||
tonic = "0.8.3"
|
||||
|
||||
[package.metadata.deb]
|
||||
|
|
|
@ -115,7 +115,7 @@ pub async fn init(
|
|||
listen_addresses,
|
||||
dbs,
|
||||
web,
|
||||
log: CustomEnvFilter(log),
|
||||
log,
|
||||
tick_interval,
|
||||
no_banner,
|
||||
..
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
use clap::builder::{NonEmptyStringValueParser, PossibleValue, TypedValueParser};
|
||||
use clap::error::{ContextKind, ContextValue, ErrorKind};
|
||||
use tracing::Level;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use crate::telemetry::filter_from_value;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CustomEnvFilter(pub EnvFilter);
|
||||
|
||||
|
@ -37,20 +38,7 @@ impl TypedValueParser for CustomEnvFilterParser {
|
|||
|
||||
let inner = NonEmptyStringValueParser::new();
|
||||
let v = inner.parse_ref(cmd, arg, value)?;
|
||||
let filter = (match v.as_str() {
|
||||
// Don't show any logs at all
|
||||
"none" => Ok(EnvFilter::default()),
|
||||
// Check if we should show all log levels
|
||||
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
|
||||
// Otherwise, let's only show errors
|
||||
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
|
||||
// Specify the log level for each code area
|
||||
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
|
||||
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::txn=error")),
|
||||
// Let's try to parse the custom log level
|
||||
_ => EnvFilter::builder().parse(v),
|
||||
})
|
||||
.map_err(|e| {
|
||||
let filter = filter_from_value(v.as_str()).map_err(|e| {
|
||||
let mut err = clap::Error::new(ErrorKind::ValueValidation).with_cmd(cmd);
|
||||
err.insert(ContextKind::Custom, ContextValue::String(e.to_string()));
|
||||
err.insert(
|
||||
|
|
|
@ -131,7 +131,7 @@ pub async fn init() -> Result<(), Error> {
|
|||
|
||||
// Setup the graceful shutdown with no timeout
|
||||
let handle = Handle::new();
|
||||
graceful_shutdown(handle.clone(), None);
|
||||
let shutdown_handler = graceful_shutdown(handle.clone());
|
||||
|
||||
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
|
||||
// configure certificate and private key used by https
|
||||
|
@ -156,6 +156,12 @@ pub async fn init() -> Result<(), Error> {
|
|||
.await?;
|
||||
};
|
||||
|
||||
// Wait for the shutdown to finish
|
||||
let _ = shutdown_handler.await;
|
||||
|
||||
// Flush all telemetry data
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
||||
info!(target: LOG, "Web server stopped. Bye!");
|
||||
|
||||
Ok(())
|
||||
|
|
585
src/net/rpc.rs
585
src/net/rpc.rs
|
@ -7,26 +7,36 @@ use crate::err::Error;
|
|||
use crate::rpc::args::Take;
|
||||
use crate::rpc::paths::{ID, METHOD, PARAMS};
|
||||
use crate::rpc::res;
|
||||
use crate::rpc::res::Data;
|
||||
use crate::rpc::res::Failure;
|
||||
use crate::rpc::res::Output;
|
||||
use crate::rpc::res::IntoRpcResponse;
|
||||
use crate::rpc::res::OutputFormat;
|
||||
use crate::rpc::CONN_CLOSED_ERR;
|
||||
use crate::telemetry::traces::rpc::span_for_request;
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use futures_util::stream::SplitSink;
|
||||
use futures_util::stream::SplitStream;
|
||||
use http_body::Body as HttpBody;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use surrealdb::channel;
|
||||
use surrealdb::channel::Sender;
|
||||
use surrealdb::channel::{Receiver, Sender};
|
||||
use surrealdb::dbs::{QueryType, Response, Session};
|
||||
use surrealdb::sql::serde::deserialize;
|
||||
use surrealdb::sql::Array;
|
||||
use surrealdb::sql::Object;
|
||||
use surrealdb::sql::Strand;
|
||||
use surrealdb::sql::Value;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::instrument;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tower_http::request_id::RequestId;
|
||||
use tracing::Span;
|
||||
use uuid::Uuid;
|
||||
|
||||
use axum::{
|
||||
|
@ -35,11 +45,12 @@ use axum::{
|
|||
};
|
||||
|
||||
// Mapping of WebSocketID to WebSocket
|
||||
type WebSockets = RwLock<HashMap<Uuid, Sender<Message>>>;
|
||||
pub(crate) struct WebSocketRef(pub(crate) Sender<Message>, pub(crate) CancellationToken);
|
||||
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
|
||||
// Mapping of LiveQueryID to WebSocketID
|
||||
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
|
||||
|
||||
static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
|
||||
pub(super) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
|
||||
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
|
||||
|
||||
pub(super) fn router<S, B>() -> Router<S, B>
|
||||
|
@ -50,22 +61,36 @@ where
|
|||
Router::new().route("/rpc", get(handler))
|
||||
}
|
||||
|
||||
async fn handler(ws: WebSocketUpgrade, Extension(sess): Extension<Session>) -> impl IntoResponse {
|
||||
async fn handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(sess): Extension<Session>,
|
||||
Extension(req_id): Extension<RequestId>,
|
||||
) -> impl IntoResponse {
|
||||
// finalize the upgrade process by returning upgrade callback.
|
||||
// we can customize the callback by sending additional info such as address.
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, sess))
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, sess, req_id))
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, sess: Session) {
|
||||
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
|
||||
let rpc = Rpc::new(sess);
|
||||
Rpc::serve(rpc, ws).await
|
||||
|
||||
// If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID
|
||||
match req_id.header_value().to_str().map(Uuid::parse_str) {
|
||||
Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => {
|
||||
rpc.write().await.ws_id = req_id
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
Rpc::serve(rpc, ws).await;
|
||||
}
|
||||
|
||||
pub struct Rpc {
|
||||
session: Session,
|
||||
format: Output,
|
||||
uuid: Uuid,
|
||||
format: OutputFormat,
|
||||
ws_id: Uuid,
|
||||
vars: BTreeMap<String, Value>,
|
||||
graceful_shutdown: CancellationToken,
|
||||
}
|
||||
|
||||
impl Rpc {
|
||||
|
@ -74,158 +99,247 @@ impl Rpc {
|
|||
// Create a new RPC variables store
|
||||
let vars = BTreeMap::new();
|
||||
// Set the default output format
|
||||
let format = Output::Json;
|
||||
// Create a unique WebSocket id
|
||||
let uuid = Uuid::new_v4();
|
||||
// Enable real-time live queries
|
||||
let format = OutputFormat::Json;
|
||||
// Enable real-time mode
|
||||
session.rt = true;
|
||||
// Create and store the Rpc connection
|
||||
Arc::new(RwLock::new(Rpc {
|
||||
session,
|
||||
format,
|
||||
uuid,
|
||||
ws_id: Uuid::new_v4(),
|
||||
vars,
|
||||
graceful_shutdown: CancellationToken::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Serve the RPC endpoint
|
||||
pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) {
|
||||
// Create a channel for sending messages
|
||||
let (chn, mut rcv) = channel::new(MAX_CONCURRENT_CALLS);
|
||||
// Split the socket into send and recv
|
||||
let (mut wtx, mut wrx) = ws.split();
|
||||
// Clone the channel for sending pings
|
||||
let png = chn.clone();
|
||||
// The WebSocket has connected
|
||||
Rpc::connected(rpc.clone(), chn.clone()).await;
|
||||
// Send Ping messages to the client
|
||||
tokio::task::spawn(async move {
|
||||
// Create the interval ticker
|
||||
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
|
||||
// Loop indefinitely
|
||||
loop {
|
||||
// Wait for the timer
|
||||
interval.tick().await;
|
||||
// Create the ping message
|
||||
let msg = Message::Ping(vec![]);
|
||||
// Send the message to the client
|
||||
if png.send(msg).await.is_err() {
|
||||
// Exit out of the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
// Send messages to the client
|
||||
tokio::task::spawn(async move {
|
||||
// Wait for the next message to send
|
||||
while let Some(res) = rcv.next().await {
|
||||
// Send the message to the client
|
||||
if let Err(err) = wtx.send(res).await {
|
||||
// Output the WebSocket error to the logs
|
||||
trace!("WebSocket error: {:?}", err);
|
||||
// It's already failed, so ignore error
|
||||
let _ = wtx.close().await;
|
||||
// Exit out of the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
// Send notifications to the client
|
||||
let moved_rpc = rpc.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let rpc = moved_rpc;
|
||||
if let Some(channel) = DB.get().unwrap().notifications() {
|
||||
while let Ok(notification) = channel.recv().await {
|
||||
// Find which WebSocket the notification belongs to
|
||||
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||
// Check to see if the WebSocket exists
|
||||
if let Some(websocket) = WEBSOCKETS.read().await.get(ws_id) {
|
||||
// Serialize the message to send
|
||||
let message = res::success(None, notification);
|
||||
// Get the current output format
|
||||
let format = rpc.read().await.format.clone();
|
||||
// Send the notification to the client
|
||||
message.send(format, websocket.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
// Get messages from the client
|
||||
while let Some(msg) = wrx.next().await {
|
||||
match msg {
|
||||
// We've received a message from the client
|
||||
// Ping is automatically handled by the WebSocket library
|
||||
Ok(msg) => match msg {
|
||||
Message::Text(_) => {
|
||||
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
tokio::task::spawn(Rpc::call(rpc.clone(), msg, chn.clone()));
|
||||
}
|
||||
Message::Close(_) => {
|
||||
break;
|
||||
}
|
||||
Message::Pong(_) => {
|
||||
continue;
|
||||
}
|
||||
_ => {
|
||||
// Ignore everything else
|
||||
}
|
||||
},
|
||||
// There was an error receiving the message
|
||||
Err(err) => {
|
||||
// Output the WebSocket error to the logs
|
||||
trace!("WebSocket error: {:?}", err);
|
||||
// Exit out of the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// The WebSocket has disconnected
|
||||
Rpc::disconnected(rpc.clone()).await;
|
||||
}
|
||||
let (sender, receiver) = ws.split();
|
||||
// Create an internal channel between the receiver and the sender
|
||||
let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
|
||||
|
||||
let ws_id = rpc.read().await.ws_id;
|
||||
|
||||
async fn connected(rpc: Arc<RwLock<Rpc>>, chn: Sender<Message>) {
|
||||
// Fetch the unique id of the WebSocket
|
||||
let id = rpc.read().await.uuid;
|
||||
// Log that the WebSocket has connected
|
||||
trace!("WebSocket {} connected", id);
|
||||
// Store this WebSocket in the list of WebSockets
|
||||
WEBSOCKETS.write().await.insert(id, chn);
|
||||
}
|
||||
WEBSOCKETS.write().await.insert(
|
||||
ws_id,
|
||||
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
|
||||
);
|
||||
|
||||
trace!("WebSocket {} connected", ws_id);
|
||||
|
||||
// Wait until all tasks finish
|
||||
tokio::join!(
|
||||
Self::ping(rpc.clone(), internal_sender.clone()),
|
||||
Self::read(rpc.clone(), receiver, internal_sender.clone()),
|
||||
Self::write(rpc.clone(), sender, internal_receiver.clone()),
|
||||
Self::lq_notifications(rpc.clone()),
|
||||
);
|
||||
|
||||
async fn disconnected(rpc: Arc<RwLock<Rpc>>) {
|
||||
// Fetch the unique id of the WebSocket
|
||||
let id = rpc.read().await.uuid;
|
||||
// Log that the WebSocket has disconnected
|
||||
trace!("WebSocket {} disconnected", id);
|
||||
// Remove this WebSocket from the list of WebSockets
|
||||
WEBSOCKETS.write().await.remove(&id);
|
||||
// Remove all live queries
|
||||
LIVE_QUERIES.write().await.retain(|key, value| {
|
||||
if value == &id {
|
||||
if value == &ws_id {
|
||||
trace!("Removing live query: {}", key);
|
||||
return false;
|
||||
}
|
||||
true
|
||||
});
|
||||
|
||||
// Remove this WebSocket from the list of WebSockets
|
||||
WEBSOCKETS.write().await.remove(&ws_id);
|
||||
|
||||
trace!("WebSocket {} disconnected", ws_id);
|
||||
}
|
||||
|
||||
/// Call RPC methods from the WebSocket
|
||||
async fn call(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
|
||||
/// Send Ping messages to the client
|
||||
async fn ping(rpc: Arc<RwLock<Rpc>>, internal_sender: Sender<Message>) {
|
||||
// Create the interval ticker
|
||||
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
|
||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||
loop {
|
||||
let is_shutdown = cancel_token.cancelled();
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
let msg = Message::Ping(vec![]);
|
||||
|
||||
// Send the message to the client and close the WebSocket connection if it fails
|
||||
if internal_sender.send(msg).await.is_err() {
|
||||
rpc.read().await.graceful_shutdown.cancel();
|
||||
break;
|
||||
}
|
||||
},
|
||||
_ = is_shutdown => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read messages sent from the client
|
||||
async fn read(
|
||||
rpc: Arc<RwLock<Rpc>>,
|
||||
mut receiver: SplitStream<WebSocket>,
|
||||
internal_sender: Sender<Message>,
|
||||
) {
|
||||
// Collect all spawned tasks so we can wait for them at the end
|
||||
let mut tasks = JoinSet::new();
|
||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||
loop {
|
||||
let is_shutdown = cancel_token.cancelled();
|
||||
tokio::select! {
|
||||
msg = receiver.next() => {
|
||||
if let Some(msg) = msg {
|
||||
match msg {
|
||||
// We've received a message from the client
|
||||
// Ping/Pong is automatically handled by the WebSocket library
|
||||
Ok(msg) => match msg {
|
||||
Message::Text(_) => {
|
||||
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
||||
}
|
||||
Message::Binary(_) => {
|
||||
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
||||
}
|
||||
Message::Close(_) => {
|
||||
// Respond with a close message
|
||||
if let Err(err) = internal_sender.send(Message::Close(None)).await {
|
||||
trace!("WebSocket error when replying to the Close frame: {:?}", err);
|
||||
};
|
||||
// Start the graceful shutdown of the WebSocket and close the channels
|
||||
rpc.read().await.graceful_shutdown.cancel();
|
||||
let _ = internal_sender.close();
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
// Ignore everything else
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
trace!("WebSocket error: {:?}", err);
|
||||
// Start the graceful shutdown of the WebSocket and close the channels
|
||||
rpc.read().await.graceful_shutdown.cancel();
|
||||
let _ = internal_sender.close();
|
||||
// Exit out of the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = is_shutdown => break,
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tasks to finish
|
||||
while let Some(res) = tasks.join_next().await {
|
||||
if let Err(err) = res {
|
||||
error!("Error while handling RPC message: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write messages to the client
|
||||
async fn write(
|
||||
rpc: Arc<RwLock<Rpc>>,
|
||||
mut sender: SplitSink<WebSocket, Message>,
|
||||
mut internal_receiver: Receiver<Message>,
|
||||
) {
|
||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||
loop {
|
||||
let is_shutdown = cancel_token.cancelled();
|
||||
tokio::select! {
|
||||
// Wait for the next message to send
|
||||
msg = internal_receiver.next() => {
|
||||
if let Some(res) = msg {
|
||||
// Send the message to the client
|
||||
if let Err(err) = sender.send(res).await {
|
||||
if err.to_string() != CONN_CLOSED_ERR {
|
||||
debug!("WebSocket error: {:?}", err);
|
||||
}
|
||||
// Close the WebSocket connection
|
||||
rpc.read().await.graceful_shutdown.cancel();
|
||||
// Exit out of the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = is_shutdown => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send live query notifications to the client
|
||||
async fn lq_notifications(rpc: Arc<RwLock<Rpc>>) {
|
||||
if let Some(channel) = DB.get().unwrap().notifications() {
|
||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = channel.recv() => {
|
||||
if let Ok(notification) = msg {
|
||||
// Find which WebSocket the notification belongs to
|
||||
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||
// Check to see if the WebSocket exists
|
||||
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
|
||||
// Serialize the message to send
|
||||
let message = res::success(None, notification);
|
||||
// Get the current output format
|
||||
let mut out = { rpc.read().await.format.clone() };
|
||||
// Clone the RPC
|
||||
let rpc = rpc.clone();
|
||||
let format = rpc.read().await.format.clone();
|
||||
// Send the notification to the client
|
||||
message.send(format, ws.clone()).await
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = cancel_token.cancelled() => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle individual WebSocket messages
|
||||
async fn handle_msg(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
|
||||
// Get the current output format
|
||||
let mut out_fmt = rpc.read().await.format.clone();
|
||||
let span = span_for_request(&rpc.read().await.ws_id);
|
||||
let _enter = span.enter();
|
||||
// Parse the request
|
||||
match Self::parse_request(msg).await {
|
||||
Ok((id, method, params, _out_fmt)) => {
|
||||
span.record(
|
||||
"rpc.jsonrpc.request_id",
|
||||
id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
|
||||
);
|
||||
if let Some(_out_fmt) = _out_fmt {
|
||||
out_fmt = _out_fmt;
|
||||
}
|
||||
|
||||
// Process the request
|
||||
let res = Self::process_request(rpc.clone(), &method, params).await;
|
||||
|
||||
// Process the response
|
||||
res.into_response(id).send(out_fmt, chn).await
|
||||
}
|
||||
Err(err) => {
|
||||
// Process the response
|
||||
res::failure(None, err).send(out_fmt, chn).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_request(
|
||||
msg: Message,
|
||||
) -> Result<(Option<Value>, String, Array, Option<OutputFormat>), Failure> {
|
||||
let mut out_fmt = None;
|
||||
let req = match msg {
|
||||
// This is a binary message
|
||||
Message::Binary(val) => {
|
||||
// Use binary output
|
||||
out = Output::Full;
|
||||
// Deserialize the input
|
||||
Value::from(val)
|
||||
out_fmt = Some(OutputFormat::Full);
|
||||
|
||||
match deserialize(&val) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
debug!("Error when trying to deserialize the request");
|
||||
return Err(Failure::PARSE_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
// This is a text message
|
||||
Message::Text(ref val) => {
|
||||
|
@ -234,14 +348,15 @@ impl Rpc {
|
|||
// The SurrealQL message parsed ok
|
||||
Ok(v) => v,
|
||||
// The SurrealQL message failed to parse
|
||||
_ => return res::failure(None, Failure::PARSE_ERROR).send(out, chn).await,
|
||||
_ => return Err(Failure::PARSE_ERROR),
|
||||
}
|
||||
}
|
||||
// Unsupported message type
|
||||
_ => return res::failure(None, Failure::INTERNAL_ERROR).send(out, chn).await,
|
||||
_ => {
|
||||
debug!("Unsupported message type: {:?}", msg);
|
||||
return Err(res::Failure::custom("Unsupported message type"));
|
||||
}
|
||||
};
|
||||
// Log the received request
|
||||
trace!("RPC Received: {}", req);
|
||||
// Fetch the 'id' argument
|
||||
let id = match req.pick(&*ID) {
|
||||
v if v.is_none() => None,
|
||||
|
@ -250,149 +365,180 @@ impl Rpc {
|
|||
v if v.is_number() => Some(v),
|
||||
v if v.is_strand() => Some(v),
|
||||
v if v.is_datetime() => Some(v),
|
||||
_ => return res::failure(None, Failure::INVALID_REQUEST).send(out, chn).await,
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
// Fetch the 'method' argument
|
||||
let method = match req.pick(&*METHOD) {
|
||||
Value::Strand(v) => v.to_raw(),
|
||||
_ => return res::failure(id, Failure::INVALID_REQUEST).send(out, chn).await,
|
||||
_ => return Err(Failure::INVALID_REQUEST),
|
||||
};
|
||||
|
||||
// Now that we know the method, we can update the span
|
||||
Span::current().record("rpc.method", &method);
|
||||
Span::current().record("otel.name", format!("surrealdb.rpc/{}", method));
|
||||
|
||||
// Fetch the 'params' argument
|
||||
let params = match req.pick(&*PARAMS) {
|
||||
Value::Array(v) => v,
|
||||
_ => Array::new(),
|
||||
};
|
||||
|
||||
Ok((id, method, params, out_fmt))
|
||||
}
|
||||
|
||||
async fn process_request(
|
||||
rpc: Arc<RwLock<Rpc>>,
|
||||
method: &str,
|
||||
params: Array,
|
||||
) -> Result<Data, Failure> {
|
||||
info!("Process RPC request");
|
||||
|
||||
// Match the method to a function
|
||||
let res = match &method[..] {
|
||||
// Handle a ping message
|
||||
"ping" => Ok(Value::None),
|
||||
match method {
|
||||
// Handle a surrealdb ping message
|
||||
//
|
||||
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
|
||||
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
|
||||
// By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent.
|
||||
//
|
||||
"ping" => Ok(Value::None.into()),
|
||||
// Retrieve the current auth record
|
||||
"info" => match params.len() {
|
||||
0 => rpc.read().await.info().await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Switch to a specific namespace and database
|
||||
"use" => match params.needs_two() {
|
||||
Ok((ns, db)) => rpc.write().await.yuse(ns, db).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((ns, db)) => {
|
||||
rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Signup to a specific authentication scope
|
||||
"signup" => match params.needs_one() {
|
||||
Ok(Value::Object(v)) => rpc.write().await.signup(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(Value::Object(v)) => {
|
||||
rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Signin as a root, namespace, database or scope user
|
||||
"signin" => match params.needs_one() {
|
||||
Ok(Value::Object(v)) => rpc.write().await.signin(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(Value::Object(v)) => {
|
||||
rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Invalidate the current authentication session
|
||||
"invalidate" => match params.len() {
|
||||
0 => rpc.write().await.invalidate().await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Authenticate using an authentication token
|
||||
"authenticate" => match params.needs_one() {
|
||||
Ok(Value::Strand(v)) => rpc.write().await.authenticate(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(Value::Strand(v)) => {
|
||||
rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Kill a live query using a query id
|
||||
"kill" => match params.needs_one() {
|
||||
Ok(v) if v.is_uuid() => rpc.read().await.kill(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(v) if v.is_uuid() => {
|
||||
rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Setup a live query on a specific table
|
||||
"live" => match params.needs_one_or_two() {
|
||||
Ok((v, d)) if v.is_table() => rpc.read().await.live(v, d).await,
|
||||
Ok((v, d)) if v.is_strand() => rpc.read().await.live(v, d).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, d)) if v.is_table() => {
|
||||
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
Ok((v, d)) if v.is_strand() => {
|
||||
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Specify a connection-wide parameter
|
||||
"let" => match params.needs_one_or_two() {
|
||||
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
},
|
||||
// Specify a connection-wide parameter
|
||||
"set" => match params.needs_one_or_two() {
|
||||
Ok((Value::Strand(s), v)) => rpc.write().await.set(s, v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
"let" | "set" => match params.needs_one_or_two() {
|
||||
Ok((Value::Strand(s), v)) => {
|
||||
rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Unset and clear a connection-wide parameter
|
||||
"unset" => match params.needs_one() {
|
||||
Ok(Value::Strand(s)) => rpc.write().await.unset(s).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(Value::Strand(s)) => {
|
||||
rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Select a value or values from the database
|
||||
"select" => match params.needs_one() {
|
||||
Ok(v) => rpc.read().await.select(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Insert a value or values in the database
|
||||
"insert" => match params.needs_one_or_two() {
|
||||
Ok((v, o)) => rpc.read().await.insert(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, o)) => {
|
||||
rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Create a value or values in the database
|
||||
"create" => match params.needs_one_or_two() {
|
||||
Ok((v, o)) => rpc.read().await.create(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, o)) => {
|
||||
rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Update a value or values in the database using `CONTENT`
|
||||
"update" => match params.needs_one_or_two() {
|
||||
Ok((v, o)) => rpc.read().await.update(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, o)) => {
|
||||
rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Update a value or values in the database using `MERGE`
|
||||
"change" | "merge" => match params.needs_one_or_two() {
|
||||
Ok((v, o)) => rpc.read().await.change(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, o)) => {
|
||||
rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Update a value or values in the database using `PATCH`
|
||||
"modify" | "patch" => match params.needs_one_or_two() {
|
||||
Ok((v, o)) => rpc.read().await.modify(v, o).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok((v, o)) => {
|
||||
rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Delete a value or values from the database
|
||||
"delete" => match params.needs_one() {
|
||||
Ok(v) => rpc.read().await.delete(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Specify the output format for text requests
|
||||
"format" => match params.needs_one() {
|
||||
Ok(Value::Strand(v)) => rpc.write().await.format(v).await,
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
Ok(Value::Strand(v)) => {
|
||||
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Get the current server version
|
||||
"version" => match params.len() {
|
||||
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
// Run a full SurrealQL query against the database
|
||||
"query" => match params.needs_one_or_two() {
|
||||
Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
|
||||
return match rpc.read().await.query(s).await {
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => {
|
||||
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
|
||||
}
|
||||
};
|
||||
rpc.read().await.query(s).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
Ok((Value::Strand(s), Value::Object(o))) => {
|
||||
return match rpc.read().await.query_with(s, o).await {
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => {
|
||||
res::failure(id, Failure::custom(e.to_string())).send(out, chn).await
|
||||
rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into)
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => return res::failure(id, Failure::INVALID_PARAMS).send(out, chn).await,
|
||||
_ => Err(Failure::INVALID_PARAMS),
|
||||
},
|
||||
_ => return res::failure(id, Failure::METHOD_NOT_FOUND).send(out, chn).await,
|
||||
};
|
||||
// Return the final response
|
||||
match res {
|
||||
Ok(v) => res::success(id, v).send(out, chn).await,
|
||||
Err(e) => res::failure(id, Failure::custom(e.to_string())).send(out, chn).await,
|
||||
_ => Err(Failure::METHOD_NOT_FOUND),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -402,15 +548,14 @@ impl Rpc {
|
|||
|
||||
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
||||
match out.as_str() {
|
||||
"json" | "application/json" => self.format = Output::Json,
|
||||
"cbor" | "application/cbor" => self.format = Output::Cbor,
|
||||
"pack" | "application/pack" => self.format = Output::Pack,
|
||||
"json" | "application/json" => self.format = OutputFormat::Json,
|
||||
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
|
||||
"pack" | "application/pack" => self.format = OutputFormat::Pack,
|
||||
_ => return Err(Error::InvalidType),
|
||||
};
|
||||
Ok(Value::None)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc use", fields(websocket=self.uuid.to_string()))]
|
||||
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
||||
if let Value::Strand(ns) = ns {
|
||||
self.session.ns = Some(ns.0);
|
||||
|
@ -421,7 +566,6 @@ impl Rpc {
|
|||
Ok(Value::None)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc signup", fields(websocket=self.uuid.to_string()))]
|
||||
async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
|
||||
|
@ -430,7 +574,6 @@ impl Rpc {
|
|||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc signin", fields(websocket=self.uuid.to_string()))]
|
||||
async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
|
||||
|
@ -438,13 +581,11 @@ impl Rpc {
|
|||
.map(Into::into)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
#[instrument(skip_all, name = "rpc invalidate", fields(websocket=self.uuid.to_string()))]
|
||||
async fn invalidate(&mut self) -> Result<Value, Error> {
|
||||
surrealdb::iam::clear::clear(&mut self.session)?;
|
||||
Ok(Value::None)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc auth", fields(websocket=self.uuid.to_string()))]
|
||||
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
|
||||
|
@ -455,7 +596,6 @@ impl Rpc {
|
|||
// Methods for identification
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc info", fields(websocket=self.uuid.to_string()))]
|
||||
async fn info(&self) -> Result<Value, Error> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
|
@ -473,7 +613,6 @@ impl Rpc {
|
|||
// Methods for setting variables
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc set", fields(websocket=self.uuid.to_string()))]
|
||||
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
|
||||
match val {
|
||||
// Remove the variable if undefined
|
||||
|
@ -484,7 +623,6 @@ impl Rpc {
|
|||
Ok(Value::Null)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc unset", fields(websocket=self.uuid.to_string()))]
|
||||
async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
|
||||
self.vars.remove(&key.0);
|
||||
Ok(Value::Null)
|
||||
|
@ -494,7 +632,6 @@ impl Rpc {
|
|||
// Methods for live queries
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc kill", fields(websocket=self.uuid.to_string()))]
|
||||
async fn kill(&self, id: Value) -> Result<Value, Error> {
|
||||
// Specify the SQL query string
|
||||
let sql = "KILL $id";
|
||||
|
@ -513,7 +650,6 @@ impl Rpc {
|
|||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc live", fields(websocket=self.uuid.to_string()))]
|
||||
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
|
||||
// Specify the SQL query string
|
||||
let sql = match diff.is_true() {
|
||||
|
@ -539,7 +675,6 @@ impl Rpc {
|
|||
// Methods for selecting
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc select", fields(websocket=self.uuid.to_string()))]
|
||||
async fn select(&self, what: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -567,7 +702,6 @@ impl Rpc {
|
|||
// Methods for inserting
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc insert", fields(websocket=self.uuid.to_string()))]
|
||||
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -596,7 +730,6 @@ impl Rpc {
|
|||
// Methods for creating
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc create", fields(websocket=self.uuid.to_string()))]
|
||||
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -625,7 +758,6 @@ impl Rpc {
|
|||
// Methods for updating
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc update", fields(websocket=self.uuid.to_string()))]
|
||||
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -654,7 +786,6 @@ impl Rpc {
|
|||
// Methods for changing
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc change", fields(websocket=self.uuid.to_string()))]
|
||||
async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -683,7 +814,6 @@ impl Rpc {
|
|||
// Methods for modifying
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc modify", fields(websocket=self.uuid.to_string()))]
|
||||
async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -712,7 +842,6 @@ impl Rpc {
|
|||
// Methods for deleting
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc delete", fields(websocket=self.uuid.to_string()))]
|
||||
async fn delete(&self, what: Value) -> Result<Value, Error> {
|
||||
// Return a single result?
|
||||
let one = what.is_thing();
|
||||
|
@ -740,7 +869,6 @@ impl Rpc {
|
|||
// Methods for querying
|
||||
// ------------------------------
|
||||
|
||||
#[instrument(skip_all, name = "rpc query", fields(websocket=self.uuid.to_string()))]
|
||||
async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
|
@ -756,7 +884,6 @@ impl Rpc {
|
|||
Ok(res)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "rpc query_with", fields(websocket=self.uuid.to_string()))]
|
||||
async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
|
@ -781,8 +908,8 @@ impl Rpc {
|
|||
QueryType::Live => {
|
||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||
// Match on Uuid type
|
||||
LIVE_QUERIES.write().await.insert(lqid.0, self.uuid);
|
||||
trace!("Registered live query {} on websocket {}", lqid, self.uuid);
|
||||
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
|
||||
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
|
||||
}
|
||||
}
|
||||
QueryType::Kill => {
|
||||
|
|
|
@ -1,17 +1,57 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use axum_server::Handle;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::err::Error;
|
||||
use crate::{
|
||||
err::Error,
|
||||
net::rpc::{WebSocketRef, WEBSOCKETS},
|
||||
};
|
||||
|
||||
/// Start a graceful shutdown on the Axum Handle when a shutdown signal is received.
|
||||
pub fn graceful_shutdown(handle: Handle, dur: Option<Duration>) {
|
||||
/// Start a graceful shutdown:
|
||||
/// * Signal the Axum Handle when a shutdown signal is received.
|
||||
/// * Stop all WebSocket connections.
|
||||
///
|
||||
/// A second signal will force an immediate shutdown.
|
||||
pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let result = listen().await.expect("Failed to listen to shutdown signal");
|
||||
info!(target: super::LOG, "{} received. Start graceful shutdown...", result);
|
||||
info!(target: super::LOG, "{} received. Waiting for graceful shutdown... A second signal will force an immediate shutdown", result);
|
||||
|
||||
handle.graceful_shutdown(dur)
|
||||
});
|
||||
tokio::select! {
|
||||
// Start a normal graceful shutdown
|
||||
_ = async {
|
||||
// First stop accepting new HTTP requests
|
||||
http_handle.graceful_shutdown(None);
|
||||
|
||||
// Close all WebSocket connections. Queued messages will still be processed.
|
||||
for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
|
||||
cancel_token.cancel();
|
||||
};
|
||||
|
||||
// Wait for all existing WebSocket connections to gracefully close
|
||||
while WEBSOCKETS.read().await.len() > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
};
|
||||
} => (),
|
||||
// Force an immediate shutdown if a second signal is received
|
||||
_ = async {
|
||||
if let Ok(signal) = listen().await {
|
||||
warn!(target: super::LOG, "{} received during graceful shutdown. Terminate immediately...", signal);
|
||||
} else {
|
||||
error!(target: super::LOG, "Failed to listen to shutdown signal. Terminate immediately...");
|
||||
}
|
||||
|
||||
// Force an immediate shutdown
|
||||
http_handle.shutdown();
|
||||
|
||||
// Close all WebSocket connections immediately
|
||||
if let Ok(mut writer) = WEBSOCKETS.try_write() {
|
||||
writer.drain();
|
||||
}
|
||||
} => (),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
|
|
|
@ -1,119 +1,15 @@
|
|||
use std::{fmt, time::Duration};
|
||||
|
||||
use axum::{
|
||||
body::{boxed, Body, BoxBody},
|
||||
extract::MatchedPath,
|
||||
headers::{
|
||||
authorization::{Basic, Bearer},
|
||||
Authorization, Origin,
|
||||
},
|
||||
Extension, RequestPartsExt, TypedHeader,
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use http::{header, request::Parts, StatusCode};
|
||||
use axum::extract::MatchedPath;
|
||||
use http::header;
|
||||
use hyper::{Request, Response};
|
||||
use surrealdb::{
|
||||
dbs::Session,
|
||||
iam::verify::{basic, token},
|
||||
};
|
||||
use tower_http::{
|
||||
auth::AsyncAuthorizeRequest,
|
||||
request_id::RequestId,
|
||||
trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
|
||||
};
|
||||
use tracing::{field, Level, Span};
|
||||
|
||||
use crate::{dbs::DB, err::Error};
|
||||
|
||||
use super::{client_ip::ExtractClientIP, AppState};
|
||||
|
||||
///
|
||||
/// SurrealAuth is a tower layer that implements the AsyncAuthorizeRequest trait.
|
||||
/// It is used to authorize requests to SurrealDB using Basic or Token authentication.
|
||||
///
|
||||
/// It has to be used in conjunction with the tower_http::auth::RequireAuthorizationLayer layer:
|
||||
///
|
||||
/// ```rust
|
||||
/// use tower_http::auth::RequireAuthorizationLayer;
|
||||
/// use surrealdb::net::SurrealAuth;
|
||||
/// use axum::Router;
|
||||
///
|
||||
/// let auth = RequireAuthorizationLayer::new(SurrealAuth);
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/version", get(|| async { "0.1.0" }))
|
||||
/// .layer(auth);
|
||||
/// ```
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct SurrealAuth;
|
||||
|
||||
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type RequestBody = B;
|
||||
type ResponseBody = BoxBody;
|
||||
type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
|
||||
|
||||
fn authorize(&mut self, request: Request<B>) -> Self::Future {
|
||||
Box::pin(async {
|
||||
let (mut parts, body) = request.into_parts();
|
||||
match check_auth(&mut parts).await {
|
||||
Ok(sess) => {
|
||||
parts.extensions.insert(sess);
|
||||
Ok(Request::from_parts(parts, body))
|
||||
}
|
||||
Err(err) => {
|
||||
let unauthorized_response = Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(boxed(Body::from(err.to_string())))
|
||||
.unwrap();
|
||||
Err(unauthorized_response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
|
||||
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
|
||||
if !or.is_null() {
|
||||
Some(or.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
|
||||
|
||||
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
|
||||
tracing::error!("Error extracting the app state: {:?}", err);
|
||||
Error::InvalidAuth
|
||||
})?;
|
||||
let ExtractClientIP(ip) =
|
||||
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
|
||||
|
||||
// Create session
|
||||
#[rustfmt::skip]
|
||||
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
|
||||
|
||||
// If Basic authentication data was supplied
|
||||
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
|
||||
basic(kvs, &mut session, au.username(), au.password()).await.map_err(|e| e.into())
|
||||
} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
|
||||
token(kvs, &mut session, au.token()).await.map_err(|e| e.into())
|
||||
} else {
|
||||
Err(Error::InvalidAuth)
|
||||
}?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
use super::client_ip::ExtractClientIP;
|
||||
|
||||
///
|
||||
/// HttpTraceLayerHooks implements custom hooks for the tower_http::trace::TraceLayer layer.
|
||||
|
@ -139,7 +35,6 @@ impl<B> MakeSpan<B> for HttpTraceLayerHooks {
|
|||
fn make_span(&mut self, req: &Request<B>) -> Span {
|
||||
// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
|
||||
let span = tracing::info_span!(
|
||||
target: "surreal::http",
|
||||
"request",
|
||||
otel.name = field::Empty,
|
||||
otel.kind = "server",
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
pub mod args;
|
||||
pub mod paths;
|
||||
pub mod res;
|
||||
|
||||
pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally";
|
||||
|
|
|
@ -7,10 +7,13 @@ use surrealdb::dbs;
|
|||
use surrealdb::dbs::Notification;
|
||||
use surrealdb::sql;
|
||||
use surrealdb::sql::Value;
|
||||
use tracing::instrument;
|
||||
use tracing::Span;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Output {
|
||||
use crate::err;
|
||||
use crate::rpc::CONN_CLOSED_ERR;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OutputFormat {
|
||||
Json, // JSON
|
||||
Cbor, // CBOR
|
||||
Pack, // MessagePack
|
||||
|
@ -37,6 +40,12 @@ impl From<Value> for Data {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<String> for Data {
|
||||
fn from(v: String) -> Self {
|
||||
Data::Other(Value::from(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<dbs::Response>> for Data {
|
||||
fn from(v: Vec<dbs::Response>) -> Self {
|
||||
Data::Query(v)
|
||||
|
@ -82,28 +91,45 @@ impl Response {
|
|||
}
|
||||
|
||||
/// Send the response to the WebSocket channel
|
||||
#[instrument(skip_all, name = "rpc response", fields(response = ?self))]
|
||||
pub async fn send(self, out: Output, chn: Sender<Message>) {
|
||||
pub async fn send(self, out: OutputFormat, chn: Sender<Message>) {
|
||||
let span = Span::current();
|
||||
|
||||
info!("Process RPC response");
|
||||
|
||||
if let Err(err) = &self.result {
|
||||
span.record("otel.status_code", "Error");
|
||||
span.record(
|
||||
"otel.status_message",
|
||||
format!("code: {}, message: {}", err.code, err.message),
|
||||
);
|
||||
span.record("rpc.jsonrpc.error_code", err.code);
|
||||
span.record("rpc.jsonrpc.error_message", err.message.as_ref());
|
||||
}
|
||||
|
||||
let message = match out {
|
||||
Output::Json => {
|
||||
OutputFormat::Json => {
|
||||
let res = serde_json::to_string(&self.simplify()).unwrap();
|
||||
Message::Text(res)
|
||||
}
|
||||
Output::Cbor => {
|
||||
OutputFormat::Cbor => {
|
||||
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
||||
Message::Binary(res)
|
||||
}
|
||||
Output::Pack => {
|
||||
OutputFormat::Pack => {
|
||||
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
||||
Message::Binary(res)
|
||||
}
|
||||
Output::Full => {
|
||||
OutputFormat::Full => {
|
||||
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
||||
Message::Binary(res)
|
||||
}
|
||||
};
|
||||
let _ = chn.send(message).await;
|
||||
trace!("Response sent");
|
||||
|
||||
if let Err(err) = chn.send(message).await {
|
||||
if err.to_string() != CONN_CLOSED_ERR {
|
||||
error!("Error sending response: {}", err);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -113,6 +139,7 @@ pub struct Failure {
|
|||
message: Cow<'static, str>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Failure {
|
||||
pub const PARSE_ERROR: Failure = Failure {
|
||||
code: -32700,
|
||||
|
@ -165,3 +192,26 @@ pub fn failure(id: Option<Value>, err: Failure) -> Response {
|
|||
result: Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
impl From<err::Error> for Failure {
|
||||
fn from(err: err::Error) -> Self {
|
||||
Failure::custom(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoRpcResponse {
|
||||
fn into_response(self, id: Option<Value>) -> Response;
|
||||
}
|
||||
|
||||
impl<T, E> IntoRpcResponse for Result<T, E>
|
||||
where
|
||||
T: Into<Data>,
|
||||
E: Into<Failure>,
|
||||
{
|
||||
fn into_response(self, id: Option<Value>) -> Response {
|
||||
match self {
|
||||
Ok(v) => success(id, v.into()),
|
||||
Err(err) => failure(id, err.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use tracing::Subscriber;
|
||||
use tracing_subscriber::fmt::format::FmtSpan;
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
use tracing_subscriber::Layer;
|
||||
|
||||
pub fn new<S>(level: String) -> Box<dyn Layer<S> + Send + Sync>
|
||||
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
||||
|
||||
pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
|
||||
where
|
||||
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
|
||||
{
|
||||
|
@ -11,6 +13,6 @@ where
|
|||
.with_ansi(true)
|
||||
.with_span_events(FmtSpan::NONE)
|
||||
.with_writer(std::io::stderr)
|
||||
.with_filter(EnvFilter::builder().parse(level).unwrap())
|
||||
.with_filter(filter.0)
|
||||
.boxed()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
mod logs;
|
||||
pub mod metrics;
|
||||
mod traces;
|
||||
pub mod traces;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
|
@ -11,8 +11,7 @@ use opentelemetry::sdk::resource::{
|
|||
};
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::KeyValue;
|
||||
use tracing::Subscriber;
|
||||
use tracing_subscriber::fmt::format::FmtSpan;
|
||||
use tracing::{Level, Subscriber};
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
#[cfg(feature = "has-storage")]
|
||||
|
@ -39,53 +38,75 @@ pub static OTEL_DEFAULT_RESOURCE: Lazy<Resource> = Lazy::new(|| {
|
|||
}
|
||||
});
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Builder {
|
||||
log_level: Option<String>,
|
||||
filter: Option<CustomEnvFilter>,
|
||||
filter: CustomEnvFilter,
|
||||
}
|
||||
|
||||
pub fn builder() -> Builder {
|
||||
Builder::default()
|
||||
}
|
||||
|
||||
impl Default for Builder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
filter: CustomEnvFilter(EnvFilter::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Builder {
|
||||
/// Set the log level on the builder
|
||||
pub fn with_log_level(mut self, log_level: &str) -> Self {
|
||||
self.log_level = Some(log_level.to_string());
|
||||
if let Ok(filter) = filter_from_value(log_level) {
|
||||
self.filter = CustomEnvFilter(filter);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the filter on the builder
|
||||
#[cfg(feature = "has-storage")]
|
||||
pub fn with_filter(mut self, filter: EnvFilter) -> Self {
|
||||
self.filter = Some(CustomEnvFilter(filter));
|
||||
pub fn with_filter(mut self, filter: CustomEnvFilter) -> Self {
|
||||
self.filter = filter;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build a tracing dispatcher with the fmt subscriber (logs) and the chosen tracer subscriber
|
||||
pub fn build(self) -> Box<dyn Subscriber + Send + Sync + 'static> {
|
||||
let registry = tracing_subscriber::registry();
|
||||
let registry = registry.with(self.filter.map(|filter| {
|
||||
tracing_subscriber::fmt::layer()
|
||||
.compact()
|
||||
.with_ansi(true)
|
||||
.with_span_events(FmtSpan::NONE)
|
||||
.with_writer(std::io::stderr)
|
||||
.with_filter(filter.0)
|
||||
.boxed()
|
||||
}));
|
||||
let registry = registry.with(self.log_level.map(logs::new));
|
||||
let registry = registry.with(traces::new());
|
||||
|
||||
// Setup logging layer
|
||||
let registry = registry.with(logs::new(self.filter.clone()));
|
||||
|
||||
// Setup tracing layer
|
||||
let registry = registry.with(traces::new(self.filter));
|
||||
|
||||
Box::new(registry)
|
||||
}
|
||||
|
||||
/// tracing pipeline
|
||||
/// Install the tracing dispatcher globally
|
||||
pub fn init(self) {
|
||||
self.build().init()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
|
||||
pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
|
||||
match v {
|
||||
// Don't show any logs at all
|
||||
"none" => Ok(EnvFilter::default()),
|
||||
// Check if we should show all log levels
|
||||
"full" => Ok(EnvFilter::default().add_directive(Level::TRACE.into())),
|
||||
// Otherwise, let's only show errors
|
||||
"error" => Ok(EnvFilter::default().add_directive(Level::ERROR.into())),
|
||||
// Specify the log level for each code area
|
||||
"warn" | "info" | "debug" | "trace" => EnvFilter::builder()
|
||||
.parse(format!("error,surreal={v},surrealdb={v},surrealdb::kvs::tx=error")),
|
||||
// Let's try to parse the custom log level
|
||||
_ => EnvFilter::builder().parse(v),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use opentelemetry::global::shutdown_tracer_provider;
|
||||
|
@ -107,7 +128,7 @@ mod tests {
|
|||
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
|
||||
],
|
||||
|| {
|
||||
let _enter = telemetry::builder().build().set_default();
|
||||
let _enter = telemetry::builder().with_log_level("info").build().set_default();
|
||||
|
||||
println!("Sending span...");
|
||||
|
||||
|
@ -123,7 +144,11 @@ mod tests {
|
|||
}
|
||||
|
||||
println!("Waiting for request...");
|
||||
let req = req_rx.recv().await.expect("missing export request");
|
||||
let req = tokio::select! {
|
||||
req = req_rx.recv() => req.expect("missing export request"),
|
||||
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
|
||||
};
|
||||
|
||||
let first_span =
|
||||
req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans.first().unwrap();
|
||||
assert_eq!("test-surreal-span", first_span.name);
|
||||
|
@ -141,11 +166,10 @@ mod tests {
|
|||
temp_env::with_vars(
|
||||
vec![
|
||||
("SURREAL_TRACING_TRACER", Some("otlp")),
|
||||
("SURREAL_TRACING_FILTER", Some("debug")),
|
||||
("OTEL_EXPORTER_OTLP_ENDPOINT", Some(otlp_endpoint.as_str())),
|
||||
],
|
||||
|| {
|
||||
let _enter = telemetry::builder().build().set_default();
|
||||
let _enter = telemetry::builder().with_log_level("debug").build().set_default();
|
||||
|
||||
println!("Sending spans...");
|
||||
|
||||
|
@ -169,7 +193,10 @@ mod tests {
|
|||
}
|
||||
|
||||
println!("Waiting for request...");
|
||||
let req = req_rx.recv().await.expect("missing export request");
|
||||
let req = tokio::select! {
|
||||
req = req_rx.recv() => req.expect("missing export request"),
|
||||
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => panic!("timeout waiting for request"),
|
||||
};
|
||||
let spans = &req.resource_spans.first().unwrap().scope_spans.first().unwrap().spans;
|
||||
|
||||
assert_eq!(1, spans.len());
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
use tracing::Subscriber;
|
||||
use tracing_subscriber::Layer;
|
||||
|
||||
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
||||
|
||||
pub mod otlp;
|
||||
pub mod rpc;
|
||||
|
||||
const TRACING_TRACER_VAR: &str = "SURREAL_TRACING_TRACER";
|
||||
|
||||
// Returns a tracer based on the value of the TRACING_TRACER_VAR env var
|
||||
pub fn new<S>() -> Option<Box<dyn Layer<S> + Send + Sync>>
|
||||
pub fn new<S>(filter: CustomEnvFilter) -> Option<Box<dyn Layer<S> + Send + Sync>>
|
||||
where
|
||||
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
|
||||
{
|
||||
|
@ -20,7 +23,7 @@ where
|
|||
// Init the registry with the OTLP tracer
|
||||
"otlp" => {
|
||||
debug!("Setup the OTLP tracer");
|
||||
Some(otlp::new())
|
||||
Some(otlp::new(filter))
|
||||
}
|
||||
tracer => {
|
||||
panic!("unsupported tracer {}", tracer);
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
use opentelemetry::sdk::trace::Tracer;
|
||||
use opentelemetry::trace::TraceError;
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tracing::{Level, Subscriber};
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
use tracing::Subscriber;
|
||||
use tracing_subscriber::Layer;
|
||||
|
||||
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
|
||||
use crate::{
|
||||
cli::validator::parser::env_filter::CustomEnvFilter, telemetry::OTEL_DEFAULT_RESOURCE,
|
||||
};
|
||||
|
||||
const TRACING_FILTER_VAR: &str = "SURREAL_TRACING_FILTER";
|
||||
|
||||
pub fn new<S>() -> Box<dyn Layer<S> + Send + Sync>
|
||||
pub fn new<S>(filter: CustomEnvFilter) -> Box<dyn Layer<S> + Send + Sync>
|
||||
where
|
||||
S: Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a> + Send + Sync,
|
||||
{
|
||||
tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter()).boxed()
|
||||
tracing_opentelemetry::layer().with_tracer(tracer().unwrap()).with_filter(filter.0).boxed()
|
||||
}
|
||||
|
||||
fn tracer() -> Result<Tracer, TraceError> {
|
||||
|
@ -24,16 +24,3 @@ fn tracer() -> Result<Tracer, TraceError> {
|
|||
)
|
||||
.install_batch(opentelemetry::runtime::Tokio)
|
||||
}
|
||||
|
||||
/// Create a filter for the OTLP subscriber
|
||||
///
|
||||
/// It creates an EnvFilter based on the TRACING_FILTER_VAR's value
|
||||
///
|
||||
/// TRACING_FILTER_VAR accepts the same syntax as RUST_LOG
|
||||
fn filter() -> EnvFilter {
|
||||
EnvFilter::builder()
|
||||
.with_env_var(TRACING_FILTER_VAR)
|
||||
.with_default_directive(Level::INFO.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
}
|
||||
|
|
31
src/telemetry/traces/rpc.rs
Normal file
31
src/telemetry/traces/rpc.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use tracing::{field, Span};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn span_for_request(ws_id: &Uuid) -> Span {
|
||||
let span = tracing::info_span!(
|
||||
// Dynamic span names need to be 'recorded', can't be used on the macro. Use a static name here and overwrite later on
|
||||
"rpc/call",
|
||||
otel.name = field::Empty,
|
||||
otel.kind = "server",
|
||||
|
||||
// To be populated by the request handler when the method is known
|
||||
rpc.method = field::Empty,
|
||||
rpc.service = "surrealdb",
|
||||
rpc.system = "jsonrpc",
|
||||
|
||||
// JSON-RPC fields
|
||||
rpc.jsonrpc.version = "2.0",
|
||||
rpc.jsonrpc.request_id = field::Empty,
|
||||
rpc.jsonrpc.error_code = field::Empty,
|
||||
rpc.jsonrpc.error_message = field::Empty,
|
||||
|
||||
// SurrealDB custom fields
|
||||
ws.id = %ws_id,
|
||||
|
||||
// Fields for error reporting
|
||||
otel.status_code = field::Empty,
|
||||
otel.status_message = field::Empty,
|
||||
);
|
||||
|
||||
span
|
||||
}
|
|
@ -5,6 +5,8 @@ mod common;
|
|||
use assert_fs::prelude::{FileTouch, FileWriteStr, PathChild};
|
||||
use serial_test::serial;
|
||||
use std::fs;
|
||||
use test_log::test;
|
||||
use tracing::info;
|
||||
|
||||
use common::{PASS, USER};
|
||||
|
||||
|
@ -32,13 +34,14 @@ fn nonexistent_option() {
|
|||
assert!(common::run("version --turbo").output().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn all_commands() {
|
||||
// Commands without credentials when auth is disabled, should succeed
|
||||
let (addr, _server) = common::start_server(false, false, true).await.unwrap();
|
||||
let creds = ""; // Anonymous user
|
||||
// Create a record
|
||||
|
||||
info!("* Create a record");
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N --db D --multi");
|
||||
assert_eq!(
|
||||
|
@ -48,7 +51,7 @@ async fn all_commands() {
|
|||
);
|
||||
}
|
||||
|
||||
// Export to stdout
|
||||
info!("* Export to stdout");
|
||||
{
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
|
||||
let output = common::run(&args).output().expect("failed to run stdout export: {args}");
|
||||
|
@ -56,7 +59,7 @@ async fn all_commands() {
|
|||
assert!(output.contains("UPDATE thing:one CONTENT { id: thing:one };"));
|
||||
}
|
||||
|
||||
// Export to file
|
||||
info!("* Export to file");
|
||||
let exported = {
|
||||
let exported = common::tmp_file("exported.surql");
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
|
||||
|
@ -64,13 +67,13 @@ async fn all_commands() {
|
|||
exported
|
||||
};
|
||||
|
||||
// Import the exported file
|
||||
info!("* Import the exported file");
|
||||
{
|
||||
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
|
||||
common::run(&args).output().expect("failed to run import: {args}");
|
||||
}
|
||||
|
||||
// Query from the import (pretty-printed this time)
|
||||
info!("* Query from the import (pretty-printed this time)");
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N --db D2 --pretty");
|
||||
assert_eq!(
|
||||
|
@ -80,7 +83,7 @@ async fn all_commands() {
|
|||
);
|
||||
}
|
||||
|
||||
// Unfinished backup CLI
|
||||
info!("* Unfinished backup CLI");
|
||||
{
|
||||
let file = common::tmp_file("backup.db");
|
||||
let args = format!("backup {creds} http://{addr} {file}");
|
||||
|
@ -90,7 +93,7 @@ async fn all_commands() {
|
|||
assert_eq!(fs::read_to_string(file).unwrap(), "Save");
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) query including error(s) over WS
|
||||
info!("* Multi-statement (and multi-line) query including error(s) over WS");
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} {creds} --ns N3 --db D3 --multi --pretty");
|
||||
let output = common::run(&args)
|
||||
|
@ -113,7 +116,7 @@ async fn all_commands() {
|
|||
assert!(output.contains("thing:also_success"), "missing also_success in {output}")
|
||||
}
|
||||
|
||||
// Multi-statement (and multi-line) transaction including error(s) over WS
|
||||
info!("* Multi-statement (and multi-line) transaction including error(s) over WS");
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} {creds} --ns N4 --db D4 --multi --pretty");
|
||||
let output = common::run(&args)
|
||||
|
@ -137,7 +140,7 @@ async fn all_commands() {
|
|||
assert!(output.contains("rgument"), "missing argument error in {output}");
|
||||
}
|
||||
|
||||
// Pass neither ns nor db
|
||||
info!("* Pass neither ns nor db");
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds}");
|
||||
let output = common::run(&args)
|
||||
|
@ -147,7 +150,7 @@ async fn all_commands() {
|
|||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only ns
|
||||
info!("* Pass only ns");
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --ns N5");
|
||||
let output = common::run(&args)
|
||||
|
@ -157,16 +160,23 @@ async fn all_commands() {
|
|||
assert!(output.contains("thing:one"), "missing thing:one in {output}");
|
||||
}
|
||||
|
||||
// Pass only db and expect an error
|
||||
info!("* Pass only db and expect an error");
|
||||
{
|
||||
let args = format!("sql --conn http://{addr} {creds} --db D5");
|
||||
common::run(&args).output().expect_err("only db");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn start_tls() {
|
||||
// Capute the server's stdout/stderr
|
||||
temp_env::async_with_vars(
|
||||
[
|
||||
("SURREAL_TEST_SERVER_STDOUT", Some("piped")),
|
||||
("SURREAL_TEST_SERVER_STDERR", Some("piped")),
|
||||
],
|
||||
async {
|
||||
let (_, server) = common::start_server(false, true, false).await.unwrap();
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_millis(2000));
|
||||
|
@ -174,9 +184,12 @@ async fn start_tls() {
|
|||
|
||||
// Test the crt/key args but the keys are self signed so don't actually connect.
|
||||
assert!(output.contains("Started web server"), "couldn't start web server: {output}");
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn with_root_auth() {
|
||||
// Commands with credentials when auth is enabled, should succeed
|
||||
|
@ -184,7 +197,7 @@ async fn with_root_auth() {
|
|||
let creds = format!("--user {USER} --pass {PASS}");
|
||||
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
|
||||
|
||||
// Can query /sql over HTTP
|
||||
info!("* Query over HTTP");
|
||||
{
|
||||
let args = format!("{sql_args} {creds}");
|
||||
let input = "INFO FOR ROOT;";
|
||||
|
@ -192,7 +205,7 @@ async fn with_root_auth() {
|
|||
assert!(output.is_ok(), "failed to query over HTTP: {}", output.err().unwrap());
|
||||
}
|
||||
|
||||
// Can query /sql over WS
|
||||
info!("* Query over WS");
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
|
||||
let input = "INFO FOR ROOT;";
|
||||
|
@ -200,7 +213,7 @@ async fn with_root_auth() {
|
|||
assert!(output.is_ok(), "failed to query over WS: {}", output.err().unwrap());
|
||||
}
|
||||
|
||||
// KV user can do exports
|
||||
info!("* Root user can do exports");
|
||||
let exported = {
|
||||
let exported = common::tmp_file("exported.surql");
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D {exported}");
|
||||
|
@ -209,13 +222,13 @@ async fn with_root_auth() {
|
|||
exported
|
||||
};
|
||||
|
||||
// KV user can do imports
|
||||
info!("* Root user can do imports");
|
||||
{
|
||||
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {exported}");
|
||||
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
|
||||
}
|
||||
|
||||
// KV user can do backups
|
||||
info!("* Root user can do backups");
|
||||
{
|
||||
let file = common::tmp_file("backup.db");
|
||||
let args = format!("backup {creds} http://{addr} {file}");
|
||||
|
@ -226,7 +239,7 @@ async fn with_root_auth() {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn with_anon_auth() {
|
||||
// Commands without credentials when auth is enabled, should fail
|
||||
|
@ -234,7 +247,7 @@ async fn with_anon_auth() {
|
|||
let creds = ""; // Anonymous user
|
||||
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
|
||||
|
||||
// Can query /sql over HTTP
|
||||
info!("* Query over HTTP");
|
||||
{
|
||||
let args = format!("{sql_args} {creds}");
|
||||
let input = "";
|
||||
|
@ -242,7 +255,7 @@ async fn with_anon_auth() {
|
|||
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
|
||||
}
|
||||
|
||||
// Can query /sql over HTTP
|
||||
info!("* Query over WS");
|
||||
{
|
||||
let args = format!("sql --conn ws://{addr} --multi --pretty {creds}");
|
||||
let input = "";
|
||||
|
@ -250,7 +263,7 @@ async fn with_anon_auth() {
|
|||
assert!(output.is_ok(), "anonymous user should be able to query: {:?}", output);
|
||||
}
|
||||
|
||||
// Can't do exports
|
||||
info!("* Can't do exports");
|
||||
{
|
||||
let args = format!("export --conn http://{addr} {creds} --ns N --db D -");
|
||||
let output = common::run(&args).output();
|
||||
|
@ -261,7 +274,7 @@ async fn with_anon_auth() {
|
|||
);
|
||||
}
|
||||
|
||||
// Can't do imports
|
||||
info!("* Can't do imports");
|
||||
{
|
||||
let tmp_file = common::tmp_file("exported.surql");
|
||||
let args = format!("import --conn http://{addr} {creds} --ns N --db D2 {tmp_file}");
|
||||
|
@ -273,7 +286,7 @@ async fn with_anon_auth() {
|
|||
);
|
||||
}
|
||||
|
||||
// Can't do backups
|
||||
info!("* Can't do backups");
|
||||
{
|
||||
let args = format!("backup {creds} http://{addr}");
|
||||
let output = common::run(&args).output();
|
||||
|
|
|
@ -1,10 +1,18 @@
|
|||
#![allow(dead_code)]
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::error::Error;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
||||
use tracing::{error, info};
|
||||
|
||||
pub const USER: &str = "root";
|
||||
pub const PASS: &str = "root";
|
||||
|
@ -52,7 +60,12 @@ impl Drop for Child {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child {
|
||||
pub fn run_internal<P: AsRef<Path>>(
|
||||
args: &str,
|
||||
current_dir: Option<P>,
|
||||
stdout: Stdio,
|
||||
stderr: Stdio,
|
||||
) -> Child {
|
||||
let mut path = std::env::current_exe().unwrap();
|
||||
assert!(path.pop());
|
||||
if path.ends_with("deps") {
|
||||
|
@ -68,8 +81,8 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
|
|||
}
|
||||
cmd.env_clear();
|
||||
cmd.stdin(Stdio::piped());
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
cmd.stdout(stdout);
|
||||
cmd.stderr(stderr);
|
||||
cmd.args(args.split_ascii_whitespace());
|
||||
Child {
|
||||
inner: Some(cmd.spawn().unwrap()),
|
||||
|
@ -78,12 +91,12 @@ pub fn run_internal<P: AsRef<Path>>(args: &str, current_dir: Option<P>) -> Child
|
|||
|
||||
/// Run the CLI with the given args
|
||||
pub fn run(args: &str) -> Child {
|
||||
run_internal::<String>(args, None)
|
||||
run_internal::<String>(args, None, Stdio::piped(), Stdio::piped())
|
||||
}
|
||||
|
||||
/// Run the CLI with the given args inside a temporary directory
|
||||
pub fn run_in_dir<P: AsRef<Path>>(args: &str, current_dir: P) -> Child {
|
||||
run_internal(args, Some(current_dir))
|
||||
run_internal(args, Some(current_dir), Stdio::piped(), Stdio::piped())
|
||||
}
|
||||
|
||||
pub fn tmp_file(name: &str) -> String {
|
||||
|
@ -91,6 +104,19 @@ pub fn tmp_file(name: &str) -> String {
|
|||
path.to_string_lossy().into_owned()
|
||||
}
|
||||
|
||||
fn parse_server_stdio_from_var(var: &str) -> Result<Stdio, Box<dyn Error>> {
|
||||
match env::var(var).as_deref() {
|
||||
Ok("inherit") => Ok(Stdio::inherit()),
|
||||
Ok("null") => Ok(Stdio::null()),
|
||||
Ok("piped") => Ok(Stdio::piped()),
|
||||
Ok(val) if val.starts_with("file://") => {
|
||||
Ok(Stdio::from(File::create(val.trim_start_matches("file://"))?))
|
||||
}
|
||||
Ok(val) => Err(format!("Unsupported stdio value: {val:?}").into()),
|
||||
_ => Ok(Stdio::null()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_server(
|
||||
auth: bool,
|
||||
tls: bool,
|
||||
|
@ -118,11 +144,14 @@ pub async fn start_server(
|
|||
extra_args.push_str(" --auth");
|
||||
}
|
||||
|
||||
let start_args = format!("start --bind {addr} memory --no-banner --log info --user {USER} --pass {PASS} {extra_args}");
|
||||
let start_args = format!("start --bind {addr} memory --no-banner --log trace --user {USER} --pass {PASS} {extra_args}");
|
||||
|
||||
println!("starting server with args: {start_args}");
|
||||
info!("starting server with args: {start_args}");
|
||||
|
||||
let server = run(&start_args);
|
||||
// Configure where the logs go when running the test
|
||||
let stdout = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDOUT")?;
|
||||
let stderr = parse_server_stdio_from_var("SURREAL_TEST_SERVER_STDERR")?;
|
||||
let server = run_internal::<String>(&start_args, None, stdout, stderr);
|
||||
|
||||
if !wait_is_ready {
|
||||
return Ok((addr, server));
|
||||
|
@ -130,17 +159,178 @@ pub async fn start_server(
|
|||
|
||||
// Wait 5 seconds for the server to start
|
||||
let mut interval = time::interval(time::Duration::from_millis(500));
|
||||
println!("Waiting for server to start...");
|
||||
info!("Waiting for server to start...");
|
||||
for _i in 0..10 {
|
||||
interval.tick().await;
|
||||
|
||||
if run(&format!("isready --conn http://{addr}")).output().is_ok() {
|
||||
println!("Server ready!");
|
||||
info!("Server ready!");
|
||||
return Ok((addr, server));
|
||||
}
|
||||
}
|
||||
|
||||
let server_out = server.kill().output().err().unwrap();
|
||||
println!("server output: {server_out}");
|
||||
error!("server output: {server_out}");
|
||||
Err("server failed to start".into())
|
||||
}
|
||||
|
||||
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
pub async fn connect_ws(addr: &str) -> Result<WsStream, Box<dyn Error>> {
|
||||
let url = format!("ws://{}/rpc", addr);
|
||||
let (ws_stream, _) = connect_async(url).await?;
|
||||
Ok(ws_stream)
|
||||
}
|
||||
|
||||
pub async fn ws_send_msg(
|
||||
socket: &mut WsStream,
|
||||
msg_req: String,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
// Use JSON format by default
|
||||
ws_send_msg_with_fmt(socket, msg_req, Format::Json).await
|
||||
}
|
||||
|
||||
pub enum Format {
|
||||
Json,
|
||||
Cbor,
|
||||
Pack,
|
||||
}
|
||||
|
||||
pub async fn ws_send_msg_with_fmt(
|
||||
socket: &mut WsStream,
|
||||
msg_req: String,
|
||||
response_format: Format,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(500)) => {
|
||||
return Err("timeout waiting for the request to be sent".into());
|
||||
}
|
||||
res = socket.send(Message::Text(msg_req)) => {
|
||||
if let Err(err) = res {
|
||||
return Err(format!("Error sending the message: {}", err).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut f = socket.try_filter(|msg| match response_format {
|
||||
Format::Json => futures_util::future::ready(msg.is_text()),
|
||||
Format::Pack | Format::Cbor => futures_util::future::ready(msg.is_binary()),
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = time::sleep(time::Duration::from_millis(2000)) => {
|
||||
Err("timeout waiting for the response".into())
|
||||
}
|
||||
res = f.select_next_some() => {
|
||||
match response_format {
|
||||
Format::Json => Ok(serde_json::from_str(&res?.to_string())?),
|
||||
Format::Cbor => Ok(serde_cbor::from_slice(&res?.into_data())?),
|
||||
Format::Pack => Ok(serde_pack::from_slice(&res?.into_data())?),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SigninParams<'a> {
|
||||
user: &'a str,
|
||||
pass: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
sc: Option<&'a str>,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UseParams<'a> {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
ns: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
db: Option<&'a str>,
|
||||
}
|
||||
|
||||
pub async fn ws_signin(
|
||||
socket: &mut WsStream,
|
||||
user: &str,
|
||||
pass: &str,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
sc: Option<&str>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
let json = json!({
|
||||
"id": "1",
|
||||
"method": "signin",
|
||||
"params": [
|
||||
SigninParams { user, pass, ns, db, sc }
|
||||
],
|
||||
});
|
||||
|
||||
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
|
||||
Ok(obj.get("result").unwrap().as_str().unwrap_or_default().to_owned())
|
||||
}
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_query(
|
||||
socket: &mut WsStream,
|
||||
query: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn Error>> {
|
||||
let json = json!({
|
||||
"id": "1",
|
||||
"method": "query",
|
||||
"params": [query],
|
||||
});
|
||||
|
||||
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
|
||||
Ok(obj.get("result").unwrap().as_array().unwrap().to_owned())
|
||||
}
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_use(
|
||||
socket: &mut WsStream,
|
||||
ns: Option<&str>,
|
||||
db: Option<&str>,
|
||||
) -> Result<serde_json::Value, Box<dyn Error>> {
|
||||
let json = json!({
|
||||
"id": "1",
|
||||
"method": "use",
|
||||
"params": [
|
||||
ns, db
|
||||
],
|
||||
});
|
||||
|
||||
let msg = ws_send_msg(socket, serde_json::to_string(&json).unwrap()).await?;
|
||||
match msg.as_object() {
|
||||
Some(obj) if obj.keys().all(|k| ["id", "error"].contains(&k.as_str())) => {
|
||||
Err(format!("unexpected error from query request: {:?}", obj.get("error")).into())
|
||||
}
|
||||
Some(obj) if obj.keys().all(|k| ["id", "result"].contains(&k.as_str())) => {
|
||||
Ok(obj.get("result").unwrap().to_owned())
|
||||
}
|
||||
_ => {
|
||||
error!("{:?}", msg.as_object().unwrap().keys().collect::<Vec<_>>());
|
||||
Err(format!("unexpected response: {:?}", msg).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,10 +7,11 @@ use http::{header, Method};
|
|||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use serial_test::serial;
|
||||
use test_log::test;
|
||||
|
||||
use crate::common::{PASS, USER};
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -52,7 +53,7 @@ async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -133,14 +134,14 @@ async fn bearer_auth() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn client_ip_extractor() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// TODO: test the client IP extractor
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -184,7 +185,7 @@ async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -196,7 +197,7 @@ async fn health_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -269,7 +270,7 @@ async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -303,7 +304,7 @@ async fn rpc_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -372,7 +373,7 @@ async fn signin_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -425,13 +426,17 @@ async fn signup_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
assert_eq!(res.status(), 200, "body: {}", res.text().await?);
|
||||
|
||||
let body: serde_json::Value = serde_json::from_str(&res.text().await?).unwrap();
|
||||
assert!(!body["token"].as_str().unwrap().to_string().is_empty(), "body: {}", body);
|
||||
assert!(
|
||||
body["token"].as_str().unwrap().starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"),
|
||||
"body: {}",
|
||||
body
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -543,7 +548,7 @@ async fn sql_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -577,7 +582,7 @@ async fn sync_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn version_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -619,7 +624,7 @@ async fn seed_table(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -696,7 +701,7 @@ async fn key_endpoint_select_all() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -759,7 +764,7 @@ async fn key_endpoint_create_all() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -829,7 +834,7 @@ async fn key_endpoint_update_all() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -899,7 +904,7 @@ async fn key_endpoint_modify_all() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -953,7 +958,7 @@ async fn key_endpoint_delete_all() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -994,7 +999,7 @@ async fn key_endpoint_select_one() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -1091,7 +1096,7 @@ async fn key_endpoint_create_one() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -1164,7 +1169,7 @@ async fn key_endpoint_update_one() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
@ -1242,7 +1247,7 @@ async fn key_endpoint_modify_one() -> Result<(), Box<dyn std::error::Error>> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test(tokio::test)]
|
||||
#[serial]
|
||||
async fn key_endpoint_delete_one() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server(true, false, true).await.unwrap();
|
||||
|
|
1109
tests/ws_integration.rs
Normal file
1109
tests/ws_integration.rs
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue