diff --git a/dev/docker/compose.yaml b/dev/docker/compose.yaml index 5bdc9f0a..bccb6e7b 100644 --- a/dev/docker/compose.yaml +++ b/dev/docker/compose.yaml @@ -2,7 +2,7 @@ version: "3" services: grafana: - image: "grafana/grafana-oss:latest" + image: "grafana/grafana-oss:main" expose: - "3000" ports: diff --git a/src/cli/start.rs b/src/cli/start.rs index 51626b67..09ecf423 100644 --- a/src/cli/start.rs +++ b/src/cli/start.rs @@ -11,6 +11,7 @@ use crate::net::{self, client_ip::ClientIp}; use crate::node; use clap::Args; use ipnet::IpNet; +use opentelemetry::Context as TelemetryContext; use std::net::SocketAddr; use std::path::PathBuf; use std::time::Duration; @@ -123,6 +124,9 @@ pub async fn init( ) -> Result<(), Error> { // Initialize opentelemetry and logging crate::telemetry::builder().with_filter(log).init(); + // Start metrics subsystem + crate::telemetry::metrics::init(&TelemetryContext::current()) + .expect("failed to initialize metrics"); // Check if a banner should be outputted if !no_banner { diff --git a/src/net/mod.rs b/src/net/mod.rs index 8caa6e0a..922e4cbc 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -129,7 +129,7 @@ pub async fn init() -> Result<(), Error> { .merge(key::router()) .layer(service); - // Setup the graceful shutdown with no timeout + // Setup the graceful shutdown let handle = Handle::new(); let shutdown_handler = graceful_shutdown(handle.clone()); @@ -159,9 +159,6 @@ pub async fn init() -> Result<(), Error> { // Wait for the shutdown to finish let _ = shutdown_handler.await; - // Flush all telemetry data - opentelemetry::global::shutdown_tracer_provider(); - info!(target: LOG, "Web server stopped. Bye!"); Ok(()) diff --git a/src/net/rpc.rs b/src/net/rpc.rs index fdab6eba..cb806918 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -1,58 +1,17 @@ -use crate::cnf::MAX_CONCURRENT_CALLS; -use crate::cnf::PKG_NAME; -use crate::cnf::PKG_VERSION; -use crate::cnf::WEBSOCKET_PING_FREQUENCY; -use crate::dbs::DB; -use crate::err::Error; -use crate::rpc::args::Take; -use crate::rpc::paths::{ID, METHOD, PARAMS}; -use crate::rpc::res; -use crate::rpc::res::Data; -use crate::rpc::res::Failure; -use crate::rpc::res::IntoRpcResponse; -use crate::rpc::res::OutputFormat; -use crate::rpc::CONN_CLOSED_ERR; -use crate::telemetry::traces::rpc::span_for_request; +use crate::rpc::connection::Connection; use axum::routing::get; use axum::Extension; use axum::Router; -use futures::{SinkExt, StreamExt}; -use futures_util::stream::SplitSink; -use futures_util::stream::SplitStream; use http_body::Body as HttpBody; -use once_cell::sync::Lazy; -use std::collections::BTreeMap; -use std::collections::HashMap; -use std::sync::Arc; -use surrealdb::channel; -use surrealdb::channel::{Receiver, Sender}; -use surrealdb::dbs::{QueryType, Response, Session}; -use surrealdb::sql::serde::deserialize; -use surrealdb::sql::Array; -use surrealdb::sql::Object; -use surrealdb::sql::Strand; -use surrealdb::sql::Value; -use tokio::sync::RwLock; -use tokio::task::JoinSet; -use tokio_util::sync::CancellationToken; +use surrealdb::dbs::Session; use tower_http::request_id::RequestId; -use tracing::Span; use uuid::Uuid; use axum::{ - extract::ws::{Message, WebSocket, WebSocketUpgrade}, + extract::ws::{WebSocket, WebSocketUpgrade}, response::IntoResponse, }; -// Mapping of WebSocketID to WebSocket -pub(crate) struct WebSocketRef(pub(crate) Sender, pub(crate) CancellationToken); -type WebSockets = RwLock>; -// Mapping of LiveQueryID to WebSocketID -type LiveQueries = RwLock>; - -pub(super) static WEBSOCKETS: Lazy = Lazy::new(WebSockets::default); -static LIVE_QUERIES: Lazy = Lazy::new(LiveQueries::default); - pub(super) fn router() -> Router where B: HttpBody + Send + 'static, @@ -72,853 +31,13 @@ async fn handler( } async fn handle_socket(ws: WebSocket, sess: Session, req_id: RequestId) { - let rpc = Rpc::new(sess); + let rpc = Connection::new(sess); - // If the request ID is a valid UUID and is not already in use, use it as the WebSocket ID - match req_id.header_value().to_str().map(Uuid::parse_str) { - Ok(Ok(req_id)) if !WEBSOCKETS.read().await.contains_key(&req_id) => { - rpc.write().await.ws_id = req_id - } - _ => (), + // Update the WebSocket ID with the Request ID + if let Ok(Ok(req_id)) = req_id.header_value().to_str().map(Uuid::parse_str) { + // If the ID couldn't be updated, ignore the error and keep the default ID + let _ = rpc.write().await.update_ws_id(req_id).await; } - Rpc::serve(rpc, ws).await; -} - -pub struct Rpc { - session: Session, - format: OutputFormat, - ws_id: Uuid, - vars: BTreeMap, - graceful_shutdown: CancellationToken, -} - -impl Rpc { - /// Instantiate a new RPC - pub fn new(mut session: Session) -> Arc> { - // Create a new RPC variables store - let vars = BTreeMap::new(); - // Set the default output format - let format = OutputFormat::Json; - // Enable real-time mode - session.rt = true; - // Create and store the Rpc connection - Arc::new(RwLock::new(Rpc { - session, - format, - ws_id: Uuid::new_v4(), - vars, - graceful_shutdown: CancellationToken::new(), - })) - } - - /// Serve the RPC endpoint - pub async fn serve(rpc: Arc>, ws: WebSocket) { - // Split the socket into send and recv - let (sender, receiver) = ws.split(); - // Create an internal channel between the receiver and the sender - let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS); - - let ws_id = rpc.read().await.ws_id; - - // Store this WebSocket in the list of WebSockets - WEBSOCKETS.write().await.insert( - ws_id, - WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()), - ); - - trace!("WebSocket {} connected", ws_id); - - // Wait until all tasks finish - tokio::join!( - Self::ping(rpc.clone(), internal_sender.clone()), - Self::read(rpc.clone(), receiver, internal_sender.clone()), - Self::write(rpc.clone(), sender, internal_receiver.clone()), - Self::lq_notifications(rpc.clone()), - ); - - // Remove all live queries - LIVE_QUERIES.write().await.retain(|key, value| { - if value == &ws_id { - trace!("Removing live query: {}", key); - return false; - } - true - }); - - // Remove this WebSocket from the list of WebSockets - WEBSOCKETS.write().await.remove(&ws_id); - - trace!("WebSocket {} disconnected", ws_id); - } - - /// Send Ping messages to the client - async fn ping(rpc: Arc>, internal_sender: Sender) { - // Create the interval ticker - let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY); - let cancel_token = rpc.read().await.graceful_shutdown.clone(); - loop { - let is_shutdown = cancel_token.cancelled(); - tokio::select! { - _ = interval.tick() => { - let msg = Message::Ping(vec![]); - - // Send the message to the client and close the WebSocket connection if it fails - if internal_sender.send(msg).await.is_err() { - rpc.read().await.graceful_shutdown.cancel(); - break; - } - }, - _ = is_shutdown => break, - } - } - } - - /// Read messages sent from the client - async fn read( - rpc: Arc>, - mut receiver: SplitStream, - internal_sender: Sender, - ) { - // Collect all spawned tasks so we can wait for them at the end - let mut tasks = JoinSet::new(); - let cancel_token = rpc.read().await.graceful_shutdown.clone(); - loop { - let is_shutdown = cancel_token.cancelled(); - tokio::select! { - msg = receiver.next() => { - if let Some(msg) = msg { - match msg { - // We've received a message from the client - // Ping/Pong is automatically handled by the WebSocket library - Ok(msg) => match msg { - Message::Text(_) => { - tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone())); - } - Message::Binary(_) => { - tasks.spawn(Rpc::handle_msg(rpc.clone(), msg, internal_sender.clone())); - } - Message::Close(_) => { - // Respond with a close message - if let Err(err) = internal_sender.send(Message::Close(None)).await { - trace!("WebSocket error when replying to the Close frame: {:?}", err); - }; - // Start the graceful shutdown of the WebSocket and close the channels - rpc.read().await.graceful_shutdown.cancel(); - let _ = internal_sender.close(); - break; - } - _ => { - // Ignore everything else - } - }, - Err(err) => { - trace!("WebSocket error: {:?}", err); - // Start the graceful shutdown of the WebSocket and close the channels - rpc.read().await.graceful_shutdown.cancel(); - let _ = internal_sender.close(); - // Exit out of the loop - break; - } - } - } - } - _ = is_shutdown => break, - } - } - - // Wait for all tasks to finish - while let Some(res) = tasks.join_next().await { - if let Err(err) = res { - error!("Error while handling RPC message: {}", err); - } - } - } - - /// Write messages to the client - async fn write( - rpc: Arc>, - mut sender: SplitSink, - mut internal_receiver: Receiver, - ) { - let cancel_token = rpc.read().await.graceful_shutdown.clone(); - loop { - let is_shutdown = cancel_token.cancelled(); - tokio::select! { - // Wait for the next message to send - msg = internal_receiver.next() => { - if let Some(res) = msg { - // Send the message to the client - if let Err(err) = sender.send(res).await { - if err.to_string() != CONN_CLOSED_ERR { - debug!("WebSocket error: {:?}", err); - } - // Close the WebSocket connection - rpc.read().await.graceful_shutdown.cancel(); - // Exit out of the loop - break; - } - } - }, - _ = is_shutdown => break, - } - } - } - - /// Send live query notifications to the client - async fn lq_notifications(rpc: Arc>) { - if let Some(channel) = DB.get().unwrap().notifications() { - let cancel_token = rpc.read().await.graceful_shutdown.clone(); - loop { - tokio::select! { - msg = channel.recv() => { - if let Ok(notification) = msg { - // Find which WebSocket the notification belongs to - if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) { - // Check to see if the WebSocket exists - if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) { - // Serialize the message to send - let message = res::success(None, notification); - // Get the current output format - let format = rpc.read().await.format.clone(); - // Send the notification to the client - message.send(format, ws.clone()).await - } - } - } - }, - _ = cancel_token.cancelled() => break, - } - } - } - } - - /// Handle individual WebSocket messages - async fn handle_msg(rpc: Arc>, msg: Message, chn: Sender) { - // Get the current output format - let mut out_fmt = rpc.read().await.format.clone(); - let span = span_for_request(&rpc.read().await.ws_id); - let _enter = span.enter(); - // Parse the request - match Self::parse_request(msg).await { - Ok((id, method, params, _out_fmt)) => { - span.record( - "rpc.jsonrpc.request_id", - id.clone().map(|v| v.as_string()).unwrap_or(String::new()), - ); - if let Some(_out_fmt) = _out_fmt { - out_fmt = _out_fmt; - } - - // Process the request - let res = Self::process_request(rpc.clone(), &method, params).await; - - // Process the response - res.into_response(id).send(out_fmt, chn).await - } - Err(err) => { - // Process the response - res::failure(None, err).send(out_fmt, chn).await - } - } - } - - async fn parse_request( - msg: Message, - ) -> Result<(Option, String, Array, Option), Failure> { - let mut out_fmt = None; - let req = match msg { - // This is a binary message - Message::Binary(val) => { - // Use binary output - out_fmt = Some(OutputFormat::Full); - - match deserialize(&val) { - Ok(v) => v, - Err(_) => { - debug!("Error when trying to deserialize the request"); - return Err(Failure::PARSE_ERROR); - } - } - } - // This is a text message - Message::Text(ref val) => { - // Parse the SurrealQL object - match surrealdb::sql::value(val) { - // The SurrealQL message parsed ok - Ok(v) => v, - // The SurrealQL message failed to parse - _ => return Err(Failure::PARSE_ERROR), - } - } - // Unsupported message type - _ => { - debug!("Unsupported message type: {:?}", msg); - return Err(res::Failure::custom("Unsupported message type")); - } - }; - // Fetch the 'id' argument - let id = match req.pick(&*ID) { - v if v.is_none() => None, - v if v.is_null() => Some(v), - v if v.is_uuid() => Some(v), - v if v.is_number() => Some(v), - v if v.is_strand() => Some(v), - v if v.is_datetime() => Some(v), - _ => return Err(Failure::INVALID_REQUEST), - }; - // Fetch the 'method' argument - let method = match req.pick(&*METHOD) { - Value::Strand(v) => v.to_raw(), - _ => return Err(Failure::INVALID_REQUEST), - }; - - // Now that we know the method, we can update the span - Span::current().record("rpc.method", &method); - Span::current().record("otel.name", format!("surrealdb.rpc/{}", method)); - - // Fetch the 'params' argument - let params = match req.pick(&*PARAMS) { - Value::Array(v) => v, - _ => Array::new(), - }; - - Ok((id, method, params, out_fmt)) - } - - async fn process_request( - rpc: Arc>, - method: &str, - params: Array, - ) -> Result { - info!("Process RPC request"); - - // Match the method to a function - match method { - // Handle a surrealdb ping message - // - // This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough. - // For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios. - // By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent. - // - "ping" => Ok(Value::None.into()), - // Retrieve the current auth record - "info" => match params.len() { - 0 => rpc.read().await.info().await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Switch to a specific namespace and database - "use" => match params.needs_two() { - Ok((ns, db)) => { - rpc.write().await.yuse(ns, db).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Signup to a specific authentication scope - "signup" => match params.needs_one() { - Ok(Value::Object(v)) => { - rpc.write().await.signup(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Signin as a root, namespace, database or scope user - "signin" => match params.needs_one() { - Ok(Value::Object(v)) => { - rpc.write().await.signin(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Invalidate the current authentication session - "invalidate" => match params.len() { - 0 => rpc.write().await.invalidate().await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Authenticate using an authentication token - "authenticate" => match params.needs_one() { - Ok(Value::Strand(v)) => { - rpc.write().await.authenticate(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Kill a live query using a query id - "kill" => match params.needs_one() { - Ok(v) => rpc.read().await.kill(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Setup a live query on a specific table - "live" => match params.needs_one_or_two() { - Ok((v, d)) if v.is_table() => { - rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into) - } - Ok((v, d)) if v.is_strand() => { - rpc.read().await.live(v, d).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Specify a connection-wide parameter - "let" | "set" => match params.needs_one_or_two() { - Ok((Value::Strand(s), v)) => { - rpc.write().await.set(s, v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Unset and clear a connection-wide parameter - "unset" => match params.needs_one() { - Ok(Value::Strand(s)) => { - rpc.write().await.unset(s).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Select a value or values from the database - "select" => match params.needs_one() { - Ok(v) => rpc.read().await.select(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Insert a value or values in the database - "insert" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.insert(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Create a value or values in the database - "create" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.create(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `CONTENT` - "update" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.update(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `MERGE` - "change" | "merge" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.change(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Update a value or values in the database using `PATCH` - "modify" | "patch" => match params.needs_one_or_two() { - Ok((v, o)) => { - rpc.read().await.modify(v, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Delete a value or values from the database - "delete" => match params.needs_one() { - Ok(v) => rpc.read().await.delete(v).await.map(Into::into).map_err(Into::into), - _ => Err(Failure::INVALID_PARAMS), - }, - // Specify the output format for text requests - "format" => match params.needs_one() { - Ok(Value::Strand(v)) => { - rpc.write().await.format(v).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - // Get the current server version - "version" => match params.len() { - 0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()), - _ => Err(Failure::INVALID_PARAMS), - }, - // Run a full SurrealQL query against the database - "query" => match params.needs_one_or_two() { - Ok((Value::Strand(s), o)) if o.is_none_or_null() => { - rpc.read().await.query(s).await.map(Into::into).map_err(Into::into) - } - Ok((Value::Strand(s), Value::Object(o))) => { - rpc.read().await.query_with(s, o).await.map(Into::into).map_err(Into::into) - } - _ => Err(Failure::INVALID_PARAMS), - }, - _ => Err(Failure::METHOD_NOT_FOUND), - } - } - - // ------------------------------ - // Methods for authentication - // ------------------------------ - - async fn format(&mut self, out: Strand) -> Result { - match out.as_str() { - "json" | "application/json" => self.format = OutputFormat::Json, - "cbor" | "application/cbor" => self.format = OutputFormat::Cbor, - "pack" | "application/pack" => self.format = OutputFormat::Pack, - _ => return Err(Error::InvalidType), - }; - Ok(Value::None) - } - - async fn yuse(&mut self, ns: Value, db: Value) -> Result { - if let Value::Strand(ns) = ns { - self.session.ns = Some(ns.0); - } - if let Value::Strand(db) = db { - self.session.db = Some(db.0); - } - Ok(Value::None) - } - - async fn signup(&mut self, vars: Object) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::signup::signup(kvs, &mut self.session, vars) - .await - .map(Into::into) - .map_err(Into::into) - } - - async fn signin(&mut self, vars: Object) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::signin::signin(kvs, &mut self.session, vars) - .await - .map(Into::into) - .map_err(Into::into) - } - async fn invalidate(&mut self) -> Result { - surrealdb::iam::clear::clear(&mut self.session)?; - Ok(Value::None) - } - - async fn authenticate(&mut self, token: Strand) -> Result { - let kvs = DB.get().unwrap(); - surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?; - Ok(Value::None) - } - - // ------------------------------ - // Methods for identification - // ------------------------------ - - async fn info(&self) -> Result { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "SELECT * FROM $auth"; - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, None).await?; - // Extract the first value from the result - let res = res.remove(0).result?.first(); - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for setting variables - // ------------------------------ - - async fn set(&mut self, key: Strand, val: Value) -> Result { - match val { - // Remove the variable if undefined - Value::None => self.vars.remove(&key.0), - // Store the variable if defined - v => self.vars.insert(key.0, v), - }; - Ok(Value::Null) - } - - async fn unset(&mut self, key: Strand) -> Result { - self.vars.remove(&key.0); - Ok(Value::Null) - } - - // ------------------------------ - // Methods for live queries - // ------------------------------ - - async fn kill(&self, id: Value) -> Result { - // Specify the SQL query string - let sql = "KILL $id"; - // Specify the query parameters - let var = map! { - String::from("id") => id, // NOTE: id can be parameter - => &self.vars - }; - // Execute the query on the database - let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?; - // Extract the first query result - let response = res.remove(0); - match response.result { - Ok(v) => Ok(v), - Err(e) => Err(Error::from(e)), - } - } - - async fn live(&self, tb: Value, diff: Value) -> Result { - // Specify the SQL query string - let sql = match diff.is_true() { - true => "LIVE SELECT DIFF FROM $tb", - false => "LIVE SELECT * FROM $tb", - }; - // Specify the query parameters - let var = map! { - String::from("tb") => tb.could_be_table(), - => &self.vars - }; - // Execute the query on the database - let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?; - // Extract the first query result - let response = res.remove(0); - match response.result { - Ok(v) => Ok(v), - Err(e) => Err(Error::from(e)), - } - } - - // ------------------------------ - // Methods for selecting - // ------------------------------ - - async fn select(&self, what: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "SELECT * FROM $what"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for inserting - // ------------------------------ - - async fn insert(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "INSERT INTO $what $data RETURN AFTER"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for creating - // ------------------------------ - - async fn create(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "CREATE $what CONTENT $data RETURN AFTER"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for updating - // ------------------------------ - - async fn update(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "UPDATE $what CONTENT $data RETURN AFTER"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for changing - // ------------------------------ - - async fn change(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "UPDATE $what MERGE $data RETURN AFTER"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for modifying - // ------------------------------ - - async fn modify(&self, what: Value, data: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "UPDATE $what PATCH $data RETURN DIFF"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - String::from("data") => data, - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for deleting - // ------------------------------ - - async fn delete(&self, what: Value) -> Result { - // Return a single result? - let one = what.is_thing(); - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the SQL query string - let sql = "DELETE $what RETURN BEFORE"; - // Specify the query parameters - let var = Some(map! { - String::from("what") => what.could_be_table(), - => &self.vars - }); - // Execute the query on the database - let mut res = kvs.execute(sql, &self.session, var).await?; - // Extract the first query result - let res = match one { - true => res.remove(0).result?.first(), - false => res.remove(0).result?, - }; - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Methods for querying - // ------------------------------ - - async fn query(&self, sql: Strand) -> Result, Error> { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the query parameters - let var = Some(self.vars.clone()); - // Execute the query on the database - let res = kvs.execute(&sql, &self.session, var).await?; - // Post-process hooks for web layer - for response in &res { - self.handle_live_query_results(response).await; - } - // Return the result to the client - Ok(res) - } - - async fn query_with(&self, sql: Strand, mut vars: Object) -> Result, Error> { - // Get a database reference - let kvs = DB.get().unwrap(); - // Specify the query parameters - let var = Some(mrg! { vars.0, &self.vars }); - // Execute the query on the database - let res = kvs.execute(&sql, &self.session, var).await?; - // Post-process hooks for web layer - for response in &res { - self.handle_live_query_results(response).await; - } - // Return the result to the client - Ok(res) - } - - // ------------------------------ - // Private methods - // ------------------------------ - - async fn handle_live_query_results(&self, res: &Response) { - match &res.query_type { - QueryType::Live => { - if let Ok(Value::Uuid(lqid)) = &res.result { - // Match on Uuid type - LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id); - trace!("Registered live query {} on websocket {}", lqid, self.ws_id); - } - } - QueryType::Kill => { - if let Ok(Value::Uuid(lqid)) = &res.result { - let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0); - if let Some(ws_id) = ws_id { - trace!("Unregistered live query {} on websocket {}", lqid, ws_id); - } - } - } - _ => {} - } - } + Connection::serve(rpc, ws).await; } diff --git a/src/net/signals.rs b/src/net/signals.rs index a14f9b96..0b418b59 100644 --- a/src/net/signals.rs +++ b/src/net/signals.rs @@ -1,12 +1,7 @@ -use std::time::Duration; - use axum_server::Handle; use tokio::task::JoinHandle; -use crate::{ - err::Error, - net::rpc::{WebSocketRef, WEBSOCKETS}, -}; +use crate::{err::Error, rpc, telemetry}; /// Start a graceful shutdown: /// * Signal the Axum Handle when a shutdown signal is received. @@ -24,15 +19,12 @@ pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> { // First stop accepting new HTTP requests http_handle.graceful_shutdown(None); - // Close all WebSocket connections. Queued messages will still be processed. - for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() { - cancel_token.cancel(); - }; + rpc::graceful_shutdown().await; - // Wait for all existing WebSocket connections to gracefully close - while WEBSOCKETS.read().await.len() > 0 { - tokio::time::sleep(Duration::from_millis(100)).await; - }; + // Flush all telemetry data + if let Err(err) = telemetry::shutdown() { + error!("Failed to flush telemetry data: {}", err); + } } => (), // Force an immediate shutdown if a second signal is received _ = async { @@ -46,9 +38,7 @@ pub fn graceful_shutdown(http_handle: Handle) -> JoinHandle<()> { http_handle.shutdown(); // Close all WebSocket connections immediately - if let Ok(mut writer) = WEBSOCKETS.try_write() { - writer.drain(); - } + rpc::shutdown(); } => (), } }) diff --git a/src/rpc/connection.rs b/src/rpc/connection.rs new file mode 100644 index 00000000..c6ee7b83 --- /dev/null +++ b/src/rpc/connection.rs @@ -0,0 +1,302 @@ +use axum::extract::ws::{Message, WebSocket}; +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{SinkExt, StreamExt}; +use opentelemetry::trace::FutureExt; +use opentelemetry::Context as TelemetryContext; +use std::collections::BTreeMap; +use std::sync::Arc; +use surrealdb::channel::{self, Receiver, Sender}; +use tokio::sync::RwLock; + +use surrealdb::dbs::Session; +use tokio::task::JoinSet; +use tokio_util::sync::CancellationToken; +use uuid::Uuid; + +use crate::cnf::{MAX_CONCURRENT_CALLS, WEBSOCKET_PING_FREQUENCY}; +use crate::dbs::DB; +use crate::rpc::res::success; +use crate::rpc::{WebSocketRef, CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; +use crate::telemetry; +use crate::telemetry::metrics::ws::RequestContext; +use crate::telemetry::traces::rpc::span_for_request; + +use super::processor::Processor; +use super::request::parse_request; +use super::res::{failure, IntoRpcResponse, OutputFormat}; + +pub struct Connection { + ws_id: Uuid, + processor: Processor, + graceful_shutdown: CancellationToken, +} + +impl Connection { + /// Instantiate a new RPC + pub fn new(mut session: Session) -> Arc> { + // Create a new RPC variables store + let vars = BTreeMap::new(); + // Set the default output format + let format = OutputFormat::Json; + // Enable real-time mode + session.rt = true; + + // Create a new RPC processor + let processor = Processor::new(session, format, vars); + + // Create and store the RPC connection + Arc::new(RwLock::new(Connection { + ws_id: processor.ws_id, + processor, + graceful_shutdown: CancellationToken::new(), + })) + } + + /// Update the WebSocket ID. If the ID already exists, do not update it. + pub async fn update_ws_id(&mut self, ws_id: Uuid) -> Result<(), Box> { + if WEBSOCKETS.read().await.contains_key(&ws_id) { + trace!("WebSocket ID '{}' is in use by another connection. Do not update it.", &ws_id); + return Err("websocket ID is in use".into()); + } + + self.ws_id = ws_id; + self.processor.ws_id = ws_id; + Ok(()) + } + + /// Serve the RPC endpoint + pub async fn serve(rpc: Arc>, ws: WebSocket) { + // Split the socket into send and recv + let (sender, receiver) = ws.split(); + // Create an internal channel between the receiver and the sender + let (internal_sender, internal_receiver) = channel::new(MAX_CONCURRENT_CALLS); + + let ws_id = rpc.read().await.ws_id; + + trace!("WebSocket {} connected", ws_id); + + if let Err(err) = telemetry::metrics::ws::on_connect() { + error!("Error running metrics::ws::on_connect hook: {}", err); + } + + // Add this WebSocket to the list + WEBSOCKETS.write().await.insert( + ws_id, + WebSocketRef(internal_sender.clone(), rpc.read().await.graceful_shutdown.clone()), + ); + + // Remove all live queries + LIVE_QUERIES.write().await.retain(|key, value| { + if value == &ws_id { + trace!("Removing live query: {}", key); + return false; + } + true + }); + + let mut tasks = JoinSet::new(); + tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone())); + tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone())); + tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone())); + tasks.spawn(Self::lq_notifications(rpc.clone())); + + // Wait until all tasks finish + while let Some(res) = tasks.join_next().await { + if let Err(err) = res { + error!("Error handling RPC connection: {}", err); + } + } + + // Remove this WebSocket from the list + WEBSOCKETS.write().await.remove(&ws_id); + + trace!("WebSocket {} disconnected", ws_id); + + if let Err(err) = telemetry::metrics::ws::on_disconnect() { + error!("Error running metrics::ws::on_disconnect hook: {}", err); + } + } + + /// Send Ping messages to the client + async fn ping(rpc: Arc>, internal_sender: Sender) { + // Create the interval ticker + let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY); + let cancel_token = rpc.read().await.graceful_shutdown.clone(); + loop { + let is_shutdown = cancel_token.cancelled(); + tokio::select! { + _ = interval.tick() => { + let msg = Message::Ping(vec![]); + + // Send the message to the client and close the WebSocket connection if it fails + if internal_sender.send(msg).await.is_err() { + rpc.read().await.graceful_shutdown.cancel(); + break; + } + }, + _ = is_shutdown => break, + } + } + } + + /// Read messages sent from the client + async fn read( + rpc: Arc>, + mut receiver: SplitStream, + internal_sender: Sender, + ) { + // Collect all spawned tasks so we can wait for them at the end + let mut tasks = JoinSet::new(); + let cancel_token = rpc.read().await.graceful_shutdown.clone(); + loop { + let is_shutdown = cancel_token.cancelled(); + tokio::select! { + msg = receiver.next() => { + if let Some(msg) = msg { + match msg { + // We've received a message from the client + // Ping/Pong is automatically handled by the WebSocket library + Ok(msg) => match msg { + Message::Text(_) => { + tasks.spawn(Connection::handle_msg(rpc.clone(), msg, internal_sender.clone())); + } + Message::Binary(_) => { + tasks.spawn(Connection::handle_msg(rpc.clone(), msg, internal_sender.clone())); + } + Message::Close(_) => { + // Respond with a close message + if let Err(err) = internal_sender.send(Message::Close(None)).await { + trace!("WebSocket error when replying to the Close frame: {:?}", err); + }; + // Start the graceful shutdown of the WebSocket and close the channels + rpc.read().await.graceful_shutdown.cancel(); + let _ = internal_sender.close(); + break; + } + _ => { + // Ignore everything else + } + }, + Err(err) => { + trace!("WebSocket error: {:?}", err); + // Start the graceful shutdown of the WebSocket and close the channels + rpc.read().await.graceful_shutdown.cancel(); + let _ = internal_sender.close(); + // Exit out of the loop + break; + } + } + } + } + _ = is_shutdown => break, + } + } + + // Wait for all tasks to finish + while let Some(res) = tasks.join_next().await { + if let Err(err) = res { + error!("Error while handling RPC message: {}", err); + } + } + } + + /// Write messages to the client + async fn write( + rpc: Arc>, + mut sender: SplitSink, + mut internal_receiver: Receiver, + ) { + let cancel_token = rpc.read().await.graceful_shutdown.clone(); + loop { + let is_shutdown = cancel_token.cancelled(); + tokio::select! { + // Wait for the next message to send + msg = internal_receiver.next() => { + if let Some(res) = msg { + // Send the message to the client + if let Err(err) = sender.send(res).await { + if err.to_string() != CONN_CLOSED_ERR { + debug!("WebSocket error: {:?}", err); + } + // Close the WebSocket connection + rpc.read().await.graceful_shutdown.cancel(); + // Exit out of the loop + break; + } + } + }, + _ = is_shutdown => break, + } + } + } + + /// Send live query notifications to the client + async fn lq_notifications(rpc: Arc>) { + if let Some(channel) = DB.get().unwrap().notifications() { + let cancel_token = rpc.read().await.graceful_shutdown.clone(); + loop { + tokio::select! { + msg = channel.recv() => { + if let Ok(notification) = msg { + // Find which WebSocket the notification belongs to + if let Some(ws_id) = LIVE_QUERIES.read().await.get(¬ification.id) { + // Check to see if the WebSocket exists + if let Some(WebSocketRef(ws, _)) = WEBSOCKETS.read().await.get(ws_id) { + // Serialize the message to send + let message = success(None, notification); + // Get the current output format + let format = rpc.read().await.processor.format.clone(); + // Send the notification to the client + message.send(format, ws.clone()).await + } + } + } + }, + _ = cancel_token.cancelled() => break, + } + } + } + } + + /// Handle individual WebSocket messages + async fn handle_msg(rpc: Arc>, msg: Message, chn: Sender) { + // Get the current output format + let mut out_fmt = rpc.read().await.processor.format.clone(); + // Prepare Span and Otel context + let span = span_for_request(&rpc.read().await.ws_id); + let _enter = span.enter(); + let req_cx = RequestContext::default(); + let otel_cx = TelemetryContext::current_with_value(req_cx.clone()); + + // Parse the request + match parse_request(msg).await { + Ok(req) => { + if let Some(_out_fmt) = req.out_fmt { + out_fmt = _out_fmt; + } + + // Now that we know the method, we can update the span and create otel context + span.record("rpc.method", &req.method); + span.record("otel.name", format!("surrealdb.rpc/{}", req.method)); + span.record( + "rpc.jsonrpc.request_id", + req.id.clone().map(|v| v.as_string()).unwrap_or(String::new()), + ); + let otel_cx = TelemetryContext::current_with_value( + req_cx.with_method(&req.method).with_size(req.size), + ); + + // Process the request + let res = + rpc.write().await.processor.process_request(&req.method, req.params).await; + + // Process the response + res.into_response(req.id).send(out_fmt, chn).with_context(otel_cx).await + } + Err(err) => { + // Process the response + failure(None, err).send(out_fmt, chn).with_context(otel_cx.clone()).await + } + } + } +} diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs index c11136eb..d201d01b 100644 --- a/src/rpc/mod.rs +++ b/src/rpc/mod.rs @@ -1,5 +1,44 @@ pub mod args; -pub mod paths; +pub mod connection; +pub mod processor; +pub mod request; pub mod res; -pub(crate) static CONN_CLOSED_ERR: &str = "Connection closed normally"; +use std::{collections::HashMap, time::Duration}; + +use axum::extract::ws::Message; +use once_cell::sync::Lazy; +use surrealdb::channel::Sender; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; +use uuid::Uuid; + +static CONN_CLOSED_ERR: &str = "Connection closed normally"; + +// Mapping of WebSocketID to WebSocket +pub struct WebSocketRef(Sender, CancellationToken); +type WebSockets = RwLock>; +// Mapping of LiveQueryID to WebSocketID +type LiveQueries = RwLock>; + +pub(crate) static WEBSOCKETS: Lazy = Lazy::new(WebSockets::default); +pub(crate) static LIVE_QUERIES: Lazy = Lazy::new(LiveQueries::default); + +pub(crate) async fn graceful_shutdown() { + // Close all WebSocket connections. Queued messages will still be processed. + for (_, WebSocketRef(_, cancel_token)) in WEBSOCKETS.read().await.iter() { + cancel_token.cancel(); + } + + // Wait for all existing WebSocket connections to gracefully close + while WEBSOCKETS.read().await.len() > 0 { + tokio::time::sleep(Duration::from_millis(100)).await; + } +} + +pub(crate) fn shutdown() { + // Close all WebSocket connections immediately + if let Ok(mut writer) = WEBSOCKETS.try_write() { + writer.drain(); + } +} diff --git a/src/rpc/paths.rs b/src/rpc/paths.rs deleted file mode 100644 index b20301f1..00000000 --- a/src/rpc/paths.rs +++ /dev/null @@ -1,8 +0,0 @@ -use once_cell::sync::Lazy; -use surrealdb::sql::Part; - -pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); - -pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); - -pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]); diff --git a/src/rpc/processor.rs b/src/rpc/processor.rs new file mode 100644 index 00000000..359df67c --- /dev/null +++ b/src/rpc/processor.rs @@ -0,0 +1,547 @@ +use crate::cnf::PKG_NAME; +use crate::cnf::PKG_VERSION; +use crate::dbs::DB; +use crate::err::Error; +use crate::rpc::args::Take; +use crate::rpc::LIVE_QUERIES; +use std::collections::BTreeMap; + +use surrealdb::dbs::QueryType; +use surrealdb::dbs::Response; +use surrealdb::sql::Object; +use surrealdb::sql::Strand; +use surrealdb::sql::Value; +use surrealdb::{dbs::Session, sql::Array}; +use uuid::Uuid; + +use super::res::{Data, Failure, OutputFormat}; + +pub struct Processor { + pub ws_id: Uuid, + session: Session, + pub format: OutputFormat, + vars: BTreeMap, +} + +impl Processor { + pub fn new(session: Session, format: OutputFormat, vars: BTreeMap) -> Self { + Self { + ws_id: Uuid::new_v4(), + session, + format, + vars, + } + } + + pub async fn process_request(&mut self, method: &str, params: Array) -> Result { + info!("Process RPC request"); + + // Match the method to a function + match method { + // Handle a surrealdb ping message + // + // This is used to keep the WebSocket connection alive in environments where the WebSocket protocol is not enough. + // For example, some browsers will wait for the TCP protocol to timeout before triggering an on_close event. This may take several seconds or even minutes in certain scenarios. + // By sending a ping message every few seconds from the client, we can force a connection check and trigger a an on_close event if the ping can't be sent. + // + "ping" => Ok(Value::None.into()), + // Retrieve the current auth record + "info" => match params.len() { + 0 => self.info().await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Switch to a specific namespace and database + "use" => match params.needs_two() { + Ok((ns, db)) => self.yuse(ns, db).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Signup to a specific authentication scope + "signup" => match params.needs_one() { + Ok(Value::Object(v)) => self.signup(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Signin as a root, namespace, database or scope user + "signin" => match params.needs_one() { + Ok(Value::Object(v)) => self.signin(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Invalidate the current authentication session + "invalidate" => match params.len() { + 0 => self.invalidate().await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Authenticate using an authentication token + "authenticate" => match params.needs_one() { + Ok(Value::Strand(v)) => { + self.authenticate(v).await.map(Into::into).map_err(Into::into) + } + _ => Err(Failure::INVALID_PARAMS), + }, + // Kill a live query using a query id + "kill" => match params.needs_one() { + Ok(v) => self.kill(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Setup a live query on a specific table + "live" => match params.needs_one_or_two() { + Ok((v, d)) if v.is_table() => { + self.live(v, d).await.map(Into::into).map_err(Into::into) + } + Ok((v, d)) if v.is_strand() => { + self.live(v, d).await.map(Into::into).map_err(Into::into) + } + _ => Err(Failure::INVALID_PARAMS), + }, + // Specify a connection-wide parameter + "let" | "set" => match params.needs_one_or_two() { + Ok((Value::Strand(s), v)) => { + self.set(s, v).await.map(Into::into).map_err(Into::into) + } + _ => Err(Failure::INVALID_PARAMS), + }, + // Unset and clear a connection-wide parameter + "unset" => match params.needs_one() { + Ok(Value::Strand(s)) => self.unset(s).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Select a value or values from the database + "select" => match params.needs_one() { + Ok(v) => self.select(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Insert a value or values in the database + "insert" => match params.needs_one_or_two() { + Ok((v, o)) => self.insert(v, o).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Create a value or values in the database + "create" => match params.needs_one_or_two() { + Ok((v, o)) => self.create(v, o).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Update a value or values in the database using `CONTENT` + "update" => match params.needs_one_or_two() { + Ok((v, o)) => self.update(v, o).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Update a value or values in the database using `MERGE` + "change" | "merge" => match params.needs_one_or_two() { + Ok((v, o)) => self.change(v, o).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Update a value or values in the database using `PATCH` + "modify" | "patch" => match params.needs_one_or_two() { + Ok((v, o)) => self.modify(v, o).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Delete a value or values from the database + "delete" => match params.needs_one() { + Ok(v) => self.delete(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Specify the output format for text requests + "format" => match params.needs_one() { + Ok(Value::Strand(v)) => self.format(v).await.map(Into::into).map_err(Into::into), + _ => Err(Failure::INVALID_PARAMS), + }, + // Get the current server version + "version" => match params.len() { + 0 => Ok(format!("{PKG_NAME}-{}", *PKG_VERSION).into()), + _ => Err(Failure::INVALID_PARAMS), + }, + // Run a full SurrealQL query against the database + "query" => match params.needs_one_or_two() { + Ok((Value::Strand(s), o)) if o.is_none_or_null() => { + self.query(s).await.map(Into::into).map_err(Into::into) + } + Ok((Value::Strand(s), Value::Object(o))) => { + self.query_with(s, o).await.map(Into::into).map_err(Into::into) + } + _ => Err(Failure::INVALID_PARAMS), + }, + _ => Err(Failure::METHOD_NOT_FOUND), + } + } + + // ------------------------------ + // Methods for authentication + // ------------------------------ + + async fn format(&mut self, out: Strand) -> Result { + match out.as_str() { + "json" | "application/json" => self.format = OutputFormat::Json, + "cbor" | "application/cbor" => self.format = OutputFormat::Cbor, + "pack" | "application/pack" => self.format = OutputFormat::Pack, + _ => return Err(Error::InvalidType), + }; + Ok(Value::None) + } + + async fn yuse(&mut self, ns: Value, db: Value) -> Result { + if let Value::Strand(ns) = ns { + self.session.ns = Some(ns.0); + } + if let Value::Strand(db) = db { + self.session.db = Some(db.0); + } + Ok(Value::None) + } + + async fn signup(&mut self, vars: Object) -> Result { + let kvs = DB.get().unwrap(); + surrealdb::iam::signup::signup(kvs, &mut self.session, vars) + .await + .map(Into::into) + .map_err(Into::into) + } + + async fn signin(&mut self, vars: Object) -> Result { + let kvs = DB.get().unwrap(); + surrealdb::iam::signin::signin(kvs, &mut self.session, vars) + .await + .map(Into::into) + .map_err(Into::into) + } + async fn invalidate(&mut self) -> Result { + surrealdb::iam::clear::clear(&mut self.session)?; + Ok(Value::None) + } + + async fn authenticate(&mut self, token: Strand) -> Result { + let kvs = DB.get().unwrap(); + surrealdb::iam::verify::token(kvs, &mut self.session, &token.0).await?; + Ok(Value::None) + } + + // ------------------------------ + // Methods for identification + // ------------------------------ + + async fn info(&self) -> Result { + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "SELECT * FROM $auth"; + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, None).await?; + // Extract the first value from the result + let res = res.remove(0).result?.first(); + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for setting variables + // ------------------------------ + + async fn set(&mut self, key: Strand, val: Value) -> Result { + match val { + // Remove the variable if undefined + Value::None => self.vars.remove(&key.0), + // Store the variable if defined + v => self.vars.insert(key.0, v), + }; + Ok(Value::Null) + } + + async fn unset(&mut self, key: Strand) -> Result { + self.vars.remove(&key.0); + Ok(Value::Null) + } + + // ------------------------------ + // Methods for live queries + // ------------------------------ + + async fn kill(&self, id: Value) -> Result { + // Specify the SQL query string + let sql = "KILL $id"; + // Specify the query parameters + let var = map! { + String::from("id") => id, // NOTE: id can be parameter + => &self.vars + }; + // Execute the query on the database + let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?; + // Extract the first query result + let response = res.remove(0); + match response.result { + Ok(v) => Ok(v), + Err(e) => Err(Error::from(e)), + } + } + + async fn live(&self, tb: Value, diff: Value) -> Result { + // Specify the SQL query string + let sql = match diff.is_true() { + true => "LIVE SELECT DIFF FROM $tb", + false => "LIVE SELECT * FROM $tb", + }; + // Specify the query parameters + let var = map! { + String::from("tb") => tb.could_be_table(), + => &self.vars + }; + // Execute the query on the database + let mut res = self.query_with(Strand::from(sql), Object::from(var)).await?; + // Extract the first query result + let response = res.remove(0); + match response.result { + Ok(v) => Ok(v), + Err(e) => Err(Error::from(e)), + } + } + + // ------------------------------ + // Methods for selecting + // ------------------------------ + + async fn select(&self, what: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "SELECT * FROM $what"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for inserting + // ------------------------------ + + async fn insert(&self, what: Value, data: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "INSERT INTO $what $data RETURN AFTER"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for creating + // ------------------------------ + + async fn create(&self, what: Value, data: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "CREATE $what CONTENT $data RETURN AFTER"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for updating + // ------------------------------ + + async fn update(&self, what: Value, data: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "UPDATE $what CONTENT $data RETURN AFTER"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for changing + // ------------------------------ + + async fn change(&self, what: Value, data: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "UPDATE $what MERGE $data RETURN AFTER"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for modifying + // ------------------------------ + + async fn modify(&self, what: Value, data: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "UPDATE $what PATCH $data RETURN DIFF"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + String::from("data") => data, + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for deleting + // ------------------------------ + + async fn delete(&self, what: Value) -> Result { + // Return a single result? + let one = what.is_thing(); + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the SQL query string + let sql = "DELETE $what RETURN BEFORE"; + // Specify the query parameters + let var = Some(map! { + String::from("what") => what.could_be_table(), + => &self.vars + }); + // Execute the query on the database + let mut res = kvs.execute(sql, &self.session, var).await?; + // Extract the first query result + let res = match one { + true => res.remove(0).result?.first(), + false => res.remove(0).result?, + }; + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Methods for querying + // ------------------------------ + + async fn query(&self, sql: Strand) -> Result, Error> { + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the query parameters + let var = Some(self.vars.clone()); + // Execute the query on the database + let res = kvs.execute(&sql, &self.session, var).await?; + // Post-process hooks for web layer + for response in &res { + self.handle_live_query_results(response).await; + } + // Return the result to the client + Ok(res) + } + + async fn query_with(&self, sql: Strand, mut vars: Object) -> Result, Error> { + // Get a database reference + let kvs = DB.get().unwrap(); + // Specify the query parameters + let var = Some(mrg! { vars.0, &self.vars }); + // Execute the query on the database + let res = kvs.execute(&sql, &self.session, var).await?; + // Post-process hooks for web layer + for response in &res { + self.handle_live_query_results(response).await; + } + // Return the result to the client + Ok(res) + } + + // ------------------------------ + // Private methods + // ------------------------------ + + async fn handle_live_query_results(&self, res: &Response) { + match &res.query_type { + QueryType::Live => { + if let Ok(Value::Uuid(lqid)) = &res.result { + // Match on Uuid type + LIVE_QUERIES.write().await.insert(lqid.0, self.ws_id); + trace!("Registered live query {} on websocket {}", lqid, self.ws_id); + } + } + QueryType::Kill => { + if let Ok(Value::Uuid(lqid)) = &res.result { + let ws_id = LIVE_QUERIES.write().await.remove(&lqid.0); + if let Some(ws_id) = ws_id { + trace!("Unregistered live query {} on websocket {}", lqid, ws_id); + } + } + } + _ => {} + } + } +} diff --git a/src/rpc/request.rs b/src/rpc/request.rs new file mode 100644 index 00000000..69a35da9 --- /dev/null +++ b/src/rpc/request.rs @@ -0,0 +1,84 @@ +use axum::extract::ws::Message; +use surrealdb::sql::{serde::deserialize, Array, Value}; + +use once_cell::sync::Lazy; +use surrealdb::sql::Part; + +use super::res::{Failure, OutputFormat}; + +pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); +pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); +pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]); + +pub struct Request { + pub id: Option, + pub method: String, + pub params: Array, + pub size: usize, + pub out_fmt: Option, +} + +/// Parse the RPC request +pub async fn parse_request(msg: Message) -> Result { + let mut out_fmt = None; + let (req, size) = match msg { + // This is a binary message + Message::Binary(val) => { + // Use binary output + out_fmt = Some(OutputFormat::Full); + + match deserialize(&val) { + Ok(v) => (v, val.len()), + Err(_) => { + debug!("Error when trying to deserialize the request"); + return Err(Failure::PARSE_ERROR); + } + } + } + // This is a text message + Message::Text(ref val) => { + // Parse the SurrealQL object + match surrealdb::sql::value(val) { + // The SurrealQL message parsed ok + Ok(v) => (v, val.len()), + // The SurrealQL message failed to parse + _ => return Err(Failure::PARSE_ERROR), + } + } + // Unsupported message type + _ => { + debug!("Unsupported message type: {:?}", msg); + return Err(Failure::custom("Unsupported message type")); + } + }; + + // Fetch the 'id' argument + let id = match req.pick(&*ID) { + v if v.is_none() => None, + v if v.is_null() => Some(v), + v if v.is_uuid() => Some(v), + v if v.is_number() => Some(v), + v if v.is_strand() => Some(v), + v if v.is_datetime() => Some(v), + _ => return Err(Failure::INVALID_REQUEST), + }; + // Fetch the 'method' argument + let method = match req.pick(&*METHOD) { + Value::Strand(v) => v.to_raw(), + _ => return Err(Failure::INVALID_REQUEST), + }; + + // Fetch the 'params' argument + let params = match req.pick(&*PARAMS) { + Value::Array(v) => v, + _ => Array::new(), + }; + + Ok(Request { + id, + method, + params, + size, + out_fmt, + }) +} diff --git a/src/rpc/res.rs b/src/rpc/res.rs index ebf1b047..94d3ca43 100644 --- a/src/rpc/res.rs +++ b/src/rpc/res.rs @@ -1,4 +1,5 @@ use axum::extract::ws::Message; +use opentelemetry::Context as TelemetryContext; use serde::Serialize; use serde_json::{json, Value as Json}; use std::borrow::Cow; @@ -11,6 +12,7 @@ use tracing::Span; use crate::err; use crate::rpc::CONN_CLOSED_ERR; +use crate::telemetry::metrics::ws::record_rpc; #[derive(Debug, Clone)] pub enum OutputFormat { @@ -96,6 +98,7 @@ impl Response { info!("Process RPC response"); + let is_error = self.result.is_err(); if let Err(err) = &self.result { span.record("otel.status_code", "Error"); span.record( @@ -106,30 +109,33 @@ impl Response { span.record("rpc.jsonrpc.error_message", err.message.as_ref()); } - let message = match out { + let (res_size, message) = match out { OutputFormat::Json => { let res = serde_json::to_string(&self.simplify()).unwrap(); - Message::Text(res) + (res.len(), Message::Text(res)) } OutputFormat::Cbor => { let res = serde_cbor::to_vec(&self.simplify()).unwrap(); - Message::Binary(res) + (res.len(), Message::Binary(res)) } OutputFormat::Pack => { let res = serde_pack::to_vec(&self.simplify()).unwrap(); - Message::Binary(res) + (res.len(), Message::Binary(res)) } OutputFormat::Full => { let res = surrealdb::sql::serde::serialize(&self).unwrap(); - Message::Binary(res) + (res.len(), Message::Binary(res)) } }; if let Err(err) = chn.send(message).await { if err.to_string() != CONN_CLOSED_ERR { error!("Error sending response: {}", err); + return; } }; + + record_rpc(&TelemetryContext::current(), res_size, is_error); } } diff --git a/src/telemetry/metrics/http/mod.rs b/src/telemetry/metrics/http/mod.rs index 079f8e08..6847631d 100644 --- a/src/telemetry/metrics/http/mod.rs +++ b/src/telemetry/metrics/http/mod.rs @@ -1,87 +1,15 @@ pub(super) mod tower_layer; use once_cell::sync::Lazy; -use opentelemetry::{ - metrics::{Histogram, Meter, MeterProvider, ObservableUpDownCounter, Unit}, - runtime, - sdk::{ - export::metrics::aggregation, - metrics::{ - controllers::{self, BasicController}, - processors, selectors, - }, - }, - Context, -}; -use opentelemetry_otlp::MetricsExporterBuilder; +use opentelemetry::metrics::{Histogram, MetricsError, ObservableUpDownCounter, Unit}; +use opentelemetry::Context as TelemetryContext; -use crate::telemetry::OTEL_DEFAULT_RESOURCE; +use self::tower_layer::HttpCallMetricTracker; -// Histogram buckets in milliseconds -static HTTP_DURATION_MS_HISTOGRAM_BUCKETS: &[f64] = &[ - 5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0, - 2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0, -]; - -const KB: f64 = 1024.0; -const MB: f64 = 1024.0 * KB; - -const HTTP_SIZE_HISTOGRAM_BUCKETS: &[f64] = &[ - 1.0 * KB, // 1 KB - 2.0 * KB, // 2 KB - 5.0 * KB, // 5 KB - 10.0 * KB, // 10 KB - 100.0 * KB, // 100 KB - 500.0 * KB, // 500 KB - 1.0 * MB, // 1 MB - 2.5 * MB, // 2 MB - 5.0 * MB, // 5 MB - 10.0 * MB, // 10 MB - 25.0 * MB, // 25 MB - 50.0 * MB, // 50 MB - 100.0 * MB, // 100 MB -]; - -static METER_PROVIDER_HTTP_DURATION: Lazy = Lazy::new(|| { - let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic()) - .build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector())) - .unwrap(); - - let builder = controllers::basic(processors::factory( - selectors::simple::histogram(HTTP_DURATION_MS_HISTOGRAM_BUCKETS), - aggregation::cumulative_temporality_selector(), - )) - .with_exporter(exporter) - .with_resource(OTEL_DEFAULT_RESOURCE.clone()); - - let controller = builder.build(); - controller.start(&Context::current(), runtime::Tokio).unwrap(); - controller -}); - -static METER_PROVIDER_HTTP_SIZE: Lazy = Lazy::new(|| { - let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic()) - .build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector())) - .unwrap(); - - let builder = controllers::basic(processors::factory( - selectors::simple::histogram(HTTP_SIZE_HISTOGRAM_BUCKETS), - aggregation::cumulative_temporality_selector(), - )) - .with_exporter(exporter) - .with_resource(OTEL_DEFAULT_RESOURCE.clone()); - - let controller = builder.build(); - controller.start(&Context::current(), runtime::Tokio).unwrap(); - controller -}); - -static HTTP_DURATION_METER: Lazy = - Lazy::new(|| METER_PROVIDER_HTTP_DURATION.meter("http_duration")); -static HTTP_SIZE_METER: Lazy = Lazy::new(|| METER_PROVIDER_HTTP_SIZE.meter("http_size")); +use super::{METER_DURATION, METER_SIZE}; pub static HTTP_SERVER_DURATION: Lazy> = Lazy::new(|| { - HTTP_DURATION_METER + METER_DURATION .u64_histogram("http.server.duration") .with_description("The HTTP server duration in milliseconds.") .with_unit(Unit::new("ms")) @@ -89,14 +17,14 @@ pub static HTTP_SERVER_DURATION: Lazy> = Lazy::new(|| { }); pub static HTTP_SERVER_ACTIVE_REQUESTS: Lazy> = Lazy::new(|| { - HTTP_DURATION_METER + METER_DURATION .i64_observable_up_down_counter("http.server.active_requests") .with_description("The number of active HTTP requests.") .init() }); pub static HTTP_SERVER_REQUEST_SIZE: Lazy> = Lazy::new(|| { - HTTP_SIZE_METER + METER_SIZE .u64_histogram("http.server.request.size") .with_description("Measures the size of HTTP request messages.") .with_unit(Unit::new("mb")) @@ -104,9 +32,49 @@ pub static HTTP_SERVER_REQUEST_SIZE: Lazy> = Lazy::new(|| { }); pub static HTTP_SERVER_RESPONSE_SIZE: Lazy> = Lazy::new(|| { - HTTP_SIZE_METER + METER_SIZE .u64_histogram("http.server.response.size") .with_description("Measures the size of HTTP response messages.") .with_unit(Unit::new("mb")) .init() }); + +fn observe_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + observe_active_request(1, tracker) +} + +fn observe_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + observe_active_request(-1, tracker) +} + +fn observe_active_request(value: i64, tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { + let attrs = tracker.active_req_attrs(); + + METER_DURATION + .register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, value, &attrs)) +} + +fn record_request_duration(tracker: &HttpCallMetricTracker) { + // Record the duration of the request. + HTTP_SERVER_DURATION.record( + &TelemetryContext::current(), + tracker.duration().as_millis() as u64, + &tracker.request_duration_attrs(), + ); +} + +fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) { + HTTP_SERVER_REQUEST_SIZE.record( + &TelemetryContext::current(), + size, + &tracker.request_size_attrs(), + ); +} + +fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) { + HTTP_SERVER_RESPONSE_SIZE.record( + &TelemetryContext::current(), + size, + &tracker.response_size_attrs(), + ); +} diff --git a/src/telemetry/metrics/http/tower_layer.rs b/src/telemetry/metrics/http/tower_layer.rs index 16ec44aa..55dfe67f 100644 --- a/src/telemetry/metrics/http/tower_layer.rs +++ b/src/telemetry/metrics/http/tower_layer.rs @@ -1,5 +1,5 @@ use axum::extract::MatchedPath; -use opentelemetry::{metrics::MetricsError, Context as TelemetryContext, KeyValue}; +use opentelemetry::{metrics::MetricsError, KeyValue}; use pin_project_lite::pin_project; use std::{ cell::Cell, @@ -13,11 +13,6 @@ use futures::Future; use http::{Request, Response, StatusCode, Version}; use tower::{Layer, Service}; -use super::{ - HTTP_DURATION_METER, HTTP_SERVER_ACTIVE_REQUESTS, HTTP_SERVER_DURATION, - HTTP_SERVER_REQUEST_SIZE, HTTP_SERVER_RESPONSE_SIZE, -}; - #[derive(Clone, Default)] pub struct HttpMetricsLayer; @@ -168,7 +163,7 @@ impl HttpCallMetricTracker { } // Follows the OpenTelemetry semantic conventions for HTTP metrics define here: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/metrics/semantic_conventions/http-metrics.md - fn olel_common_attrs(&self) -> Vec { + fn otel_common_attrs(&self) -> Vec { let mut res = vec![ KeyValue::new("http.request.method", self.method.as_str().to_owned()), KeyValue::new("network.protocol.name", "http".to_owned()), @@ -186,11 +181,11 @@ impl HttpCallMetricTracker { } pub(super) fn active_req_attrs(&self) -> Vec { - self.olel_common_attrs() + self.otel_common_attrs() } pub(super) fn request_duration_attrs(&self) -> Vec { - let mut res = self.olel_common_attrs(); + let mut res = self.otel_common_attrs(); res.push(KeyValue::new( "http.response.status_code", @@ -247,64 +242,25 @@ impl Drop for HttpCallMetricTracker { pub fn on_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { // Setup the active_requests observer - observe_active_request_start(tracker) + super::observe_request_start(tracker) } pub fn on_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { // Setup the active_requests observer - observe_active_request_finish(tracker)?; + super::observe_request_finish(tracker)?; // Record the duration of the request. - record_request_duration(tracker); + super::record_request_duration(tracker); // Record the request size if known if let Some(size) = tracker.request_size { - record_request_size(tracker, size) + super::record_request_size(tracker, size) } // Record the response size if known if let Some(size) = tracker.response_size { - record_response_size(tracker, size) + super::record_response_size(tracker, size) } Ok(()) } - -fn observe_active_request_start(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { - let attrs = tracker.active_req_attrs(); - // Setup the callback to observe the active requests. - HTTP_DURATION_METER - .register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, 1, &attrs)) -} - -fn observe_active_request_finish(tracker: &HttpCallMetricTracker) -> Result<(), MetricsError> { - let attrs = tracker.active_req_attrs(); - // Setup the callback to observe the active requests. - HTTP_DURATION_METER - .register_callback(move |ctx| HTTP_SERVER_ACTIVE_REQUESTS.observe(ctx, -1, &attrs)) -} - -fn record_request_duration(tracker: &HttpCallMetricTracker) { - // Record the duration of the request. - HTTP_SERVER_DURATION.record( - &TelemetryContext::current(), - tracker.duration().as_millis() as u64, - &tracker.request_duration_attrs(), - ); -} - -pub fn record_request_size(tracker: &HttpCallMetricTracker, size: u64) { - HTTP_SERVER_REQUEST_SIZE.record( - &TelemetryContext::current(), - size, - &tracker.request_size_attrs(), - ); -} - -pub fn record_response_size(tracker: &HttpCallMetricTracker, size: u64) { - HTTP_SERVER_RESPONSE_SIZE.record( - &TelemetryContext::current(), - size, - &tracker.response_size_attrs(), - ); -} diff --git a/src/telemetry/metrics/mod.rs b/src/telemetry/metrics/mod.rs index b2b235a3..b35c74ac 100644 --- a/src/telemetry/metrics/mod.rs +++ b/src/telemetry/metrics/mod.rs @@ -1,3 +1,97 @@ pub mod http; +pub mod ws; + +use std::time::Duration; + +use once_cell::sync::Lazy; +use opentelemetry::Context as TelemetryContext; +use opentelemetry::{ + metrics::{Meter, MeterProvider, MetricsError}, + runtime, + sdk::{ + export::metrics::aggregation, + metrics::{ + controllers::{self, BasicController}, + processors, selectors, + }, + }, +}; +use opentelemetry_otlp::MetricsExporterBuilder; pub use self::http::tower_layer::HttpMetricsLayer; +use self::ws::observe_active_connection; + +use super::OTEL_DEFAULT_RESOURCE; + +// Histogram buckets in milliseconds +static HISTOGRAM_BUCKETS_MS: &[f64] = &[ + 5.0, 10.0, 20.0, 50.0, 75.0, 100.0, 150.0, 200.0, 250.0, 300.0, 500.0, 750.0, 1000.0, 1500.0, + 2000.0, 2500.0, 5000.0, 10000.0, 15000.0, 30000.0, +]; + +// Histogram buckets in bytes +const KB: f64 = 1024.0; +const MB: f64 = 1024.0 * KB; +const HISTOGRAM_BUCKETS_BYTES: &[f64] = &[ + 1.0 * KB, // 1 KB + 2.0 * KB, // 2 KB + 5.0 * KB, // 5 KB + 10.0 * KB, // 10 KB + 100.0 * KB, // 100 KB + 500.0 * KB, // 500 KB + 1.0 * MB, // 1 MB + 2.5 * MB, // 2 MB + 5.0 * MB, // 5 MB + 10.0 * MB, // 10 MB + 25.0 * MB, // 25 MB + 50.0 * MB, // 50 MB + 100.0 * MB, // 100 MB +]; + +fn build_controller(boundaries: &'static [f64]) -> BasicController { + let exporter = MetricsExporterBuilder::from(opentelemetry_otlp::new_exporter().tonic()) + .build_metrics_exporter(Box::new(aggregation::cumulative_temporality_selector())) + .unwrap(); + + let builder = controllers::basic(processors::factory( + selectors::simple::histogram(boundaries), + aggregation::cumulative_temporality_selector(), + )) + .with_push_timeout(Duration::from_secs(5)) + .with_collect_period(Duration::from_secs(5)) + .with_exporter(exporter) + .with_resource(OTEL_DEFAULT_RESOURCE.clone()); + + builder.build() +} + +static METER_PROVIDER_DURATION: Lazy = + Lazy::new(|| build_controller(HISTOGRAM_BUCKETS_MS)); + +static METER_PROVIDER_SIZE: Lazy = + Lazy::new(|| build_controller(HISTOGRAM_BUCKETS_BYTES)); + +static METER_DURATION: Lazy = Lazy::new(|| METER_PROVIDER_DURATION.meter("duration")); +static METER_SIZE: Lazy = Lazy::new(|| METER_PROVIDER_SIZE.meter("size")); + +/// Initialize the metrics subsystem +pub fn init(cx: &TelemetryContext) -> Result<(), MetricsError> { + METER_PROVIDER_DURATION.start(cx, runtime::Tokio)?; + METER_PROVIDER_SIZE.start(cx, runtime::Tokio)?; + + observe_active_connection(0)?; + + Ok(()) +} + +// +// Shutdown the metrics providers +// +pub fn shutdown(cx: &TelemetryContext) -> Result<(), MetricsError> { + METER_PROVIDER_DURATION.stop(cx)?; + METER_PROVIDER_DURATION.collect(cx)?; + METER_PROVIDER_SIZE.stop(cx)?; + METER_PROVIDER_SIZE.collect(cx)?; + + Ok(()) +} diff --git a/src/telemetry/metrics/ws/mod.rs b/src/telemetry/metrics/ws/mod.rs new file mode 100644 index 00000000..144a223e --- /dev/null +++ b/src/telemetry/metrics/ws/mod.rs @@ -0,0 +1,126 @@ +use std::time::Instant; + +use once_cell::sync::Lazy; +use opentelemetry::KeyValue; +use opentelemetry::{ + metrics::{Histogram, MetricsError, ObservableUpDownCounter, Unit}, + Context as TelemetryContext, +}; + +use super::{METER_DURATION, METER_SIZE}; + +pub static RPC_SERVER_DURATION: Lazy> = Lazy::new(|| { + METER_DURATION + .u64_histogram("rpc.server.duration") + .with_description("Measures duration of inbound RPC requests in milliseconds.") + .with_unit(Unit::new("ms")) + .init() +}); + +pub static RPC_SERVER_ACTIVE_CONNECTIONS: Lazy> = Lazy::new(|| { + METER_DURATION + .i64_observable_up_down_counter("rpc.server.active_connections") + .with_description("The number of active WebSocket connections.") + .init() +}); + +pub static RPC_SERVER_REQUEST_SIZE: Lazy> = Lazy::new(|| { + METER_SIZE + .u64_histogram("rpc.server.request.size") + .with_description("Measures the size of HTTP request messages.") + .with_unit(Unit::new("mb")) + .init() +}); + +pub static RPC_SERVER_RESPONSE_SIZE: Lazy> = Lazy::new(|| { + METER_SIZE + .u64_histogram("rpc.server.response.size") + .with_description("Measures the size of HTTP response messages.") + .with_unit(Unit::new("mb")) + .init() +}); + +fn otel_common_attrs() -> Vec { + vec![KeyValue::new("rpc.system", "jsonrpc"), KeyValue::new("rpc.service", "surrealdb")] +} + +/// Registers the callback that increases the number of active RPC connections. +pub fn on_connect() -> Result<(), MetricsError> { + observe_active_connection(1) +} + +/// Registers the callback that increases the number of active RPC connections. +pub fn on_disconnect() -> Result<(), MetricsError> { + observe_active_connection(-1) +} + +pub(super) fn observe_active_connection(value: i64) -> Result<(), MetricsError> { + let attrs = otel_common_attrs(); + + METER_DURATION + .register_callback(move |cx| RPC_SERVER_ACTIVE_CONNECTIONS.observe(cx, value, &attrs)) +} + +// +// Record an RPC command +// + +#[derive(Clone, Debug, PartialEq)] +pub struct RequestContext { + start: Instant, + pub method: String, + pub size: usize, +} + +impl Default for RequestContext { + fn default() -> Self { + Self { + start: Instant::now(), + method: "unknown".to_string(), + size: 0, + } + } +} + +impl RequestContext { + pub fn with_method(self, method: &str) -> Self { + Self { + method: method.to_string(), + ..self + } + } + + pub fn with_size(self, size: usize) -> Self { + Self { + size, + ..self + } + } +} + +/// Updates the request and response metrics for an RPC method. +pub fn record_rpc(cx: &TelemetryContext, res_size: usize, is_error: bool) { + let mut attrs = otel_common_attrs(); + let mut duration = 0; + let mut req_size = 0; + + if let Some(cx) = cx.get::() { + attrs.extend_from_slice(&[ + KeyValue::new("rpc.method", cx.method.clone()), + KeyValue::new("rpc.error", is_error), + ]); + duration = cx.start.elapsed().as_millis() as u64; + req_size = cx.size as u64; + } else { + // If a bug causes the RequestContent to be empty, we still want to record the metrics to avoid a silent failure. + warn!("record_rpc: no request context found, resulting metrics will be invalid"); + attrs.extend_from_slice(&[ + KeyValue::new("rpc.method", "unknown"), + KeyValue::new("rpc.error", is_error), + ]); + }; + + RPC_SERVER_DURATION.record(cx, duration, &attrs); + RPC_SERVER_REQUEST_SIZE.record(cx, req_size, &attrs); + RPC_SERVER_RESPONSE_SIZE.record(cx, res_size as u64, &attrs); +} diff --git a/src/telemetry/mod.rs b/src/telemetry/mod.rs index b1be2d1e..4dd24a41 100644 --- a/src/telemetry/mod.rs +++ b/src/telemetry/mod.rs @@ -6,11 +6,12 @@ use std::time::Duration; use crate::cli::validator::parser::env_filter::CustomEnvFilter; use once_cell::sync::Lazy; +use opentelemetry::metrics::MetricsError; use opentelemetry::sdk::resource::{ EnvResourceDetector, SdkProvidedResourceDetector, TelemetryResourceDetector, }; use opentelemetry::sdk::Resource; -use opentelemetry::KeyValue; +use opentelemetry::{Context as TelemetryContext, KeyValue}; use tracing::{Level, Subscriber}; use tracing_subscriber::prelude::*; use tracing_subscriber::util::SubscriberInitExt; @@ -86,10 +87,18 @@ impl Builder { /// Install the tracing dispatcher globally pub fn init(self) { - self.build().init() + self.build().init(); } } +pub fn shutdown() -> Result<(), MetricsError> { + // Flush all telemetry data + opentelemetry::global::shutdown_tracer_provider(); + metrics::shutdown(&TelemetryContext::current())?; + + Ok(()) +} + /// Create an EnvFilter from the given value. If the value is not a valid log level, it will be treated as EnvFilter directives. pub fn filter_from_value(v: &str) -> Result { match v {