Remove global static Datastore (#4377)

This commit is contained in:
Sergii Glushchenko 2024-07-18 14:11:59 +02:00 committed by GitHub
parent 968b1714dc
commit c73435a881
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 102 additions and 65 deletions

View file

@ -4,7 +4,7 @@ use crate::cli::validator::parser::env_filter::CustomEnvFilter;
use crate::cli::validator::parser::env_filter::CustomEnvFilterParser; use crate::cli::validator::parser::env_filter::CustomEnvFilterParser;
use crate::cnf::LOGO; use crate::cnf::LOGO;
use crate::dbs; use crate::dbs;
use crate::dbs::{StartCommandDbsOptions, DB}; use crate::dbs::StartCommandDbsOptions;
use crate::env; use crate::env;
use crate::err::Error; use crate::err::Error;
use crate::net::{self, client_ip::ClientIp}; use crate::net::{self, client_ip::ClientIp};
@ -12,6 +12,7 @@ use clap::Args;
use opentelemetry::Context; use opentelemetry::Context;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use surrealdb::engine::any::IntoEndpoint; use surrealdb::engine::any::IntoEndpoint;
use surrealdb::engine::tasks::start_tasks; use surrealdb::engine::tasks::start_tasks;
@ -189,15 +190,13 @@ pub async fn init(
let ct = CancellationToken::new(); let ct = CancellationToken::new();
// Initiate environment // Initiate environment
env::init().await?; env::init().await?;
// Start the kvs server // Start the datastore
dbs::init(dbs).await?; let ds = Arc::new(dbs::init(dbs).await?);
// Start the node agent // Start the node agent
let (tasks, task_chans) = start_tasks( let (tasks, task_chans) =
&config::CF.get().unwrap().engine.unwrap_or_default(), start_tasks(&config::CF.get().unwrap().engine.unwrap_or_default(), ds.clone());
DB.get().unwrap().clone(),
);
// Start the web server // Start the web server
net::init(ct.clone()).await?; net::init(ds, ct.clone()).await?;
// Shutdown and stop closed tasks // Shutdown and stop closed tasks
task_chans.into_iter().for_each(|chan| { task_chans.into_iter().for_each(|chan| {
if chan.send(()).is_err() { if chan.send(()).is_err() {

View file

@ -2,13 +2,10 @@ use crate::cli::CF;
use crate::err::Error; use crate::err::Error;
use clap::Args; use clap::Args;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use std::time::Duration; use std::time::Duration;
use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets}; use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets};
use surrealdb::kvs::Datastore; use surrealdb::kvs::Datastore;
pub static DB: OnceLock<Arc<Datastore>> = OnceLock::new();
#[derive(Args, Debug)] #[derive(Args, Debug)]
pub struct StartCommandDbsOptions { pub struct StartCommandDbsOptions {
#[arg(help = "Whether strict mode is enabled on this database instance")] #[arg(help = "Whether strict mode is enabled on this database instance")]
@ -211,7 +208,7 @@ pub async fn init(
capabilities, capabilities,
temporary_directory, temporary_directory,
}: StartCommandDbsOptions, }: StartCommandDbsOptions,
) -> Result<(), Error> { ) -> Result<Datastore, Error> {
// Get local copy of options // Get local copy of options
let opt = CF.get().unwrap(); let opt = CF.get().unwrap();
// Convert the capabilities // Convert the capabilities
@ -248,10 +245,8 @@ pub async fn init(
} }
// Bootstrap the datastore // Bootstrap the datastore
dbs.bootstrap().await?; dbs.bootstrap().await?;
// Store database instance
let _ = DB.set(Arc::new(dbs));
// All ok // All ok
Ok(()) Ok(dbs)
} }
#[cfg(test)] #[cfg(test)]

View file

@ -15,7 +15,7 @@ use surrealdb::{
}; };
use tower_http::auth::AsyncAuthorizeRequest; use tower_http::auth::AsyncAuthorizeRequest;
use crate::{dbs::DB, err::Error}; use crate::err::Error;
use super::{ use super::{
client_ip::ExtractClientIP, client_ip::ExtractClientIP,
@ -75,8 +75,6 @@ where
} }
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> { async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
let kvs = DB.get().unwrap();
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await { let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
if !or.is_null() { if !or.is_null() {
Some(or.to_string()) Some(or.to_string())
@ -113,6 +111,8 @@ async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
Error::InvalidAuth Error::InvalidAuth
})?; })?;
let kvs = &state.datastore;
let ExtractClientIP(ip) = let ExtractClientIP(ip) =
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None)); parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));

View file

@ -1,4 +1,4 @@
use crate::dbs::DB; use super::AppState;
use crate::err::Error; use crate::err::Error;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::get; use axum::routing::get;
@ -21,9 +21,12 @@ where
Router::new().route("/export", get(handler)) Router::new().route("/export", get(handler))
} }
async fn handler(Extension(session): Extension<Session>) -> Result<impl IntoResponse, Error> { async fn handler(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>,
) -> Result<impl IntoResponse, Error> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Create a chunked response // Create a chunked response
let (mut chn, body) = Body::channel(); let (mut chn, body) = Body::channel();
// Ensure a NS and DB are set // Ensure a NS and DB are set

View file

@ -1,7 +1,8 @@
use crate::dbs::DB; use super::AppState;
use crate::err::Error; use crate::err::Error;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::get; use axum::routing::get;
use axum::Extension;
use axum::Router; use axum::Router;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use surrealdb::kvs::{LockType::*, TransactionType::*}; use surrealdb::kvs::{LockType::*, TransactionType::*};
@ -14,9 +15,9 @@ where
Router::new().route("/health", get(handler)) Router::new().route("/health", get(handler))
} }
async fn handler() -> impl IntoResponse { async fn handler(Extension(state): Extension<AppState>) -> impl IntoResponse {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Attempt to open a transaction // Attempt to open a transaction
match db.transaction(Read, Optimistic).await { match db.transaction(Read, Optimistic).await {
// The transaction failed to start // The transaction failed to start

View file

@ -1,5 +1,5 @@
use super::headers::Accept; use super::headers::Accept;
use crate::dbs::DB; use super::AppState;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
@ -32,12 +32,13 @@ where
} }
async fn handler( async fn handler(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
sql: Bytes, sql: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Convert the body to a byte slice // Convert the body to a byte slice
let sql = bytes_to_utf8(&sql)?; let sql = bytes_to_utf8(&sql)?;
// Check the permissions level // Check the permissions level

View file

@ -1,4 +1,3 @@
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
@ -18,6 +17,7 @@ use surrealdb::sql::Value;
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use super::headers::Accept; use super::headers::Accept;
use super::AppState;
const MAX: usize = 1024 * 16; // 16 KiB const MAX: usize = 1024 * 16; // 16 KiB
@ -68,13 +68,14 @@ where
// ------------------------------ // ------------------------------
async fn select_all( async fn select_all(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>, Path(table): Path<String>,
Query(query): Query<QueryOptions>, Query(query): Query<QueryOptions>,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Specify the request statement // Specify the request statement
@ -108,6 +109,7 @@ async fn select_all(
} }
async fn create_all( async fn create_all(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>, Path(table): Path<String>,
@ -115,7 +117,7 @@ async fn create_all(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -152,6 +154,7 @@ async fn create_all(
} }
async fn update_all( async fn update_all(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>, Path(table): Path<String>,
@ -159,7 +162,7 @@ async fn update_all(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -196,6 +199,7 @@ async fn update_all(
} }
async fn modify_all( async fn modify_all(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>, Path(table): Path<String>,
@ -203,7 +207,7 @@ async fn modify_all(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -240,13 +244,14 @@ async fn modify_all(
} }
async fn delete_all( async fn delete_all(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path(table): Path<String>, Path(table): Path<String>,
Query(params): Query<Params>, Query(params): Query<Params>,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Specify the request statement // Specify the request statement
@ -278,13 +283,14 @@ async fn delete_all(
// ------------------------------ // ------------------------------
async fn select_one( async fn select_one(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path((table, id)): Path<(String, String)>, Path((table, id)): Path<(String, String)>,
Query(query): Query<QueryOptions>, Query(query): Query<QueryOptions>,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Specify the request statement // Specify the request statement
@ -321,6 +327,7 @@ async fn select_one(
} }
async fn create_one( async fn create_one(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>, Query(params): Query<Params>,
@ -328,7 +335,7 @@ async fn create_one(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -371,6 +378,7 @@ async fn create_one(
} }
async fn update_one( async fn update_one(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>, Query(params): Query<Params>,
@ -378,7 +386,7 @@ async fn update_one(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -421,6 +429,7 @@ async fn update_one(
} }
async fn modify_one( async fn modify_one(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Query(params): Query<Params>, Query(params): Query<Params>,
@ -428,7 +437,7 @@ async fn modify_one(
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Convert the HTTP request body // Convert the HTTP request body
@ -471,12 +480,13 @@ async fn modify_one(
} }
async fn delete_one( async fn delete_one(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
Path((table, id)): Path<(String, String)>, Path((table, id)): Path<(String, String)>,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let _ = check_ns_db(&session)?; let _ = check_ns_db(&session)?;
// Specify the request statement // Specify the request statement

View file

@ -1,5 +1,5 @@
//! This file defines the endpoints for the ML API for importing and exporting SurrealML models. //! This file defines the endpoints for the ML API for importing and exporting SurrealML models.
use crate::dbs::DB; use super::AppState;
use crate::err::Error; use crate::err::Error;
use crate::net::output; use crate::net::output;
use axum::extract::{BodyStream, DefaultBodyLimit, Path}; use axum::extract::{BodyStream, DefaultBodyLimit, Path};
@ -41,11 +41,12 @@ where
/// This endpoint allows the user to import a model into the database. /// This endpoint allows the user to import a model into the database.
async fn import( async fn import(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
mut stream: BodyStream, mut stream: BodyStream,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(&session)?; let (nsv, dbv) = check_ns_db(&session)?;
// Check the permissions level // Check the permissions level
@ -92,11 +93,12 @@ async fn import(
/// This endpoint allows the user to export a model from the database. /// This endpoint allows the user to export a model from the database.
async fn export( async fn export(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
Path((name, version)): Path<(String, String)>, Path((name, version)): Path<(String, String)>,
) -> Result<impl IntoResponse, Error> { ) -> Result<impl IntoResponse, Error> {
// Get the datastore reference // Get the datastore reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Ensure a NS and DB are set // Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(&session)?; let (nsv, dbv) = check_ns_db(&session)?;
// Check the permissions level // Check the permissions level

View file

@ -36,6 +36,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use surrealdb::headers::{AUTH_DB, AUTH_NS, DB, ID, NS}; use surrealdb::headers::{AUTH_DB, AUTH_NS, DB, ID, NS};
use surrealdb::kvs::Datastore;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::add_extension::AddExtensionLayer; use tower_http::add_extension::AddExtensionLayer;
@ -60,14 +61,16 @@ const LOG: &str = "surrealdb::net";
#[derive(Clone)] #[derive(Clone)]
struct AppState { struct AppState {
client_ip: client_ip::ClientIp, client_ip: client_ip::ClientIp,
datastore: Arc<Datastore>,
} }
pub async fn init(ct: CancellationToken) -> Result<(), Error> { pub async fn init(ds: Arc<Datastore>, ct: CancellationToken) -> Result<(), Error> {
// Get local copy of options // Get local copy of options
let opt = CF.get().unwrap(); let opt = CF.get().unwrap();
let app_state = AppState { let app_state = AppState {
client_ip: opt.client_ip, client_ip: opt.client_ip,
datastore: ds.clone(),
}; };
// Specify headers to be obfuscated from all requests/responses // Specify headers to be obfuscated from all requests/responses
@ -186,7 +189,7 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
let axum_app = axum_app.with_state(rpc_state.clone()); let axum_app = axum_app.with_state(rpc_state.clone());
// Spawn a task to handle notifications // Spawn a task to handle notifications
tokio::spawn(async move { notifications(rpc_state, ct.clone()).await }); tokio::spawn(async move { notifications(ds, rpc_state, ct.clone()).await });
// If a certificate and key are specified then setup TLS // If a certificate and key are specified then setup TLS
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) { if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
// Configure certificate and private key used by https // Configure certificate and private key used by https

View file

@ -3,7 +3,6 @@ use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use crate::cnf; use crate::cnf;
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::rpc::connection::Connection; use crate::rpc::connection::Connection;
use crate::rpc::format::HttpFormat; use crate::rpc::format::HttpFormat;
@ -23,6 +22,7 @@ use bytes::Bytes;
use http::HeaderValue; use http::HeaderValue;
use http_body::Body as HttpBody; use http_body::Body as HttpBody;
use surrealdb::dbs::Session; use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
use surrealdb::rpc::format::Format; use surrealdb::rpc::format::Format;
use surrealdb::rpc::format::PROTOCOLS; use surrealdb::rpc::format::PROTOCOLS;
use surrealdb::rpc::method::Method; use surrealdb::rpc::method::Method;
@ -31,6 +31,7 @@ use uuid::Uuid;
use super::headers::Accept; use super::headers::Accept;
use super::headers::ContentType; use super::headers::ContentType;
use super::AppState;
use surrealdb::rpc::rpc_context::RpcContext; use surrealdb::rpc::rpc_context::RpcContext;
@ -45,6 +46,7 @@ where
async fn get_handler( async fn get_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(state): Extension<AppState>,
Extension(id): Extension<RequestId>, Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>, Extension(sess): Extension<Session>,
State(rpc_state): State<Arc<RpcState>>, State(rpc_state): State<Arc<RpcState>>,
@ -79,10 +81,18 @@ async fn get_handler(
// Set the maximum WebSocket message size // Set the maximum WebSocket message size
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE) .max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
// Handle the WebSocket upgrade and process messages // Handle the WebSocket upgrade and process messages
.on_upgrade(move |socket| handle_socket(rpc_state, socket, sess, id))) .on_upgrade(move |socket| {
handle_socket(state.datastore.clone(), rpc_state, socket, sess, id)
}))
} }
async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: Uuid) { async fn handle_socket(
datastore: Arc<Datastore>,
state: Arc<RpcState>,
ws: WebSocket,
sess: Session,
id: Uuid,
) {
// Check if there is a WebSocket protocol specified // Check if there is a WebSocket protocol specified
let format = match ws.protocol().map(HeaderValue::to_str) { let format = match ws.protocol().map(HeaderValue::to_str) {
// Any selected protocol will always be a valie value // Any selected protocol will always be a valie value
@ -92,12 +102,13 @@ async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: U
}; };
// Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here // Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here
// Create a new connection instance // Create a new connection instance
let rpc = Connection::new(state, id, sess, format); let rpc = Connection::new(datastore, state, id, sess, format);
// Serve the socket connection requests // Serve the socket connection requests
Connection::serve(rpc, ws).await; Connection::serve(rpc, ws).await;
} }
async fn post_handler( async fn post_handler(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
output: Option<TypedHeader<Accept>>, output: Option<TypedHeader<Accept>>,
content_type: TypedHeader<ContentType>, content_type: TypedHeader<ContentType>,
@ -114,7 +125,7 @@ async fn post_handler(
return Err(Error::InvalidType); return Err(Error::InvalidType);
} }
let mut rpc_ctx = PostRpcContext::new(DB.get().unwrap(), session, BTreeMap::new()); let mut rpc_ctx = PostRpcContext::new(&state.datastore, session, BTreeMap::new());
match fmt.req_http(body) { match fmt.req_http(body) {
Ok(req) => { Ok(req) => {

View file

@ -1,4 +1,3 @@
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
@ -16,6 +15,7 @@ use surrealdb::sql::Value;
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use super::headers::Accept; use super::headers::Accept;
use super::AppState;
const MAX: usize = 1024; // 1 KiB const MAX: usize = 1024; // 1 KiB
@ -50,12 +50,13 @@ where
} }
async fn handler( async fn handler(
Extension(state): Extension<AppState>,
Extension(mut session): Extension<Session>, Extension(mut session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = &state.datastore;
// Convert the HTTP body into text // Convert the HTTP body into text
let data = bytes_to_utf8(&body)?; let data = bytes_to_utf8(&body)?;
// Parse the provided data as JSON // Parse the provided data as JSON

View file

@ -1,4 +1,3 @@
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
@ -14,6 +13,7 @@ use surrealdb::sql::Value;
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use super::headers::Accept; use super::headers::Accept;
use super::AppState;
const MAX: usize = 1024; // 1 KiB const MAX: usize = 1024; // 1 KiB
@ -48,12 +48,13 @@ where
} }
async fn handler( async fn handler(
Extension(state): Extension<AppState>,
Extension(mut session): Extension<Session>, Extension(mut session): Extension<Session>,
accept: Option<TypedHeader<Accept>>, accept: Option<TypedHeader<Accept>>,
body: Bytes, body: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get a database reference // Get a database reference
let kvs = DB.get().unwrap(); let kvs = &state.datastore;
// Convert the HTTP body into text // Convert the HTTP body into text
let data = bytes_to_utf8(&body)?; let data = bytes_to_utf8(&body)?;
// Parse the provided data as JSON // Parse the provided data as JSON

View file

@ -1,4 +1,3 @@
use crate::dbs::DB;
use crate::err::Error; use crate::err::Error;
use crate::net::input::bytes_to_utf8; use crate::net::input::bytes_to_utf8;
use crate::net::output; use crate::net::output;
@ -20,6 +19,7 @@ use surrealdb::dbs::Session;
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use super::headers::Accept; use super::headers::Accept;
use super::AppState;
const MAX: usize = 1024 * 1024; // 1 MiB const MAX: usize = 1024 * 1024; // 1 MiB
@ -37,13 +37,14 @@ where
} }
async fn post_handler( async fn post_handler(
Extension(state): Extension<AppState>,
Extension(session): Extension<Session>, Extension(session): Extension<Session>,
output: Option<TypedHeader<Accept>>, output: Option<TypedHeader<Accept>>,
params: Query<Params>, params: Query<Params>,
sql: Bytes, sql: Bytes,
) -> Result<impl IntoResponse, impl IntoResponse> { ) -> Result<impl IntoResponse, impl IntoResponse> {
// Get a database reference // Get a database reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Convert the received sql query // Convert the received sql query
let sql = bytes_to_utf8(&sql)?; let sql = bytes_to_utf8(&sql)?;
// Execute the received sql query // Execute the received sql query
@ -65,12 +66,13 @@ async fn post_handler(
async fn ws_handler( async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(state): Extension<AppState>,
Extension(sess): Extension<Session>, Extension(sess): Extension<Session>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, sess)) ws.on_upgrade(move |socket| handle_socket(state, socket, sess))
} }
async fn handle_socket(ws: WebSocket, session: Session) { async fn handle_socket(state: AppState, ws: WebSocket, session: Session) {
// Split the WebSocket connection // Split the WebSocket connection
let (mut tx, mut rx) = ws.split(); let (mut tx, mut rx) = ws.split();
// Wait to receive the next message // Wait to receive the next message
@ -78,7 +80,7 @@ async fn handle_socket(ws: WebSocket, session: Session) {
if let Ok(msg) = res { if let Ok(msg) = res {
if let Ok(sql) = msg.to_text() { if let Ok(sql) = msg.to_text() {
// Get a database reference // Get a database reference
let db = DB.get().unwrap(); let db = &state.datastore;
// Execute the received sql query // Execute the received sql query
let _ = match db.execute(sql, &session, None).await { let _ = match db.execute(sql, &session, None).await {
// Convert the response to JSON // Convert the response to JSON

View file

@ -1,7 +1,6 @@
use crate::cnf::{ use crate::cnf::{
PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY, PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY,
}; };
use crate::dbs::DB;
use crate::rpc::failure::Failure; use crate::rpc::failure::Failure;
use crate::rpc::format::WsFormat; use crate::rpc::format::WsFormat;
use crate::rpc::response::{failure, IntoRpcResponse}; use crate::rpc::response::{failure, IntoRpcResponse};
@ -44,11 +43,13 @@ pub struct Connection {
pub(crate) canceller: CancellationToken, pub(crate) canceller: CancellationToken,
pub(crate) channels: (Sender<Message>, Receiver<Message>), pub(crate) channels: (Sender<Message>, Receiver<Message>),
pub(crate) state: Arc<RpcState>, pub(crate) state: Arc<RpcState>,
pub(crate) datastore: Arc<Datastore>,
} }
impl Connection { impl Connection {
/// Instantiate a new RPC /// Instantiate a new RPC
pub fn new( pub fn new(
datastore: Arc<Datastore>,
state: Arc<RpcState>, state: Arc<RpcState>,
id: Uuid, id: Uuid,
mut session: Session, mut session: Session,
@ -66,6 +67,7 @@ impl Connection {
canceller: CancellationToken::new(), canceller: CancellationToken::new(),
channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS), channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS),
state, state,
datastore,
})) }))
} }
@ -77,6 +79,8 @@ impl Connection {
let id = rpc_lock.id; let id = rpc_lock.id;
// Get the WebSocket state // Get the WebSocket state
let state = rpc_lock.state.clone(); let state = rpc_lock.state.clone();
// Get the Datastore
let ds = rpc_lock.datastore.clone();
// Log the succesful WebSocket connection // Log the succesful WebSocket connection
trace!("WebSocket {} connected", id); trace!("WebSocket {} connected", id);
// Split the socket into sending and receiving streams // Split the socket into sending and receiving streams
@ -125,7 +129,7 @@ impl Connection {
true true
}); });
if let Err(err) = DB.get().unwrap().delete_queries(gc).await { if let Err(err) = ds.delete_queries(gc).await {
error!("Error handling RPC connection: {}", err); error!("Error handling RPC connection: {}", err);
} }
@ -367,7 +371,7 @@ impl Connection {
impl RpcContext for Connection { impl RpcContext for Connection {
fn kvs(&self) -> &Datastore { fn kvs(&self) -> &Datastore {
DB.get().unwrap() &self.datastore
} }
fn session(&self) -> &Session { fn session(&self) -> &Session {
@ -410,7 +414,7 @@ impl RpcContext for Connection {
return Err(RpcError::InvalidParams); return Err(RpcError::InvalidParams);
}; };
let out: Result<Value, RpcError> = let out: Result<Value, RpcError> =
surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v) surrealdb::iam::signup::signup(&self.datastore, &mut self.session, v)
.await .await
.map(Into::into) .map(Into::into)
.map_err(Into::into); .map_err(Into::into);
@ -423,7 +427,7 @@ impl RpcContext for Connection {
return Err(RpcError::InvalidParams); return Err(RpcError::InvalidParams);
}; };
let out: Result<Value, RpcError> = let out: Result<Value, RpcError> =
surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v) surrealdb::iam::signin::signin(&self.datastore, &mut self.session, v)
.await .await
.map(Into::into) .map(Into::into)
.map_err(Into::into); .map_err(Into::into);
@ -434,7 +438,7 @@ impl RpcContext for Connection {
let Ok(Value::Strand(token)) = params.needs_one() else { let Ok(Value::Strand(token)) = params.needs_one() else {
return Err(RpcError::InvalidParams); return Err(RpcError::InvalidParams);
}; };
surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?; surrealdb::iam::verify::token(&self.datastore, &mut self.session, &token.0).await?;
Ok(Value::None) Ok(Value::None)
} }
} }

View file

@ -4,7 +4,6 @@ pub mod format;
pub mod post_context; pub mod post_context;
pub mod response; pub mod response;
use crate::dbs::DB;
use crate::rpc::connection::Connection; use crate::rpc::connection::Connection;
use crate::rpc::response::success; use crate::rpc::response::success;
use crate::telemetry::metrics::ws::NotificationContext; use crate::telemetry::metrics::ws::NotificationContext;
@ -12,6 +11,7 @@ use opentelemetry::Context as TelemetryContext;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use surrealdb::kvs::Datastore;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use uuid::Uuid; use uuid::Uuid;
@ -41,9 +41,13 @@ impl RpcState {
} }
/// Performs notification delivery to the WebSockets /// Performs notification delivery to the WebSockets
pub(crate) async fn notifications(state: Arc<RpcState>, canceller: CancellationToken) { pub(crate) async fn notifications(
ds: Arc<Datastore>,
state: Arc<RpcState>,
canceller: CancellationToken,
) {
// Listen to the notifications channel // Listen to the notifications channel
if let Some(channel) = DB.get().unwrap().notifications() { if let Some(channel) = ds.notifications() {
// Loop continuously // Loop continuously
loop { loop {
tokio::select! { tokio::select! {