diff --git a/src/cli/start.rs b/src/cli/start.rs index 941d7c0b..5b388570 100644 --- a/src/cli/start.rs +++ b/src/cli/start.rs @@ -4,7 +4,7 @@ use crate::cli::validator::parser::env_filter::CustomEnvFilter; use crate::cli::validator::parser::env_filter::CustomEnvFilterParser; use crate::cnf::LOGO; use crate::dbs; -use crate::dbs::{StartCommandDbsOptions, DB}; +use crate::dbs::StartCommandDbsOptions; use crate::env; use crate::err::Error; use crate::net::{self, client_ip::ClientIp}; @@ -12,6 +12,7 @@ use clap::Args; use opentelemetry::Context; use std::net::SocketAddr; use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; use surrealdb::engine::any::IntoEndpoint; use surrealdb::engine::tasks::start_tasks; @@ -189,15 +190,13 @@ pub async fn init( let ct = CancellationToken::new(); // Initiate environment env::init().await?; - // Start the kvs server - dbs::init(dbs).await?; + // Start the datastore + let ds = Arc::new(dbs::init(dbs).await?); // Start the node agent - let (tasks, task_chans) = start_tasks( - &config::CF.get().unwrap().engine.unwrap_or_default(), - DB.get().unwrap().clone(), - ); + let (tasks, task_chans) = + start_tasks(&config::CF.get().unwrap().engine.unwrap_or_default(), ds.clone()); // Start the web server - net::init(ct.clone()).await?; + net::init(ds, ct.clone()).await?; // Shutdown and stop closed tasks task_chans.into_iter().for_each(|chan| { if chan.send(()).is_err() { diff --git a/src/dbs/mod.rs b/src/dbs/mod.rs index 2303a88e..aacc783d 100644 --- a/src/dbs/mod.rs +++ b/src/dbs/mod.rs @@ -2,13 +2,10 @@ use crate::cli::CF; use crate::err::Error; use clap::Args; use std::path::PathBuf; -use std::sync::{Arc, OnceLock}; use std::time::Duration; use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets}; use surrealdb::kvs::Datastore; -pub static DB: OnceLock> = OnceLock::new(); - #[derive(Args, Debug)] pub struct StartCommandDbsOptions { #[arg(help = "Whether strict mode is enabled on this database instance")] @@ -211,7 +208,7 @@ pub async fn init( capabilities, temporary_directory, }: StartCommandDbsOptions, -) -> Result<(), Error> { +) -> Result { // Get local copy of options let opt = CF.get().unwrap(); // Convert the capabilities @@ -248,10 +245,8 @@ pub async fn init( } // Bootstrap the datastore dbs.bootstrap().await?; - // Store database instance - let _ = DB.set(Arc::new(dbs)); // All ok - Ok(()) + Ok(dbs) } #[cfg(test)] diff --git a/src/net/auth.rs b/src/net/auth.rs index dbc859f5..6a2eb190 100644 --- a/src/net/auth.rs +++ b/src/net/auth.rs @@ -15,7 +15,7 @@ use surrealdb::{ }; use tower_http::auth::AsyncAuthorizeRequest; -use crate::{dbs::DB, err::Error}; +use crate::err::Error; use super::{ client_ip::ExtractClientIP, @@ -75,8 +75,6 @@ where } async fn check_auth(parts: &mut Parts) -> Result { - let kvs = DB.get().unwrap(); - let or = if let Ok(or) = parts.extract::>().await { if !or.is_null() { Some(or.to_string()) @@ -113,6 +111,8 @@ async fn check_auth(parts: &mut Parts) -> Result { Error::InvalidAuth })?; + let kvs = &state.datastore; + let ExtractClientIP(ip) = parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None)); diff --git a/src/net/export.rs b/src/net/export.rs index ed12d6bb..3157be0e 100644 --- a/src/net/export.rs +++ b/src/net/export.rs @@ -1,4 +1,4 @@ -use crate::dbs::DB; +use super::AppState; use crate::err::Error; use axum::response::IntoResponse; use axum::routing::get; @@ -21,9 +21,12 @@ where Router::new().route("/export", get(handler)) } -async fn handler(Extension(session): Extension) -> Result { +async fn handler( + Extension(state): Extension, + Extension(session): Extension, +) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Create a chunked response let (mut chn, body) = Body::channel(); // Ensure a NS and DB are set diff --git a/src/net/health.rs b/src/net/health.rs index 9182bc7f..c4e1463f 100644 --- a/src/net/health.rs +++ b/src/net/health.rs @@ -1,7 +1,8 @@ -use crate::dbs::DB; +use super::AppState; use crate::err::Error; use axum::response::IntoResponse; use axum::routing::get; +use axum::Extension; use axum::Router; use http_body::Body as HttpBody; use surrealdb::kvs::{LockType::*, TransactionType::*}; @@ -14,9 +15,9 @@ where Router::new().route("/health", get(handler)) } -async fn handler() -> impl IntoResponse { +async fn handler(Extension(state): Extension) -> impl IntoResponse { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Attempt to open a transaction match db.transaction(Read, Optimistic).await { // The transaction failed to start diff --git a/src/net/import.rs b/src/net/import.rs index 1902b534..40a89edb 100644 --- a/src/net/import.rs +++ b/src/net/import.rs @@ -1,5 +1,5 @@ use super::headers::Accept; -use crate::dbs::DB; +use super::AppState; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; @@ -32,12 +32,13 @@ where } async fn handler( + Extension(state): Extension, Extension(session): Extension, accept: Option>, sql: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Convert the body to a byte slice let sql = bytes_to_utf8(&sql)?; // Check the permissions level diff --git a/src/net/key.rs b/src/net/key.rs index 975f9c04..d82f0475 100644 --- a/src/net/key.rs +++ b/src/net/key.rs @@ -1,4 +1,3 @@ -use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; @@ -18,6 +17,7 @@ use surrealdb::sql::Value; use tower_http::limit::RequestBodyLimitLayer; use super::headers::Accept; +use super::AppState; const MAX: usize = 1024 * 16; // 16 KiB @@ -68,13 +68,14 @@ where // ------------------------------ async fn select_all( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path(table): Path, Query(query): Query, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Specify the request statement @@ -108,6 +109,7 @@ async fn select_all( } async fn create_all( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path(table): Path, @@ -115,7 +117,7 @@ async fn create_all( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -152,6 +154,7 @@ async fn create_all( } async fn update_all( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path(table): Path, @@ -159,7 +162,7 @@ async fn update_all( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -196,6 +199,7 @@ async fn update_all( } async fn modify_all( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path(table): Path, @@ -203,7 +207,7 @@ async fn modify_all( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -240,13 +244,14 @@ async fn modify_all( } async fn delete_all( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path(table): Path, Query(params): Query, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Specify the request statement @@ -278,13 +283,14 @@ async fn delete_all( // ------------------------------ async fn select_one( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path((table, id)): Path<(String, String)>, Query(query): Query, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Specify the request statement @@ -321,6 +327,7 @@ async fn select_one( } async fn create_one( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Query(params): Query, @@ -328,7 +335,7 @@ async fn create_one( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -371,6 +378,7 @@ async fn create_one( } async fn update_one( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Query(params): Query, @@ -378,7 +386,7 @@ async fn update_one( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -421,6 +429,7 @@ async fn update_one( } async fn modify_one( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Query(params): Query, @@ -428,7 +437,7 @@ async fn modify_one( body: Bytes, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Convert the HTTP request body @@ -471,12 +480,13 @@ async fn modify_one( } async fn delete_one( + Extension(state): Extension, Extension(session): Extension, accept: Option>, Path((table, id)): Path<(String, String)>, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let _ = check_ns_db(&session)?; // Specify the request statement diff --git a/src/net/ml.rs b/src/net/ml.rs index 411ed95b..a65dd51b 100644 --- a/src/net/ml.rs +++ b/src/net/ml.rs @@ -1,5 +1,5 @@ //! This file defines the endpoints for the ML API for importing and exporting SurrealML models. -use crate::dbs::DB; +use super::AppState; use crate::err::Error; use crate::net::output; use axum::extract::{BodyStream, DefaultBodyLimit, Path}; @@ -41,11 +41,12 @@ where /// This endpoint allows the user to import a model into the database. async fn import( + Extension(state): Extension, Extension(session): Extension, mut stream: BodyStream, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let (nsv, dbv) = check_ns_db(&session)?; // Check the permissions level @@ -92,11 +93,12 @@ async fn import( /// This endpoint allows the user to export a model from the database. async fn export( + Extension(state): Extension, Extension(session): Extension, Path((name, version)): Path<(String, String)>, ) -> Result { // Get the datastore reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Ensure a NS and DB are set let (nsv, dbv) = check_ns_db(&session)?; // Check the permissions level diff --git a/src/net/mod.rs b/src/net/mod.rs index fcb1eaba..c2640f20 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -36,6 +36,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use surrealdb::headers::{AUTH_DB, AUTH_NS, DB, ID, NS}; +use surrealdb::kvs::Datastore; use tokio_util::sync::CancellationToken; use tower::ServiceBuilder; use tower_http::add_extension::AddExtensionLayer; @@ -60,14 +61,16 @@ const LOG: &str = "surrealdb::net"; #[derive(Clone)] struct AppState { client_ip: client_ip::ClientIp, + datastore: Arc, } -pub async fn init(ct: CancellationToken) -> Result<(), Error> { +pub async fn init(ds: Arc, ct: CancellationToken) -> Result<(), Error> { // Get local copy of options let opt = CF.get().unwrap(); let app_state = AppState { client_ip: opt.client_ip, + datastore: ds.clone(), }; // Specify headers to be obfuscated from all requests/responses @@ -186,7 +189,7 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> { let axum_app = axum_app.with_state(rpc_state.clone()); // Spawn a task to handle notifications - tokio::spawn(async move { notifications(rpc_state, ct.clone()).await }); + tokio::spawn(async move { notifications(ds, rpc_state, ct.clone()).await }); // If a certificate and key are specified then setup TLS if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) { // Configure certificate and private key used by https diff --git a/src/net/rpc.rs b/src/net/rpc.rs index 52b2452d..b04d8fcf 100644 --- a/src/net/rpc.rs +++ b/src/net/rpc.rs @@ -3,7 +3,6 @@ use std::ops::Deref; use std::sync::Arc; use crate::cnf; -use crate::dbs::DB; use crate::err::Error; use crate::rpc::connection::Connection; use crate::rpc::format::HttpFormat; @@ -23,6 +22,7 @@ use bytes::Bytes; use http::HeaderValue; use http_body::Body as HttpBody; use surrealdb::dbs::Session; +use surrealdb::kvs::Datastore; use surrealdb::rpc::format::Format; use surrealdb::rpc::format::PROTOCOLS; use surrealdb::rpc::method::Method; @@ -31,6 +31,7 @@ use uuid::Uuid; use super::headers::Accept; use super::headers::ContentType; +use super::AppState; use surrealdb::rpc::rpc_context::RpcContext; @@ -45,6 +46,7 @@ where async fn get_handler( ws: WebSocketUpgrade, + Extension(state): Extension, Extension(id): Extension, Extension(sess): Extension, State(rpc_state): State>, @@ -79,10 +81,18 @@ async fn get_handler( // Set the maximum WebSocket message size .max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE) // Handle the WebSocket upgrade and process messages - .on_upgrade(move |socket| handle_socket(rpc_state, socket, sess, id))) + .on_upgrade(move |socket| { + handle_socket(state.datastore.clone(), rpc_state, socket, sess, id) + })) } -async fn handle_socket(state: Arc, ws: WebSocket, sess: Session, id: Uuid) { +async fn handle_socket( + datastore: Arc, + state: Arc, + ws: WebSocket, + sess: Session, + id: Uuid, +) { // Check if there is a WebSocket protocol specified let format = match ws.protocol().map(HeaderValue::to_str) { // Any selected protocol will always be a valie value @@ -92,12 +102,13 @@ async fn handle_socket(state: Arc, ws: WebSocket, sess: Session, id: U }; // Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here // Create a new connection instance - let rpc = Connection::new(state, id, sess, format); + let rpc = Connection::new(datastore, state, id, sess, format); // Serve the socket connection requests Connection::serve(rpc, ws).await; } async fn post_handler( + Extension(state): Extension, Extension(session): Extension, output: Option>, content_type: TypedHeader, @@ -114,7 +125,7 @@ async fn post_handler( return Err(Error::InvalidType); } - let mut rpc_ctx = PostRpcContext::new(DB.get().unwrap(), session, BTreeMap::new()); + let mut rpc_ctx = PostRpcContext::new(&state.datastore, session, BTreeMap::new()); match fmt.req_http(body) { Ok(req) => { diff --git a/src/net/signin.rs b/src/net/signin.rs index e129a21e..c52a48d8 100644 --- a/src/net/signin.rs +++ b/src/net/signin.rs @@ -1,4 +1,3 @@ -use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; @@ -16,6 +15,7 @@ use surrealdb::sql::Value; use tower_http::limit::RequestBodyLimitLayer; use super::headers::Accept; +use super::AppState; const MAX: usize = 1024; // 1 KiB @@ -50,12 +50,13 @@ where } async fn handler( + Extension(state): Extension, Extension(mut session): Extension, accept: Option>, body: Bytes, ) -> Result { // Get a database reference - let kvs = DB.get().unwrap(); + let kvs = &state.datastore; // Convert the HTTP body into text let data = bytes_to_utf8(&body)?; // Parse the provided data as JSON diff --git a/src/net/signup.rs b/src/net/signup.rs index a1a07410..41943e5a 100644 --- a/src/net/signup.rs +++ b/src/net/signup.rs @@ -1,4 +1,3 @@ -use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; @@ -14,6 +13,7 @@ use surrealdb::sql::Value; use tower_http::limit::RequestBodyLimitLayer; use super::headers::Accept; +use super::AppState; const MAX: usize = 1024; // 1 KiB @@ -48,12 +48,13 @@ where } async fn handler( + Extension(state): Extension, Extension(mut session): Extension, accept: Option>, body: Bytes, ) -> Result { // Get a database reference - let kvs = DB.get().unwrap(); + let kvs = &state.datastore; // Convert the HTTP body into text let data = bytes_to_utf8(&body)?; // Parse the provided data as JSON diff --git a/src/net/sql.rs b/src/net/sql.rs index 50ed614f..8e9f6c76 100644 --- a/src/net/sql.rs +++ b/src/net/sql.rs @@ -1,4 +1,3 @@ -use crate::dbs::DB; use crate::err::Error; use crate::net::input::bytes_to_utf8; use crate::net::output; @@ -20,6 +19,7 @@ use surrealdb::dbs::Session; use tower_http::limit::RequestBodyLimitLayer; use super::headers::Accept; +use super::AppState; const MAX: usize = 1024 * 1024; // 1 MiB @@ -37,13 +37,14 @@ where } async fn post_handler( + Extension(state): Extension, Extension(session): Extension, output: Option>, params: Query, sql: Bytes, ) -> Result { // Get a database reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Convert the received sql query let sql = bytes_to_utf8(&sql)?; // Execute the received sql query @@ -65,12 +66,13 @@ async fn post_handler( async fn ws_handler( ws: WebSocketUpgrade, + Extension(state): Extension, Extension(sess): Extension, ) -> impl IntoResponse { - ws.on_upgrade(move |socket| handle_socket(socket, sess)) + ws.on_upgrade(move |socket| handle_socket(state, socket, sess)) } -async fn handle_socket(ws: WebSocket, session: Session) { +async fn handle_socket(state: AppState, ws: WebSocket, session: Session) { // Split the WebSocket connection let (mut tx, mut rx) = ws.split(); // Wait to receive the next message @@ -78,7 +80,7 @@ async fn handle_socket(ws: WebSocket, session: Session) { if let Ok(msg) = res { if let Ok(sql) = msg.to_text() { // Get a database reference - let db = DB.get().unwrap(); + let db = &state.datastore; // Execute the received sql query let _ = match db.execute(sql, &session, None).await { // Convert the response to JSON diff --git a/src/rpc/connection.rs b/src/rpc/connection.rs index fff75afb..addbc856 100644 --- a/src/rpc/connection.rs +++ b/src/rpc/connection.rs @@ -1,7 +1,6 @@ use crate::cnf::{ PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY, }; -use crate::dbs::DB; use crate::rpc::failure::Failure; use crate::rpc::format::WsFormat; use crate::rpc::response::{failure, IntoRpcResponse}; @@ -44,11 +43,13 @@ pub struct Connection { pub(crate) canceller: CancellationToken, pub(crate) channels: (Sender, Receiver), pub(crate) state: Arc, + pub(crate) datastore: Arc, } impl Connection { /// Instantiate a new RPC pub fn new( + datastore: Arc, state: Arc, id: Uuid, mut session: Session, @@ -66,6 +67,7 @@ impl Connection { canceller: CancellationToken::new(), channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS), state, + datastore, })) } @@ -77,6 +79,8 @@ impl Connection { let id = rpc_lock.id; // Get the WebSocket state let state = rpc_lock.state.clone(); + // Get the Datastore + let ds = rpc_lock.datastore.clone(); // Log the succesful WebSocket connection trace!("WebSocket {} connected", id); // Split the socket into sending and receiving streams @@ -125,7 +129,7 @@ impl Connection { true }); - if let Err(err) = DB.get().unwrap().delete_queries(gc).await { + if let Err(err) = ds.delete_queries(gc).await { error!("Error handling RPC connection: {}", err); } @@ -367,7 +371,7 @@ impl Connection { impl RpcContext for Connection { fn kvs(&self) -> &Datastore { - DB.get().unwrap() + &self.datastore } fn session(&self) -> &Session { @@ -410,7 +414,7 @@ impl RpcContext for Connection { return Err(RpcError::InvalidParams); }; let out: Result = - surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v) + surrealdb::iam::signup::signup(&self.datastore, &mut self.session, v) .await .map(Into::into) .map_err(Into::into); @@ -423,7 +427,7 @@ impl RpcContext for Connection { return Err(RpcError::InvalidParams); }; let out: Result = - surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v) + surrealdb::iam::signin::signin(&self.datastore, &mut self.session, v) .await .map(Into::into) .map_err(Into::into); @@ -434,7 +438,7 @@ impl RpcContext for Connection { let Ok(Value::Strand(token)) = params.needs_one() else { return Err(RpcError::InvalidParams); }; - surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?; + surrealdb::iam::verify::token(&self.datastore, &mut self.session, &token.0).await?; Ok(Value::None) } } diff --git a/src/rpc/mod.rs b/src/rpc/mod.rs index a1eba119..f7416659 100644 --- a/src/rpc/mod.rs +++ b/src/rpc/mod.rs @@ -4,7 +4,6 @@ pub mod format; pub mod post_context; pub mod response; -use crate::dbs::DB; use crate::rpc::connection::Connection; use crate::rpc::response::success; use crate::telemetry::metrics::ws::NotificationContext; @@ -12,6 +11,7 @@ use opentelemetry::Context as TelemetryContext; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use surrealdb::kvs::Datastore; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -41,9 +41,13 @@ impl RpcState { } /// Performs notification delivery to the WebSockets -pub(crate) async fn notifications(state: Arc, canceller: CancellationToken) { +pub(crate) async fn notifications( + ds: Arc, + state: Arc, + canceller: CancellationToken, +) { // Listen to the notifications channel - if let Some(channel) = DB.get().unwrap().notifications() { + if let Some(channel) = ds.notifications() { // Loop continuously loop { tokio::select! {