Refactor ds tasks to single part of codebase ()

This commit is contained in:
Przemyslaw Hugh Kaznowski 2024-03-18 12:30:31 +00:00 committed by GitHub
parent 0728afd60c
commit 47a1589018
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 520 additions and 403 deletions

21
core/src/dbs/lifecycle.rs Normal file
View file

@ -0,0 +1,21 @@
/// LoggingLifecycle is used to create log messages upon creation, and log messages when it is dropped
#[doc(hidden)]
pub struct LoggingLifecycle {
identifier: String,
}
impl LoggingLifecycle {
#[doc(hidden)]
pub fn new(identifier: String) -> Self {
debug!("Started {}", identifier);
Self {
identifier,
}
}
}
impl Drop for LoggingLifecycle {
fn drop(&mut self) {
debug!("Stopped {}", self.identifier);
}
}

View file

@ -14,6 +14,7 @@ mod statement;
mod transaction;
mod variables;
pub use self::lifecycle::*;
pub use self::notification::*;
pub use self::options::*;
pub use self::response::*;
@ -30,6 +31,8 @@ pub use self::capabilities::Capabilities;
pub mod node;
mod group;
#[doc(hidden)]
pub mod lifecycle;
mod processor;
mod result;
mod store;

View file

@ -910,6 +910,10 @@ pub enum Error {
/// The session has an invalid expiration
#[error("The session has an invalid expiration")]
InvalidSessionExpiration,
/// A node task has failed
#[error("A node task has failed: {0}")]
NodeAgent(&'static str),
}
impl From<Error> for String {

View file

@ -1,3 +1,5 @@
use std::time::Duration;
/// Configuration for the engine behaviour
/// The defaults are optimal so please only modify these if you know deliberately why you are modifying them.
#[derive(Clone, Copy, Debug)]
@ -7,6 +9,7 @@ pub struct EngineOptions {
pub new_live_queries_per_transaction: u32,
/// The size of batches being requested per update in order to catch up a live query
pub live_query_catchup_size: u32,
pub tick_interval: Duration,
}
impl Default for EngineOptions {
@ -14,6 +17,7 @@ impl Default for EngineOptions {
Self {
new_live_queries_per_transaction: 100,
live_query_catchup_size: 1000,
tick_interval: Duration::from_secs(1),
}
}
}

View file

@ -5,20 +5,18 @@ use crate::api::conn::Param;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::engine::local::Db;
use crate::api::engine::local::DEFAULT_TICK_INTERVAL;
use crate::api::opt::{Endpoint, EndpointKind};
use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Options;
use crate::dbs::Session;
use crate::engine::IntervalStream;
use crate::fflags::FFLAGS;
use crate::engine::tasks::start_tasks;
use crate::iam::Level;
use crate::kvs::Datastore;
use crate::opt::auth::Root;
use crate::opt::WaitFor;
use crate::options::EngineOptions;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
@ -35,10 +33,7 @@ use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::watch;
use tokio::time;
use tokio::time::MissedTickBehavior;
impl crate::api::Connection for Db {}
@ -156,9 +151,20 @@ pub(crate) fn router(
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let (maintenance_tx, maintenance_rx) = flume::bounded::<()>(1);
let tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL);
run_maintenance(kvs.clone(), tick_interval, maintenance_rx);
#[cfg(feature = "sql2")]
let opt = {
let tick_interval = address
.config
.tick_interval
.unwrap_or(crate::api::engine::local::DEFAULT_TICK_INTERVAL);
EngineOptions {
tick_interval,
..Default::default()
}
};
#[cfg(not(feature = "sql2"))]
let opt = EngineOptions::default();
let (tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let notification_stream = poll_fn(move |cx| match &mut notifications {
@ -207,65 +213,11 @@ pub(crate) fn router(
}
// Stop maintenance tasks
let _ = maintenance_tx.into_send_async(()).await;
});
}
fn run_maintenance(kvs: Arc<Datastore>, tick_interval: Duration, stop_signal: Receiver<()>) {
trace!("Starting maintenance");
// Some classic ownership shenanigans
let kvs_two = kvs.clone();
let stop_signal_two = stop_signal.clone();
// Spawn the ticker, which is used for tracking versionstamps and heartbeats across databases
tokio::spawn(async move {
let mut interval = time::interval(tick_interval);
// Don't bombard the database if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// Delay sending the first tick
interval.tick().await;
let ticker = IntervalStream::new(interval);
let streams = (ticker.map(Some), stop_signal.into_stream().map(|_| None));
let mut stream = streams.merge();
while let Some(Some(_)) = stream.next().await {
match kvs.clone().tick().await {
Ok(()) => trace!("Node agent tick ran successfully"),
Err(error) => error!("Error running node agent tick: {error}"),
for chan in task_chans {
if let Err(e) = chan.send(()) {
error!("Error sending shutdown signal to task: {}", e);
}
}
tasks.resolve().await.unwrap();
});
if FFLAGS.change_feed_live_queries.enabled() {
trace!("Live queries v2 enabled");
// Spawn the live query change feed consumer, which is used for catching up on relevant change feeds
tokio::spawn(async move {
let kvs = kvs_two;
let stop_signal = stop_signal_two;
let mut interval = time::interval(tick_interval);
// Don't bombard the database if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// Delay sending the first tick
interval.tick().await;
let ticker = IntervalStream::new(interval);
let streams = (ticker.map(Some), stop_signal.into_stream().map(|_| None));
let mut stream = streams.merge();
let opt = Options::default();
while let Some(Some(_)) = stream.next().await {
match kvs.process_lq_notifications(&opt).await {
Ok(()) => trace!("Live Query poll ran successfully"),
Err(error) => error!("Error running live query poll: {error}"),
}
}
});
} else {
trace!("Live queries v2 disabled")
}
}

View file

@ -13,12 +13,12 @@ use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Options;
use crate::dbs::Session;
use crate::engine::IntervalStream;
use crate::fflags::FFLAGS;
use crate::engine::tasks::start_tasks;
use crate::iam::Level;
use crate::kvs::Datastore;
use crate::opt::auth::Root;
use crate::opt::WaitFor;
use crate::options::EngineOptions;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
@ -35,11 +35,8 @@ use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use std::task::Poll;
use std::time::Duration;
use tokio::sync::watch;
use wasm_bindgen_futures::spawn_local;
use wasmtimer::tokio as time;
use wasmtimer::tokio::MissedTickBehavior;
impl crate::api::Connection for Db {}
@ -151,9 +148,12 @@ pub(crate) fn router(
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let (maintenance_tx, maintenance_rx) = flume::bounded::<()>(1);
let tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL);
run_maintenance(kvs.clone(), tick_interval, maintenance_rx);
let opt = EngineOptions {
tick_interval,
..Default::default()
};
let (tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let notification_stream = poll_fn(move |cx| match &mut notifications {
@ -202,61 +202,10 @@ pub(crate) fn router(
}
// Stop maintenance tasks
let _ = maintenance_tx.into_send_async(()).await;
});
}
fn run_maintenance(kvs: Arc<Datastore>, tick_interval: Duration, stop_signal: Receiver<()>) {
// Some classic ownership shenanigans
let kvs_two = kvs.clone();
let stop_signal_two = stop_signal.clone();
// Spawn the ticker, which is used for tracking versionstamps and heartbeats across databases
spawn_local(async move {
let mut interval = time::interval(tick_interval);
// Don't bombard the database if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// Delay sending the first tick
interval.tick().await;
let ticker = IntervalStream::new(interval);
let streams = (ticker.map(Some), stop_signal.into_stream().map(|_| None));
let mut stream = streams.merge();
while let Some(Some(_)) = stream.next().await {
match kvs.tick().await {
Ok(()) => trace!("Node agent tick ran successfully"),
Err(error) => error!("Error running node agent tick: {error}"),
for chan in task_chans {
if let Err(e) = chan.send(()) {
error!("Error sending shutdown signal to maintenance task: {e}");
}
}
});
if FFLAGS.change_feed_live_queries.enabled() {
// Spawn the live query change feed consumer, which is used for catching up on relevant change feeds
spawn_local(async move {
let kvs = kvs_two;
let stop_signal = stop_signal_two;
let mut interval = time::interval(tick_interval);
// Don't bombard the database if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// Delay sending the first tick
interval.tick().await;
let ticker = IntervalStream::new(interval);
let streams = (ticker.map(Some), stop_signal.into_stream().map(|_| None));
let mut stream = streams.merge();
let opt = Options::default();
while let Some(Some(_)) = stream.next().await {
match kvs.process_lq_notifications(&opt).await {
Ok(()) => trace!("Live Query poll ran successfully"),
Err(error) => error!("Error running live query poll: {error}"),
}
}
})
}
}

View file

@ -13,6 +13,8 @@ pub mod any;
pub mod local;
#[cfg(any(feature = "protocol-http", feature = "protocol-ws"))]
pub mod remote;
#[doc(hidden)]
pub mod tasks;
use crate::sql::statements::CreateStatement;
use crate::sql::statements::DeleteStatement;

212
lib/src/api/engine/tasks.rs Normal file
View file

@ -0,0 +1,212 @@
use flume::Sender;
use futures::StreamExt;
use futures_concurrency::stream::Merge;
#[cfg(target_arch = "wasm32")]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[cfg(not(target_arch = "wasm32"))]
use tokio::task::JoinHandle;
use crate::dbs::Options;
use crate::fflags::FFLAGS;
use crate::kvs::Datastore;
use crate::options::EngineOptions;
use crate::engine::IntervalStream;
#[cfg(not(target_arch = "wasm32"))]
use crate::Error as RootError;
#[cfg(not(target_arch = "wasm32"))]
use tokio::spawn as spawn_future;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local as spawn_future;
#[cfg(not(target_arch = "wasm32"))]
type FutureTask = JoinHandle<()>;
#[cfg(target_arch = "wasm32")]
/// This will be true if a task has completed
type FutureTask = Arc<AtomicBool>;
pub struct Tasks {
pub nd: FutureTask,
pub lq: FutureTask,
}
impl Tasks {
#[cfg(not(target_arch = "wasm32"))]
pub async fn resolve(self) -> Result<(), RootError> {
self.nd.await.map_err(|e| {
error!("Node agent task failed: {}", e);
#[cfg(not(feature = "sql2"))]
let inner_err = surrealdb_core1::err::Error::Unreachable(
"This feature won't go live with sql1, so delete this branching",
);
#[cfg(feature = "sql2")]
let inner_err = surrealdb_core2::err::Error::NodeAgent("node task failed and has been logged");
RootError::Db(inner_err)
})?;
self.lq.await.map_err(|e| {
error!("Live query task failed: {}", e);
#[cfg(not(feature = "sql2"))]
let inner_err = surrealdb_core1::err::Error::Unreachable(
"This feature won't go live with sql1, so delete this branching",
);
#[cfg(feature = "sql2")]
let inner_err = surrealdb_core2::err::Error::NodeAgent(
"live query task failed and has been logged",
);
RootError::Db(inner_err)
})?;
Ok(())
}
}
/// Starts tasks that are required for the correct running of the engine
pub fn start_tasks(opt: &EngineOptions, dbs: Arc<Datastore>) -> (Tasks, [Sender<()>; 2]) {
let nd = init(opt, dbs.clone());
let lq = live_query_change_feed(opt, dbs);
let cancellation_channels = [nd.1, lq.1];
(
Tasks {
nd: nd.0,
lq: lq.0,
},
cancellation_channels,
)
}
// The init starts a long-running thread for periodically calling Datastore.tick.
// Datastore.tick is responsible for running garbage collection and other
// background tasks.
//
// This function needs to be called before after the dbs::init and before the net::init functions.
// It needs to be before net::init because the net::init function blocks until the web server stops.
fn init(opt: &EngineOptions, dbs: Arc<Datastore>) -> (FutureTask, Sender<()>) {
#[cfg(feature = "sql2")]
let _init = crate::dbs::LoggingLifecycle::new("node agent initialisation".to_string());
#[cfg(feature = "sql2")]
let tick_interval = opt.tick_interval;
#[cfg(not(feature = "sql2"))]
let tick_interval = Duration::from_secs(1);
trace!("Ticker interval is {:?}", tick_interval);
#[cfg(target_arch = "wasm32")]
let completed_status = Arc::new(AtomicBool::new(false));
#[cfg(target_arch = "wasm32")]
let ret_status = completed_status.clone();
// We create a channel that can be streamed that will indicate termination
let (tx, rx) = flume::bounded(1);
let _fut = spawn_future(async move {
#[cfg(feature = "sql2")]
let _lifecycle = crate::dbs::LoggingLifecycle::new("heartbeat task".to_string());
let ticker = interval_ticker(tick_interval).await;
let streams = (
ticker.map(|i| {
trace!("Node agent tick: {:?}", i);
Some(i)
}),
rx.into_stream().map(|_| None),
);
let mut streams = streams.merge();
while let Some(Some(_)) = streams.next().await {
if let Err(e) = dbs.tick().await {
error!("Error running node agent tick: {}", e);
break;
}
}
#[cfg(target_arch = "wasm32")]
completed_status.store(true, Ordering::Relaxed);
});
#[cfg(not(target_arch = "wasm32"))]
return (_fut, tx);
#[cfg(target_arch = "wasm32")]
return (ret_status, tx);
}
// Start live query on change feeds notification processing
fn live_query_change_feed(opt: &EngineOptions, dbs: Arc<Datastore>) -> (FutureTask, Sender<()>) {
#[cfg(feature = "sql2")]
let tick_interval = opt.tick_interval;
#[cfg(not(feature = "sql2"))]
let tick_interval = Duration::from_secs(1);
#[cfg(target_arch = "wasm32")]
let completed_status = Arc::new(AtomicBool::new(false));
#[cfg(target_arch = "wasm32")]
let ret_status = completed_status.clone();
// We create a channel that can be streamed that will indicate termination
let (tx, rx) = flume::bounded(1);
let _fut = spawn_future(async move {
#[cfg(feature = "sql2")]
let _lifecycle = crate::dbs::LoggingLifecycle::new("live query agent task".to_string());
if !FFLAGS.change_feed_live_queries.enabled() {
// TODO verify test fails since return without completion
#[cfg(target_arch = "wasm32")]
completed_status.store(true, Ordering::Relaxed);
return;
}
let ticker = interval_ticker(tick_interval).await;
let streams = (
ticker.map(|i| {
trace!("Live query agent tick: {:?}", i);
Some(i)
}),
rx.into_stream().map(|_| None),
);
let mut streams = streams.merge();
let opt = Options::default();
while let Some(Some(_)) = streams.next().await {
if let Err(e) = dbs.process_lq_notifications(&opt).await {
error!("Error running node agent tick: {}", e);
break;
}
}
#[cfg(target_arch = "wasm32")]
completed_status.store(true, Ordering::Relaxed);
});
#[cfg(not(target_arch = "wasm32"))]
return (_fut, tx);
#[cfg(target_arch = "wasm32")]
return (ret_status, tx);
}
async fn interval_ticker(interval: Duration) -> IntervalStream {
#[cfg(not(target_arch = "wasm32"))]
use tokio::{time, time::MissedTickBehavior};
#[cfg(target_arch = "wasm32")]
use wasmtimer::{tokio as time, tokio::MissedTickBehavior};
let mut interval = time::interval(interval);
// Don't bombard the database if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
interval.tick().await;
IntervalStream::new(interval)
}
#[cfg(test)]
#[cfg(feature = "kv-mem")]
mod test {
use crate::engine::tasks::start_tasks;
use crate::kvs::Datastore;
use crate::options::EngineOptions;
use std::sync::Arc;
#[test_log::test(tokio::test)]
pub async fn tasks_complete() {
let opt = EngineOptions::default();
let dbs = Arc::new(Datastore::new("memory").await.unwrap());
let (val, chans) = start_tasks(&opt, dbs.clone());
for chan in chans {
chan.send(()).unwrap();
}
val.resolve().await.unwrap();
}
}

View file

@ -137,20 +137,20 @@ async fn database_change_feeds() -> Result<(), Error> {
);
Some(&tmp)
.filter(|x| *x == &val)
.map(|v| ())
.map(|_v| ())
.ok_or(format!("Expected UPDATE value:\nleft: {}\nright: {}", tmp, val))?;
// DELETE
let tmp = res.remove(0).result?;
let val = Value::parse("[]");
Some(&tmp)
.filter(|x| *x == &val)
.map(|v| ())
.map(|_v| ())
.ok_or(format!("Expected DELETE value:\nleft: {}\nright: {}", tmp, val))?;
// SHOW CHANGES
let tmp = res.remove(0).result?;
Some(&tmp)
.filter(|x| *x == cf_val_arr)
.map(|v| ())
.map(|_v| ())
.ok_or(format!("Expected SHOW CHANGES value:\nleft: {}\nright: {}", tmp, cf_val_arr))?;
Ok(())
}

View file

@ -4,17 +4,17 @@ use crate::cli::validator::parser::env_filter::CustomEnvFilter;
use crate::cli::validator::parser::env_filter::CustomEnvFilterParser;
use crate::cnf::LOGO;
use crate::dbs;
use crate::dbs::StartCommandDbsOptions;
use crate::dbs::{StartCommandDbsOptions, DB};
use crate::env;
use crate::err::Error;
use crate::net::{self, client_ip::ClientIp};
use crate::node;
use clap::Args;
use opentelemetry::Context as TelemetryContext;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::time::Duration;
use surrealdb::engine::any::IntoEndpoint;
use surrealdb::engine::tasks::start_tasks;
use tokio_util::sync::CancellationToken;
#[derive(Args, Debug)]
@ -183,20 +183,20 @@ pub async fn init(
// Start the kvs server
dbs::init(dbs).await?;
// Start the node agent
// This is equivalent to run_maintenance in native/wasm drivers
let nd = node::init(ct.clone());
let lq = node::live_query_change_feed(ct.clone());
let (tasks, task_chans) = start_tasks(
&config::CF.get().unwrap().engine.unwrap_or_default(),
DB.get().unwrap().clone(),
);
// Start the web server
net::init(ct).await?;
// Wait for the node agent to stop
if let Err(e) = nd.await {
error!("Node agent failed while running: {}", e);
return Err(Error::NodeAgent);
}
if let Err(e) = lq.await {
error!("Live query change feed failed while running: {}", e);
return Err(Error::NodeAgent);
}
net::init(ct.clone()).await?;
// Shutdown and stop closed tasks
task_chans.into_iter().for_each(|chan| {
if let Err(e) = chan.send(()) {
error!("Failed to send shutdown signal to task: {}", e);
}
});
ct.cancel();
tasks.resolve().await?;
// All ok
Ok(())
}

View file

@ -1,12 +1,12 @@
use crate::cli::CF;
use crate::err::Error;
use clap::Args;
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use surrealdb::dbs::capabilities::{Capabilities, FuncTarget, NetTarget, Targets};
use surrealdb::kvs::Datastore;
pub static DB: OnceLock<Datastore> = OnceLock::new();
pub static DB: OnceLock<Arc<Datastore>> = OnceLock::new();
#[derive(Args, Debug)]
pub struct StartCommandDbsOptions {
@ -266,7 +266,7 @@ pub async fn init(
}
// Store database instance
let _ = DB.set(dbs);
let _ = DB.set(Arc::new(dbs));
// All ok
Ok(())

View file

@ -1,72 +1 @@
use std::time::Duration;
use surrealdb::dbs::Options;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use surrealdb::fflags::FFLAGS;
use crate::cli::CF;
const LOG: &str = "surrealdb::node";
// The init starts a long-running thread for periodically calling Datastore.tick.
// Datastore.tick is responsible for running garbage collection and other
// background tasks.
//
// This function needs to be called before after the dbs::init and before the net::init functions.
// It needs to be before net::init because the net::init function blocks until the web server stops.
pub fn init(ct: CancellationToken) -> JoinHandle<()> {
let opt = CF.get().unwrap();
let tick_interval = opt.tick_interval;
info!(target: LOG, "Started node agent");
// This requires the nodes::init function to be called after the dbs::init function.
let dbs = crate::dbs::DB.get().unwrap();
tokio::spawn(async move {
loop {
if let Err(e) = dbs.tick().await {
error!("Error running node agent tick: {}", e);
}
tokio::select! {
_ = ct.cancelled() => {
info!(target: LOG, "Gracefully stopping node agent");
break;
}
_ = tokio::time::sleep(tick_interval) => {}
}
}
info!(target: LOG, "Stopped node agent");
})
}
// Start live query on change feeds notification processing
pub fn live_query_change_feed(ct: CancellationToken) -> JoinHandle<()> {
tokio::spawn(async move {
if !FFLAGS.change_feed_live_queries.enabled() {
return;
}
// Spawn the live query change feed consumer, which is used for catching up on relevant change feeds
tokio::spawn(async move {
let kvs = crate::dbs::DB.get().unwrap();
let tick_interval = Duration::from_secs(1);
let opt = Options::default();
loop {
if let Err(e) = kvs.process_lq_notifications(&opt).await {
error!("Error running node agent live query tick: {}", e);
}
tokio::select! {
_ = ct.cancelled() => {
info!(target: LOG, "Gracefully stopping live query node agent");
break;
}
_ = tokio::time::sleep(tick_interval) => {}
}
}
info!("Stopped live query node agent")
});
})
}

View file

@ -9,6 +9,7 @@ mod cli_integration {
use serde_json::json;
use std::fs::File;
use std::time;
use std::time::Duration;
use surrealdb::fflags::FFLAGS;
use test_log::test;
use tokio::time::sleep;
@ -17,8 +18,8 @@ mod cli_integration {
use super::common::{self, StartServerArguments, PASS, USER};
const ONE_SEC: time::Duration = time::Duration::new(1, 0);
const TWO_SECS: time::Duration = time::Duration::new(2, 0);
const ONE_SEC: Duration = Duration::new(1, 0);
const TWO_SECS: Duration = Duration::new(2, 0);
#[test]
fn version_command() {
@ -257,7 +258,7 @@ mod cli_integration {
#[test(tokio::test)]
async fn with_root_auth() {
// Commands with credentials when auth is enabled, should succeed
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
let creds = format!("--user {USER} --pass {PASS}");
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
@ -298,13 +299,13 @@ mod cli_integration {
common::run(&args).output().unwrap_or_else(|_| panic!("failed to run import: {args}"));
}
server.finish()
server.finish().unwrap();
}
#[test(tokio::test)]
async fn with_auth_level() {
// Commands with credentials for different auth levels
let (addr, server) = common::start_server_with_auth_level().await.unwrap();
let (addr, mut server) = common::start_server_with_auth_level().await.unwrap();
let creds = format!("--user {USER} --pass {PASS}");
let ns = Ulid::new();
let db = Ulid::new();
@ -483,14 +484,14 @@ mod cli_integration {
output
);
}
server.finish();
server.finish().unwrap();
}
#[test(tokio::test)]
// TODO(gguillemas): Remove this test once the legacy authentication is deprecated in v2.0.0
async fn without_auth_level() {
// Commands with credentials for different auth levels
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
let creds = format!("--user {USER} --pass {PASS}");
// Prefix with 'a' so that we don't start with a number and cause a parsing error
let ns = format!("a{}", Ulid::new());
@ -551,13 +552,13 @@ mod cli_integration {
output
);
}
server.finish()
server.finish().unwrap();
}
#[test(tokio::test)]
async fn with_anon_auth() {
// Commands without credentials when auth is enabled, should fail
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
let creds = ""; // Anonymous user
let sql_args = format!("sql --conn http://{addr} --multi --pretty");
@ -606,14 +607,13 @@ mod cli_integration {
output
);
}
server.finish();
server.finish().unwrap();
}
#[test(tokio::test)]
async fn node() {
// Commands without credentials when auth is disabled, should succeed
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
auth: false,
tls: false,
wait_is_ready: true,
@ -666,15 +666,19 @@ mod cli_integration {
.output()
.unwrap();
let output = remove_debug_info(output).replace('\n', "");
// TODO: when enabling the feature flag, turn these to `create` not `update`
let allowed = [
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 1 }, { changes: [{ create: { id: thing:one } }], versionstamp: 2 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 1 }, { changes: [{ create: { id: thing:one } }], versionstamp: 3 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 2 }, { changes: [{ create: { id: thing:one } }], versionstamp: 3 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 2 }, { changes: [{ create: { id: thing:one } }], versionstamp: 4 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 65536 }, { changes: [{ update: { id: thing:one } }], versionstamp: 131072 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 65536 }, { changes: [{ update: { id: thing:one } }], versionstamp: 196608 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 131072 }, { changes: [{ update: { id: thing:one } }], versionstamp: 196608 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 131072 }, { changes: [{ update: { id: thing:one } }], versionstamp: 262144 }]]",
];
allowed
.into_iter()
.find(|case| *case == output)
.find(|case| {
println!("Comparing 2:\n{case}\n{output}");
*case == output
})
.ok_or(format!("Output didnt match an example output: {output}"))
.unwrap();
} else {
@ -684,10 +688,16 @@ mod cli_integration {
.unwrap();
let output = remove_debug_info(output).replace('\n', "");
let allowed = [
// Delete these
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 1 }, { changes: [{ update: { id: thing:one } }], versionstamp: 2 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 1 }, { changes: [{ update: { id: thing:one } }], versionstamp: 3 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 2 }, { changes: [{ update: { id: thing:one } }], versionstamp: 3 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 2 }, { changes: [{ update: { id: thing:one } }], versionstamp: 4 }]]",
// Keep these
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 65536 }, { changes: [{ update: { id: thing:one } }], versionstamp: 131072 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 65536 }, { changes: [{ update: { id: thing:one } }], versionstamp: 196608 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 131072 }, { changes: [{ update: { id: thing:one } }], versionstamp: 196608 }]]",
"[[{ changes: [{ define_table: { name: 'thing' } }], versionstamp: 131072 }, { changes: [{ update: { id: thing:one } }], versionstamp: 262144 }]]",
];
allowed
.into_iter()
@ -715,7 +725,7 @@ mod cli_integration {
let output = remove_debug_info(output);
assert_eq!(output, "[[]]\n\n".to_owned(), "failed to send sql: {args}");
}
server.finish()
server.finish().unwrap();
}
#[test]
@ -854,7 +864,7 @@ mod cli_integration {
let _ =
futures::future::join(async { send_future.await.unwrap_err() }, signal_send_fut).await;
server.finish()
server.finish().unwrap();
}
#[test(tokio::test)]
@ -863,7 +873,7 @@ mod cli_integration {
// Default capabilities only allow functions
info!("* When default capabilities");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "".to_owned(),
..Default::default()
})
@ -890,13 +900,13 @@ mod cli_integration {
"unexpected output: {output:?}"
);
server.finish();
server.finish().unwrap();
}
// Deny all, denies all users to execute functions and access any network address
info!("* When all capabilities are denied");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--deny-all".to_owned(),
..Default::default()
})
@ -922,13 +932,13 @@ mod cli_integration {
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}"
);
server.finish()
server.finish().unwrap();
}
// When all capabilities are allowed, anyone (including non-authenticated users) can execute functions and access any network address
info!("* When all capabilities are allowed");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--allow-all".to_owned(),
..Default::default()
})
@ -948,12 +958,12 @@ mod cli_integration {
let output = common::run(&cmd).input(query).output().unwrap();
assert!(output.starts_with("['1']"), "unexpected output: {output:?}");
server.finish()
server.finish().unwrap();
}
info!("* When scripting is denied");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--deny-scripting".to_owned(),
..Default::default()
})
@ -972,12 +982,12 @@ mod cli_integration {
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}"
);
server.finish()
server.finish().unwrap();
}
info!("* When net is denied and function is enabled");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--deny-net 127.0.0.1 --allow-funcs http::get".to_owned(),
..Default::default()
})
@ -998,12 +1008,12 @@ mod cli_integration {
),
"unexpected output: {output:?}"
);
server.finish()
server.finish().unwrap();
}
info!("* When net is enabled for an IP and also denied for a specific port that doesn't match");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--allow-net 127.0.0.1 --deny-net 127.0.0.1:80 --allow-funcs http::get"
.to_owned(),
..Default::default()
@ -1019,12 +1029,12 @@ mod cli_integration {
let query = format!("RETURN http::get('http://{}/version');\n\n", addr);
let output = common::run(&cmd).input(&query).output().unwrap();
assert!(output.starts_with("['surrealdb"), "unexpected output: {output:?}");
server.finish()
server.finish().unwrap();
}
info!("* When a function family is denied");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
args: "--deny-funcs http".to_owned(),
..Default::default()
})
@ -1042,12 +1052,12 @@ mod cli_integration {
output.contains("Function 'http::get' is not allowed"),
"unexpected output: {output:?}"
);
server.finish()
server.finish().unwrap();
}
info!("* When auth is enabled and guest access is allowed");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
auth: true,
args: "--allow-guests".to_owned(),
..Default::default()
@ -1063,12 +1073,12 @@ mod cli_integration {
let query = "RETURN 1;\n\n";
let output = common::run(&cmd).input(query).output().unwrap();
assert!(output.contains("[1]"), "unexpected output: {output:?}");
server.finish()
server.finish().unwrap();
}
info!("* When auth is enabled and guest access is denied");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
auth: true,
args: "--deny-guests".to_owned(),
..Default::default()
@ -1087,12 +1097,12 @@ mod cli_integration {
output.contains("Not enough permissions to perform this action"),
"unexpected output: {output:?}"
);
server.finish()
server.finish().unwrap();
}
info!("* When auth is disabled, guest access is always allowed");
{
let (addr, server) = common::start_server(StartServerArguments {
let (addr, mut server) = common::start_server(StartServerArguments {
auth: false,
args: "--deny-guests".to_owned(),
..Default::default()
@ -1108,7 +1118,7 @@ mod cli_integration {
let query = "RETURN 1;\n\n";
let output = common::run(&cmd).input(query).output().unwrap();
assert!(output.contains("[1]"), "unexpected output: {output:?}");
server.finish()
server.finish().unwrap();
}
}
}

View file

@ -5,6 +5,7 @@ use std::path::Path;
use std::process::{Command, ExitStatus, Stdio};
use std::{env, fs};
use tokio::time;
use tokio_stream::StreamExt;
use tracing::{debug, error, info};
pub const USER: &str = "root";
@ -33,8 +34,13 @@ impl Child {
self
}
pub fn finish(mut self) {
self.inner.take().unwrap().kill().unwrap();
pub fn finish(&mut self) -> Result<&mut Self, String> {
let a = self
.inner
.as_mut()
.map(|child| child.kill().map_err(|e| format!("failed to kill: {}", e)))
.unwrap_or(Err("no inner".to_string()));
a.map(|_ok| self)
}
pub fn send_signal(&self, signal: nix::sys::signal::Signal) -> nix::Result<()> {
@ -58,8 +64,8 @@ impl Child {
/// Read the child's stdout concatenated with its stderr. Returns Ok if the child
/// returns successfully, Err otherwise.
pub fn output(mut self) -> Result<String, String> {
let status = self.inner.take().unwrap().wait().unwrap();
pub fn output(&mut self) -> Result<String, String> {
let status = self.inner.as_mut().map(|child| child.wait().unwrap()).unwrap();
let mut buf = self.stdout();
buf.push_str(&self.stderr());

View file

@ -8,7 +8,7 @@ use test_log::test;
#[test(tokio::test)]
async fn ping() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Send INFO command
@ -19,14 +19,14 @@ async fn ping() -> Result<(), Box<dyn std::error::Error>> {
let res = res.as_object().unwrap();
assert!(res.keys().all(|k| ["id", "result"].contains(&k.as_str())), "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn info() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -65,14 +65,14 @@ async fn info() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_object().unwrap();
assert_eq!(res["user"], "user", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn signup() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -113,14 +113,14 @@ async fn signup() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_str().unwrap();
assert!(res.starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"), "result: {}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn signin() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -187,14 +187,14 @@ async fn signin() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_str().unwrap();
assert!(res.starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9"), "result: {}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn invalidate() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -215,14 +215,14 @@ async fn invalidate() -> Result<(), Box<dyn std::error::Error>> {
res
);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn authenticate() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -237,14 +237,14 @@ async fn authenticate() -> Result<(), Box<dyn std::error::Error>> {
let res = socket.send_message_query("DEFINE NAMESPACE test").await?;
assert_eq!(res[0]["status"], "OK", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn letset() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -258,14 +258,14 @@ async fn letset() -> Result<(), Box<dyn std::error::Error>> {
// Verify the variables are set
let res = socket.send_message_query("SELECT * FROM $let_var, $set_var").await?;
assert_eq!(res[0]["result"], json!(["let_value", "set_value"]), "result: {:?}", res);
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn unset() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -283,14 +283,14 @@ async fn unset() -> Result<(), Box<dyn std::error::Error>> {
let res = socket.send_message_query("SELECT * FROM $let_var").await?;
assert_eq!(res[0]["result"], json!([null]), "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn select() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -308,14 +308,14 @@ async fn select() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res[0]["name"], "foo", "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn insert() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -349,14 +349,14 @@ async fn insert() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res[0]["name"], "foo", "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn create() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -387,14 +387,14 @@ async fn create() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res.len(), 1, "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn update() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -428,14 +428,14 @@ async fn update() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res[0]["name"], json!(null), "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn merge() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -470,14 +470,14 @@ async fn merge() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res[0]["name"], "foo", "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn patch() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -517,14 +517,14 @@ async fn patch() -> Result<(), Box<dyn std::error::Error>> {
assert_eq!(res[0]["name"], json!(null), "result: {:?}", res);
assert_eq!(res[0]["value"], "bar", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn delete() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -554,14 +554,14 @@ async fn delete() -> Result<(), Box<dyn std::error::Error>> {
let res = res[0]["result"].as_array().unwrap();
assert_eq!(res.len(), 0, "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn query() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -581,14 +581,14 @@ async fn query() -> Result<(), Box<dyn std::error::Error>> {
let res = res[0]["result"].as_array().unwrap();
assert_eq!(res.len(), 1, "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn version() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Send version command
@ -597,7 +597,7 @@ async fn version() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_str().unwrap();
assert!(res.starts_with("surrealdb-"), "result: {}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
@ -605,7 +605,7 @@ async fn version() -> Result<(), Box<dyn std::error::Error>> {
#[test(tokio::test)]
async fn concurrency() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -636,14 +636,14 @@ async fn concurrency() -> Result<(), Box<dyn std::error::Error>> {
assert!(res.iter().all(|v| v["error"].is_null()), "Unexpected error received: {:#?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn live() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -703,14 +703,14 @@ async fn live() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_object().unwrap();
assert_eq!(res["id"], "tester:id", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn kill() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -806,14 +806,14 @@ async fn kill() -> Result<(), Box<dyn std::error::Error>> {
let msgs = socket.receive_all_other_messages(0, Duration::from_secs(1)).await?;
assert!(msgs.iter().all(|v| v["error"].is_null()), "Unexpected error received: {:?}", msgs);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn live_second_connection() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket1 = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -854,14 +854,14 @@ async fn live_second_connection() -> Result<(), Box<dyn std::error::Error>> {
let res = res["result"].as_object().unwrap();
assert_eq!(res["id"], "tester:id", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn variable_auth_live_query() -> Result<(), Box<dyn std::error::Error>> {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await?;
// Authenticate the connection
@ -913,14 +913,14 @@ async fn variable_auth_live_query() -> Result<(), Box<dyn std::error::Error>> {
let msgs = socket.receive_all_other_messages(0, Duration::from_secs(1)).await?;
assert!(msgs.iter().all(|v| v["error"].is_null()), "Unexpected error received: {:?}", msgs);
// Test passed
server.finish();
server.finish().unwrap();
Ok(())
}
#[test(tokio::test)]
async fn session_expiration() {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection
@ -936,7 +936,8 @@ async fn session_expiration() {
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
;"#,
)
.await.unwrap();
.await
.unwrap();
// Create resource that requires a scope session to query
socket
.send_message_query(
@ -945,14 +946,16 @@ async fn session_expiration() {
PERMISSIONS FOR select, create, update, delete WHERE $scope = "scope"
;"#,
)
.await.unwrap();
.await
.unwrap();
socket
.send_message_query(
r#"
CREATE test:1 SET working = "yes"
;"#,
)
.await.unwrap();
.await
.unwrap();
// Send SIGNUP command
let res = socket
.send_request(
@ -991,7 +994,10 @@ async fn session_expiration() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
// Sign in again using the same session
let res = socket
.send_request(
@ -1017,13 +1023,13 @@ async fn session_expiration() {
let res = socket.send_message_query("SELECT VALUE working FROM test:1").await.unwrap();
assert_eq!(res[0]["result"], json!(["yes"]), "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_expiration_operations() {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection
@ -1040,7 +1046,8 @@ async fn session_expiration_operations() {
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
;"#,
)
.await.unwrap();
.await
.unwrap();
// Create resource that requires a scope session to query
socket
.send_message_query(
@ -1049,14 +1056,16 @@ async fn session_expiration_operations() {
PERMISSIONS FOR select, create, update, delete WHERE $scope = "scope"
;"#,
)
.await.unwrap();
.await
.unwrap();
socket
.send_message_query(
r#"
CREATE test:1 SET working = "yes"
;"#,
)
.await.unwrap();
.await
.unwrap();
// Send SIGNUP command
let res = socket
.send_request(
@ -1095,72 +1104,70 @@ async fn session_expiration_operations() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
// Test operations that SHOULD NOT work with an expired session
let operations_ko = vec![
socket.send_request("let", json!(["let_var", "let_value",])),
socket.send_request("set", json!(["set_var", "set_value",])),
socket.send_request("info", json!([])),
socket.send_request("select", json!(["tester",])),
socket
.send_request(
"insert",
json!([
"tester",
socket.send_request(
"insert",
json!([
"tester",
{
"name": "foo",
"value": "bar",
}
]),
),
socket.send_request(
"create",
json!([
"tester",
{
"value": "bar",
}
]),
),
socket.send_request(
"update",
json!([
"tester",
{
"value": "bar",
}
]),
),
socket.send_request(
"merge",
json!([
"tester",
{
"value": "bar",
}
]),
),
socket.send_request(
"patch",
json!([
"tester:id",
[
{
"name": "foo",
"value": "bar",
}
]),
),
socket
.send_request(
"create",
json!([
"tester",
"op": "add",
"path": "value",
"value": "bar"
},
{
"value": "bar",
"op": "remove",
"path": "name",
}
]),
),
socket
.send_request(
"update",
json!([
"tester",
{
"value": "bar",
}
]),
),
socket
.send_request(
"merge",
json!([
"tester",
{
"value": "bar",
}
]),
),
socket
.send_request(
"patch",
json!([
"tester:id",
[
{
"op": "add",
"path": "value",
"value": "bar"
},
{
"op": "remove",
"path": "name",
}
]
]),
),
]
]),
),
socket.send_request("delete", json!(["tester"])),
socket.send_request("live", json!(["tester"])),
socket.send_request("kill", json!(["tester"])),
@ -1172,8 +1179,11 @@ async fn session_expiration_operations() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
};
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
}
// Test operations that SHOULD work with an expired session
let operations_ok = vec![
@ -1191,22 +1201,22 @@ async fn session_expiration_operations() {
let res = res.as_object().unwrap();
// Verify response contains no error
assert!(res.keys().all(|k| ["id", "result"].contains(&k.as_str())), "result: {:?}", res);
};
}
// Test operations that SHOULD work with an expired session
// These operations will refresh the session expiration
let res = socket
.send_request(
"signup",
json!([{
"ns": NS,
"db": DB,
"sc": "scope",
"email": "another@email.com",
"pass": "pass",
}]),
)
.await;
.send_request(
"signup",
json!([{
"ns": NS,
"db": DB,
"sc": "scope",
"email": "another@email.com",
"pass": "pass",
}]),
)
.await;
assert!(res.is_ok(), "result: {:?}", res);
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
@ -1221,7 +1231,10 @@ async fn session_expiration_operations() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
let res = socket
.send_request(
"signin",
@ -1250,7 +1263,10 @@ async fn session_expiration_operations() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
// This needs to be last operation as the session will no longer expire afterwards
let res = socket.send_request("authenticate", json!([root_token,])).await;
@ -1262,13 +1278,13 @@ async fn session_expiration_operations() {
assert!(res.keys().all(|k| ["id", "result"].contains(&k.as_str())), "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_reauthentication() {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection and store the root level token
@ -1286,7 +1302,8 @@ async fn session_reauthentication() {
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
;"#,
)
.await.unwrap();
.await
.unwrap();
// Create resource that requires a scope session to query
socket
.send_message_query(
@ -1295,14 +1312,16 @@ async fn session_reauthentication() {
PERMISSIONS FOR select, create, update, delete WHERE $scope = "scope"
;"#,
)
.await.unwrap();
.await
.unwrap();
socket
.send_message_query(
r#"
CREATE test:1 SET working = "yes"
;"#,
)
.await.unwrap();
.await
.unwrap();
// Send SIGNUP command
let res = socket
.send_request(
@ -1347,13 +1366,13 @@ async fn session_reauthentication() {
let res = socket.send_message_query("INFO FOR ROOT").await.unwrap();
assert_eq!(res[0]["status"], "OK", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_reauthentication_expired() {
// Setup database server
let (addr, server) = common::start_server_with_defaults().await.unwrap();
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection and store the root level token
@ -1371,7 +1390,8 @@ async fn session_reauthentication_expired() {
SIGNIN ( SELECT * FROM user WHERE email = $email AND crypto::argon2::compare(pass, $pass) )
;"#,
)
.await.unwrap();
.await
.unwrap();
// Create resource that requires a scope session to query
socket
.send_message_query(
@ -1380,14 +1400,16 @@ async fn session_reauthentication_expired() {
PERMISSIONS FOR select, create, update, delete WHERE $scope = "scope"
;"#,
)
.await.unwrap();
.await
.unwrap();
socket
.send_message_query(
r#"
CREATE test:1 SET working = "yes"
;"#,
)
.await.unwrap();
.await
.unwrap();
// Send SIGNUP command
let res = socket
.send_request(
@ -1423,12 +1445,15 @@ async fn session_reauthentication_expired() {
let res = res.unwrap();
assert!(res.is_object(), "result: {:?}", res);
let res = res.as_object().unwrap();
assert_eq!(res["error"], json!({"code": -32000, "message": "There was a problem with the database: The session has expired"}));
assert_eq!(
res["error"],
json!({"code": -32000, "message": "There was a problem with the database: The session has expired"})
);
// Authenticate using the root token, which has not expired yet
socket.send_request("authenticate", json!([root_token,])).await.unwrap();
// Check that we have root access and the session is not expired
let res = socket.send_message_query("INFO FOR ROOT").await.unwrap();
assert_eq!(res[0]["status"], "OK", "result: {:?}", res);
// Test passed
server.finish();
server.finish().unwrap();
}