Remove global static Datastore (#4377)
This commit is contained in:
parent
968b1714dc
commit
c73435a881
15 changed files with 102 additions and 65 deletions
|
@ -4,7 +4,7 @@ use crate::cli::validator::parser::env_filter::CustomEnvFilter;
|
|||
use crate::cli::validator::parser::env_filter::CustomEnvFilterParser;
|
||||
use crate::cnf::LOGO;
|
||||
use crate::dbs;
|
||||
use crate::dbs::{StartCommandDbsOptions, DB};
|
||||
use crate::dbs::StartCommandDbsOptions;
|
||||
use crate::env;
|
||||
use crate::err::Error;
|
||||
use crate::net::{self, client_ip::ClientIp};
|
||||
|
@ -12,6 +12,7 @@ use clap::Args;
|
|||
use opentelemetry::Context;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use surrealdb::engine::any::IntoEndpoint;
|
||||
use surrealdb::engine::tasks::start_tasks;
|
||||
|
@ -189,15 +190,13 @@ pub async fn init(
|
|||
let ct = CancellationToken::new();
|
||||
// Initiate environment
|
||||
env::init().await?;
|
||||
// Start the kvs server
|
||||
dbs::init(dbs).await?;
|
||||
// Start the datastore
|
||||
let ds = Arc::new(dbs::init(dbs).await?);
|
||||
// Start the node agent
|
||||
let (tasks, task_chans) = start_tasks(
|
||||
&config::CF.get().unwrap().engine.unwrap_or_default(),
|
||||
DB.get().unwrap().clone(),
|
||||
);
|
||||
let (tasks, task_chans) =
|
||||
start_tasks(&config::CF.get().unwrap().engine.unwrap_or_default(), ds.clone());
|
||||
// Start the web server
|
||||
net::init(ct.clone()).await?;
|
||||
net::init(ds, ct.clone()).await?;
|
||||
// Shutdown and stop closed tasks
|
||||
task_chans.into_iter().for_each(|chan| {
|
||||
if chan.send(()).is_err() {
|
||||
|
|
|
@ -2,13 +2,10 @@ use crate::cli::CF;
|
|||
use crate::err::Error;
|
||||
use clap::Args;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::Duration;
|
||||
use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets};
|
||||
use surrealdb::kvs::Datastore;
|
||||
|
||||
pub static DB: OnceLock<Arc<Datastore>> = OnceLock::new();
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
pub struct StartCommandDbsOptions {
|
||||
#[arg(help = "Whether strict mode is enabled on this database instance")]
|
||||
|
@ -211,7 +208,7 @@ pub async fn init(
|
|||
capabilities,
|
||||
temporary_directory,
|
||||
}: StartCommandDbsOptions,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<Datastore, Error> {
|
||||
// Get local copy of options
|
||||
let opt = CF.get().unwrap();
|
||||
// Convert the capabilities
|
||||
|
@ -248,10 +245,8 @@ pub async fn init(
|
|||
}
|
||||
// Bootstrap the datastore
|
||||
dbs.bootstrap().await?;
|
||||
// Store database instance
|
||||
let _ = DB.set(Arc::new(dbs));
|
||||
// All ok
|
||||
Ok(())
|
||||
Ok(dbs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -15,7 +15,7 @@ use surrealdb::{
|
|||
};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
|
||||
use crate::{dbs::DB, err::Error};
|
||||
use crate::err::Error;
|
||||
|
||||
use super::{
|
||||
client_ip::ExtractClientIP,
|
||||
|
@ -75,8 +75,6 @@ where
|
|||
}
|
||||
|
||||
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
||||
let kvs = DB.get().unwrap();
|
||||
|
||||
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
|
||||
if !or.is_null() {
|
||||
Some(or.to_string())
|
||||
|
@ -113,6 +111,8 @@ async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
|||
Error::InvalidAuth
|
||||
})?;
|
||||
|
||||
let kvs = &state.datastore;
|
||||
|
||||
let ExtractClientIP(ip) =
|
||||
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::dbs::DB;
|
||||
use super::AppState;
|
||||
use crate::err::Error;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
|
@ -21,9 +21,12 @@ where
|
|||
Router::new().route("/export", get(handler))
|
||||
}
|
||||
|
||||
async fn handler(Extension(session): Extension<Session>) -> Result<impl IntoResponse, Error> {
|
||||
async fn handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Create a chunked response
|
||||
let (mut chn, body) = Body::channel();
|
||||
// Ensure a NS and DB are set
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use crate::dbs::DB;
|
||||
use super::AppState;
|
||||
use crate::err::Error;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
use axum::Router;
|
||||
use http_body::Body as HttpBody;
|
||||
use surrealdb::kvs::{LockType::*, TransactionType::*};
|
||||
|
@ -14,9 +15,9 @@ where
|
|||
Router::new().route("/health", get(handler))
|
||||
}
|
||||
|
||||
async fn handler() -> impl IntoResponse {
|
||||
async fn handler(Extension(state): Extension<AppState>) -> impl IntoResponse {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Attempt to open a transaction
|
||||
match db.transaction(Read, Optimistic).await {
|
||||
// The transaction failed to start
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::headers::Accept;
|
||||
use crate::dbs::DB;
|
||||
use super::AppState;
|
||||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
|
@ -32,12 +32,13 @@ where
|
|||
}
|
||||
|
||||
async fn handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
sql: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Convert the body to a byte slice
|
||||
let sql = bytes_to_utf8(&sql)?;
|
||||
// Check the permissions level
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
|
@ -18,6 +17,7 @@ use surrealdb::sql::Value;
|
|||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
use super::headers::Accept;
|
||||
use super::AppState;
|
||||
|
||||
const MAX: usize = 1024 * 16; // 16 KiB
|
||||
|
||||
|
@ -68,13 +68,14 @@ where
|
|||
// ------------------------------
|
||||
|
||||
async fn select_all(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(query): Query<QueryOptions>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Specify the request statement
|
||||
|
@ -108,6 +109,7 @@ async fn select_all(
|
|||
}
|
||||
|
||||
async fn create_all(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
|
@ -115,7 +117,7 @@ async fn create_all(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -152,6 +154,7 @@ async fn create_all(
|
|||
}
|
||||
|
||||
async fn update_all(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
|
@ -159,7 +162,7 @@ async fn update_all(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -196,6 +199,7 @@ async fn update_all(
|
|||
}
|
||||
|
||||
async fn modify_all(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
|
@ -203,7 +207,7 @@ async fn modify_all(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -240,13 +244,14 @@ async fn modify_all(
|
|||
}
|
||||
|
||||
async fn delete_all(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<Params>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Specify the request statement
|
||||
|
@ -278,13 +283,14 @@ async fn delete_all(
|
|||
// ------------------------------
|
||||
|
||||
async fn select_one(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
Query(query): Query<QueryOptions>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Specify the request statement
|
||||
|
@ -321,6 +327,7 @@ async fn select_one(
|
|||
}
|
||||
|
||||
async fn create_one(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
|
@ -328,7 +335,7 @@ async fn create_one(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -371,6 +378,7 @@ async fn create_one(
|
|||
}
|
||||
|
||||
async fn update_one(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
|
@ -378,7 +386,7 @@ async fn update_one(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -421,6 +429,7 @@ async fn update_one(
|
|||
}
|
||||
|
||||
async fn modify_one(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Query(params): Query<Params>,
|
||||
|
@ -428,7 +437,7 @@ async fn modify_one(
|
|||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Convert the HTTP request body
|
||||
|
@ -471,12 +480,13 @@ async fn modify_one(
|
|||
}
|
||||
|
||||
async fn delete_one(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
Path((table, id)): Path<(String, String)>,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let _ = check_ns_db(&session)?;
|
||||
// Specify the request statement
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
//! This file defines the endpoints for the ML API for importing and exporting SurrealML models.
|
||||
use crate::dbs::DB;
|
||||
use super::AppState;
|
||||
use crate::err::Error;
|
||||
use crate::net::output;
|
||||
use axum::extract::{BodyStream, DefaultBodyLimit, Path};
|
||||
|
@ -41,11 +41,12 @@ where
|
|||
|
||||
/// This endpoint allows the user to import a model into the database.
|
||||
async fn import(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
mut stream: BodyStream,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let (nsv, dbv) = check_ns_db(&session)?;
|
||||
// Check the permissions level
|
||||
|
@ -92,11 +93,12 @@ async fn import(
|
|||
|
||||
/// This endpoint allows the user to export a model from the database.
|
||||
async fn export(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
Path((name, version)): Path<(String, String)>,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
// Get the datastore reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Ensure a NS and DB are set
|
||||
let (nsv, dbv) = check_ns_db(&session)?;
|
||||
// Check the permissions level
|
||||
|
|
|
@ -36,6 +36,7 @@ use std::net::SocketAddr;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use surrealdb::headers::{AUTH_DB, AUTH_NS, DB, ID, NS};
|
||||
use surrealdb::kvs::Datastore;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::add_extension::AddExtensionLayer;
|
||||
|
@ -60,14 +61,16 @@ const LOG: &str = "surrealdb::net";
|
|||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
client_ip: client_ip::ClientIp,
|
||||
datastore: Arc<Datastore>,
|
||||
}
|
||||
|
||||
pub async fn init(ct: CancellationToken) -> Result<(), Error> {
|
||||
pub async fn init(ds: Arc<Datastore>, ct: CancellationToken) -> Result<(), Error> {
|
||||
// Get local copy of options
|
||||
let opt = CF.get().unwrap();
|
||||
|
||||
let app_state = AppState {
|
||||
client_ip: opt.client_ip,
|
||||
datastore: ds.clone(),
|
||||
};
|
||||
|
||||
// Specify headers to be obfuscated from all requests/responses
|
||||
|
@ -186,7 +189,7 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
|
|||
let axum_app = axum_app.with_state(rpc_state.clone());
|
||||
|
||||
// Spawn a task to handle notifications
|
||||
tokio::spawn(async move { notifications(rpc_state, ct.clone()).await });
|
||||
tokio::spawn(async move { notifications(ds, rpc_state, ct.clone()).await });
|
||||
// If a certificate and key are specified then setup TLS
|
||||
if let (Some(cert), Some(key)) = (&opt.crt, &opt.key) {
|
||||
// Configure certificate and private key used by https
|
||||
|
|
|
@ -3,7 +3,6 @@ use std::ops::Deref;
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::cnf;
|
||||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::rpc::connection::Connection;
|
||||
use crate::rpc::format::HttpFormat;
|
||||
|
@ -23,6 +22,7 @@ use bytes::Bytes;
|
|||
use http::HeaderValue;
|
||||
use http_body::Body as HttpBody;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::kvs::Datastore;
|
||||
use surrealdb::rpc::format::Format;
|
||||
use surrealdb::rpc::format::PROTOCOLS;
|
||||
use surrealdb::rpc::method::Method;
|
||||
|
@ -31,6 +31,7 @@ use uuid::Uuid;
|
|||
|
||||
use super::headers::Accept;
|
||||
use super::headers::ContentType;
|
||||
use super::AppState;
|
||||
|
||||
use surrealdb::rpc::rpc_context::RpcContext;
|
||||
|
||||
|
@ -45,6 +46,7 @@ where
|
|||
|
||||
async fn get_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(id): Extension<RequestId>,
|
||||
Extension(sess): Extension<Session>,
|
||||
State(rpc_state): State<Arc<RpcState>>,
|
||||
|
@ -79,10 +81,18 @@ async fn get_handler(
|
|||
// Set the maximum WebSocket message size
|
||||
.max_message_size(*cnf::WEBSOCKET_MAX_MESSAGE_SIZE)
|
||||
// Handle the WebSocket upgrade and process messages
|
||||
.on_upgrade(move |socket| handle_socket(rpc_state, socket, sess, id)))
|
||||
.on_upgrade(move |socket| {
|
||||
handle_socket(state.datastore.clone(), rpc_state, socket, sess, id)
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: Uuid) {
|
||||
async fn handle_socket(
|
||||
datastore: Arc<Datastore>,
|
||||
state: Arc<RpcState>,
|
||||
ws: WebSocket,
|
||||
sess: Session,
|
||||
id: Uuid,
|
||||
) {
|
||||
// Check if there is a WebSocket protocol specified
|
||||
let format = match ws.protocol().map(HeaderValue::to_str) {
|
||||
// Any selected protocol will always be a valie value
|
||||
|
@ -92,12 +102,13 @@ async fn handle_socket(state: Arc<RpcState>, ws: WebSocket, sess: Session, id: U
|
|||
};
|
||||
// Format::Unsupported is not in the PROTOCOLS list so cannot be the value of format here
|
||||
// Create a new connection instance
|
||||
let rpc = Connection::new(state, id, sess, format);
|
||||
let rpc = Connection::new(datastore, state, id, sess, format);
|
||||
// Serve the socket connection requests
|
||||
Connection::serve(rpc, ws).await;
|
||||
}
|
||||
|
||||
async fn post_handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
output: Option<TypedHeader<Accept>>,
|
||||
content_type: TypedHeader<ContentType>,
|
||||
|
@ -114,7 +125,7 @@ async fn post_handler(
|
|||
return Err(Error::InvalidType);
|
||||
}
|
||||
|
||||
let mut rpc_ctx = PostRpcContext::new(DB.get().unwrap(), session, BTreeMap::new());
|
||||
let mut rpc_ctx = PostRpcContext::new(&state.datastore, session, BTreeMap::new());
|
||||
|
||||
match fmt.req_http(body) {
|
||||
Ok(req) => {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
|
@ -16,6 +15,7 @@ use surrealdb::sql::Value;
|
|||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
use super::headers::Accept;
|
||||
use super::AppState;
|
||||
|
||||
const MAX: usize = 1024; // 1 KiB
|
||||
|
||||
|
@ -50,12 +50,13 @@ where
|
|||
}
|
||||
|
||||
async fn handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(mut session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
let kvs = &state.datastore;
|
||||
// Convert the HTTP body into text
|
||||
let data = bytes_to_utf8(&body)?;
|
||||
// Parse the provided data as JSON
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
|
@ -14,6 +13,7 @@ use surrealdb::sql::Value;
|
|||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
use super::headers::Accept;
|
||||
use super::AppState;
|
||||
|
||||
const MAX: usize = 1024; // 1 KiB
|
||||
|
||||
|
@ -48,12 +48,13 @@ where
|
|||
}
|
||||
|
||||
async fn handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(mut session): Extension<Session>,
|
||||
accept: Option<TypedHeader<Accept>>,
|
||||
body: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let kvs = DB.get().unwrap();
|
||||
let kvs = &state.datastore;
|
||||
// Convert the HTTP body into text
|
||||
let data = bytes_to_utf8(&body)?;
|
||||
// Parse the provided data as JSON
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::dbs::DB;
|
||||
use crate::err::Error;
|
||||
use crate::net::input::bytes_to_utf8;
|
||||
use crate::net::output;
|
||||
|
@ -20,6 +19,7 @@ use surrealdb::dbs::Session;
|
|||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
|
||||
use super::headers::Accept;
|
||||
use super::AppState;
|
||||
|
||||
const MAX: usize = 1024 * 1024; // 1 MiB
|
||||
|
||||
|
@ -37,13 +37,14 @@ where
|
|||
}
|
||||
|
||||
async fn post_handler(
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(session): Extension<Session>,
|
||||
output: Option<TypedHeader<Accept>>,
|
||||
params: Query<Params>,
|
||||
sql: Bytes,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Get a database reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Convert the received sql query
|
||||
let sql = bytes_to_utf8(&sql)?;
|
||||
// Execute the received sql query
|
||||
|
@ -65,12 +66,13 @@ async fn post_handler(
|
|||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(sess): Extension<Session>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, sess))
|
||||
ws.on_upgrade(move |socket| handle_socket(state, socket, sess))
|
||||
}
|
||||
|
||||
async fn handle_socket(ws: WebSocket, session: Session) {
|
||||
async fn handle_socket(state: AppState, ws: WebSocket, session: Session) {
|
||||
// Split the WebSocket connection
|
||||
let (mut tx, mut rx) = ws.split();
|
||||
// Wait to receive the next message
|
||||
|
@ -78,7 +80,7 @@ async fn handle_socket(ws: WebSocket, session: Session) {
|
|||
if let Ok(msg) = res {
|
||||
if let Ok(sql) = msg.to_text() {
|
||||
// Get a database reference
|
||||
let db = DB.get().unwrap();
|
||||
let db = &state.datastore;
|
||||
// Execute the received sql query
|
||||
let _ = match db.execute(sql, &session, None).await {
|
||||
// Convert the response to JSON
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use crate::cnf::{
|
||||
PKG_NAME, PKG_VERSION, WEBSOCKET_MAX_CONCURRENT_REQUESTS, WEBSOCKET_PING_FREQUENCY,
|
||||
};
|
||||
use crate::dbs::DB;
|
||||
use crate::rpc::failure::Failure;
|
||||
use crate::rpc::format::WsFormat;
|
||||
use crate::rpc::response::{failure, IntoRpcResponse};
|
||||
|
@ -44,11 +43,13 @@ pub struct Connection {
|
|||
pub(crate) canceller: CancellationToken,
|
||||
pub(crate) channels: (Sender<Message>, Receiver<Message>),
|
||||
pub(crate) state: Arc<RpcState>,
|
||||
pub(crate) datastore: Arc<Datastore>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Instantiate a new RPC
|
||||
pub fn new(
|
||||
datastore: Arc<Datastore>,
|
||||
state: Arc<RpcState>,
|
||||
id: Uuid,
|
||||
mut session: Session,
|
||||
|
@ -66,6 +67,7 @@ impl Connection {
|
|||
canceller: CancellationToken::new(),
|
||||
channels: channel::bounded(*WEBSOCKET_MAX_CONCURRENT_REQUESTS),
|
||||
state,
|
||||
datastore,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -77,6 +79,8 @@ impl Connection {
|
|||
let id = rpc_lock.id;
|
||||
// Get the WebSocket state
|
||||
let state = rpc_lock.state.clone();
|
||||
// Get the Datastore
|
||||
let ds = rpc_lock.datastore.clone();
|
||||
// Log the succesful WebSocket connection
|
||||
trace!("WebSocket {} connected", id);
|
||||
// Split the socket into sending and receiving streams
|
||||
|
@ -125,7 +129,7 @@ impl Connection {
|
|||
true
|
||||
});
|
||||
|
||||
if let Err(err) = DB.get().unwrap().delete_queries(gc).await {
|
||||
if let Err(err) = ds.delete_queries(gc).await {
|
||||
error!("Error handling RPC connection: {}", err);
|
||||
}
|
||||
|
||||
|
@ -367,7 +371,7 @@ impl Connection {
|
|||
|
||||
impl RpcContext for Connection {
|
||||
fn kvs(&self) -> &Datastore {
|
||||
DB.get().unwrap()
|
||||
&self.datastore
|
||||
}
|
||||
|
||||
fn session(&self) -> &Session {
|
||||
|
@ -410,7 +414,7 @@ impl RpcContext for Connection {
|
|||
return Err(RpcError::InvalidParams);
|
||||
};
|
||||
let out: Result<Value, RpcError> =
|
||||
surrealdb::iam::signup::signup(DB.get().unwrap(), &mut self.session, v)
|
||||
surrealdb::iam::signup::signup(&self.datastore, &mut self.session, v)
|
||||
.await
|
||||
.map(Into::into)
|
||||
.map_err(Into::into);
|
||||
|
@ -423,7 +427,7 @@ impl RpcContext for Connection {
|
|||
return Err(RpcError::InvalidParams);
|
||||
};
|
||||
let out: Result<Value, RpcError> =
|
||||
surrealdb::iam::signin::signin(DB.get().unwrap(), &mut self.session, v)
|
||||
surrealdb::iam::signin::signin(&self.datastore, &mut self.session, v)
|
||||
.await
|
||||
.map(Into::into)
|
||||
.map_err(Into::into);
|
||||
|
@ -434,7 +438,7 @@ impl RpcContext for Connection {
|
|||
let Ok(Value::Strand(token)) = params.needs_one() else {
|
||||
return Err(RpcError::InvalidParams);
|
||||
};
|
||||
surrealdb::iam::verify::token(DB.get().unwrap(), &mut self.session, &token.0).await?;
|
||||
surrealdb::iam::verify::token(&self.datastore, &mut self.session, &token.0).await?;
|
||||
Ok(Value::None)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ pub mod format;
|
|||
pub mod post_context;
|
||||
pub mod response;
|
||||
|
||||
use crate::dbs::DB;
|
||||
use crate::rpc::connection::Connection;
|
||||
use crate::rpc::response::success;
|
||||
use crate::telemetry::metrics::ws::NotificationContext;
|
||||
|
@ -12,6 +11,7 @@ use opentelemetry::Context as TelemetryContext;
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use surrealdb::kvs::Datastore;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
@ -41,9 +41,13 @@ impl RpcState {
|
|||
}
|
||||
|
||||
/// Performs notification delivery to the WebSockets
|
||||
pub(crate) async fn notifications(state: Arc<RpcState>, canceller: CancellationToken) {
|
||||
pub(crate) async fn notifications(
|
||||
ds: Arc<Datastore>,
|
||||
state: Arc<RpcState>,
|
||||
canceller: CancellationToken,
|
||||
) {
|
||||
// Listen to the notifications channel
|
||||
if let Some(channel) = DB.get().unwrap().notifications() {
|
||||
if let Some(channel) = ds.notifications() {
|
||||
// Loop continuously
|
||||
loop {
|
||||
tokio::select! {
|
||||
|
|
Loading…
Reference in a new issue