use crate::cnf::{ PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY, }; use crate::dbs::DB; use crate::rpc::failure::Failure; use crate::rpc::format::Format; use crate::rpc::response::{failure, IntoRpcResponse}; use crate::rpc::{CONN_CLOSED_ERR, LIVE_QUERIES, WEBSOCKETS}; use crate::telemetry; use crate::telemetry::metrics::ws::RequestContext; use crate::telemetry::traces::rpc::span_for_request; use axum::extract::ws::{Message, WebSocket}; use futures_util::stream::{SplitSink, SplitStream}; use futures_util::{SinkExt, StreamExt}; use opentelemetry::trace::FutureExt; use opentelemetry::Context as TelemetryContext; use std::collections::BTreeMap; use std::sync::Arc; use surrealdb::channel::{self, Receiver, Sender}; use surrealdb::dbs::Session; use surrealdb::kvs::Datastore; use surrealdb::rpc::args::Take; use surrealdb::rpc::method::Method; use surrealdb::rpc::RpcContext; use surrealdb::rpc::{Data, RpcError}; use surrealdb::sql::Array; use surrealdb::sql::Value; use tokio::sync::{RwLock, Semaphore}; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::Instrument; use tracing::Span; use uuid::Uuid; pub struct Connection { pub(crate) id: Uuid, pub(crate) format: Format, pub(crate) session: Session, pub(crate) vars: BTreeMap, pub(crate) limiter: Arc, pub(crate) canceller: CancellationToken, pub(crate) channels: (Sender, Receiver), } impl Connection { /// Instantiate a new RPC pub fn new(id: Uuid, mut session: Session, format: Format) -> Arc> { // Enable real-time mode session.rt = true; // Create and store the RPC connection Arc::new(RwLock::new(Connection { id, format, session, vars: BTreeMap::new(), limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)), canceller: CancellationToken::new(), channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS), })) } /// Serve the RPC endpoint pub async fn serve(rpc: Arc>, ws: WebSocket) { // Get the WebSocket ID let id = rpc.read().await.id; // Split the socket into sending and receiving streams let (sender, receiver) = ws.split(); // Create an internal channel for sending and receiving let internal_sender = rpc.read().await.channels.0.clone(); let internal_receiver = rpc.read().await.channels.1.clone(); trace!("WebSocket {} connected", id); if let Err(err) = telemetry::metrics::ws::on_connect() { error!("Error running metrics::ws::on_connect hook: {}", err); } // Add this WebSocket to the list WEBSOCKETS.write().await.insert(id, rpc.clone()); // Spawn async tasks for the WebSocket let mut tasks = JoinSet::new(); tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone())); tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone())); tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone())); // Wait until all tasks finish while let Some(res) = tasks.join_next().await { if let Err(err) = res { error!("Error handling RPC connection: {}", err); } } internal_sender.close(); trace!("WebSocket {} disconnected", id); // Remove this WebSocket from the list WEBSOCKETS.write().await.remove(&id); // Remove all live queries let mut gc = Vec::new(); LIVE_QUERIES.write().await.retain(|key, value| { if value == &id { trace!("Removing live query: {}", key); gc.push(*key); return false; } true }); // Garbage collect queries if let Err(e) = DB.get().unwrap().garbage_collect_dead_session(gc.as_slice()).await { error!("Failed to garbage collect dead sessions: {:?}", e); } if let Err(err) = telemetry::metrics::ws::on_disconnect() { error!("Error running metrics::ws::on_disconnect hook: {}", err); } } /// Send Ping messages to the client async fn ping(rpc: Arc>, internal_sender: Sender) { // Create the interval ticker let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY); // Clone the WebSocket cancellation token let canceller = rpc.read().await.canceller.clone(); // Loop, and listen for messages to write loop { tokio::select! { // biased; // Check if this has shutdown _ = canceller.cancelled() => break, // Send a regular ping message _ = interval.tick() => { // Create a new ping message let msg = Message::Ping(vec![]); // Close the connection if the message fails if internal_sender.send(msg).await.is_err() { // Cancel the WebSocket tasks rpc.read().await.canceller.cancel(); // Exit out of the loop break; } }, } } } /// Write messages to the client async fn write( rpc: Arc>, mut sender: SplitSink, mut internal_receiver: Receiver, ) { // Clone the WebSocket cancellation token let canceller = rpc.read().await.canceller.clone(); // Loop, and listen for messages to write loop { tokio::select! { // biased; // Check if this has shutdown _ = canceller.cancelled() => break, // Wait for the next message to send Some(res) = internal_receiver.next() => { // Send the message to the client if let Err(err) = sender.send(res).await { // Output any errors if not a close error if err.to_string() != CONN_CLOSED_ERR { debug!("WebSocket error: {:?}", err); } // Cancel the WebSocket tasks rpc.read().await.canceller.cancel(); // Exit out of the loop break; } }, } } } /// Read messages sent from the client async fn read( rpc: Arc>, mut receiver: SplitStream, internal_sender: Sender, ) { // Store spawned tasks so we can wait for them let mut tasks = JoinSet::new(); // Clone the WebSocket cancellation token let canceller = rpc.read().await.canceller.clone(); // Loop, and listen for messages to write loop { tokio::select! { // biased; // Check if this has shutdown _ = canceller.cancelled() => break, // Remove any completed tasks Some(out) = tasks.join_next() => match out { // The task completed successfully Ok(_) => continue, // There was an uncaught panic in the task Err(err) => { // There was an error with the task trace!("WebSocket request error: {:?}", err); // Cancel the WebSocket tasks rpc.read().await.canceller.cancel(); // Exit out of the loop break; } }, // Wait for the next received message Some(msg) = receiver.next() => match msg { // We've received a message from the client Ok(msg) => match msg { Message::Text(_) => { tasks.spawn(Connection::handle_message(rpc.clone(), msg, internal_sender.clone())); } Message::Binary(_) => { tasks.spawn(Connection::handle_message(rpc.clone(), msg, internal_sender.clone())); } Message::Close(_) => { // Respond with a close message if let Err(err) = internal_sender.send(Message::Close(None)).await { trace!("WebSocket error when replying to the Close frame: {:?}", err); }; // Cancel the WebSocket tasks rpc.read().await.canceller.cancel(); // Exit out of the loop break; } _ => { // Ignore everything else } }, Err(err) => { // There was an error with the WebSocket trace!("WebSocket error: {:?}", err); // Cancel the WebSocket tasks rpc.read().await.canceller.cancel(); // Exit out of the loop break; } } } } // Wait for all tasks to finish while let Some(res) = tasks.join_next().await { if let Err(err) = res { // There was an error with the task trace!("WebSocket request error: {:?}", err); } } // Abort all tasks tasks.shutdown().await; } /// Handle individual WebSocket messages async fn handle_message(rpc: Arc>, msg: Message, chn: Sender) { // Get the current output format let mut fmt = rpc.read().await.format; // Prepare Span and Otel context let span = span_for_request(&rpc.read().await.id); // Acquire concurrent request rate limiter let permit = rpc.read().await.limiter.clone().acquire_owned().await.unwrap(); // Calculate the length of the message let len = match msg { Message::Text(ref msg) => { // If no format was specified, default to JSON if fmt.is_none() { fmt = Format::Json; rpc.write().await.format = fmt; } // Retrieve the length of the message msg.len() } Message::Binary(ref msg) => { // If no format was specified, default to Bincode if fmt.is_none() { fmt = Format::Bincode; rpc.write().await.format = fmt; } // Retrieve the length of the message msg.len() } _ => unreachable!(), }; // Parse the request async move { let span = Span::current(); let req_cx = RequestContext::default(); let otel_cx = Arc::new(TelemetryContext::new().with_value(req_cx.clone())); // Parse the RPC request structure match fmt.req_ws(msg) { Ok(req) => { // Now that we know the method, we can update the span and create otel context span.record("rpc.method", &req.method); span.record("otel.name", format!("surrealdb.rpc/{}", req.method)); span.record( "rpc.request_id", req.id.clone().map(Value::as_string).unwrap_or_default(), ); let otel_cx = Arc::new(TelemetryContext::current_with_value( req_cx.with_method(&req.method).with_size(len), )); // Process the message let res = Connection::process_message(rpc.clone(), &req.method, req.params).await; // Process the response res.into_response(req.id) .send(otel_cx.clone(), fmt, &chn) .with_context(otel_cx.as_ref().clone()) .await } Err(err) => { // Process the response failure(None, err) .send(otel_cx.clone(), fmt, &chn) .with_context(otel_cx.as_ref().clone()) .await } } } .instrument(span) .await; // Drop the rate limiter permit drop(permit); } pub async fn process_message( rpc: Arc>, method: &str, params: Array, ) -> Result { debug!("Process RPC request"); let method = Method::parse(method); if !method.is_valid() { return Err(Failure::METHOD_NOT_FOUND); } // if the write lock is a bottleneck then execute could be refactored into execute_mut and execute // rpc.write().await.execute(method, params).await.map_err(Into::into) match method.needs_mut() { true => rpc.write().await.execute(method, params).await.map_err(Into::into), false => rpc.read().await.execute_immut(method, params).await.map_err(Into::into), } } } impl RpcContext for Connection { fn kvs(&self) -> &Datastore { DB.get().unwrap() } fn session(&self) -> &Session { &self.session } fn session_mut(&mut self) -> &mut Session { &mut self.session } fn vars(&self) -> &BTreeMap { &self.vars } fn vars_mut(&mut self) -> &mut BTreeMap { &mut self.vars } fn version_data(&self) -> impl Into { format!("{PKG_NAME}-{}", *PKG_VERSION) } const LQ_SUPPORT: bool = true; async fn handle_live(&self, lqid: &Uuid) { LIVE_QUERIES.write().await.insert(*lqid, self.id); trace!("Registered live query {} on websocket {}", lqid, self.id); } async fn handle_kill(&self, lqid: &Uuid) { if let Some(id) = LIVE_QUERIES.write().await.remove(lqid) { trace!("Unregistered live query {} on websocket {}", lqid, id); } } // reimplimentaions async fn signup(&mut self, params: Array) -> Result, RpcError> { let Ok(Value::Object(v)) = params.needs_one() else { return Err(RpcError::InvalidParams); }; let out: Result = surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v) .await .map(Into::into) .map_err(Into::into); out } async fn signin(&mut self, params: Array) -> Result, RpcError> { let Ok(Value::Object(v)) = params.needs_one() else { return Err(RpcError::InvalidParams); }; let out: Result = surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v) .await .map(Into::into) .map_err(Into::into); out } async fn authenticate(&mut self, params: Array) -> Result, RpcError> { let Ok(Value::Strand(token)) = params.needs_one() else { return Err(RpcError::InvalidParams); }; surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?; Ok(Value::None) } }