surrealpatch/src/rpc/connection.rs
2024-07-18 12:11:59 +00:00

444 lines
13 KiB
Rust

use crate::cnf::{
PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY,
};
use crate::rpc::failure::Failure;
use crate::rpc::format::WsFormat;
use crate::rpc::response::{failure, IntoRpcResponse};
use crate::rpc::CONN_CLOSED_ERR;
use crate::telemetry;
use crate::telemetry::metrics::ws::RequestContext;
use crate::telemetry::traces::rpc::span_for_request;
use axum::extract::ws::{Message, WebSocket};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use opentelemetry::trace::FutureExt;
use opentelemetry::Context as TelemetryContext;
use std::collections::BTreeMap;
use std::sync::Arc;
use surrealdb::channel::{self, Receiver, Sender};
use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
use surrealdb::rpc::args::Take;
use surrealdb::rpc::format::Format;
use surrealdb::rpc::method::Method;
use surrealdb::rpc::RpcContext;
use surrealdb::rpc::{Data, RpcError};
use surrealdb::sql::Array;
use surrealdb::sql::Value;
use tokio::sync::{RwLock, Semaphore};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
use tracing::Span;
use uuid::Uuid;
use super::RpcState;
/// State held for a single WebSocket RPC connection.
///
/// Instances are created by `Connection::new`, which wraps them in an
/// `Arc<RwLock<..>>` so the connection's tasks can share them.
pub struct Connection {
/// Identifier of this WebSocket connection
pub(crate) id: Uuid,
/// Format used for this connection's messages; when it is unset,
/// `handle_message` picks a default based on the type of the incoming message
pub(crate) format: Format,
/// Session for this connection (switched to real-time mode in `new`)
pub(crate) session: Session,
/// Values stored on this connection, keyed by string
pub(crate) vars: BTreeMap<String, Value>,
/// Semaphore limiting how many requests are processed concurrently
pub(crate) limiter: Arc<Semaphore>,
/// Token which the connection's tasks watch, and cancel, to shut the connection down
pub(crate) canceller: CancellationToken,
/// Sending and receiving halves of the internal channel of messages
/// which are waiting to be written to the client
pub(crate) channels: (Sender<Message>, Receiver<Message>),
/// Shared RPC state, holding the registries of WebSockets and live queries
pub(crate) state: Arc<RpcState>,
/// Datastore used by this connection
pub(crate) datastore: Arc<Datastore>,
}
impl Connection {
/// Instantiate a new RPC
pub fn new(
datastore: Arc<Datastore>,
state: Arc<RpcState>,
id: Uuid,
mut session: Session,
format: Format,
) -> Arc<RwLock<Connection>> {
// Enable real-time mode
session.rt = true;
// Create and store the RPC connection
Arc::new(RwLock::new(Connection {
id,
format,
session,
vars: BTreeMap::new(),
limiter: Arc::new(Semaphore::new(*WEBSOCKET_MAX_CONCURRENT_REQUESTS)),
canceller: CancellationToken::new(),
channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS),
state,
datastore,
}))
}
/// Serve the RPC endpoint over the given WebSocket until the connection ends.
///
/// Registers the connection in the shared state, runs the ping, read and write
/// tasks to completion, and then unregisters the connection and deletes the
/// live queries which were registered against it.
pub async fn serve(rpc: Arc<RwLock<Connection>>, ws: WebSocket) {
    // Copy everything needed out of the connection inside a narrow scope,
    // so that the read lock is released and rpc is free to be written to
    let (id, state, ds, internal_sender, internal_receiver) = {
        let conn = rpc.read().await;
        (
            conn.id,
            conn.state.clone(),
            conn.datastore.clone(),
            conn.channels.0.clone(),
            conn.channels.1.clone(),
        )
    };
    // Log the successful WebSocket connection
    trace!("WebSocket {} connected", id);
    // Split the socket into its sending and receiving halves
    let (sender, receiver) = ws.split();
    // Run the connection metrics hook
    if let Err(err) = telemetry::metrics::ws::on_connect() {
        error!("Error running metrics::ws::on_connect hook: {}", err);
    }
    // Register this WebSocket in the shared list
    state.web_sockets.write().await.insert(id, rpc.clone());
    // Start the three tasks which drive the WebSocket
    let mut tasks = JoinSet::new();
    tasks.spawn(Self::ping(rpc.clone(), internal_sender.clone()));
    tasks.spawn(Self::read(rpc.clone(), receiver, internal_sender.clone()));
    tasks.spawn(Self::write(rpc.clone(), sender, internal_receiver.clone()));
    // Block until every task has finished, reporting any which failed to join
    while let Some(outcome) = tasks.join_next().await {
        if let Err(err) = outcome {
            error!("Error handling RPC connection: {}", err);
        }
    }
    // Close the internal channel now that nothing is using it
    internal_sender.close();
    trace!("WebSocket {} disconnected", id);
    // Unregister this WebSocket from the shared list
    state.web_sockets.write().await.remove(&id);
    // Drop every live query registered against this WebSocket, collecting their ids
    let mut gc = Vec::new();
    state.live_queries.write().await.retain(|key, value| {
        let owned = value == &id;
        if owned {
            trace!("Removing live query: {}", key);
            gc.push(*key);
        }
        !owned
    });
    // Delete the collected live queries from the datastore
    if let Err(err) = ds.delete_queries(gc).await {
        error!("Error handling RPC connection: {}", err);
    }
    // Run the disconnection metrics hook
    if let Err(err) = telemetry::metrics::ws::on_disconnect() {
        error!("Error running metrics::ws::on_disconnect hook: {}", err);
    }
}
/// Send Ping messages to the client
///
/// A ping is queued on the internal channel on every tick of an interval of
/// `WEBSOCKET_PING_FREQUENCY`. The task stops as soon as the connection's
/// cancellation token fires. If a ping cannot be queued, the task cancels
/// the token itself, so that the other WebSocket tasks stop too, and exits.
async fn ping(rpc: Arc<RwLock<Connection>>, internal_sender: Sender<Message>) {
    // Create the interval ticker
    let mut interval = tokio::time::interval(WEBSOCKET_PING_FREQUENCY);
    // Clone the WebSocket cancellation token
    let canceller = rpc.read().await.canceller.clone();
    // Loop, queueing a ping on every tick
    loop {
        tokio::select! {
            // Always check for shutdown before anything else
            biased;
            // Check if this has shutdown
            _ = canceller.cancelled() => break,
            // Send a regular ping message
            _ = interval.tick() => {
                // Create a new ping message
                let msg = Message::Ping(vec![]);
                // Close the connection if the message fails
                if internal_sender.send(msg).await.is_err() {
                    // Cancel the WebSocket tasks. This goes through the token
                    // cloned above instead of locking the connection again:
                    // the clone shares its cancellation state with the stored
                    // token, and waiting for the read lock here would delay
                    // the shutdown for as long as `process_message` holds
                    // the write lock for a mutating request.
                    canceller.cancel();
                    // Exit out of the loop
                    break;
                }
            },
        }
    }
}
/// Write messages to the client
///
/// Takes messages off the internal channel and sends each one over the
/// WebSocket. The task stops as soon as the connection's cancellation token
/// fires. If sending to the client fails, the task cancels the token itself,
/// so that the other WebSocket tasks stop too, and exits.
async fn write(
    rpc: Arc<RwLock<Connection>>,
    mut sender: SplitSink<WebSocket, Message>,
    mut internal_receiver: Receiver<Message>,
) {
    // Clone the WebSocket cancellation token
    let canceller = rpc.read().await.canceller.clone();
    // Loop, and listen for messages to write
    loop {
        tokio::select! {
            // Always check for shutdown before anything else
            biased;
            // Check if this has shutdown
            _ = canceller.cancelled() => break,
            // Wait for the next message to send
            Some(res) = internal_receiver.next() => {
                // Send the message to the client
                if let Err(err) = sender.send(res).await {
                    // Output any errors if not a close error
                    if err.to_string() != CONN_CLOSED_ERR {
                        debug!("WebSocket error: {:?}", err);
                    }
                    // Cancel the WebSocket tasks. This goes through the token
                    // cloned above instead of locking the connection again:
                    // the clone shares its cancellation state with the stored
                    // token, and waiting for the read lock here would delay
                    // the shutdown for as long as `process_message` holds
                    // the write lock for a mutating request.
                    canceller.cancel();
                    // Exit out of the loop
                    break;
                }
            },
        }
    }
}
/// Read messages sent from the client
///
/// Every text or binary message is handed to `handle_message` in its own
/// spawned task. A close message is answered with a close message and then
/// shuts the connection down, as do a WebSocket error and a request task
/// which fails to join. All other message types are ignored. Once the loop
/// has ended, the task waits for the outstanding request tasks to finish.
async fn read(
    rpc: Arc<RwLock<Connection>>,
    mut receiver: SplitStream<WebSocket>,
    internal_sender: Sender<Message>,
) {
    // Store spawned tasks so we can wait for them
    let mut tasks = JoinSet::new();
    // Clone the WebSocket cancellation token. Every cancellation below goes
    // through this clone instead of locking the connection again: the clone
    // shares its cancellation state with the stored token, and waiting for
    // the read lock would delay the shutdown for as long as `process_message`
    // holds the write lock for a mutating request.
    let canceller = rpc.read().await.canceller.clone();
    // Loop, and listen for messages to read
    loop {
        tokio::select! {
            // Always check for shutdown before anything else
            biased;
            // Check if this has shutdown
            _ = canceller.cancelled() => break,
            // Remove any completed tasks
            Some(out) = tasks.join_next() => match out {
                // The task completed successfully
                Ok(_) => continue,
                // The task failed to join (for example, it panicked)
                Err(err) => {
                    // There was an error with the task
                    trace!("WebSocket request error: {:?}", err);
                    // Cancel the WebSocket tasks
                    canceller.cancel();
                    // Exit out of the loop
                    break;
                }
            },
            // Wait for the next received message
            Some(msg) = receiver.next() => match msg {
                // We've received a request from the client
                Ok(msg @ (Message::Text(_) | Message::Binary(_))) => {
                    tasks.spawn(Connection::handle_message(rpc.clone(), msg, internal_sender.clone()));
                }
                // The client is closing the connection
                Ok(Message::Close(_)) => {
                    // Respond with a close message
                    if let Err(err) = internal_sender.send(Message::Close(None)).await {
                        trace!("WebSocket error when replying to the Close frame: {:?}", err);
                    };
                    // Cancel the WebSocket tasks
                    canceller.cancel();
                    // Exit out of the loop
                    break;
                }
                // Ignore everything else
                Ok(_) => {}
                Err(err) => {
                    // There was an error with the WebSocket
                    trace!("WebSocket error: {:?}", err);
                    // Cancel the WebSocket tasks
                    canceller.cancel();
                    // Exit out of the loop
                    break;
                }
            }
        }
    }
    // Wait for all tasks to finish
    while let Some(res) = tasks.join_next().await {
        if let Err(err) = res {
            // There was an error with the task
            trace!("WebSocket request error: {:?}", err);
        }
    }
    // Abort anything which might still be in the set
    tasks.shutdown().await;
}
/// Handle individual WebSocket messages
///
/// Waits for a permit from the connection's request limiter, parses `msg` as
/// an RPC request in the connection's format, processes it, and sends the
/// response (or a failure, if the request could not be parsed) to `chn`.
///
/// When the connection has no format yet, a text message sets it to JSON and
/// a binary message sets it to Bincode.
///
/// # Panics
///
/// Panics if `msg` is neither a text nor a binary message, or if the
/// limiter's permit cannot be acquired.
async fn handle_message(rpc: Arc<RwLock<Connection>>, msg: Message, chn: Sender<Message>) {
    // Take everything needed from the connection under one short-lived read
    // lock: the current output format, the request span, and the limiter
    let (mut fmt, span, limiter) = {
        let conn = rpc.read().await;
        (conn.format, span_for_request(&conn.id), conn.limiter.clone())
    };
    // Acquire concurrent request rate limiter. The read lock above must be
    // released before waiting here: a guard held while waiting for a permit
    // would block every request which already holds a permit and needs the
    // write lock, and those requests would then never release their permits.
    let permit = limiter.acquire_owned().await.unwrap();
    // Calculate the length of the message, and the format to fall back to
    let (len, default_fmt) = match &msg {
        // Text messages default to JSON
        Message::Text(text) => (text.len(), Format::Json),
        // Binary messages default to Bincode
        Message::Binary(bytes) => (bytes.len(), Format::Bincode),
        // The caller only passes text and binary messages
        _ => unreachable!(),
    };
    // If no format was specified, store the default on the connection
    if fmt.is_none() {
        fmt = default_fmt;
        rpc.write().await.format = fmt;
    }
    // Parse the request
    async move {
        let span = Span::current();
        let req_cx = RequestContext::default();
        let otel_cx = Arc::new(TelemetryContext::new().with_value(req_cx.clone()));
        // Parse the RPC request structure
        match fmt.req_ws(msg) {
            Ok(req) => {
                // Now that we know the method, we can update the span and create otel context
                span.record("rpc.method", &req.method);
                span.record("otel.name", format!("surrealdb.rpc/{}", req.method));
                span.record(
                    "rpc.request_id",
                    req.id.clone().map(Value::as_string).unwrap_or_default(),
                );
                let otel_cx = Arc::new(TelemetryContext::current_with_value(
                    req_cx.with_method(&req.method).with_size(len),
                ));
                // Process the message
                let res =
                    Connection::process_message(rpc.clone(), &req.method, req.params).await;
                // Process the response
                res.into_response(req.id)
                    .send(otel_cx.clone(), fmt, &chn)
                    .with_context(otel_cx.as_ref().clone())
                    .await
            }
            Err(err) => {
                // Process the response
                failure(None, err)
                    .send(otel_cx.clone(), fmt, &chn)
                    .with_context(otel_cx.as_ref().clone())
                    .await
            }
        }
    }
    .instrument(span)
    .await;
    // Drop the rate limiter permit
    drop(permit);
}
/// Parse the method name and execute the RPC call on the connection.
///
/// Returns `Failure::METHOD_NOT_FOUND` when `method` does not parse to a
/// valid method. Methods which report `needs_mut()` are executed under the
/// connection's write lock; every other method is executed under a read lock.
pub async fn process_message(
    rpc: Arc<RwLock<Connection>>,
    method: &str,
    params: Array,
) -> Result<Data, Failure> {
    debug!("Process RPC request");
    // Resolve the method name, rejecting anything which is not recognised
    let method = Method::parse(method);
    if !method.is_valid() {
        return Err(Failure::METHOD_NOT_FOUND);
    }
    // Only take the write lock for methods which need to mutate the connection
    if method.needs_mut() {
        rpc.write().await.execute(method, params).await.map_err(Into::into)
    } else {
        rpc.read().await.execute_immut(method, params).await.map_err(Into::into)
    }
}
}
impl RpcContext for Connection {
fn kvs(&self) -> &Datastore {
&self.datastore
}
fn session(&self) -> &Session {
&self.session
}
fn session_mut(&mut self) -> &mut Session {
&mut self.session
}
fn vars(&self) -> &BTreeMap<String, Value> {
&self.vars
}
fn vars_mut(&mut self) -> &mut BTreeMap<String, Value> {
&mut self.vars
}
fn version_data(&self) -> impl Into<Data> {
format!("{PKG_NAME}-{}", *PKG_VERSION)
}
const LQ_SUPPORT: bool = true;
async fn handle_live(&self, lqid: &Uuid) {
self.state.live_queries.write().await.insert(*lqid, self.id);
trace!("Registered live query {} on websocket {}", lqid, self.id);
}
async fn handle_kill(&self, lqid: &Uuid) {
if let Some(id) = self.state.live_queries.write().await.remove(lqid) {
trace!("Unregistered live query {} on websocket {}", lqid, id);
}
}
// reimplimentaions
async fn signup(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signup::signup(&self.datastore, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn signin(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Object(v)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
let out: Result<Value, RpcError> =
surrealdb::iam::signin::signin(&self.datastore, &mut self.session, v)
.await
.map(Into::into)
.map_err(Into::into);
out
}
async fn authenticate(&mut self, params: Array) -> Result<impl Into<Data>, RpcError> {
let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams);
};
surrealdb::iam::verify::token(&self.datastore, &mut self.session, &token.0).await?;
Ok(Value::None)
}
}