[rpc] Add WebSocket metrics (#2413)
This commit is contained in:
parent
4288d9f188
commit
3b26ad2a44
16 changed files with 1295 additions and 1062 deletions
|
@ -2,7 +2,7 @@
|
||||||
version: "3"
|
version: "3"
|
||||||
services:
|
services:
|
||||||
grafana:
|
grafana:
|
||||||
image: "grafana/grafana-oss:latest"
|
image: "grafana/grafana-oss:main"
|
||||||
expose:
|
expose:
|
||||||
- "3000"
|
- "3000"
|
||||||
ports:
|
ports:
|
||||||
|
|
|
@ -11,6 +11,7 @@ use crate::net::{self, client_ip::ClientIp};
|
||||||
use crate::node;
|
use crate::node;
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
use ipnet::IpNet;
|
use ipnet::IpNet;
|
||||||
|
use opentelemetry::Context as TelemetryContext;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
@ -123,6 +124,9 @@ pub async fn init(
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
// Initialize opentelemetry and logging
|
// Initialize opentelemetry and logging
|
||||||
crate::telemetry::builder().with_filter(log).init();
|
crate::telemetry::builder().with_filter(log).init();
|
||||||
|
// Start metrics subsystem
|
||||||
|
crate::telemetry::metrics::init(&TelemetryContext::current())
|
||||||
|
.expect("failed to initialize metrics");
|
||||||
|
|
||||||
// Check if a banner should be outputted
|
// Check if a banner should be outputted
|
||||||
if !no_banner {
|
if !no_banner {
|
||||||
|
|
|
@ -129,7 +129,7 @@ pub async fn init() -> Result<(), Error> {
|
||||||
.merge(key::router())
|
.merge(key::router())
|
||||||
.layer(service);
|
.layer(service);
|
||||||
|
|
||||||
// Setup the graceful shutdown with no timeout
|
// Setup the graceful shutdown
|
||||||
let handle = Handle::new();
|
let handle = Handle::new();
|
||||||
let shutdown_handler = graceful_shutdown(handle.clone());
|
let shutdown_handler = graceful_shutdown(handle.clone());
|
||||||
|
|
||||||
|
@ -159,9 +159,6 @@ pub async fn init() -> Result<(), Error> {
|
||||||
// Wait for the shutdown to finish
|
// Wait for the shutdown to finish
|
||||||
let _ = shutdown_handler.await;
|
let _ = shutdown_handler.await;
|
||||||
|
|
||||||
// Flush all telemetry data
|
|
||||||
opentelemetry::global::shutdown_tracer_provider();
|
|
||||||
|
|
||||||
info!(target: LOG, "Web server stopped. Bye!");
|
info!(target: LOG, "Web server stopped. Bye!");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
899
src/net/rpc.rs
899
src/net/rpc.rs
|
@ -1,58 +1,17 @@
|
||||||
use crate::cnf::MAX_CONCURRENT_CALLS;
|
use crate::rpc::connection::Connection;
|
||||||
use crate::cnf::PKG_NAME;
|
|
||||||
use crate::cnf::PKG_VERSION;
|
|
||||||
use crate::cnf::WEBSOCKET_PING_FREQUENCY;
|
|
||||||
use crate::dbs::DB;
|
|
||||||
use crate::err::Error;
|
|
||||||
use crate::rpc::args::Take;
|
|
||||||
use crate::rpc::paths::{ID, METHOD, PARAMS};
|
|
||||||
use crate::rpc::res;
|
|
||||||
use crate::rpc::res::Data;
|
|
||||||
use crate::rpc::res::Failure;
|
|
||||||
use crate::rpc::res::IntoRpcResponse;
|
|
||||||
use crate::rpc::res::OutputFormat;
|
|
||||||
use crate::rpc::CONN_CLOSED_ERR;
|
|
||||||
use crate::telemetry::traces::rpc::span_for_request;
|
|
||||||
use axum::routing::get;
|
use axum::routing::get;
|
||||||
use axum::Extension;
|
use axum::Extension;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use futures::{SinkExt, StreamExt};
|
|
||||||
use futures_util::stream::SplitSink;
|
|
||||||
use futures_util::stream::SplitStream;
|
|
||||||
use http_body::Body as HttpBody;
|
use http_body::Body as HttpBody;
|
||||||
use once_cell::sync::Lazy;
|
use surrealdb::dbs::Session;
|
||||||
use std::collections::BTreeMap;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use surrealdb::channel;
|
|
||||||
use surrealdb::channel::{Receiver, Sender};
|
|
||||||
use surrealdb::dbs::{QueryType, Response, Session};
|
|
||||||
use surrealdb::sql::serde::deserialize;
|
|
||||||
use surrealdb::sql::Array;
|
|
||||||
use surrealdb::sql::Object;
|
|
||||||
use surrealdb::sql::Strand;
|
|
||||||
use surrealdb::sql::Value;
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use tokio::task::JoinSet;
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
use tower_http::request_id::RequestId;
|
use tower_http::request_id::RequestId;
|
||||||
use tracing::Span;
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::ws::{Message, WebSocket, WebSocketUpgrade},
|
extract::ws::{WebSocket, WebSocketUpgrade},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Mapping of WebSocketID to WebSocket
|
|
||||||
pub(crate) struct WebSocketRef(pub(crate) Sender<Message>, pub(crate) CancellationToken);
|
|
||||||
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
|
|
||||||
// Mapping of LiveQueryID to WebSocketID
|
|
||||||
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
|
|
||||||
|
|
||||||
pub(super) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
|
|
||||||
static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
|
|
||||||
|
|
||||||
pub(super) fn router<S, B>() -> Router<S, B>
|
pub(super) fn router<S, B>() -> Router<S, B>
|
||||||
where
|
where
|
||||||
B: HttpBody + Send + 'static,
|
B: HttpBody + Send + 'static,
|
||||||
|
@ -72,853 +31,13 @@ async fn handler(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
|
async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) {
|
||||||
let rpc = Rpc::new(sess);
|
let rpc = Connection::new(sess);
|
||||||
|
|
||||||
// If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID
|
// Update the WebSocket ID with the Request ID
|
||||||
match req_id.header_value().to_str().map(Uuid::parse_str) {
|
if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) {
|
||||||
Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => {
|
// If the ID couldn't be updated, ignore the error and keep the default ID
|
||||||
rpc.write().await.ws_id = req_id
|
let _ = rpc.write().await.update_ws_id(req_id).await;
|
||||||
}
|
|
||||||
_ => (),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Rpc::serve(rpc, ws).await;
|
Connection::serve(rpc, ws).await;
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Rpc {
|
|
||||||
session: Session,
|
|
||||||
format: OutputFormat,
|
|
||||||
ws_id: Uuid,
|
|
||||||
vars: BTreeMap<String, Value>,
|
|
||||||
graceful_shutdown: CancellationToken,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Rpc {
|
|
||||||
/// Instantiate a new RPC
|
|
||||||
pub fn new(mut session: Session) -> Arc<RwLock<Rpc>> {
|
|
||||||
// Create a new RPC variables store
|
|
||||||
let vars = BTreeMap::new();
|
|
||||||
// Set the default output format
|
|
||||||
let format = OutputFormat::Json;
|
|
||||||
// Enable real-time mode
|
|
||||||
session.rt = true;
|
|
||||||
// Create and store the Rpc connection
|
|
||||||
Arc::new(RwLock::new(Rpc {
|
|
||||||
session,
|
|
||||||
format,
|
|
||||||
ws_id: Uuid::new_v4(),
|
|
||||||
vars,
|
|
||||||
graceful_shutdown: CancellationToken::new(),
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serve the RPC endpoint
|
|
||||||
pub async fn serve(rpc: Arc<RwLock<Rpc>>, ws: WebSocket) {
|
|
||||||
// Split the socket into send and recv
|
|
||||||
let (sender, receiver) = ws.split();
|
|
||||||
// Create an internal channel between the receiver and the sender
|
|
||||||
let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
|
|
||||||
|
|
||||||
let ws_id = rpc.read().await.ws_id;
|
|
||||||
|
|
||||||
// Store this WebSocket in the list of WebSockets
|
|
||||||
WEBSOCKETS.write().await.insert(
|
|
||||||
ws_id,
|
|
||||||
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
|
|
||||||
);
|
|
||||||
|
|
||||||
trace!("WebSocket {} connected", ws_id);
|
|
||||||
|
|
||||||
// Wait until all tasks finish
|
|
||||||
tokio::join!(
|
|
||||||
Self::ping(rpc.clone(), internal_sender.clone()),
|
|
||||||
Self::read(rpc.clone(), receiver, internal_sender.clone()),
|
|
||||||
Self::write(rpc.clone(), sender, internal_receiver.clone()),
|
|
||||||
Self::lq_notifications(rpc.clone()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Remove all live queries
|
|
||||||
LIVE_QUERIES.write().await.retain(|key, value| {
|
|
||||||
if value == &ws_id {
|
|
||||||
trace!("Removing live query: {}", key);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
true
|
|
||||||
});
|
|
||||||
|
|
||||||
// Remove this WebSocket from the list of WebSockets
|
|
||||||
WEBSOCKETS.write().await.remove(&ws_id);
|
|
||||||
|
|
||||||
trace!("WebSocket {} disconnected", ws_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send Ping messages to the client
|
|
||||||
async fn ping(rpc: Arc<RwLock<Rpc>>, internal_sender: Sender<Message>) {
|
|
||||||
// Create the interval ticker
|
|
||||||
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
|
|
||||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
|
||||||
loop {
|
|
||||||
let is_shutdown = cancel_token.cancelled();
|
|
||||||
tokio::select! {
|
|
||||||
_ = interval.tick() => {
|
|
||||||
let msg = Message::Ping(vec![]);
|
|
||||||
|
|
||||||
// Send the message to the client and close the WebSocket connection if it fails
|
|
||||||
if internal_sender.send(msg).await.is_err() {
|
|
||||||
rpc.read().await.graceful_shutdown.cancel();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ = is_shutdown => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Read messages sent from the client
|
|
||||||
async fn read(
|
|
||||||
rpc: Arc<RwLock<Rpc>>,
|
|
||||||
mut receiver: SplitStream<WebSocket>,
|
|
||||||
internal_sender: Sender<Message>,
|
|
||||||
) {
|
|
||||||
// Collect all spawned tasks so we can wait for them at the end
|
|
||||||
let mut tasks = JoinSet::new();
|
|
||||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
|
||||||
loop {
|
|
||||||
let is_shutdown = cancel_token.cancelled();
|
|
||||||
tokio::select! {
|
|
||||||
msg = receiver.next() => {
|
|
||||||
if let Some(msg) = msg {
|
|
||||||
match msg {
|
|
||||||
// We've received a message from the client
|
|
||||||
// Ping/Pong is automatically handled by the WebSocket library
|
|
||||||
Ok(msg) => match msg {
|
|
||||||
Message::Text(_) => {
|
|
||||||
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
|
||||||
}
|
|
||||||
Message::Binary(_) => {
|
|
||||||
tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
|
||||||
}
|
|
||||||
Message::Close(_) => {
|
|
||||||
// Respond with a close message
|
|
||||||
if let Err(err) = internal_sender.send(Message::Close(None)).await {
|
|
||||||
trace!("WebSocket error when replying to the Close frame: {:?}", err);
|
|
||||||
};
|
|
||||||
// Start the graceful shutdown of the WebSocket and close the channels
|
|
||||||
rpc.read().await.graceful_shutdown.cancel();
|
|
||||||
let _ = internal_sender.close();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Ignore everything else
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(err) => {
|
|
||||||
trace!("WebSocket error: {:?}", err);
|
|
||||||
// Start the graceful shutdown of the WebSocket and close the channels
|
|
||||||
rpc.read().await.graceful_shutdown.cancel();
|
|
||||||
let _ = internal_sender.close();
|
|
||||||
// Exit out of the loop
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = is_shutdown => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all tasks to finish
|
|
||||||
while let Some(res) = tasks.join_next().await {
|
|
||||||
if let Err(err) = res {
|
|
||||||
error!("Error while handling RPC message: {}", err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Write messages to the client
|
|
||||||
async fn write(
|
|
||||||
rpc: Arc<RwLock<Rpc>>,
|
|
||||||
mut sender: SplitSink<WebSocket, Message>,
|
|
||||||
mut internal_receiver: Receiver<Message>,
|
|
||||||
) {
|
|
||||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
|
||||||
loop {
|
|
||||||
let is_shutdown = cancel_token.cancelled();
|
|
||||||
tokio::select! {
|
|
||||||
// Wait for the next message to send
|
|
||||||
msg = internal_receiver.next() => {
|
|
||||||
if let Some(res) = msg {
|
|
||||||
// Send the message to the client
|
|
||||||
if let Err(err) = sender.send(res).await {
|
|
||||||
if err.to_string() != CONN_CLOSED_ERR {
|
|
||||||
debug!("WebSocket error: {:?}", err);
|
|
||||||
}
|
|
||||||
// Close the WebSocket connection
|
|
||||||
rpc.read().await.graceful_shutdown.cancel();
|
|
||||||
// Exit out of the loop
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ = is_shutdown => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Send live query notifications to the client
|
|
||||||
async fn lq_notifications(rpc: Arc<RwLock<Rpc>>) {
|
|
||||||
if let Some(channel) = DB.get().unwrap().notifications() {
|
|
||||||
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
msg = channel.recv() => {
|
|
||||||
if let Ok(notification) = msg {
|
|
||||||
// Find which WebSocket the notification belongs to
|
|
||||||
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
|
||||||
// Check to see if the WebSocket exists
|
|
||||||
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
|
|
||||||
// Serialize the message to send
|
|
||||||
let message = res::success(None, notification);
|
|
||||||
// Get the current output format
|
|
||||||
let format = rpc.read().await.format.clone();
|
|
||||||
// Send the notification to the client
|
|
||||||
message.send(format, ws.clone()).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ = cancel_token.cancelled() => break,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Handle individual WebSocket messages
|
|
||||||
async fn handle_msg(rpc: Arc<RwLock<Rpc>>, msg: Message, chn: Sender<Message>) {
|
|
||||||
// Get the current output format
|
|
||||||
let mut out_fmt = rpc.read().await.format.clone();
|
|
||||||
let span = span_for_request(&rpc.read().await.ws_id);
|
|
||||||
let _enter = span.enter();
|
|
||||||
// Parse the request
|
|
||||||
match Self::parse_request(msg).await {
|
|
||||||
Ok((id, method, params, _out_fmt)) => {
|
|
||||||
span.record(
|
|
||||||
"rpc.jsonrpc.request_id",
|
|
||||||
id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
|
|
||||||
);
|
|
||||||
if let Some(_out_fmt) = _out_fmt {
|
|
||||||
out_fmt = _out_fmt;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the request
|
|
||||||
let res = Self::process_request(rpc.clone(), &method, params).await;
|
|
||||||
|
|
||||||
// Process the response
|
|
||||||
res.into_response(id).send(out_fmt, chn).await
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
// Process the response
|
|
||||||
res::failure(None, err).send(out_fmt, chn).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn parse_request(
|
|
||||||
msg: Message,
|
|
||||||
) -> Result<(Option<Value>, String, Array, Option<OutputFormat>), Failure> {
|
|
||||||
let mut out_fmt = None;
|
|
||||||
let req = match msg {
|
|
||||||
// This is a binary message
|
|
||||||
Message::Binary(val) => {
|
|
||||||
// Use binary output
|
|
||||||
out_fmt = Some(OutputFormat::Full);
|
|
||||||
|
|
||||||
match deserialize(&val) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(_) => {
|
|
||||||
debug!("Error when trying to deserialize the request");
|
|
||||||
return Err(Failure::PARSE_ERROR);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// This is a text message
|
|
||||||
Message::Text(ref val) => {
|
|
||||||
// Parse the SurrealQL object
|
|
||||||
match surrealdb::sql::value(val) {
|
|
||||||
// The SurrealQL message parsed ok
|
|
||||||
Ok(v) => v,
|
|
||||||
// The SurrealQL message failed to parse
|
|
||||||
_ => return Err(Failure::PARSE_ERROR),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Unsupported message type
|
|
||||||
_ => {
|
|
||||||
debug!("Unsupported message type: {:?}", msg);
|
|
||||||
return Err(res::Failure::custom("Unsupported message type"));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
// Fetch the 'id' argument
|
|
||||||
let id = match req.pick(&*ID) {
|
|
||||||
v if v.is_none() => None,
|
|
||||||
v if v.is_null() => Some(v),
|
|
||||||
v if v.is_uuid() => Some(v),
|
|
||||||
v if v.is_number() => Some(v),
|
|
||||||
v if v.is_strand() => Some(v),
|
|
||||||
v if v.is_datetime() => Some(v),
|
|
||||||
_ => return Err(Failure::INVALID_REQUEST),
|
|
||||||
};
|
|
||||||
// Fetch the 'method' argument
|
|
||||||
let method = match req.pick(&*METHOD) {
|
|
||||||
Value::Strand(v) => v.to_raw(),
|
|
||||||
_ => return Err(Failure::INVALID_REQUEST),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Now that we know the method, we can update the span
|
|
||||||
Span::current().record("rpc.method", &method);
|
|
||||||
Span::current().record("otel.name", format!("surrealdb.rpc/{}", method));
|
|
||||||
|
|
||||||
// Fetch the 'params' argument
|
|
||||||
let params = match req.pick(&*PARAMS) {
|
|
||||||
Value::Array(v) => v,
|
|
||||||
_ => Array::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((id, method, params, out_fmt))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn process_request(
|
|
||||||
rpc: Arc<RwLock<Rpc>>,
|
|
||||||
method: &str,
|
|
||||||
params: Array,
|
|
||||||
) -> Result<Data, Failure> {
|
|
||||||
info!("Process RPC request");
|
|
||||||
|
|
||||||
// Match the method to a function
|
|
||||||
match method {
|
|
||||||
// Handle a surrealdb ping message
|
|
||||||
//
|
|
||||||
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
|
|
||||||
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
|
|
||||||
// By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent.
|
|
||||||
//
|
|
||||||
"ping" => Ok(Value::None.into()),
|
|
||||||
// Retrieve the current auth record
|
|
||||||
"info" => match params.len() {
|
|
||||||
0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Switch to a specific namespace and database
|
|
||||||
"use" => match params.needs_two() {
|
|
||||||
Ok((ns, db)) => {
|
|
||||||
rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Signup to a specific authentication scope
|
|
||||||
"signup" => match params.needs_one() {
|
|
||||||
Ok(Value::Object(v)) => {
|
|
||||||
rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Signin as a root, namespace, database or scope user
|
|
||||||
"signin" => match params.needs_one() {
|
|
||||||
Ok(Value::Object(v)) => {
|
|
||||||
rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Invalidate the current authentication session
|
|
||||||
"invalidate" => match params.len() {
|
|
||||||
0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Authenticate using an authentication token
|
|
||||||
"authenticate" => match params.needs_one() {
|
|
||||||
Ok(Value::Strand(v)) => {
|
|
||||||
rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Kill a live query using a query id
|
|
||||||
"kill" => match params.needs_one() {
|
|
||||||
Ok(v) => rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Setup a live query on a specific table
|
|
||||||
"live" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, d)) if v.is_table() => {
|
|
||||||
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
Ok((v, d)) if v.is_strand() => {
|
|
||||||
rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Specify a connection-wide parameter
|
|
||||||
"let" | "set" => match params.needs_one_or_two() {
|
|
||||||
Ok((Value::Strand(s), v)) => {
|
|
||||||
rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Unset and clear a connection-wide parameter
|
|
||||||
"unset" => match params.needs_one() {
|
|
||||||
Ok(Value::Strand(s)) => {
|
|
||||||
rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Select a value or values from the database
|
|
||||||
"select" => match params.needs_one() {
|
|
||||||
Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Insert a value or values in the database
|
|
||||||
"insert" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, o)) => {
|
|
||||||
rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Create a value or values in the database
|
|
||||||
"create" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, o)) => {
|
|
||||||
rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Update a value or values in the database using `CONTENT`
|
|
||||||
"update" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, o)) => {
|
|
||||||
rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Update a value or values in the database using `MERGE`
|
|
||||||
"change" | "merge" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, o)) => {
|
|
||||||
rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Update a value or values in the database using `PATCH`
|
|
||||||
"modify" | "patch" => match params.needs_one_or_two() {
|
|
||||||
Ok((v, o)) => {
|
|
||||||
rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Delete a value or values from the database
|
|
||||||
"delete" => match params.needs_one() {
|
|
||||||
Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Specify the output format for text requests
|
|
||||||
"format" => match params.needs_one() {
|
|
||||||
Ok(Value::Strand(v)) => {
|
|
||||||
rpc.write().await.format(v).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Get the current server version
|
|
||||||
"version" => match params.len() {
|
|
||||||
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
// Run a full SurrealQL query against the database
|
|
||||||
"query" => match params.needs_one_or_two() {
|
|
||||||
Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
|
|
||||||
rpc.read().await.query(s).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
Ok((Value::Strand(s), Value::Object(o))) => {
|
|
||||||
rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into)
|
|
||||||
}
|
|
||||||
_ => Err(Failure::INVALID_PARAMS),
|
|
||||||
},
|
|
||||||
_ => Err(Failure::METHOD_NOT_FOUND),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for authentication
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
|
||||||
match out.as_str() {
|
|
||||||
"json" | "application/json" => self.format = OutputFormat::Json,
|
|
||||||
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
|
|
||||||
"pack" | "application/pack" => self.format = OutputFormat::Pack,
|
|
||||||
_ => return Err(Error::InvalidType),
|
|
||||||
};
|
|
||||||
Ok(Value::None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
|
||||||
if let Value::Strand(ns) = ns {
|
|
||||||
self.session.ns = Some(ns.0);
|
|
||||||
}
|
|
||||||
if let Value::Strand(db) = db {
|
|
||||||
self.session.db = Some(db.0);
|
|
||||||
}
|
|
||||||
Ok(Value::None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
|
|
||||||
.await
|
|
||||||
.map(Into::into)
|
|
||||||
.map_err(Into::into)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
|
|
||||||
.await
|
|
||||||
.map(Into::into)
|
|
||||||
.map_err(Into::into)
|
|
||||||
}
|
|
||||||
async fn invalidate(&mut self) -> Result<Value, Error> {
|
|
||||||
surrealdb::iam::clear::clear(&mut self.session)?;
|
|
||||||
Ok(Value::None)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
|
|
||||||
Ok(Value::None)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for identification
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn info(&self) -> Result<Value, Error> {
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "SELECT * FROM $auth";
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, None).await?;
|
|
||||||
// Extract the first value from the result
|
|
||||||
let res = res.remove(0).result?.first();
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for setting variables
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
|
|
||||||
match val {
|
|
||||||
// Remove the variable if undefined
|
|
||||||
Value::None => self.vars.remove(&key.0),
|
|
||||||
// Store the variable if defined
|
|
||||||
v => self.vars.insert(key.0, v),
|
|
||||||
};
|
|
||||||
Ok(Value::Null)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
|
|
||||||
self.vars.remove(&key.0);
|
|
||||||
Ok(Value::Null)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for live queries
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn kill(&self, id: Value) -> Result<Value, Error> {
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "KILL $id";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = map! {
|
|
||||||
String::from("id") => id, // NOTE: id can be parameter
|
|
||||||
=> &self.vars
|
|
||||||
};
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let response = res.remove(0);
|
|
||||||
match response.result {
|
|
||||||
Ok(v) => Ok(v),
|
|
||||||
Err(e) => Err(Error::from(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = match diff.is_true() {
|
|
||||||
true => "LIVE SELECT DIFF FROM $tb",
|
|
||||||
false => "LIVE SELECT * FROM $tb",
|
|
||||||
};
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = map! {
|
|
||||||
String::from("tb") => tb.could_be_table(),
|
|
||||||
=> &self.vars
|
|
||||||
};
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let response = res.remove(0);
|
|
||||||
match response.result {
|
|
||||||
Ok(v) => Ok(v),
|
|
||||||
Err(e) => Err(Error::from(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for selecting
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn select(&self, what: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "SELECT * FROM $what";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for inserting
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "INSERT INTO $what $data RETURN AFTER";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
String::from("data") => data,
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for creating
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "CREATE $what CONTENT $data RETURN AFTER";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
String::from("data") => data,
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for updating
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "UPDATE $what CONTENT $data RETURN AFTER";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
String::from("data") => data,
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for changing
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "UPDATE $what MERGE $data RETURN AFTER";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
String::from("data") => data,
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for modifying
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "UPDATE $what PATCH $data RETURN DIFF";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
String::from("data") => data,
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for deleting
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn delete(&self, what: Value) -> Result<Value, Error> {
|
|
||||||
// Return a single result?
|
|
||||||
let one = what.is_thing();
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the SQL query string
|
|
||||||
let sql = "DELETE $what RETURN BEFORE";
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(map! {
|
|
||||||
String::from("what") => what.could_be_table(),
|
|
||||||
=> &self.vars
|
|
||||||
});
|
|
||||||
// Execute the query on the database
|
|
||||||
let mut res = kvs.execute(sql, &self.session, var).await?;
|
|
||||||
// Extract the first query result
|
|
||||||
let res = match one {
|
|
||||||
true => res.remove(0).result?.first(),
|
|
||||||
false => res.remove(0).result?,
|
|
||||||
};
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Methods for querying
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(self.vars.clone());
|
|
||||||
// Execute the query on the database
|
|
||||||
let res = kvs.execute(&sql, &self.session, var).await?;
|
|
||||||
// Post-process hooks for web layer
|
|
||||||
for response in &res {
|
|
||||||
self.handle_live_query_results(response).await;
|
|
||||||
}
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
|
|
||||||
// Get a database reference
|
|
||||||
let kvs = DB.get().unwrap();
|
|
||||||
// Specify the query parameters
|
|
||||||
let var = Some(mrg! { vars.0, &self.vars });
|
|
||||||
// Execute the query on the database
|
|
||||||
let res = kvs.execute(&sql, &self.session, var).await?;
|
|
||||||
// Post-process hooks for web layer
|
|
||||||
for response in &res {
|
|
||||||
self.handle_live_query_results(response).await;
|
|
||||||
}
|
|
||||||
// Return the result to the client
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ------------------------------
|
|
||||||
// Private methods
|
|
||||||
// ------------------------------
|
|
||||||
|
|
||||||
async fn handle_live_query_results(&self, res: &Response) {
|
|
||||||
match &res.query_type {
|
|
||||||
QueryType::Live => {
|
|
||||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
|
||||||
// Match on Uuid type
|
|
||||||
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
|
|
||||||
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
QueryType::Kill => {
|
|
||||||
if let Ok(Value::Uuid(lqid)) = &res.result {
|
|
||||||
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0);
|
|
||||||
if let Some(ws_id) = ws_id {
|
|
||||||
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,7 @@
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use axum_server::Handle;
|
use axum_server::Handle;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
use crate::{
|
use crate::{err::Error, rpc, telemetry};
|
||||||
err::Error,
|
|
||||||
net::rpc::{WebSocketRef, WEBSOCKETS},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Start a graceful shutdown:
|
/// Start a graceful shutdown:
|
||||||
/// * Signal the Axum Handle when a shutdown signal is received.
|
/// * Signal the Axum Handle when a shutdown signal is received.
|
||||||
|
@ -24,15 +19,12 @@ pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
|
||||||
// First stop accepting new HTTP requests
|
// First stop accepting new HTTP requests
|
||||||
http_handle.graceful_shutdown(None);
|
http_handle.graceful_shutdown(None);
|
||||||
|
|
||||||
// Close all WebSocket connections. Queued messages will still be processed.
|
rpc::graceful_shutdown().await;
|
||||||
for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
|
|
||||||
cancel_token.cancel();
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wait for all existing WebSocket connections to gracefully close
|
// Flush all telemetry data
|
||||||
while WEBSOCKETS.read().await.len() > 0 {
|
if let Err(err) = telemetry::shutdown() {
|
||||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
error!("Failed to flush telemetry data: {}", err);
|
||||||
};
|
}
|
||||||
} => (),
|
} => (),
|
||||||
// Force an immediate shutdown if a second signal is received
|
// Force an immediate shutdown if a second signal is received
|
||||||
_ = async {
|
_ = async {
|
||||||
|
@ -46,9 +38,7 @@ pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> {
|
||||||
http_handle.shutdown();
|
http_handle.shutdown();
|
||||||
|
|
||||||
// Close all WebSocket connections immediately
|
// Close all WebSocket connections immediately
|
||||||
if let Ok(mut writer) = WEBSOCKETS.try_write() {
|
rpc::shutdown();
|
||||||
writer.drain();
|
|
||||||
}
|
|
||||||
} => (),
|
} => (),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
302
src/rpc/connection.rs
Normal file
302
src/rpc/connection.rs
Normal file
|
@ -0,0 +1,302 @@
|
||||||
|
use axum::extract::ws::{Message, WebSocket};
|
||||||
|
use futures_util::stream::{SplitSink, SplitStream};
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use opentelemetry::trace::FutureExt;
|
||||||
|
use opentelemetry::Context as TelemetryContext;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use surrealdb::channel::{self, Receiver, Sender};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use surrealdb::dbs::Session;
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::cnf::{MAX_CONCURRENT_CALLS, WEBSOCKET_PING_FREQUENCY};
|
||||||
|
use crate::dbs::DB;
|
||||||
|
use crate::rpc::res::success;
|
||||||
|
use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS};
|
||||||
|
use crate::telemetry;
|
||||||
|
use crate::telemetry::metrics::ws::RequestContext;
|
||||||
|
use crate::telemetry::traces::rpc::span_for_request;
|
||||||
|
|
||||||
|
use super::processor::Processor;
|
||||||
|
use super::request::parse_request;
|
||||||
|
use super::res::{failure, IntoRpcResponse, OutputFormat};
|
||||||
|
|
||||||
|
pub struct Connection {
|
||||||
|
ws_id: Uuid,
|
||||||
|
processor: Processor,
|
||||||
|
graceful_shutdown: CancellationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection {
|
||||||
|
/// Instantiate a new RPC
|
||||||
|
pub fn new(mut session: Session) -> Arc<RwLock<Connection>> {
|
||||||
|
// Create a new RPC variables store
|
||||||
|
let vars = BTreeMap::new();
|
||||||
|
// Set the default output format
|
||||||
|
let format = OutputFormat::Json;
|
||||||
|
// Enable real-time mode
|
||||||
|
session.rt = true;
|
||||||
|
|
||||||
|
// Create a new RPC processor
|
||||||
|
let processor = Processor::new(session, format, vars);
|
||||||
|
|
||||||
|
// Create and store the RPC connection
|
||||||
|
Arc::new(RwLock::new(Connection {
|
||||||
|
ws_id: processor.ws_id,
|
||||||
|
processor,
|
||||||
|
graceful_shutdown: CancellationToken::new(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the WebSocket ID. If the ID already exists, do not update it.
|
||||||
|
pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
if WEBSOCKETS.read().await.contains_key(&ws_id) {
|
||||||
|
trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id);
|
||||||
|
return Err("websocket ID is in use".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ws_id = ws_id;
|
||||||
|
self.processor.ws_id = ws_id;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serve the RPC endpoint
|
||||||
|
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
|
||||||
|
// Split the socket into send and recv
|
||||||
|
let (sender, receiver) = ws.split();
|
||||||
|
// Create an internal channel between the receiver and the sender
|
||||||
|
let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS);
|
||||||
|
|
||||||
|
let ws_id = rpc.read().await.ws_id;
|
||||||
|
|
||||||
|
trace!("WebSocket {} connected", ws_id);
|
||||||
|
|
||||||
|
if let Err(err) = telemetry::metrics::ws::on_connect() {
|
||||||
|
error!("Error running metrics::ws::on_connect hook: {}", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add this WebSocket to the list
|
||||||
|
WEBSOCKETS.write().await.insert(
|
||||||
|
ws_id,
|
||||||
|
WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Remove all live queries
|
||||||
|
LIVE_QUERIES.write().await.retain(|key, value| {
|
||||||
|
if value == &ws_id {
|
||||||
|
trace!("Removing live query: {}", key);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
true
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut tasks = JoinSet::new();
|
||||||
|
tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone()));
|
||||||
|
tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone()));
|
||||||
|
tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone()));
|
||||||
|
tasks.spawn(Self::lq_notifications(rpc.clone()));
|
||||||
|
|
||||||
|
// Wait until all tasks finish
|
||||||
|
while let Some(res) = tasks.join_next().await {
|
||||||
|
if let Err(err) = res {
|
||||||
|
error!("Error handling RPC connection: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove this WebSocket from the list
|
||||||
|
WEBSOCKETS.write().await.remove(&ws_id);
|
||||||
|
|
||||||
|
trace!("WebSocket {} disconnected", ws_id);
|
||||||
|
|
||||||
|
if let Err(err) = telemetry::metrics::ws::on_disconnect() {
|
||||||
|
error!("Error running metrics::ws::on_disconnect hook: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send Ping messages to the client
|
||||||
|
async fn ping(rpc: Arc<RwLock<Connection>>, internal_sender: Sender<Message>) {
|
||||||
|
// Create the interval ticker
|
||||||
|
let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
|
||||||
|
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||||
|
loop {
|
||||||
|
let is_shutdown = cancel_token.cancelled();
|
||||||
|
tokio::select! {
|
||||||
|
_ = interval.tick() => {
|
||||||
|
let msg = Message::Ping(vec![]);
|
||||||
|
|
||||||
|
// Send the message to the client and close the WebSocket connection if it fails
|
||||||
|
if internal_sender.send(msg).await.is_err() {
|
||||||
|
rpc.read().await.graceful_shutdown.cancel();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ = is_shutdown => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read messages sent from the client
|
||||||
|
async fn read(
|
||||||
|
rpc: Arc<RwLock<Connection>>,
|
||||||
|
mut receiver: SplitStream<WebSocket>,
|
||||||
|
internal_sender: Sender<Message>,
|
||||||
|
) {
|
||||||
|
// Collect all spawned tasks so we can wait for them at the end
|
||||||
|
let mut tasks = JoinSet::new();
|
||||||
|
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||||
|
loop {
|
||||||
|
let is_shutdown = cancel_token.cancelled();
|
||||||
|
tokio::select! {
|
||||||
|
msg = receiver.next() => {
|
||||||
|
if let Some(msg) = msg {
|
||||||
|
match msg {
|
||||||
|
// We've received a message from the client
|
||||||
|
// Ping/Pong is automatically handled by the WebSocket library
|
||||||
|
Ok(msg) => match msg {
|
||||||
|
Message::Text(_) => {
|
||||||
|
tasks.spawn(Connection::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
||||||
|
}
|
||||||
|
Message::Binary(_) => {
|
||||||
|
tasks.spawn(Connection::handle_msg(rpc.clone(), msg, internal_sender.clone()));
|
||||||
|
}
|
||||||
|
Message::Close(_) => {
|
||||||
|
// Respond with a close message
|
||||||
|
if let Err(err) = internal_sender.send(Message::Close(None)).await {
|
||||||
|
trace!("WebSocket error when replying to the Close frame: {:?}", err);
|
||||||
|
};
|
||||||
|
// Start the graceful shutdown of the WebSocket and close the channels
|
||||||
|
rpc.read().await.graceful_shutdown.cancel();
|
||||||
|
let _ = internal_sender.close();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Ignore everything else
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(err) => {
|
||||||
|
trace!("WebSocket error: {:?}", err);
|
||||||
|
// Start the graceful shutdown of the WebSocket and close the channels
|
||||||
|
rpc.read().await.graceful_shutdown.cancel();
|
||||||
|
let _ = internal_sender.close();
|
||||||
|
// Exit out of the loop
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = is_shutdown => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all tasks to finish
|
||||||
|
while let Some(res) = tasks.join_next().await {
|
||||||
|
if let Err(err) = res {
|
||||||
|
error!("Error while handling RPC message: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write messages to the client
|
||||||
|
async fn write(
|
||||||
|
rpc: Arc<RwLock<Connection>>,
|
||||||
|
mut sender: SplitSink<WebSocket, Message>,
|
||||||
|
mut internal_receiver: Receiver<Message>,
|
||||||
|
) {
|
||||||
|
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||||
|
loop {
|
||||||
|
let is_shutdown = cancel_token.cancelled();
|
||||||
|
tokio::select! {
|
||||||
|
// Wait for the next message to send
|
||||||
|
msg = internal_receiver.next() => {
|
||||||
|
if let Some(res) = msg {
|
||||||
|
// Send the message to the client
|
||||||
|
if let Err(err) = sender.send(res).await {
|
||||||
|
if err.to_string() != CONN_CLOSED_ERR {
|
||||||
|
debug!("WebSocket error: {:?}", err);
|
||||||
|
}
|
||||||
|
// Close the WebSocket connection
|
||||||
|
rpc.read().await.graceful_shutdown.cancel();
|
||||||
|
// Exit out of the loop
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ = is_shutdown => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send live query notifications to the client
|
||||||
|
async fn lq_notifications(rpc: Arc<RwLock<Connection>>) {
|
||||||
|
if let Some(channel) = DB.get().unwrap().notifications() {
|
||||||
|
let cancel_token = rpc.read().await.graceful_shutdown.clone();
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
msg = channel.recv() => {
|
||||||
|
if let Ok(notification) = msg {
|
||||||
|
// Find which WebSocket the notification belongs to
|
||||||
|
if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) {
|
||||||
|
// Check to see if the WebSocket exists
|
||||||
|
if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) {
|
||||||
|
// Serialize the message to send
|
||||||
|
let message = success(None, notification);
|
||||||
|
// Get the current output format
|
||||||
|
let format = rpc.read().await.processor.format.clone();
|
||||||
|
// Send the notification to the client
|
||||||
|
message.send(format, ws.clone()).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ = cancel_token.cancelled() => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle individual WebSocket messages
|
||||||
|
async fn handle_msg(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
|
||||||
|
// Get the current output format
|
||||||
|
let mut out_fmt = rpc.read().await.processor.format.clone();
|
||||||
|
// Prepare Span and Otel context
|
||||||
|
let span = span_for_request(&rpc.read().await.ws_id);
|
||||||
|
let _enter = span.enter();
|
||||||
|
let req_cx = RequestContext::default();
|
||||||
|
let otel_cx = TelemetryContext::current_with_value(req_cx.clone());
|
||||||
|
|
||||||
|
// Parse the request
|
||||||
|
match parse_request(msg).await {
|
||||||
|
Ok(req) => {
|
||||||
|
if let Some(_out_fmt) = req.out_fmt {
|
||||||
|
out_fmt = _out_fmt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we know the method, we can update the span and create otel context
|
||||||
|
span.record("rpc.method", &req.method);
|
||||||
|
span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
|
||||||
|
span.record(
|
||||||
|
"rpc.jsonrpc.request_id",
|
||||||
|
req.id.clone().map(|v| v.as_string()).unwrap_or(String::new()),
|
||||||
|
);
|
||||||
|
let otel_cx = TelemetryContext::current_with_value(
|
||||||
|
req_cx.with_method(&req.method).with_size(req.size),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
let res =
|
||||||
|
rpc.write().await.processor.process_request(&req.method, req.params).await;
|
||||||
|
|
||||||
|
// Process the response
|
||||||
|
res.into_response(req.id).send(out_fmt, chn).with_context(otel_cx).await
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
// Process the response
|
||||||
|
failure(None, err).send(out_fmt, chn).with_context(otel_cx.clone()).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,44 @@
|
||||||
pub mod args;
|
pub mod args;
|
||||||
pub mod paths;
|
pub mod connection;
|
||||||
|
pub mod processor;
|
||||||
|
pub mod request;
|
||||||
pub mod res;
|
pub mod res;
|
||||||
|
|
||||||
pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally";
|
use std::{collections::HashMap, time::Duration};
|
||||||
|
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use surrealdb::channel::Sender;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
static CONN_CLOSED_ERR: &str = "Connection closed normally";
|
||||||
|
|
||||||
|
// Mapping of WebSocketID to WebSocket
|
||||||
|
pub struct WebSocketRef(Sender<Message>, CancellationToken);
|
||||||
|
type WebSockets = RwLock<HashMap<Uuid, WebSocketRef>>;
|
||||||
|
// Mapping of LiveQueryID to WebSocketID
|
||||||
|
type LiveQueries = RwLock<HashMap<Uuid, Uuid>>;
|
||||||
|
|
||||||
|
pub(crate) static WEBSOCKETS: Lazy<WebSockets> = Lazy::new(WebSockets::default);
|
||||||
|
pub(crate) static LIVE_QUERIES: Lazy<LiveQueries> = Lazy::new(LiveQueries::default);
|
||||||
|
|
||||||
|
pub(crate) async fn graceful_shutdown() {
|
||||||
|
// Close all WebSocket connections. Queued messages will still be processed.
|
||||||
|
for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() {
|
||||||
|
cancel_token.cancel();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all existing WebSocket connections to gracefully close
|
||||||
|
while WEBSOCKETS.read().await.len() > 0 {
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn shutdown() {
|
||||||
|
// Close all WebSocket connections immediately
|
||||||
|
if let Ok(mut writer) = WEBSOCKETS.try_write() {
|
||||||
|
writer.drain();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use surrealdb::sql::Part;
|
|
||||||
|
|
||||||
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
|
|
||||||
|
|
||||||
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
|
|
||||||
|
|
||||||
pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]);
|
|
547
src/rpc/processor.rs
Normal file
547
src/rpc/processor.rs
Normal file
|
@ -0,0 +1,547 @@
|
||||||
|
use crate::cnf::PKG_NAME;
|
||||||
|
use crate::cnf::PKG_VERSION;
|
||||||
|
use crate::dbs::DB;
|
||||||
|
use crate::err::Error;
|
||||||
|
use crate::rpc::args::Take;
|
||||||
|
use crate::rpc::LIVE_QUERIES;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
use surrealdb::dbs::QueryType;
|
||||||
|
use surrealdb::dbs::Response;
|
||||||
|
use surrealdb::sql::Object;
|
||||||
|
use surrealdb::sql::Strand;
|
||||||
|
use surrealdb::sql::Value;
|
||||||
|
use surrealdb::{dbs::Session, sql::Array};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use super::res::{Data, Failure, OutputFormat};
|
||||||
|
|
||||||
|
pub struct Processor {
|
||||||
|
pub ws_id: Uuid,
|
||||||
|
session: Session,
|
||||||
|
pub format: OutputFormat,
|
||||||
|
vars: BTreeMap<String, Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Processor {
|
||||||
|
pub fn new(session: Session, format: OutputFormat, vars: BTreeMap<String, Value>) -> Self {
|
||||||
|
Self {
|
||||||
|
ws_id: Uuid::new_v4(),
|
||||||
|
session,
|
||||||
|
format,
|
||||||
|
vars,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn process_request(&mut self, method: &str, params: Array) -> Result<Data, Failure> {
|
||||||
|
info!("Process RPC request");
|
||||||
|
|
||||||
|
// Match the method to a function
|
||||||
|
match method {
|
||||||
|
// Handle a surrealdb ping message
|
||||||
|
//
|
||||||
|
// This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough.
|
||||||
|
// For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios.
|
||||||
|
// By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent.
|
||||||
|
//
|
||||||
|
"ping" => Ok(Value::None.into()),
|
||||||
|
// Retrieve the current auth record
|
||||||
|
"info" => match params.len() {
|
||||||
|
0 => self.info().await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Switch to a specific namespace and database
|
||||||
|
"use" => match params.needs_two() {
|
||||||
|
Ok((ns, db)) => self.yuse(ns, db).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Signup to a specific authentication scope
|
||||||
|
"signup" => match params.needs_one() {
|
||||||
|
Ok(Value::Object(v)) => self.signup(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Signin as a root, namespace, database or scope user
|
||||||
|
"signin" => match params.needs_one() {
|
||||||
|
Ok(Value::Object(v)) => self.signin(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Invalidate the current authentication session
|
||||||
|
"invalidate" => match params.len() {
|
||||||
|
0 => self.invalidate().await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Authenticate using an authentication token
|
||||||
|
"authenticate" => match params.needs_one() {
|
||||||
|
Ok(Value::Strand(v)) => {
|
||||||
|
self.authenticate(v).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Kill a live query using a query id
|
||||||
|
"kill" => match params.needs_one() {
|
||||||
|
Ok(v) => self.kill(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Setup a live query on a specific table
|
||||||
|
"live" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, d)) if v.is_table() => {
|
||||||
|
self.live(v, d).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
Ok((v, d)) if v.is_strand() => {
|
||||||
|
self.live(v, d).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Specify a connection-wide parameter
|
||||||
|
"let" | "set" => match params.needs_one_or_two() {
|
||||||
|
Ok((Value::Strand(s), v)) => {
|
||||||
|
self.set(s, v).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Unset and clear a connection-wide parameter
|
||||||
|
"unset" => match params.needs_one() {
|
||||||
|
Ok(Value::Strand(s)) => self.unset(s).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Select a value or values from the database
|
||||||
|
"select" => match params.needs_one() {
|
||||||
|
Ok(v) => self.select(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Insert a value or values in the database
|
||||||
|
"insert" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, o)) => self.insert(v, o).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Create a value or values in the database
|
||||||
|
"create" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, o)) => self.create(v, o).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Update a value or values in the database using `CONTENT`
|
||||||
|
"update" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, o)) => self.update(v, o).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Update a value or values in the database using `MERGE`
|
||||||
|
"change" | "merge" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, o)) => self.change(v, o).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Update a value or values in the database using `PATCH`
|
||||||
|
"modify" | "patch" => match params.needs_one_or_two() {
|
||||||
|
Ok((v, o)) => self.modify(v, o).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Delete a value or values from the database
|
||||||
|
"delete" => match params.needs_one() {
|
||||||
|
Ok(v) => self.delete(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Specify the output format for text requests
|
||||||
|
"format" => match params.needs_one() {
|
||||||
|
Ok(Value::Strand(v)) => self.format(v).await.map(Into::into).map_err(Into::into),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Get the current server version
|
||||||
|
"version" => match params.len() {
|
||||||
|
0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()),
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
// Run a full SurrealQL query against the database
|
||||||
|
"query" => match params.needs_one_or_two() {
|
||||||
|
Ok((Value::Strand(s), o)) if o.is_none_or_null() => {
|
||||||
|
self.query(s).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
Ok((Value::Strand(s), Value::Object(o))) => {
|
||||||
|
self.query_with(s, o).await.map(Into::into).map_err(Into::into)
|
||||||
|
}
|
||||||
|
_ => Err(Failure::INVALID_PARAMS),
|
||||||
|
},
|
||||||
|
_ => Err(Failure::METHOD_NOT_FOUND),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for authentication
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn format(&mut self, out: Strand) -> Result<Value, Error> {
|
||||||
|
match out.as_str() {
|
||||||
|
"json" | "application/json" => self.format = OutputFormat::Json,
|
||||||
|
"cbor" | "application/cbor" => self.format = OutputFormat::Cbor,
|
||||||
|
"pack" | "application/pack" => self.format = OutputFormat::Pack,
|
||||||
|
_ => return Err(Error::InvalidType),
|
||||||
|
};
|
||||||
|
Ok(Value::None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn yuse(&mut self, ns: Value, db: Value) -> Result<Value, Error> {
|
||||||
|
if let Value::Strand(ns) = ns {
|
||||||
|
self.session.ns = Some(ns.0);
|
||||||
|
}
|
||||||
|
if let Value::Strand(db) = db {
|
||||||
|
self.session.db = Some(db.0);
|
||||||
|
}
|
||||||
|
Ok(Value::None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn signup(&mut self, vars: Object) -> Result<Value, Error> {
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
surrealdb::iam::signup::signup(kvs, &mut self.session, vars)
|
||||||
|
.await
|
||||||
|
.map(Into::into)
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn signin(&mut self, vars: Object) -> Result<Value, Error> {
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
surrealdb::iam::signin::signin(kvs, &mut self.session, vars)
|
||||||
|
.await
|
||||||
|
.map(Into::into)
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
async fn invalidate(&mut self) -> Result<Value, Error> {
|
||||||
|
surrealdb::iam::clear::clear(&mut self.session)?;
|
||||||
|
Ok(Value::None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn authenticate(&mut self, token: Strand) -> Result<Value, Error> {
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?;
|
||||||
|
Ok(Value::None)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for identification
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn info(&self) -> Result<Value, Error> {
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "SELECT * FROM $auth";
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, None).await?;
|
||||||
|
// Extract the first value from the result
|
||||||
|
let res = res.remove(0).result?.first();
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for setting variables
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn set(&mut self, key: Strand, val: Value) -> Result<Value, Error> {
|
||||||
|
match val {
|
||||||
|
// Remove the variable if undefined
|
||||||
|
Value::None => self.vars.remove(&key.0),
|
||||||
|
// Store the variable if defined
|
||||||
|
v => self.vars.insert(key.0, v),
|
||||||
|
};
|
||||||
|
Ok(Value::Null)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn unset(&mut self, key: Strand) -> Result<Value, Error> {
|
||||||
|
self.vars.remove(&key.0);
|
||||||
|
Ok(Value::Null)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for live queries
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn kill(&self, id: Value) -> Result<Value, Error> {
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "KILL $id";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = map! {
|
||||||
|
String::from("id") => id, // NOTE: id can be parameter
|
||||||
|
=> &self.vars
|
||||||
|
};
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let response = res.remove(0);
|
||||||
|
match response.result {
|
||||||
|
Ok(v) => Ok(v),
|
||||||
|
Err(e) => Err(Error::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn live(&self, tb: Value, diff: Value) -> Result<Value, Error> {
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = match diff.is_true() {
|
||||||
|
true => "LIVE SELECT DIFF FROM $tb",
|
||||||
|
false => "LIVE SELECT * FROM $tb",
|
||||||
|
};
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = map! {
|
||||||
|
String::from("tb") => tb.could_be_table(),
|
||||||
|
=> &self.vars
|
||||||
|
};
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let response = res.remove(0);
|
||||||
|
match response.result {
|
||||||
|
Ok(v) => Ok(v),
|
||||||
|
Err(e) => Err(Error::from(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for selecting
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn select(&self, what: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "SELECT * FROM $what";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for inserting
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn insert(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "INSERT INTO $what $data RETURN AFTER";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
String::from("data") => data,
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for creating
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn create(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "CREATE $what CONTENT $data RETURN AFTER";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
String::from("data") => data,
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for updating
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn update(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "UPDATE $what CONTENT $data RETURN AFTER";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
String::from("data") => data,
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for changing
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn change(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "UPDATE $what MERGE $data RETURN AFTER";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
String::from("data") => data,
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for modifying
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn modify(&self, what: Value, data: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "UPDATE $what PATCH $data RETURN DIFF";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
String::from("data") => data,
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for deleting
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn delete(&self, what: Value) -> Result<Value, Error> {
|
||||||
|
// Return a single result?
|
||||||
|
let one = what.is_thing();
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the SQL query string
|
||||||
|
let sql = "DELETE $what RETURN BEFORE";
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(map! {
|
||||||
|
String::from("what") => what.could_be_table(),
|
||||||
|
=> &self.vars
|
||||||
|
});
|
||||||
|
// Execute the query on the database
|
||||||
|
let mut res = kvs.execute(sql, &self.session, var).await?;
|
||||||
|
// Extract the first query result
|
||||||
|
let res = match one {
|
||||||
|
true => res.remove(0).result?.first(),
|
||||||
|
false => res.remove(0).result?,
|
||||||
|
};
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Methods for querying
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn query(&self, sql: Strand) -> Result<Vec<Response>, Error> {
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(self.vars.clone());
|
||||||
|
// Execute the query on the database
|
||||||
|
let res = kvs.execute(&sql, &self.session, var).await?;
|
||||||
|
// Post-process hooks for web layer
|
||||||
|
for response in &res {
|
||||||
|
self.handle_live_query_results(response).await;
|
||||||
|
}
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn query_with(&self, sql: Strand, mut vars: Object) -> Result<Vec<Response>, Error> {
|
||||||
|
// Get a database reference
|
||||||
|
let kvs = DB.get().unwrap();
|
||||||
|
// Specify the query parameters
|
||||||
|
let var = Some(mrg! { vars.0, &self.vars });
|
||||||
|
// Execute the query on the database
|
||||||
|
let res = kvs.execute(&sql, &self.session, var).await?;
|
||||||
|
// Post-process hooks for web layer
|
||||||
|
for response in &res {
|
||||||
|
self.handle_live_query_results(response).await;
|
||||||
|
}
|
||||||
|
// Return the result to the client
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------
|
||||||
|
// Private methods
|
||||||
|
// ------------------------------
|
||||||
|
|
||||||
|
async fn handle_live_query_results(&self, res: &Response) {
|
||||||
|
match &res.query_type {
|
||||||
|
QueryType::Live => {
|
||||||
|
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||||
|
// Match on Uuid type
|
||||||
|
LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id);
|
||||||
|
trace!("Registered live query {} on websocket {}", lqid, self.ws_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
QueryType::Kill => {
|
||||||
|
if let Ok(Value::Uuid(lqid)) = &res.result {
|
||||||
|
let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0);
|
||||||
|
if let Some(ws_id) = ws_id {
|
||||||
|
trace!("Unregistered live query {} on websocket {}", lqid, ws_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
84
src/rpc/request.rs
Normal file
84
src/rpc/request.rs
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
use axum::extract::ws::Message;
|
||||||
|
use surrealdb::sql::{serde::deserialize, Array, Value};
|
||||||
|
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use surrealdb::sql::Part;
|
||||||
|
|
||||||
|
use super::res::{Failure, OutputFormat};
|
||||||
|
|
||||||
|
pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
|
||||||
|
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
|
||||||
|
pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]);
|
||||||
|
|
||||||
|
pub struct Request {
|
||||||
|
pub id: Option<Value>,
|
||||||
|
pub method: String,
|
||||||
|
pub params: Array,
|
||||||
|
pub size: usize,
|
||||||
|
pub out_fmt: Option<OutputFormat>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse the RPC request
|
||||||
|
pub async fn parse_request(msg: Message) -> Result<Request, Failure> {
|
||||||
|
let mut out_fmt = None;
|
||||||
|
let (req, size) = match msg {
|
||||||
|
// This is a binary message
|
||||||
|
Message::Binary(val) => {
|
||||||
|
// Use binary output
|
||||||
|
out_fmt = Some(OutputFormat::Full);
|
||||||
|
|
||||||
|
match deserialize(&val) {
|
||||||
|
Ok(v) => (v, val.len()),
|
||||||
|
Err(_) => {
|
||||||
|
debug!("Error when trying to deserialize the request");
|
||||||
|
return Err(Failure::PARSE_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// This is a text message
|
||||||
|
Message::Text(ref val) => {
|
||||||
|
// Parse the SurrealQL object
|
||||||
|
match surrealdb::sql::value(val) {
|
||||||
|
// The SurrealQL message parsed ok
|
||||||
|
Ok(v) => (v, val.len()),
|
||||||
|
// The SurrealQL message failed to parse
|
||||||
|
_ => return Err(Failure::PARSE_ERROR),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Unsupported message type
|
||||||
|
_ => {
|
||||||
|
debug!("Unsupported message type: {:?}", msg);
|
||||||
|
return Err(Failure::custom("Unsupported message type"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fetch the 'id' argument
|
||||||
|
let id = match req.pick(&*ID) {
|
||||||
|
v if v.is_none() => None,
|
||||||
|
v if v.is_null() => Some(v),
|
||||||
|
v if v.is_uuid() => Some(v),
|
||||||
|
v if v.is_number() => Some(v),
|
||||||
|
v if v.is_strand() => Some(v),
|
||||||
|
v if v.is_datetime() => Some(v),
|
||||||
|
_ => return Err(Failure::INVALID_REQUEST),
|
||||||
|
};
|
||||||
|
// Fetch the 'method' argument
|
||||||
|
let method = match req.pick(&*METHOD) {
|
||||||
|
Value::Strand(v) => v.to_raw(),
|
||||||
|
_ => return Err(Failure::INVALID_REQUEST),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fetch the 'params' argument
|
||||||
|
let params = match req.pick(&*PARAMS) {
|
||||||
|
Value::Array(v) => v,
|
||||||
|
_ => Array::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Request {
|
||||||
|
id,
|
||||||
|
method,
|
||||||
|
params,
|
||||||
|
size,
|
||||||
|
out_fmt,
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,4 +1,5 @@
|
||||||
use axum::extract::ws::Message;
|
use axum::extract::ws::Message;
|
||||||
|
use opentelemetry::Context as TelemetryContext;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::{json, Value as Json};
|
use serde_json::{json, Value as Json};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
@ -11,6 +12,7 @@ use tracing::Span;
|
||||||
|
|
||||||
use crate::err;
|
use crate::err;
|
||||||
use crate::rpc::CONN_CLOSED_ERR;
|
use crate::rpc::CONN_CLOSED_ERR;
|
||||||
|
use crate::telemetry::metrics::ws::record_rpc;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum OutputFormat {
|
pub enum OutputFormat {
|
||||||
|
@ -96,6 +98,7 @@ impl Response {
|
||||||
|
|
||||||
info!("Process RPC response");
|
info!("Process RPC response");
|
||||||
|
|
||||||
|
let is_error = self.result.is_err();
|
||||||
if let Err(err) = &self.result {
|
if let Err(err) = &self.result {
|
||||||
span.record("otel.status_code", "Error");
|
span.record("otel.status_code", "Error");
|
||||||
span.record(
|
span.record(
|
||||||
|
@ -106,30 +109,33 @@ impl Response {
|
||||||
span.record("rpc.jsonrpc.error_message", err.message.as_ref());
|
span.record("rpc.jsonrpc.error_message", err.message.as_ref());
|
||||||
}
|
}
|
||||||
|
|
||||||
let message = match out {
|
let (res_size, message) = match out {
|
||||||
OutputFormat::Json => {
|
OutputFormat::Json => {
|
||||||
let res = serde_json::to_string(&self.simplify()).unwrap();
|
let res = serde_json::to_string(&self.simplify()).unwrap();
|
||||||
Message::Text(res)
|
(res.len(), Message::Text(res))
|
||||||
}
|
}
|
||||||
OutputFormat::Cbor => {
|
OutputFormat::Cbor => {
|
||||||
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
let res = serde_cbor::to_vec(&self.simplify()).unwrap();
|
||||||
Message::Binary(res)
|
(res.len(), Message::Binary(res))
|
||||||
}
|
}
|
||||||
OutputFormat::Pack => {
|
OutputFormat::Pack => {
|
||||||
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
let res = serde_pack::to_vec(&self.simplify()).unwrap();
|
||||||
Message::Binary(res)
|
(res.len(), Message::Binary(res))
|
||||||
}
|
}
|
||||||
OutputFormat::Full => {
|
OutputFormat::Full => {
|
||||||
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
let res = surrealdb::sql::serde::serialize(&self).unwrap();
|
||||||
Message::Binary(res)
|
(res.len(), Message::Binary(res))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(err) = chn.send(message).await {
|
if let Err(err) = chn.send(message).await {
|
||||||
if err.to_string() != CONN_CLOSED_ERR {
|
if err.to_string() != CONN_CLOSED_ERR {
|
||||||
error!("Error sending response: {}", err);
|
error!("Error sending response: {}", err);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
record_rpc(&TelemetryContext::current(), res_size, is_error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,87 +1,15 @@
|
||||||
pub(super) mod tower_layer;
|
pub(super) mod tower_layer;
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use opentelemetry::{
|
use opentelemetry::metrics::{Histogram, MetricsError, ObservableUpDownCounter, Unit};
|
||||||
metrics::{Histogram, Meter, MeterProvider, ObservableUpDownCounter, Unit},
|
use opentelemetry::Context as TelemetryContext;
|
||||||
runtime,
|
|
||||||
sdk::{
|
|
||||||
export::metrics::aggregation,
|
|
||||||
metrics::{
|
|
||||||
controllers::{self, BasicController},
|
|
||||||
processors, selectors,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Context,
|
|
||||||
};
|
|
||||||
use opentelemetry_otlp::MetricsExporterBuilder;
|
|
||||||
|
|
||||||
use crate::telemetry::OTEL_DEFAULT_RESOURCE;
|
use self::tower_layer::HttpCallMetricTracker;
|
||||||
|
|
||||||
// Histogram buckets in milliseconds
|
use super::{METER_DURATION, METER_SIZE};
|
||||||
static HTTP_DURATION_MS_HISTOGRAM_BUCKETS: &[f64] = &[
|
|
||||||
5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0,
|
|
||||||
2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0,
|
|
||||||
];
|
|
||||||
|
|
||||||
const KB: f64 = 1024.0;
|
|
||||||
const MB: f64 = 1024.0 * KB;
|
|
||||||
|
|
||||||
const HTTP_SIZE_HISTOGRAM_BUCKETS: &[f64] = &[
|
|
||||||
1.0 * KB, // 1 KB
|
|
||||||
2.0 * KB, // 2 KB
|
|
||||||
5.0 * KB, // 5 KB
|
|
||||||
10.0 * KB, // 10 KB
|
|
||||||
100.0 * KB, // 100 KB
|
|
||||||
500.0 * KB, // 500 KB
|
|
||||||
1.0 * MB, // 1 MB
|
|
||||||
2.5 * MB, // 2 MB
|
|
||||||
5.0 * MB, // 5 MB
|
|
||||||
10.0 * MB, // 10 MB
|
|
||||||
25.0 * MB, // 25 MB
|
|
||||||
50.0 * MB, // 50 MB
|
|
||||||
100.0 * MB, // 100 MB
|
|
||||||
];
|
|
||||||
|
|
||||||
static METER_PROVIDER_HTTP_DURATION: Lazy<BasicController> = Lazy::new(|| {
|
|
||||||
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
|
|
||||||
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let builder = controllers::basic(processors::factory(
|
|
||||||
selectors::simple::histogram(HTTP_DURATION_MS_HISTOGRAM_BUCKETS),
|
|
||||||
aggregation::cumulative_temporality_selector(),
|
|
||||||
))
|
|
||||||
.with_exporter(exporter)
|
|
||||||
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
|
|
||||||
|
|
||||||
let controller = builder.build();
|
|
||||||
controller.start(&Context::current(), runtime::Tokio).unwrap();
|
|
||||||
controller
|
|
||||||
});
|
|
||||||
|
|
||||||
static METER_PROVIDER_HTTP_SIZE: Lazy<BasicController> = Lazy::new(|| {
|
|
||||||
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
|
|
||||||
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let builder = controllers::basic(processors::factory(
|
|
||||||
selectors::simple::histogram(HTTP_SIZE_HISTOGRAM_BUCKETS),
|
|
||||||
aggregation::cumulative_temporality_selector(),
|
|
||||||
))
|
|
||||||
.with_exporter(exporter)
|
|
||||||
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
|
|
||||||
|
|
||||||
let controller = builder.build();
|
|
||||||
controller.start(&Context::current(), runtime::Tokio).unwrap();
|
|
||||||
controller
|
|
||||||
});
|
|
||||||
|
|
||||||
static HTTP_DURATION_METER: Lazy<Meter> =
|
|
||||||
Lazy::new(|| METER_PROVIDER_HTTP_DURATION.meter("http_duration"));
|
|
||||||
static HTTP_SIZE_METER: Lazy<Meter> = Lazy::new(|| METER_PROVIDER_HTTP_SIZE.meter("http_size"));
|
|
||||||
|
|
||||||
pub static HTTP_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
|
pub static HTTP_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
HTTP_DURATION_METER
|
METER_DURATION
|
||||||
.u64_histogram("http.server.duration")
|
.u64_histogram("http.server.duration")
|
||||||
.with_description("The HTTP server duration in milliseconds.")
|
.with_description("The HTTP server duration in milliseconds.")
|
||||||
.with_unit(Unit::new("ms"))
|
.with_unit(Unit::new("ms"))
|
||||||
|
@ -89,14 +17,14 @@ pub static HTTP_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy<ObservableUpDownCounter<i64>> = Lazy::new(|| {
|
pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy<ObservableUpDownCounter<i64>> = Lazy::new(|| {
|
||||||
HTTP_DURATION_METER
|
METER_DURATION
|
||||||
.i64_observable_up_down_counter("http.server.active_requests")
|
.i64_observable_up_down_counter("http.server.active_requests")
|
||||||
.with_description("The number of active HTTP requests.")
|
.with_description("The number of active HTTP requests.")
|
||||||
.init()
|
.init()
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static HTTP_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
pub static HTTP_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
HTTP_SIZE_METER
|
METER_SIZE
|
||||||
.u64_histogram("http.server.request.size")
|
.u64_histogram("http.server.request.size")
|
||||||
.with_description("Measures the size of HTTP request messages.")
|
.with_description("Measures the size of HTTP request messages.")
|
||||||
.with_unit(Unit::new("mb"))
|
.with_unit(Unit::new("mb"))
|
||||||
|
@ -104,9 +32,49 @@ pub static HTTP_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
});
|
});
|
||||||
|
|
||||||
pub static HTTP_SERVER_RESPONSE_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
pub static HTTP_SERVER_RESPONSE_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
HTTP_SIZE_METER
|
METER_SIZE
|
||||||
.u64_histogram("http.server.response.size")
|
.u64_histogram("http.server.response.size")
|
||||||
.with_description("Measures the size of HTTP response messages.")
|
.with_description("Measures the size of HTTP response messages.")
|
||||||
.with_unit(Unit::new("mb"))
|
.with_unit(Unit::new("mb"))
|
||||||
.init()
|
.init()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
fn observe_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||||
|
observe_active_request(1, tracker)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn observe_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||||
|
observe_active_request(-1, tracker)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn observe_active_request(value: i64, tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||||
|
let attrs = tracker.active_req_attrs();
|
||||||
|
|
||||||
|
METER_DURATION
|
||||||
|
.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, value, &attrs))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_request_duration(tracker: &HttpCallMetricTracker) {
|
||||||
|
// Record the duration of the request.
|
||||||
|
HTTP_SERVER_DURATION.record(
|
||||||
|
&TelemetryContext::current(),
|
||||||
|
tracker.duration().as_millis() as u64,
|
||||||
|
&tracker.request_duration_attrs(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) {
|
||||||
|
HTTP_SERVER_REQUEST_SIZE.record(
|
||||||
|
&TelemetryContext::current(),
|
||||||
|
size,
|
||||||
|
&tracker.request_size_attrs(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) {
|
||||||
|
HTTP_SERVER_RESPONSE_SIZE.record(
|
||||||
|
&TelemetryContext::current(),
|
||||||
|
size,
|
||||||
|
&tracker.response_size_attrs(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use axum::extract::MatchedPath;
|
use axum::extract::MatchedPath;
|
||||||
use opentelemetry::{metrics::MetricsError, Context as TelemetryContext, KeyValue};
|
use opentelemetry::{metrics::MetricsError, KeyValue};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
cell::Cell,
|
cell::Cell,
|
||||||
|
@ -13,11 +13,6 @@ use futures::Future;
|
||||||
use http::{Request, Response, StatusCode, Version};
|
use http::{Request, Response, StatusCode, Version};
|
||||||
use tower::{Layer, Service};
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
use super::{
|
|
||||||
HTTP_DURATION_METER, HTTP_SERVER_ACTIVE_REQUESTS, HTTP_SERVER_DURATION,
|
|
||||||
HTTP_SERVER_REQUEST_SIZE, HTTP_SERVER_RESPONSE_SIZE,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Clone, Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct HttpMetricsLayer;
|
pub struct HttpMetricsLayer;
|
||||||
|
|
||||||
|
@ -168,7 +163,7 @@ impl HttpCallMetricTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Follows the OpenTelemetry semantic conventions for HTTP metrics define here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md
|
// Follows the OpenTelemetry semantic conventions for HTTP metrics define here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md
|
||||||
fn olel_common_attrs(&self) -> Vec<KeyValue> {
|
fn otel_common_attrs(&self) -> Vec<KeyValue> {
|
||||||
let mut res = vec![
|
let mut res = vec![
|
||||||
KeyValue::new("http.request.method", self.method.as_str().to_owned()),
|
KeyValue::new("http.request.method", self.method.as_str().to_owned()),
|
||||||
KeyValue::new("network.protocol.name", "http".to_owned()),
|
KeyValue::new("network.protocol.name", "http".to_owned()),
|
||||||
|
@ -186,11 +181,11 @@ impl HttpCallMetricTracker {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn active_req_attrs(&self) -> Vec<KeyValue> {
|
pub(super) fn active_req_attrs(&self) -> Vec<KeyValue> {
|
||||||
self.olel_common_attrs()
|
self.otel_common_attrs()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn request_duration_attrs(&self) -> Vec<KeyValue> {
|
pub(super) fn request_duration_attrs(&self) -> Vec<KeyValue> {
|
||||||
let mut res = self.olel_common_attrs();
|
let mut res = self.otel_common_attrs();
|
||||||
|
|
||||||
res.push(KeyValue::new(
|
res.push(KeyValue::new(
|
||||||
"http.response.status_code",
|
"http.response.status_code",
|
||||||
|
@ -247,64 +242,25 @@ impl Drop for HttpCallMetricTracker {
|
||||||
|
|
||||||
pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||||
// Setup the active_requests observer
|
// Setup the active_requests observer
|
||||||
observe_active_request_start(tracker)
|
super::observe_request_start(tracker)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
||||||
// Setup the active_requests observer
|
// Setup the active_requests observer
|
||||||
observe_active_request_finish(tracker)?;
|
super::observe_request_finish(tracker)?;
|
||||||
|
|
||||||
// Record the duration of the request.
|
// Record the duration of the request.
|
||||||
record_request_duration(tracker);
|
super::record_request_duration(tracker);
|
||||||
|
|
||||||
// Record the request size if known
|
// Record the request size if known
|
||||||
if let Some(size) = tracker.request_size {
|
if let Some(size) = tracker.request_size {
|
||||||
record_request_size(tracker, size)
|
super::record_request_size(tracker, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record the response size if known
|
// Record the response size if known
|
||||||
if let Some(size) = tracker.response_size {
|
if let Some(size) = tracker.response_size {
|
||||||
record_response_size(tracker, size)
|
super::record_response_size(tracker, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn observe_active_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
|
||||||
let attrs = tracker.active_req_attrs();
|
|
||||||
// Setup the callback to observe the active requests.
|
|
||||||
HTTP_DURATION_METER
|
|
||||||
.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, 1, &attrs))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn observe_active_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> {
|
|
||||||
let attrs = tracker.active_req_attrs();
|
|
||||||
// Setup the callback to observe the active requests.
|
|
||||||
HTTP_DURATION_METER
|
|
||||||
.register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, -1, &attrs))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_request_duration(tracker: &HttpCallMetricTracker) {
|
|
||||||
// Record the duration of the request.
|
|
||||||
HTTP_SERVER_DURATION.record(
|
|
||||||
&TelemetryContext::current(),
|
|
||||||
tracker.duration().as_millis() as u64,
|
|
||||||
&tracker.request_duration_attrs(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) {
|
|
||||||
HTTP_SERVER_REQUEST_SIZE.record(
|
|
||||||
&TelemetryContext::current(),
|
|
||||||
size,
|
|
||||||
&tracker.request_size_attrs(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) {
|
|
||||||
HTTP_SERVER_RESPONSE_SIZE.record(
|
|
||||||
&TelemetryContext::current(),
|
|
||||||
size,
|
|
||||||
&tracker.response_size_attrs(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,3 +1,97 @@
|
||||||
pub mod http;
|
pub mod http;
|
||||||
|
pub mod ws;
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use opentelemetry::Context as TelemetryContext;
|
||||||
|
use opentelemetry::{
|
||||||
|
metrics::{Meter, MeterProvider, MetricsError},
|
||||||
|
runtime,
|
||||||
|
sdk::{
|
||||||
|
export::metrics::aggregation,
|
||||||
|
metrics::{
|
||||||
|
controllers::{self, BasicController},
|
||||||
|
processors, selectors,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use opentelemetry_otlp::MetricsExporterBuilder;
|
||||||
|
|
||||||
pub use self::http::tower_layer::HttpMetricsLayer;
|
pub use self::http::tower_layer::HttpMetricsLayer;
|
||||||
|
use self::ws::observe_active_connection;
|
||||||
|
|
||||||
|
use super::OTEL_DEFAULT_RESOURCE;
|
||||||
|
|
||||||
|
// Histogram buckets in milliseconds
|
||||||
|
static HISTOGRAM_BUCKETS_MS: &[f64] = &[
|
||||||
|
5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0,
|
||||||
|
2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Histogram buckets in bytes
|
||||||
|
const KB: f64 = 1024.0;
|
||||||
|
const MB: f64 = 1024.0 * KB;
|
||||||
|
const HISTOGRAM_BUCKETS_BYTES: &[f64] = &[
|
||||||
|
1.0 * KB, // 1 KB
|
||||||
|
2.0 * KB, // 2 KB
|
||||||
|
5.0 * KB, // 5 KB
|
||||||
|
10.0 * KB, // 10 KB
|
||||||
|
100.0 * KB, // 100 KB
|
||||||
|
500.0 * KB, // 500 KB
|
||||||
|
1.0 * MB, // 1 MB
|
||||||
|
2.5 * MB, // 2 MB
|
||||||
|
5.0 * MB, // 5 MB
|
||||||
|
10.0 * MB, // 10 MB
|
||||||
|
25.0 * MB, // 25 MB
|
||||||
|
50.0 * MB, // 50 MB
|
||||||
|
100.0 * MB, // 100 MB
|
||||||
|
];
|
||||||
|
|
||||||
|
fn build_controller(boundaries: &'static [f64]) -> BasicController {
|
||||||
|
let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic())
|
||||||
|
.build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector()))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let builder = controllers::basic(processors::factory(
|
||||||
|
selectors::simple::histogram(boundaries),
|
||||||
|
aggregation::cumulative_temporality_selector(),
|
||||||
|
))
|
||||||
|
.with_push_timeout(Duration::from_secs(5))
|
||||||
|
.with_collect_period(Duration::from_secs(5))
|
||||||
|
.with_exporter(exporter)
|
||||||
|
.with_resource(OTEL_DEFAULT_RESOURCE.clone());
|
||||||
|
|
||||||
|
builder.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
static METER_PROVIDER_DURATION: Lazy<BasicController> =
|
||||||
|
Lazy::new(|| build_controller(HISTOGRAM_BUCKETS_MS));
|
||||||
|
|
||||||
|
static METER_PROVIDER_SIZE: Lazy<BasicController> =
|
||||||
|
Lazy::new(|| build_controller(HISTOGRAM_BUCKETS_BYTES));
|
||||||
|
|
||||||
|
static METER_DURATION: Lazy<Meter> = Lazy::new(|| METER_PROVIDER_DURATION.meter("duration"));
|
||||||
|
static METER_SIZE: Lazy<Meter> = Lazy::new(|| METER_PROVIDER_SIZE.meter("size"));
|
||||||
|
|
||||||
|
/// Initialize the metrics subsystem
|
||||||
|
pub fn init(cx: &TelemetryContext) -> Result<(), MetricsError> {
|
||||||
|
METER_PROVIDER_DURATION.start(cx, runtime::Tokio)?;
|
||||||
|
METER_PROVIDER_SIZE.start(cx, runtime::Tokio)?;
|
||||||
|
|
||||||
|
observe_active_connection(0)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Shutdown the metrics providers
|
||||||
|
//
|
||||||
|
pub fn shutdown(cx: &TelemetryContext) -> Result<(), MetricsError> {
|
||||||
|
METER_PROVIDER_DURATION.stop(cx)?;
|
||||||
|
METER_PROVIDER_DURATION.collect(cx)?;
|
||||||
|
METER_PROVIDER_SIZE.stop(cx)?;
|
||||||
|
METER_PROVIDER_SIZE.collect(cx)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
126
src/telemetry/metrics/ws/mod.rs
Normal file
126
src/telemetry/metrics/ws/mod.rs
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use opentelemetry::KeyValue;
|
||||||
|
use opentelemetry::{
|
||||||
|
metrics::{Histogram, MetricsError, ObservableUpDownCounter, Unit},
|
||||||
|
Context as TelemetryContext,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{METER_DURATION, METER_SIZE};
|
||||||
|
|
||||||
|
pub static RPC_SERVER_DURATION: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
|
METER_DURATION
|
||||||
|
.u64_histogram("rpc.server.duration")
|
||||||
|
.with_description("Measures duration of inbound RPC requests in milliseconds.")
|
||||||
|
.with_unit(Unit::new("ms"))
|
||||||
|
.init()
|
||||||
|
});
|
||||||
|
|
||||||
|
pub static RPC_SERVER_ACTIVE_CONNECTIONS: Lazy<ObservableUpDownCounter<i64>> = Lazy::new(|| {
|
||||||
|
METER_DURATION
|
||||||
|
.i64_observable_up_down_counter("rpc.server.active_connections")
|
||||||
|
.with_description("The number of active WebSocket connections.")
|
||||||
|
.init()
|
||||||
|
});
|
||||||
|
|
||||||
|
pub static RPC_SERVER_REQUEST_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
|
METER_SIZE
|
||||||
|
.u64_histogram("rpc.server.request.size")
|
||||||
|
.with_description("Measures the size of HTTP request messages.")
|
||||||
|
.with_unit(Unit::new("mb"))
|
||||||
|
.init()
|
||||||
|
});
|
||||||
|
|
||||||
|
pub static RPC_SERVER_RESPONSE_SIZE: Lazy<Histogram<u64>> = Lazy::new(|| {
|
||||||
|
METER_SIZE
|
||||||
|
.u64_histogram("rpc.server.response.size")
|
||||||
|
.with_description("Measures the size of HTTP response messages.")
|
||||||
|
.with_unit(Unit::new("mb"))
|
||||||
|
.init()
|
||||||
|
});
|
||||||
|
|
||||||
|
fn otel_common_attrs() -> Vec<KeyValue> {
|
||||||
|
vec![KeyValue::new("rpc.system", "jsonrpc"), KeyValue::new("rpc.service", "surrealdb")]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the callback that increases the number of active RPC connections.
|
||||||
|
pub fn on_connect() -> Result<(), MetricsError> {
|
||||||
|
observe_active_connection(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the callback that increases the number of active RPC connections.
|
||||||
|
pub fn on_disconnect() -> Result<(), MetricsError> {
|
||||||
|
observe_active_connection(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn observe_active_connection(value: i64) -> Result<(), MetricsError> {
|
||||||
|
let attrs = otel_common_attrs();
|
||||||
|
|
||||||
|
METER_DURATION
|
||||||
|
.register_callback(move |cx| RPC_SERVER_ACTIVE_CONNECTIONS.observe(cx, value, &attrs))
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Record an RPC command
|
||||||
|
//
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
|
pub struct RequestContext {
|
||||||
|
start: Instant,
|
||||||
|
pub method: String,
|
||||||
|
pub size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RequestContext {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
start: Instant::now(),
|
||||||
|
method: "unknown".to_string(),
|
||||||
|
size: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestContext {
|
||||||
|
pub fn with_method(self, method: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
method: method.to_string(),
|
||||||
|
..self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_size(self, size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
size,
|
||||||
|
..self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates the request and response metrics for an RPC method.
|
||||||
|
pub fn record_rpc(cx: &TelemetryContext, res_size: usize, is_error: bool) {
|
||||||
|
let mut attrs = otel_common_attrs();
|
||||||
|
let mut duration = 0;
|
||||||
|
let mut req_size = 0;
|
||||||
|
|
||||||
|
if let Some(cx) = cx.get::<RequestContext>() {
|
||||||
|
attrs.extend_from_slice(&[
|
||||||
|
KeyValue::new("rpc.method", cx.method.clone()),
|
||||||
|
KeyValue::new("rpc.error", is_error),
|
||||||
|
]);
|
||||||
|
duration = cx.start.elapsed().as_millis() as u64;
|
||||||
|
req_size = cx.size as u64;
|
||||||
|
} else {
|
||||||
|
// If a bug causes the RequestContent to be empty, we still want to record the metrics to avoid a silent failure.
|
||||||
|
warn!("record_rpc: no request context found, resulting metrics will be invalid");
|
||||||
|
attrs.extend_from_slice(&[
|
||||||
|
KeyValue::new("rpc.method", "unknown"),
|
||||||
|
KeyValue::new("rpc.error", is_error),
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
|
RPC_SERVER_DURATION.record(cx, duration, &attrs);
|
||||||
|
RPC_SERVER_REQUEST_SIZE.record(cx, req_size, &attrs);
|
||||||
|
RPC_SERVER_RESPONSE_SIZE.record(cx, res_size as u64, &attrs);
|
||||||
|
}
|
|
@ -6,11 +6,12 @@ use std::time::Duration;
|
||||||
|
|
||||||
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
use opentelemetry::metrics::MetricsError;
|
||||||
use opentelemetry::sdk::resource::{
|
use opentelemetry::sdk::resource::{
|
||||||
EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector,
|
EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector,
|
||||||
};
|
};
|
||||||
use opentelemetry::sdk::Resource;
|
use opentelemetry::sdk::Resource;
|
||||||
use opentelemetry::KeyValue;
|
use opentelemetry::{Context as TelemetryContext, KeyValue};
|
||||||
use tracing::{Level, Subscriber};
|
use tracing::{Level, Subscriber};
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
@ -86,10 +87,18 @@ impl Builder {
|
||||||
|
|
||||||
/// Install the tracing dispatcher globally
|
/// Install the tracing dispatcher globally
|
||||||
pub fn init(self) {
|
pub fn init(self) {
|
||||||
self.build().init()
|
self.build().init();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn shutdown() -> Result<(), MetricsError> {
|
||||||
|
// Flush all telemetry data
|
||||||
|
opentelemetry::global::shutdown_tracer_provider();
|
||||||
|
metrics::shutdown(&TelemetryContext::current())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
|
/// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives.
|
||||||
pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
|
pub fn filter_from_value(v: &str) -> Result<EnvFilter, tracing_subscriber::filter::ParseError> {
|
||||||
match v {
|
match v {
|
||||||
|
|
Loading…
Reference in a new issue