diff --git a/Cargo.lock b/Cargo.lock index 5be3f3d3..b7d4be75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,19 +2177,6 @@ dependencies = [ "futures-sink", ] -[[package]] -name = "futures-concurrency" -version = "7.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b590a729e1cbaf9ae3ec294143ea034d93cbb1de01c884d04bcd0af8b613d02" -dependencies = [ - "bitvec", - "futures-core", - "pin-project", - "slab", - "smallvec", -] - [[package]] name = "futures-core" version = "0.3.30" @@ -5952,7 +5939,6 @@ dependencies = [ "flate2", "flume", "futures", - "futures-concurrency", "geo 0.27.0", "hashbrown 0.14.5", "indexmap 2.2.6", diff --git a/Makefile.ci.toml b/Makefile.ci.toml index 06ac562b..3907ae01 100644 --- a/Makefile.ci.toml +++ b/Makefile.ci.toml @@ -10,13 +10,11 @@ args = ["check", "--locked", "--workspace"] [tasks.ci-check-wasm] category = "CI - CHECK" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["check", "--locked", "--package", "surrealdb", "--features", "protocol-ws,protocol-http,kv-mem,kv-indxdb,http,jwks", "--target", "wasm32-unknown-unknown"] [tasks.ci-clippy] category = "CI - CHECK" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["clippy", "--all-targets", "--features", "storage-mem,storage-rocksdb,storage-tikv,storage-fdb,scripting,http,jwks", "--tests", "--benches", "--examples", "--bins", "--", "-D", "warnings"] # @@ -26,31 +24,43 @@ args = ["clippy", "--all-targets", "--features", "storage-mem,storage-rocksdb,st [tasks.ci-cli-integration] category = "CI - INTEGRATION TESTS" command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } +env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,storage-surrealkv,http,scripting,jwks", "--workspace", "--test", "cli_integration", "--", "cli_integration"] [tasks.ci-http-integration] category = "CI - INTEGRATION TESTS" command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } +env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression,jwks", "--workspace", "--test", "http_integration", "--", "http_integration"] [tasks.ci-ws-integration] category = "WS - INTEGRATION TESTS" command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } +env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } } args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration"] [tasks.ci-ml-integration] category = "ML - INTEGRATION TESTS" command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } } +env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } } args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"] + +[tasks.ci-test-workspace] +category = "CI - INTEGRATION TESTS" +command = "cargo" +args = [ + "test", "--locked", "--no-default-features", "--features", "storage-mem,scripting,http,jwks", "--workspace", "--", + "--skip", "api_integration", + "--skip", "cli_integration", + "--skip", "http_integration", + "--skip", "ws_integration", + "--skip", "database_upgrade" +] + [tasks.ci-workspace-coverage] category = "CI - INTEGRATION TESTS" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = [ "llvm-cov", "--html", "--locked", "--no-default-features", "--features", "storage-mem,scripting,http,jwks", "--workspace", "--", "--skip", "api_integration", @@ -63,7 +73,6 @@ args = [ [tasks.test-workspace-coverage-complete] category = "CI - INTEGRATION TESTS" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["llvm-cov", "--html", "--locked", "--no-default-features", "--features", "protocol-ws,protocol-http,kv-mem,kv-rocksdb", "--workspace"] [tasks.ci-workspace-coverage-complete] @@ -83,25 +92,25 @@ run_task = { name = ["test-database-upgrade"], fork = true } [tasks.test-kvs] private = true command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { RUST_BACKTRACE = 1 } args = ["test", "--locked", "--package", "surrealdb", "--no-default-features", "--features", "${_TEST_FEATURES}", "--lib", "kvs"] [tasks.test-api-integration] private = true command = "cargo" -env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { RUST_BACKTRACE = 1 } args = ["test", "--locked", "--package", "surrealdb", "--no-default-features", "--features", "${_TEST_FEATURES}", "--test", "api", "api_integration::${_TEST_API_ENGINE}"] [tasks.ci-api-integration] -env = { RUST_BACKTRACE = 1, _START_SURREALDB_PATH = "memory", RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { RUST_BACKTRACE = 1, _START_SURREALDB_PATH = "memory" } private = true run_task = { name = ["start-surrealdb", "test-api-integration", "stop-surrealdb"], fork = true } [tasks.test-database-upgrade] private = true command = "cargo" -env = { RUST_BACKTRACE = 1, RUST_LOG = "info", RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { RUST_BACKTRACE = 1, RUST_LOG = "info" } args = ["test", "--locked", "--no-default-features", "--features", "${_TEST_FEATURES}", "--workspace", "--test", "database_upgrade", "--", "database_upgrade", "--show-output"] @@ -111,17 +120,17 @@ args = ["test", "--locked", "--no-default-features", "--features", "${_TEST_FEAT [tasks.ci-api-integration-http] category = "CI - INTEGRATION TESTS" -env = { _TEST_API_ENGINE = "http", _TEST_FEATURES = "protocol-http", RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { _TEST_API_ENGINE = "http", _TEST_FEATURES = "protocol-http" } run_task = "ci-api-integration" [tasks.ci-api-integration-ws] category = "CI - INTEGRATION TESTS" -env = { _TEST_API_ENGINE = "ws", _TEST_FEATURES = "protocol-ws", RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { _TEST_API_ENGINE = "ws", _TEST_FEATURES = "protocol-ws" } run_task = "ci-api-integration" [tasks.ci-api-integration-any] category = "CI - INTEGRATION TESTS" -env = { _TEST_API_ENGINE = "any", _TEST_FEATURES = "protocol-http", RUSTFLAGS = "--cfg surrealdb_unstable" } +env = { _TEST_API_ENGINE = "any", _TEST_FEATURES = "protocol-http" } run_task = "ci-api-integration" # @@ -256,7 +265,6 @@ ${HOME}/.tiup/bin/tiup clean --all [tasks.build-surrealdb] category = "CI - BUILD" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["build", "--locked", "--no-default-features", "--features", "storage-mem"] # @@ -265,7 +273,6 @@ args = ["build", "--locked", "--no-default-features", "--features", "storage-mem [tasks.ci-bench] category = "CI - BENCHMARK" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["bench", "--quiet", "--package", "surrealdb", "--no-default-features", "--features", "kv-mem,scripting,http,jwks", "${@}"] # @@ -277,13 +284,11 @@ BENCH_NUM_OPS = { value = "1000", condition = { env_not_set = ["BENCH_NUM_OPS"] BENCH_DURATION = { value = "30", condition = { env_not_set = ["BENCH_DURATION"] } } BENCH_SAMPLE_SIZE = { value = "10", condition = { env_not_set = ["BENCH_SAMPLE_SIZE"] } } BENCH_FEATURES = { value = "protocol-ws,kv-mem,kv-rocksdb,kv-fdb-7_1,kv-surrealkv", condition = { env_not_set = ["BENCH_FEATURES"] } } -RUSTFLAGS = "--cfg surrealdb_unstable" [tasks.bench-target] private = true category = "CI - BENCHMARK - SurrealDB Target" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["bench", "--package", "surrealdb", "--bench", "sdb", "--no-default-features", "--features", "${BENCH_FEATURES}", "${@}"] [tasks.bench-lib-mem] diff --git a/Makefile.local.toml b/Makefile.local.toml index 0122e7cd..235f63de 100644 --- a/Makefile.local.toml +++ b/Makefile.local.toml @@ -17,21 +17,19 @@ dependencies = ["cargo-upgrade", "cargo-update"] [tasks.docs] category = "LOCAL USAGE" command = "cargo" -env = { RUSTDOCFLAGS = "--cfg surrealdb_unstable" } args = ["doc", "--open", "--no-deps", "--package", "surrealdb", "--features", "rustls,native-tls,protocol-ws,protocol-http,kv-mem,kv-rocksdb,kv-tikv,http,scripting,jwks"] # Test [tasks.test] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable", RUSTDOCFLAGS = "--cfg surrealdb_unstable", RUST_MIN_STACK={ value = "4194304", condition = { env_not_set = ["RUST_MIN_STACK"] } } } +env = { RUST_MIN_STACK={ value = "4194304", condition = { env_not_set = ["RUST_MIN_STACK"] } } } args = ["test", "--workspace", "--no-fail-fast"] # Check [tasks.cargo-check] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["check", "--workspace", "--features", "${DEV_FEATURES}"] [tasks.cargo-fmt] @@ -50,7 +48,6 @@ script = """ [tasks.cargo-clippy] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["clippy", "--all-targets", "--all-features", "--", "-D", "warnings"] [tasks.check] @@ -71,28 +68,24 @@ args = ["clean"] [tasks.bench] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["bench", "--package", "surrealdb", "--no-default-features", "--features", "kv-mem,http,scripting,jwks", "--", "${@}"] # Run [tasks.run] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "${@}"] # Serve [tasks.serve] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start", "--allow-all", "${@}"] # SQL [tasks.sql] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--pretty", "${@}"] # Quick @@ -105,7 +98,6 @@ args = ["build", "${@}"] [tasks.build] category = "LOCAL USAGE" command = "cargo" -env = { RUSTFLAGS = "--cfg surrealdb_unstable" } args = ["build", "--release", "${@}"] # Default diff --git a/core/src/dbs/iterator.rs b/core/src/dbs/iterator.rs index 84ab3e96..2db0d25c 100644 --- a/core/src/dbs/iterator.rs +++ b/core/src/dbs/iterator.rs @@ -16,7 +16,9 @@ use crate::sql::range::Range; use crate::sql::table::Table; use crate::sql::thing::Thing; use crate::sql::value::Value; -use reblessive::{tree::Stk, TreeStack}; +use reblessive::tree::Stk; +#[cfg(not(target_arch = "wasm32"))] +use reblessive::TreeStack; use std::mem; #[derive(Clone)] diff --git a/core/src/rpc/request.rs b/core/src/rpc/request.rs index 4a742508..22ff8180 100644 --- a/core/src/rpc/request.rs +++ b/core/src/rpc/request.rs @@ -9,6 +9,7 @@ pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]); pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]); pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]); +#[derive(Debug)] pub struct Request { pub id: Option, pub method: String, diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 8b3c5131..7dc9a25a 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -79,7 +79,6 @@ chrono = { version = "0.4.31", features = ["serde"] } dmp = "0.2.0" flume = "0.11.0" futures = "0.3.29" -futures-concurrency = "7.4.3" geo = { version = "0.27.0", features = ["use-serde"] } indexmap = { version = "2.1.0", features = ["serde"] } native-tls = { version = "0.2.11", optional = true } diff --git a/lib/src/api/conn.rs b/lib/src/api/conn.rs index 86c8f64b..61a92d99 100644 --- a/lib/src/api/conn.rs +++ b/lib/src/api/conn.rs @@ -31,7 +31,7 @@ pub(crate) struct Route { /// Message router #[derive(Debug)] pub struct Router { - pub(crate) sender: Sender>, + pub(crate) sender: Sender, pub(crate) last_id: AtomicI64, pub(crate) features: HashSet, } @@ -42,12 +42,6 @@ impl Router { } } -impl Drop for Router { - fn drop(&mut self) { - let _res = self.sender.send(None); - } -} - /// The query method #[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)] #[serde(rename_all = "lowercase")] diff --git a/lib/src/api/engine/any/native.rs b/lib/src/api/engine/any/native.rs index d55c8a37..1b088f05 100644 --- a/lib/src/api/engine/any/native.rs +++ b/lib/src/api/engine/any/native.rs @@ -68,7 +68,7 @@ impl Connection for Any { { features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::LiveQueries); - engine::local::native::router(address, conn_tx, route_rx); + tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await?? } @@ -83,7 +83,7 @@ impl Connection for Any { { features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::LiveQueries); - engine::local::native::router(address, conn_tx, route_rx); + tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await?? } @@ -98,7 +98,7 @@ impl Connection for Any { { features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::LiveQueries); - engine::local::native::router(address, conn_tx, route_rx); + tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await?? } @@ -114,7 +114,7 @@ impl Connection for Any { { features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::LiveQueries); - engine::local::native::router(address, conn_tx, route_rx); + tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await?? } @@ -129,7 +129,7 @@ impl Connection for Any { { features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::LiveQueries); - engine::local::native::router(address, conn_tx, route_rx); + tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await?? } @@ -162,7 +162,9 @@ impl Connection for Any { client.get(base_url.join(Method::Health.as_str())?), ) .await?; - engine::remote::http::native::router(base_url, client, route_rx); + tokio::spawn(engine::remote::http::native::run_router( + base_url, client, route_rx, + )); } #[cfg(not(feature = "protocol-http"))] @@ -195,14 +197,14 @@ impl Connection for Any { maybe_connector.clone(), ) .await?; - engine::remote::ws::native::router( + tokio::spawn(engine::remote::ws::native::run_router( endpoint, maybe_connector, capacity, config, socket, route_rx, - ); + )); } #[cfg(not(feature = "protocol-ws"))] @@ -237,7 +239,7 @@ impl Connection for Any { request: (self.id, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } diff --git a/lib/src/api/engine/any/wasm.rs b/lib/src/api/engine/any/wasm.rs index d5cc6d77..aef4058e 100644 --- a/lib/src/api/engine/any/wasm.rs +++ b/lib/src/api/engine/any/wasm.rs @@ -18,12 +18,12 @@ use crate::opt::WaitFor; use flume::Receiver; use std::collections::HashSet; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::sync::OnceLock; use tokio::sync::watch; +use wasm_bindgen_futures::spawn_local; impl crate::api::Connection for Any {} @@ -54,7 +54,7 @@ impl Connection for Any { #[cfg(feature = "kv-fdb")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -68,7 +68,7 @@ impl Connection for Any { #[cfg(feature = "kv-indxdb")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -82,7 +82,7 @@ impl Connection for Any { #[cfg(feature = "kv-mem")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -96,7 +96,7 @@ impl Connection for Any { #[cfg(feature = "kv-rocksdb")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -111,7 +111,7 @@ impl Connection for Any { #[cfg(feature = "kv-surrealkv")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -126,7 +126,7 @@ impl Connection for Any { #[cfg(feature = "kv-tikv")] { features.insert(ExtraFeatures::LiveQueries); - engine::local::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; } @@ -139,7 +139,9 @@ impl Connection for Any { EndpointKind::Http | EndpointKind::Https => { #[cfg(feature = "protocol-http")] { - engine::remote::http::wasm::router(address, conn_tx, route_rx); + spawn_local(engine::remote::http::wasm::run_router( + address, conn_tx, route_rx, + )); } #[cfg(not(feature = "protocol-http"))] @@ -155,7 +157,9 @@ impl Connection for Any { features.insert(ExtraFeatures::LiveQueries); let mut endpoint = address; endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?; - engine::remote::ws::wasm::router(endpoint, capacity, conn_tx, route_rx); + spawn_local(engine::remote::ws::wasm::run_router( + endpoint, capacity, conn_tx, route_rx, + )); conn_rx.into_recv_async().await??; } @@ -192,7 +196,7 @@ impl Connection for Any { request: (self.id, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } diff --git a/lib/src/api/engine/local/native.rs b/lib/src/api/engine/local/native.rs index fdabd4a5..66ddb09a 100644 --- a/lib/src/api/engine/local/native.rs +++ b/lib/src/api/engine/local/native.rs @@ -19,10 +19,8 @@ use crate::opt::WaitFor; use crate::options::EngineOptions; use flume::Receiver; use flume::Sender; -use futures::future::Either; use futures::stream::poll_fn; use futures::StreamExt; -use futures_concurrency::stream::Merge as _; use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; @@ -55,7 +53,7 @@ impl Connection for Db { let (conn_tx, conn_rx) = flume::bounded(1); - router(address, conn_tx, route_rx); + tokio::spawn(run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; @@ -85,144 +83,142 @@ impl Connection for Db { request: (0, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } } -pub(crate) fn router( +pub(crate) async fn run_router( address: Endpoint, conn_tx: Sender>, - route_rx: Receiver>, + route_rx: Receiver, ) { - tokio::spawn(async move { - let configured_root = match address.config.auth { - Level::Root => Some(Root { - username: &address.config.username, - password: &address.config.password, - }), - _ => None, - }; + let configured_root = match address.config.auth { + Level::Root => Some(Root { + username: &address.config.username, + password: &address.config.password, + }), + _ => None, + }; - let endpoint = match EndpointKind::from(address.url.scheme()) { - EndpointKind::TiKv => address.url.as_str(), - _ => &address.path, - }; + let endpoint = match EndpointKind::from(address.url.scheme()) { + EndpointKind::TiKv => address.url.as_str(), + _ => &address.path, + }; - let kvs = match Datastore::new(endpoint).await { - Ok(kvs) => { - if let Err(error) = kvs.bootstrap().await { - let _ = conn_tx.into_send_async(Err(error.into())).await; - return; - } - // If a root user is specified, setup the initial datastore credentials - if let Some(root) = configured_root { - if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await - { - let _ = conn_tx.into_send_async(Err(error.into())).await; - return; - } - } - let _ = conn_tx.into_send_async(Ok(())).await; - kvs.with_auth_enabled(configured_root.is_some()) - } - Err(error) => { + let kvs = match Datastore::new(endpoint).await { + Ok(kvs) => { + if let Err(error) = kvs.bootstrap().await { let _ = conn_tx.into_send_async(Err(error.into())).await; return; } - }; - - let kvs = match address.config.capabilities.allows_live_query_notifications() { - true => kvs.with_notifications(), - false => kvs, - }; - - let kvs = kvs - .with_strict_mode(address.config.strict) - .with_query_timeout(address.config.query_timeout) - .with_transaction_timeout(address.config.transaction_timeout) - .with_capabilities(address.config.capabilities); - - #[cfg(any( - feature = "kv-mem", - feature = "kv-surrealkv", - feature = "kv-rocksdb", - feature = "kv-fdb", - feature = "kv-tikv", - ))] - let kvs = match address.config.temporary_directory { - Some(tmp_dir) => kvs.with_temporary_directory(tmp_dir), - _ => kvs, - }; - - let kvs = Arc::new(kvs); - let mut vars = BTreeMap::new(); - let mut live_queries = HashMap::new(); - let mut session = Session::default().with_rt(true); - - let opt = { - let mut engine_options = EngineOptions::default(); - engine_options.tick_interval = address - .config - .tick_interval - .unwrap_or(crate::api::engine::local::DEFAULT_TICK_INTERVAL); - engine_options - }; - let (tasks, task_chans) = start_tasks(&opt, kvs.clone()); - - let mut notifications = kvs.notifications(); - let notification_stream = poll_fn(move |cx| match &mut notifications { - Some(rx) => rx.poll_next_unpin(cx), - None => Poll::Ready(None), - }); - - let streams = (route_rx.stream().map(Either::Left), notification_stream.map(Either::Right)); - let mut merged = streams.merge(); - - while let Some(either) = merged.next().await { - match either { - Either::Left(None) => break, // Received a shutdown signal - Either::Left(Some(route)) => { - match super::router( - route.request, - &kvs, - &mut session, - &mut vars, - &mut live_queries, - ) - .await - { - Ok(value) => { - let _ = route.response.into_send_async(Ok(value)).await; - } - Err(error) => { - let _ = route.response.into_send_async(Err(error)).await; - } - } - } - Either::Right(notification) => { - let id = notification.id; - if let Some(sender) = live_queries.get(&id) { - if sender.send(notification).await.is_err() { - live_queries.remove(&id); - if let Err(error) = - super::kill_live_query(&kvs, id, &session, vars.clone()).await - { - warn!("Failed to kill live query '{id}'; {error}"); - } - } - } + // If a root user is specified, setup the initial datastore credentials + if let Some(root) = configured_root { + if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await { + let _ = conn_tx.into_send_async(Err(error.into())).await; + return; } } + let _ = conn_tx.into_send_async(Ok(())).await; + kvs.with_auth_enabled(configured_root.is_some()) } + Err(error) => { + let _ = conn_tx.into_send_async(Err(error.into())).await; + return; + } + }; - // Stop maintenance tasks - for chan in task_chans { - if let Err(_empty_tuple) = chan.send(()) { - error!("Error sending shutdown signal to task"); - } - } - tasks.resolve().await.unwrap(); + let kvs = match address.config.capabilities.allows_live_query_notifications() { + true => kvs.with_notifications(), + false => kvs, + }; + + let kvs = kvs + .with_strict_mode(address.config.strict) + .with_query_timeout(address.config.query_timeout) + .with_transaction_timeout(address.config.transaction_timeout) + .with_capabilities(address.config.capabilities); + + #[cfg(any( + feature = "kv-mem", + feature = "kv-surrealkv", + feature = "kv-rocksdb", + feature = "kv-fdb", + feature = "kv-tikv", + ))] + let kvs = match address.config.temporary_directory { + Some(tmp_dir) => kvs.with_temporary_directory(tmp_dir), + _ => kvs, + }; + + let kvs = Arc::new(kvs); + let mut vars = BTreeMap::new(); + let mut live_queries = HashMap::new(); + let mut session = Session::default().with_rt(true); + + let opt = { + let mut engine_options = EngineOptions::default(); + engine_options.tick_interval = address + .config + .tick_interval + .unwrap_or(crate::api::engine::local::DEFAULT_TICK_INTERVAL); + engine_options + }; + let (tasks, task_chans) = start_tasks(&opt, kvs.clone()); + + let mut notifications = kvs.notifications(); + let mut notification_stream = poll_fn(move |cx| match &mut notifications { + Some(rx) => rx.poll_next_unpin(cx), + // return poll pending so that this future is never woken up again and therefore not + // constantly polled. + None => Poll::Pending, }); + let mut route_stream = route_rx.into_stream(); + + loop { + tokio::select! { + route = route_stream.next() => { + let Some(route) = route else { + break + }; + match super::router(route.request, &kvs, &mut session, &mut vars, &mut live_queries) + .await + { + Ok(value) => { + let _ = route.response.into_send_async(Ok(value)).await; + } + Err(error) => { + let _ = route.response.into_send_async(Err(error)).await; + } + } + } + notification = notification_stream.next() => { + let Some(notification) = notification else { + // TODO: Maybe we should do something more then ignore a closed notifications + // channel? + continue + }; + let id = notification.id; + if let Some(sender) = live_queries.get(&id) { + if sender.send(notification).await.is_err() { + live_queries.remove(&id); + if let Err(error) = + super::kill_live_query(&kvs, id, &session, vars.clone()).await + { + warn!("Failed to kill live query '{id}'; {error}"); + } + } + } + } + } + } + + // Stop maintenance tasks + for chan in task_chans { + if chan.send(()).is_err() { + error!("Error sending shutdown signal to task"); + } + } + tasks.resolve().await.unwrap(); } diff --git a/lib/src/api/engine/local/wasm.rs b/lib/src/api/engine/local/wasm.rs index 70d134b9..73a850c6 100644 --- a/lib/src/api/engine/local/wasm.rs +++ b/lib/src/api/engine/local/wasm.rs @@ -20,15 +20,13 @@ use crate::opt::WaitFor; use crate::options::EngineOptions; use flume::Receiver; use flume::Sender; -use futures::future::Either; use futures::stream::poll_fn; +use futures::FutureExt; use futures::StreamExt; -use futures_concurrency::stream::Merge as _; use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::Arc; @@ -58,7 +56,7 @@ impl Connection for Db { let (conn_tx, conn_rx) = flume::bounded(1); - router(address, conn_tx, route_rx); + spawn_local(run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; @@ -87,120 +85,127 @@ impl Connection for Db { request: (0, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } } -pub(crate) fn router( +pub(crate) async fn run_router( address: Endpoint, conn_tx: Sender>, - route_rx: Receiver>, + route_rx: Receiver, ) { - spawn_local(async move { - let configured_root = match address.config.auth { - Level::Root => Some(Root { - username: &address.config.username, - password: &address.config.password, - }), - _ => None, - }; + let configured_root = match address.config.auth { + Level::Root => Some(Root { + username: &address.config.username, + password: &address.config.password, + }), + _ => None, + }; - let kvs = match Datastore::new(&address.path).await { - Ok(kvs) => { - if let Err(error) = kvs.bootstrap().await { - let _ = conn_tx.into_send_async(Err(error.into())).await; - return; - } - // If a root user is specified, setup the initial datastore credentials - if let Some(root) = configured_root { - if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await - { - let _ = conn_tx.into_send_async(Err(error.into())).await; - return; - } - } - let _ = conn_tx.into_send_async(Ok(())).await; - kvs.with_auth_enabled(configured_root.is_some()) - } - Err(error) => { + let kvs = match Datastore::new(&address.path).await { + Ok(kvs) => { + if let Err(error) = kvs.bootstrap().await { let _ = conn_tx.into_send_async(Err(error.into())).await; return; } - }; - - let kvs = match address.config.capabilities.allows_live_query_notifications() { - true => kvs.with_notifications(), - false => kvs, - }; - - let kvs = kvs - .with_strict_mode(address.config.strict) - .with_query_timeout(address.config.query_timeout) - .with_transaction_timeout(address.config.transaction_timeout) - .with_capabilities(address.config.capabilities); - - let kvs = Arc::new(kvs); - let mut vars = BTreeMap::new(); - let mut live_queries = HashMap::new(); - let mut session = Session::default().with_rt(true); - - let mut opt = EngineOptions::default(); - opt.tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL); - let (_tasks, task_chans) = start_tasks(&opt, kvs.clone()); - - let mut notifications = kvs.notifications(); - let notification_stream = poll_fn(move |cx| match &mut notifications { - Some(rx) => rx.poll_next_unpin(cx), - None => Poll::Ready(None), - }); - - let streams = (route_rx.stream().map(Either::Left), notification_stream.map(Either::Right)); - let mut merged = streams.merge(); - - while let Some(either) = merged.next().await { - match either { - Either::Left(None) => break, // Received a shutdown signal - Either::Left(Some(route)) => { - match super::router( - route.request, - &kvs, - &mut session, - &mut vars, - &mut live_queries, - ) - .await - { - Ok(value) => { - let _ = route.response.into_send_async(Ok(value)).await; - } - Err(error) => { - let _ = route.response.into_send_async(Err(error)).await; - } - } - } - Either::Right(notification) => { - let id = notification.id; - if let Some(sender) = live_queries.get(&id) { - if sender.send(notification).await.is_err() { - live_queries.remove(&id); - if let Err(error) = - super::kill_live_query(&kvs, id, &session, vars.clone()).await - { - warn!("Failed to kill live query '{id}'; {error}"); - } - } - } + // If a root user is specified, setup the initial datastore credentials + if let Some(root) = configured_root { + if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await { + let _ = conn_tx.into_send_async(Err(error.into())).await; + return; } } + let _ = conn_tx.into_send_async(Ok(())).await; + kvs.with_auth_enabled(configured_root.is_some()) } + Err(error) => { + let _ = conn_tx.into_send_async(Err(error.into())).await; + return; + } + }; - // Stop maintenance tasks - for chan in task_chans { - if let Err(_empty_tuple) = chan.send(()) { - error!("Error sending shutdown signal to maintenance task"); - } - } + let kvs = match address.config.capabilities.allows_live_query_notifications() { + true => kvs.with_notifications(), + false => kvs, + }; + + let kvs = kvs + .with_strict_mode(address.config.strict) + .with_query_timeout(address.config.query_timeout) + .with_transaction_timeout(address.config.transaction_timeout) + .with_capabilities(address.config.capabilities); + + let kvs = Arc::new(kvs); + let mut vars = BTreeMap::new(); + let mut live_queries = HashMap::new(); + let mut session = Session::default().with_rt(true); + + let mut opt = EngineOptions::default(); + opt.tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL); + let (_tasks, task_chans) = start_tasks(&opt, kvs.clone()); + + let mut notifications = kvs.notifications(); + let mut notification_stream = poll_fn(move |cx| match &mut notifications { + Some(rx) => rx.poll_next_unpin(cx), + None => Poll::Pending, }); + + let mut route_stream = route_rx.into_stream(); + + loop { + // use the less ergonomic futures::select as tokio::select is not available. + futures::select! { + route = route_stream.next().fuse() => { + let Some(route) = route else { + // termination requested + break + }; + + match super::router( + route.request, + &kvs, + &mut session, + &mut vars, + &mut live_queries, + ) + .await + { + Ok(value) => { + let _ = route.response.into_send_async(Ok(value)).await; + } + Err(error) => { + let _ = route.response.into_send_async(Err(error)).await; + } + } + } + notification = notification_stream.next().fuse() => { + let Some(notification) = notification else { + // TODO: maybe do something else then ignore a disconnected notification + // channel. + continue; + }; + + let id = notification.id; + if let Some(sender) = live_queries.get(&id) { + if sender.send(notification).await.is_err() { + live_queries.remove(&id); + if let Err(error) = + super::kill_live_query(&kvs, id, &session, vars.clone()).await + { + warn!("Failed to kill live query '{id}'; {error}"); + } + } + } + } + } + } + + // Stop maintenance tasks + for chan in task_chans { + if chan.send(()).is_err() { + error!("Error sending shutdown signal to maintenance task"); + } + } } diff --git a/lib/src/api/engine/remote/http/native.rs b/lib/src/api/engine/remote/http/native.rs index 951a2a1f..feb57f1e 100644 --- a/lib/src/api/engine/remote/http/native.rs +++ b/lib/src/api/engine/remote/http/native.rs @@ -67,7 +67,7 @@ impl Connection for Client { capacity => flume::bounded(capacity), }; - router(base_url, client, route_rx); + tokio::spawn(run_router(base_url, client, route_rx)); let mut features = HashSet::new(); features.insert(ExtraFeatures::Backup); @@ -94,30 +94,22 @@ impl Connection for Client { request: (0, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } } -pub(crate) fn router(base_url: Url, client: reqwest::Client, route_rx: Receiver>) { - tokio::spawn(async move { - let mut headers = HeaderMap::new(); - let mut vars = IndexMap::new(); - let mut auth = None; - let mut stream = route_rx.into_stream(); +pub(crate) async fn run_router(base_url: Url, client: reqwest::Client, route_rx: Receiver) { + let mut headers = HeaderMap::new(); + let mut vars = IndexMap::new(); + let mut auth = None; + let mut stream = route_rx.into_stream(); - while let Some(Some(route)) = stream.next().await { - let result = super::router( - route.request, - &base_url, - &client, - &mut headers, - &mut vars, - &mut auth, - ) - .await; - let _ = route.response.into_send_async(result).await; - } - }); + while let Some(route) = stream.next().await { + let result = + super::router(route.request, &base_url, &client, &mut headers, &mut vars, &mut auth) + .await; + let _ = route.response.into_send_async(result).await; + } } diff --git a/lib/src/api/engine/remote/http/wasm.rs b/lib/src/api/engine/remote/http/wasm.rs index 913c3936..ce621185 100644 --- a/lib/src/api/engine/remote/http/wasm.rs +++ b/lib/src/api/engine/remote/http/wasm.rs @@ -18,7 +18,6 @@ use reqwest::header::HeaderMap; use reqwest::ClientBuilder; use std::collections::HashSet; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::Arc; @@ -48,7 +47,7 @@ impl Connection for Client { let (conn_tx, conn_rx) = flume::bounded(1); - router(address, conn_tx, route_rx); + spawn_local(run_router(address, conn_tx, route_rx)); conn_rx.into_recv_async().await??; @@ -75,7 +74,7 @@ impl Connection for Client { request: (0, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } @@ -90,48 +89,39 @@ async fn client(base_url: &Url) -> Result { Ok(client) } -pub(crate) fn router( +pub(crate) async fn run_router( address: Endpoint, conn_tx: Sender>, - route_rx: Receiver>, + route_rx: Receiver, ) { - spawn_local(async move { - let base_url = address.url; + let base_url = address.url; - let client = match client(&base_url).await { - Ok(client) => { - let _ = conn_tx.into_send_async(Ok(())).await; - client + let client = match client(&base_url).await { + Ok(client) => { + let _ = conn_tx.into_send_async(Ok(())).await; + client + } + Err(error) => { + let _ = conn_tx.into_send_async(Err(error)).await; + return; + } + }; + + let mut headers = HeaderMap::new(); + let mut vars = IndexMap::new(); + let mut auth = None; + let mut stream = route_rx.into_stream(); + + while let Some(route) = stream.next().await { + match super::router(route.request, &base_url, &client, &mut headers, &mut vars, &mut auth) + .await + { + Ok(value) => { + let _ = route.response.into_send_async(Ok(value)).await; } Err(error) => { - let _ = conn_tx.into_send_async(Err(error)).await; - return; - } - }; - - let mut headers = HeaderMap::new(); - let mut vars = IndexMap::new(); - let mut auth = None; - let mut stream = route_rx.into_stream(); - - while let Some(Some(route)) = stream.next().await { - match super::router( - route.request, - &base_url, - &client, - &mut headers, - &mut vars, - &mut auth, - ) - .await - { - Ok(value) => { - let _ = route.response.into_send_async(Ok(value)).await; - } - Err(error) => { - let _ = route.response.into_send_async(Err(error)).await; - } + let _ = route.response.into_send_async(Err(error)).await; } } - }); + } } diff --git a/lib/src/api/engine/remote/ws/mod.rs b/lib/src/api/engine/remote/ws/mod.rs index 3a86eb88..97fc6d66 100644 --- a/lib/src/api/engine/remote/ws/mod.rs +++ b/lib/src/api/engine/remote/ws/mod.rs @@ -20,20 +20,160 @@ use crate::dbs::Status; use crate::method::Stats; use crate::opt::IntoEndpoint; use crate::sql::Value; +use bincode::Options as _; +use flume::Sender; use indexmap::IndexMap; use revision::revisioned; use revision::Revisioned; use serde::de::DeserializeOwned; +use serde::ser::SerializeMap; use serde::Deserialize; +use serde::Serialize; +use std::collections::HashMap; use std::io::Read; use std::marker::PhantomData; use std::time::Duration; +use surrealdb_core::dbs::Notification as CoreNotification; +use trice::Instant; +use uuid::Uuid; pub(crate) const PATH: &str = "rpc"; const PING_INTERVAL: Duration = Duration::from_secs(5); const PING_METHOD: &str = "ping"; const REVISION_HEADER: &str = "revision"; +/// A struct which will be serialized as a map to behave like the previously used BTreeMap. +/// +/// This struct serializes as if it is a surrealdb_core::sql::Value::Object. +#[derive(Debug)] +struct RouterRequest { + id: Option, + method: Value, + params: Option, +} + +impl Serialize for RouterRequest { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + struct InnerRequest<'a>(&'a RouterRequest); + + impl Serialize for InnerRequest<'_> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let size = 1 + self.0.id.is_some() as usize + self.0.params.is_some() as usize; + let mut map = serializer.serialize_map(Some(size))?; + if let Some(id) = self.0.id.as_ref() { + map.serialize_entry("id", id)?; + } + map.serialize_entry("method", &self.0.method)?; + if let Some(params) = self.0.params.as_ref() { + map.serialize_entry("params", params)?; + } + map.end() + } + } + + serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self)) + } +} + +impl Revisioned for RouterRequest { + fn revision() -> u16 { + 1 + } + + fn serialize_revisioned( + &self, + w: &mut W, + ) -> std::result::Result<(), revision::Error> { + // version + Revisioned::serialize_revisioned(&1u32, w)?; + // object variant + Revisioned::serialize_revisioned(&9u32, w)?; + // object wrapper version + Revisioned::serialize_revisioned(&1u32, w)?; + + let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize; + size.serialize_revisioned(w)?; + + let serializer = bincode::options() + .with_no_limit() + .with_little_endian() + .with_varint_encoding() + .reject_trailing_bytes(); + + if let Some(x) = self.id.as_ref() { + serializer + .serialize_into(&mut *w, "id") + .map_err(|err| revision::Error::Serialize(err.to_string()))?; + x.serialize_revisioned(w)?; + } + serializer + .serialize_into(&mut *w, "method") + .map_err(|err| revision::Error::Serialize(err.to_string()))?; + self.method.serialize_revisioned(w)?; + + if let Some(x) = self.params.as_ref() { + serializer + .serialize_into(&mut *w, "params") + .map_err(|err| revision::Error::Serialize(err.to_string()))?; + x.serialize_revisioned(w)?; + } + + Ok(()) + } + + fn deserialize_revisioned(_: &mut R) -> std::result::Result + where + Self: Sized, + { + panic!("deliberately unimplemented"); + } +} + +struct RouterState { + var_stash: IndexMap, + /// Vars currently set by the set method, + vars: IndexMap, + /// Messages which aught to be replayed on a reconnect. + replay: IndexMap, + /// Pending live queries + live_queries: HashMap>, + + routes: HashMap>)>, + + last_activity: Instant, + + sink: Sink, + stream: Stream, +} + +impl RouterState { + pub fn new(sink: Sink, stream: Stream) -> Self { + RouterState { + var_stash: IndexMap::new(), + vars: IndexMap::new(), + replay: IndexMap::new(), + live_queries: HashMap::new(), + routes: HashMap::new(), + last_activity: Instant::now(), + sink, + stream, + } + } +} + +enum HandleResult { + /// Socket disconnected, should continue to reconnect + Disconnected, + /// Nothing wrong continue as normal. + Ok, +} + /// The WS scheme used to connect to `ws://` endpoints #[derive(Debug)] pub struct Ws; @@ -156,7 +296,10 @@ pub(crate) struct Response { pub(crate) result: ServerResult, } -fn serialize(value: &Value, revisioned: bool) -> Result> { +fn serialize(value: &V, revisioned: bool) -> Result> +where + V: serde::Serialize + Revisioned, +{ if revisioned { let mut buf = Vec::new(); value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?; @@ -177,3 +320,67 @@ where bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?; crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into())) } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use revision::Revisioned; + use surrealdb_core::sql::Value; + + use super::RouterRequest; + + fn assert_converts(req: &RouterRequest, s: S, d: D) + where + S: FnOnce(&RouterRequest) -> I, + D: FnOnce(I) -> Value, + { + let ser = s(req); + let val = d(ser); + let Value::Object(obj) = val else { + panic!("not an object"); + }; + assert_eq!(obj.get("id").cloned(), req.id); + assert_eq!(obj.get("method").unwrap().clone(), req.method); + assert_eq!(obj.get("params").cloned(), req.params); + } + + #[test] + fn router_request_value_conversion() { + let request = RouterRequest { + id: Some(Value::from(1234i64)), + method: Value::from("request"), + params: Some(vec![Value::from(1234i64), Value::from("request")].into()), + }; + + println!("test convert bincode"); + + assert_converts( + &request, + |i| bincode::serialize(i).unwrap(), + |b| bincode::deserialize(&b).unwrap(), + ); + + println!("test convert json"); + + assert_converts( + &request, + |i| serde_json::to_string(i).unwrap(), + |b| serde_json::from_str(&b).unwrap(), + ); + + println!("test convert revisioned"); + + assert_converts( + &request, + |i| { + let mut buf = Vec::new(); + i.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap(); + buf + }, + |b| Value::deserialize_revisioned(&mut Cursor::new(b)).unwrap(), + ); + + println!("done"); + } +} diff --git a/lib/src/api/engine/remote/ws/native.rs b/lib/src/api/engine/remote/ws/native.rs index 819a085e..891c228a 100644 --- a/lib/src/api/engine/remote/ws/native.rs +++ b/lib/src/api/engine/remote/ws/native.rs @@ -1,5 +1,6 @@ use super::PATH; use super::{deserialize, serialize}; +use super::{HandleResult, RouterRequest}; use crate::api::conn::Connection; use crate::api::conn::DbResponse; use crate::api::conn::Method; @@ -23,16 +24,12 @@ use crate::engine::IntervalStream; use crate::opt::WaitFor; use crate::sql::Value; use flume::Receiver; -use futures::stream::SplitSink; +use futures::stream::{SplitSink, SplitStream}; use futures::SinkExt; use futures::StreamExt; -use futures_concurrency::stream::Merge as _; -use indexmap::IndexMap; use revision::revisioned; use serde::Deserialize; use std::collections::hash_map::Entry; -use std::collections::BTreeMap; -use std::collections::HashMap; use std::collections::HashSet; use std::future::Future; use std::mem; @@ -55,19 +52,15 @@ use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use trice::Instant; -type WsResult = std::result::Result; - pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; // 64 MiB pub(crate) const MAX_FRAME_SIZE: usize = 16 << 20; // 16 MiB pub(crate) const WRITE_BUFFER_SIZE: usize = 128000; // tungstenite default pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE + MAX_MESSAGE_SIZE; // Recommended max according to tungstenite docs pub(crate) const NAGLE_ALG: bool = false; -pub(crate) enum Either { - Request(Option), - Response(WsResult), - Ping, -} +type MessageSink = SplitSink>, Message>; +type MessageStream = SplitStream>>; +type RouterState = super::RouterState; #[cfg(any(feature = "native-tls", feature = "rustls"))] impl From for Connector { @@ -144,7 +137,7 @@ impl Connection for Client { capacity => flume::bounded(capacity), }; - router(address, maybe_connector, capacity, config, socket, route_rx); + tokio::spawn(run_router(address, maybe_connector, capacity, config, socket, route_rx)); let mut features = HashSet::new(); features.insert(ExtraFeatures::LiveQueries); @@ -172,384 +165,392 @@ impl Connection for Client { request: (self.id, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } } -#[allow(clippy::too_many_lines)] -pub(crate) fn router( - endpoint: Endpoint, - maybe_connector: Option, - capacity: usize, - config: WebSocketConfig, - mut socket: WebSocketStream>, - route_rx: Receiver>, -) { - tokio::spawn(async move { - let ping = { - let mut request = BTreeMap::new(); - request.insert("method".to_owned(), PING_METHOD.into()); - let value = Value::from(request); - let value = serialize(&value, endpoint.supports_revision).unwrap(); - Message::Binary(value) - }; - - let mut var_stash = IndexMap::new(); - let mut vars = IndexMap::new(); - let mut replay = IndexMap::new(); - - 'router: loop { - let (socket_sink, socket_stream) = socket.split(); - let mut socket_sink = Socket(Some(socket_sink)); - - if let Socket(Some(socket_sink)) = &mut socket_sink { - let mut routes = match capacity { - 0 => HashMap::new(), - capacity => HashMap::with_capacity(capacity), - }; - let mut live_queries = HashMap::new(); - - let mut interval = time::interval(PING_INTERVAL); - // don't bombard the server with pings if we miss some ticks - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - - let pinger = IntervalStream::new(interval); - - let streams = ( - socket_stream.map(Either::Response), - route_rx.stream().map(Either::Request), - pinger.map(|_| Either::Ping), - ); - - let mut merged = streams.merge(); - let mut last_activity = Instant::now(); - - while let Some(either) = merged.next().await { - match either { - Either::Request(Some(Route { - request, - response, - })) => { - let (id, method, param) = request; - let params = match param.query { - Some((query, bindings)) => { - vec![query.into(), bindings.into()] - } - None => param.other, - }; - match method { - Method::Set => { - if let [Value::Strand(key), value] = ¶ms[..2] { - var_stash.insert(id, (key.0.clone(), value.clone())); - } - } - Method::Unset => { - if let [Value::Strand(key)] = ¶ms[..1] { - vars.swap_remove(&key.0); - } - } - Method::Live => { - if let Some(sender) = param.notification_sender { - if let [Value::Uuid(id)] = ¶ms[..1] { - live_queries.insert(*id, sender); - } - } - if response - .into_send_async(Ok(DbResponse::Other(Value::None))) - .await - .is_err() - { - trace!("Receiver dropped"); - } - // There is nothing to send to the server here - continue; - } - Method::Kill => { - if let [Value::Uuid(id)] = ¶ms[..1] { - live_queries.remove(id); - } - } - _ => {} - } - let method_str = match method { - Method::Health => PING_METHOD, - _ => method.as_str(), - }; - let message = { - let mut request = BTreeMap::new(); - request.insert("id".to_owned(), Value::from(id)); - request.insert("method".to_owned(), method_str.into()); - if !params.is_empty() { - request.insert("params".to_owned(), params.into()); - } - let payload = Value::from(request); - trace!("Request {payload}"); - let payload = - serialize(&payload, endpoint.supports_revision).unwrap(); - Message::Binary(payload) - }; - if let Method::Authenticate - | Method::Invalidate - | Method::Signin - | Method::Signup - | Method::Use = method - { - replay.insert(method, message.clone()); - } - match socket_sink.send(message).await { - Ok(..) => { - last_activity = Instant::now(); - match routes.entry(id) { - Entry::Vacant(entry) => { - // Register query route - entry.insert((method, response)); - } - Entry::Occupied(..) => { - let error = Error::DuplicateRequestId(id); - if response - .into_send_async(Err(error.into())) - .await - .is_err() - { - trace!("Receiver dropped"); - } - } - } - } - Err(error) => { - let error = Error::Ws(error.to_string()); - if response.into_send_async(Err(error.into())).await.is_err() { - trace!("Receiver dropped"); - } - break; - } - } - } - Either::Response(result) => { - last_activity = Instant::now(); - match result { - Ok(message) => { - match Response::try_from(&message, endpoint.supports_revision) { - Ok(option) => { - // We are only interested in responses that are not empty - if let Some(response) = option { - trace!("{response:?}"); - match response.id { - // If `id` is set this is a normal response - Some(id) => { - if let Ok(id) = id.coerce_to_i64() { - // We can only route responses with IDs - if let Some((method, sender)) = - routes.remove(&id) - { - if matches!(method, Method::Set) { - if let Some((key, value)) = - var_stash.swap_remove(&id) - { - vars.insert(key, value); - } - } - // Send the response back to the caller - let mut response = response.result; - if matches!(method, Method::Insert) - { - // For insert, we need to flatten single responses in an array - if let Ok(Data::Other( - Value::Array(value), - )) = &mut response - { - if let [value] = - &mut value.0[..] - { - response = - Ok(Data::Other( - mem::take( - value, - ), - )); - } - } - } - let _res = sender - .into_send_async( - DbResponse::from(response), - ) - .await; - } - } - } - // If `id` is not set, this may be a live query notification - None => match response.result { - Ok(Data::Live(notification)) => { - let live_query_id = notification.id; - // Check if this live query is registered - if let Some(sender) = - live_queries.get(&live_query_id) - { - // Send the notification back to the caller or kill live query if the receiver is already dropped - if sender - .send(notification) - .await - .is_err() - { - live_queries - .remove(&live_query_id); - let kill = { - let mut request = - BTreeMap::new(); - request.insert( - "method".to_owned(), - Method::Kill - .as_str() - .into(), - ); - request.insert( - "params".to_owned(), - vec![Value::from( - live_query_id, - )] - .into(), - ); - let value = - Value::from(request); - let value = serialize( - &value, - endpoint - .supports_revision, - ) - .unwrap(); - Message::Binary(value) - }; - if let Err(error) = - socket_sink.send(kill).await - { - trace!("failed to send kill query to the server; {error:?}"); - break; - } - } - } - } - Ok(..) => { /* Ignored responses like pings */ - } - Err(error) => error!("{error:?}"), - }, - } - } - } - Err(error) => { - #[revisioned(revision = 1)] - #[derive(Deserialize)] - struct Response { - id: Option, - } - - // Let's try to find out the ID of the response that failed to deserialise - if let Message::Binary(binary) = message { - if let Ok(Response { - id, - }) = deserialize( - &mut &binary[..], - endpoint.supports_revision, - ) { - // Return an error if an ID was returned - if let Some(Ok(id)) = - id.map(Value::coerce_to_i64) - { - if let Some((_method, sender)) = - routes.remove(&id) - { - let _res = sender - .into_send_async(Err(error)) - .await; - } - } - } else { - // Unfortunately, we don't know which response failed to deserialize - warn!( - "Failed to deserialise message; {error:?}" - ); - } - } - } - } - } - Err(error) => { - match error { - WsError::ConnectionClosed => { - trace!("Connection successfully closed on the server"); - } - error => { - trace!("{error}"); - } - } - break; - } - } - } - Either::Ping => { - // only ping if we haven't talked to the server recently - if last_activity.elapsed() >= PING_INTERVAL { - trace!("Pinging the server"); - if let Err(error) = socket_sink.send(ping.clone()).await { - trace!("failed to ping the server; {error:?}"); - break; - } - } - } - // Close connection request received - Either::Request(None) => { - match socket_sink.send(Message::Close(None)).await { - Ok(..) => trace!("Connection closed successfully"), - Err(error) => { - warn!("Failed to close database connection; {error}") - } - } - break 'router; - } - } +async fn router_handle_route( + Route { + request, + response, + }: Route, + state: &mut RouterState, + endpoint: &Endpoint, +) -> HandleResult { + let (id, method, param) = request; + let params = match param.query { + Some((query, bindings)) => { + vec![query.into(), bindings.into()] + } + None => param.other, + }; + match method { + Method::Set => { + if let [Value::Strand(key), value] = ¶ms[..2] { + state.var_stash.insert(id, (key.0.clone(), value.clone())); + } + } + Method::Unset => { + if let [Value::Strand(key)] = ¶ms[..1] { + state.vars.swap_remove(&key.0); + } + } + Method::Live => { + if let Some(sender) = param.notification_sender { + if let [Value::Uuid(id)] = ¶ms[..1] { + state.live_queries.insert(id.0, sender); } } + if response.clone().into_send_async(Ok(DbResponse::Other(Value::None))).await.is_err() { + trace!("Receiver dropped"); + } + // There is nothing to send to the server here + } + Method::Kill => { + if let [Value::Uuid(id)] = ¶ms[..1] { + state.live_queries.remove(id); + } + } + _ => {} + } + let method_str = match method { + Method::Health => PING_METHOD, + _ => method.as_str(), + }; + let message = { + let request = RouterRequest { + id: Some(Value::from(id)), + method: method_str.into(), + params: (!params.is_empty()).then(|| params.into()), + }; - 'reconnect: loop { - trace!("Reconnecting..."); - match connect(&endpoint, Some(config), maybe_connector.clone()).await { - Ok(s) => { - socket = s; - for (_, message) in &replay { - if let Err(error) = socket.send(message.clone()).await { - trace!("{error}"); - time::sleep(time::Duration::from_secs(1)).await; - continue 'reconnect; - } - } - for (key, value) in &vars { - let mut request = BTreeMap::new(); - request.insert("method".to_owned(), Method::Set.as_str().into()); - request.insert( - "params".to_owned(), - vec![key.as_str().into(), value.clone()].into(), - ); - let payload = Value::from(request); - trace!("Request {payload}"); - if let Err(error) = socket.send(Message::Binary(payload.into())).await { - trace!("{error}"); - time::sleep(time::Duration::from_secs(1)).await; - continue 'reconnect; - } - } - trace!("Reconnected successfully"); - break; - } - Err(error) => { - trace!("Failed to reconnect; {error}"); - time::sleep(time::Duration::from_secs(1)).await; + trace!("Request {:?}", request); + let payload = serialize(&request, endpoint.supports_revision).unwrap(); + Message::Binary(payload) + }; + if let Method::Authenticate + | Method::Invalidate + | Method::Signin + | Method::Signup + | Method::Use = method + { + state.replay.insert(method, message.clone()); + } + match state.sink.send(message).await { + Ok(_) => { + state.last_activity = Instant::now(); + match state.routes.entry(id) { + Entry::Vacant(entry) => { + // Register query route + entry.insert((method, response)); + } + Entry::Occupied(..) => { + let error = Error::DuplicateRequestId(id); + if response.into_send_async(Err(error.into())).await.is_err() { + trace!("Receiver dropped"); } } } } - }); + Err(error) => { + let error = Error::Ws(error.to_string()); + if response.into_send_async(Err(error.into())).await.is_err() { + trace!("Receiver dropped"); + } + return HandleResult::Disconnected; + } + } + HandleResult::Ok +} + +async fn router_handle_response( + response: Message, + state: &mut RouterState, + endpoint: &Endpoint, +) -> HandleResult { + match Response::try_from(&response, endpoint.supports_revision) { + Ok(option) => { + // We are only interested in responses that are not empty + if let Some(response) = option { + trace!("{response:?}"); + match response.id { + // If `id` is set this is a normal response + Some(id) => { + if let Ok(id) = id.coerce_to_i64() { + // We can only route responses with IDs + if let Some((method, sender)) = state.routes.remove(&id) { + if matches!(method, Method::Set) { + if let Some((key, value)) = state.var_stash.swap_remove(&id) { + state.vars.insert(key, value); + } + } + // Send the response back to the caller + let mut response = response.result; + if matches!(method, Method::Insert) { + // For insert, we need to flatten single responses in an array + if let Ok(Data::Other(Value::Array(value))) = &mut response { + if let [value] = &mut value.0[..] { + response = Ok(Data::Other(mem::take(value))); + } + } + } + let _res = sender.into_send_async(DbResponse::from(response)).await; + } + } + } + // If `id` is not set, this may be a live query notification + None => { + match response.result { + Ok(Data::Live(notification)) => { + let live_query_id = notification.id; + // Check if this live query is registered + if let Some(sender) = state.live_queries.get(&live_query_id) { + // Send the notification back to the caller or kill live query if the receiver is already dropped + if sender.send(notification).await.is_err() { + state.live_queries.remove(&live_query_id); + let kill = { + let request = RouterRequest { + id: None, + method: Method::Kill.as_str().into(), + params: Some( + vec![Value::from(live_query_id)].into(), + ), + }; + let value = + serialize(&request, endpoint.supports_revision) + .unwrap(); + Message::Binary(value) + }; + if let Err(error) = state.sink.send(kill).await { + trace!("failed to send kill query to the server; {error:?}"); + return HandleResult::Disconnected; + } + } + } + } + Ok(..) => { /* Ignored responses like pings */ } + Err(error) => error!("{error:?}"), + } + } + } + } + } + Err(error) => { + #[revisioned(revision = 1)] + #[derive(Deserialize)] + struct Response { + id: Option, + } + + // Let's try to find out the ID of the response that failed to deserialise + if let Message::Binary(binary) = response { + if let Ok(Response { + id, + }) = deserialize(&mut &binary[..], endpoint.supports_revision) + { + // Return an error if an ID was returned + if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { + if let Some((_method, sender)) = state.routes.remove(&id) { + let _res = sender.into_send_async(Err(error)).await; + } + } + } else { + // Unfortunately, we don't know which response failed to deserialize + warn!("Failed to deserialise message; {error:?}"); + } + } + } + } + HandleResult::Ok +} + +async fn router_reconnect( + maybe_connector: &Option, + config: &WebSocketConfig, + state: &mut RouterState, + endpoint: &Endpoint, +) { + loop { + trace!("Reconnecting..."); + match connect(endpoint, Some(*config), maybe_connector.clone()).await { + Ok(s) => { + let (new_sink, new_stream) = s.split(); + state.sink = new_sink; + state.stream = new_stream; + for (_, message) in &state.replay { + if let Err(error) = state.sink.send(message.clone()).await { + trace!("{error}"); + time::sleep(time::Duration::from_secs(1)).await; + continue; + } + } + for (key, value) in &state.vars { + let request = RouterRequest { + id: None, + method: Method::Set.as_str().into(), + params: Some(vec![key.as_str().into(), value.clone()].into()), + }; + trace!("Request {:?}", request); + let payload = serialize(&request, endpoint.supports_revision).unwrap(); + if let Err(error) = state.sink.send(Message::Binary(payload)).await { + trace!("{error}"); + time::sleep(time::Duration::from_secs(1)).await; + continue; + } + } + trace!("Reconnected successfully"); + break; + } + Err(error) => { + trace!("Failed to reconnect; {error}"); + time::sleep(time::Duration::from_secs(1)).await; + } + } + } +} + +pub(crate) async fn run_router( + endpoint: Endpoint, + maybe_connector: Option, + _capacity: usize, + config: WebSocketConfig, + socket: WebSocketStream>, + route_rx: Receiver, +) { + let ping = { + let request = RouterRequest { + id: None, + method: PING_METHOD.into(), + params: None, + }; + let value = serialize(&request, endpoint.supports_revision).unwrap(); + Message::Binary(value) + }; + + let (socket_sink, socket_stream) = socket.split(); + let mut state = RouterState::new(socket_sink, socket_stream); + let mut route_stream = route_rx.into_stream(); + + 'router: loop { + let mut interval = time::interval(PING_INTERVAL); + // don't bombard the server with pings if we miss some ticks + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + let mut pinger = IntervalStream::new(interval); + // Turn into a stream instead of calling recv_async + // The stream seems to be able to keep some state which would otherwise need to be + // recreated with each next. + + state.last_activity = Instant::now(); + state.live_queries.clear(); + state.routes.clear(); + + loop { + tokio::select! { + route = route_stream.next() => { + // handle incoming route + + let Some(response) = route else { + // route returned none, frontend dropped the channel, meaning the router + // should quit. + match state.sink.send(Message::Close(None)).await { + Ok(..) => trace!("Connection closed successfully"), + Err(error) => { + warn!("Failed to close database connection; {error}") + } + } + break 'router; + }; + + match router_handle_route(response, &mut state, &endpoint).await { + HandleResult::Ok => {}, + HandleResult::Disconnected => { + router_reconnect( + &maybe_connector, + &config, + &mut state, + &endpoint, + ) + .await; + continue 'router; + } + } + } + result = state.stream.next() => { + // Handle result from database. + + let Some(result) = result else { + // stream returned none meaning the connection dropped, try to reconnect. + router_reconnect( + &maybe_connector, + &config, + &mut state, + &endpoint, + ) + .await; + continue 'router; + }; + + state.last_activity = Instant::now(); + match result { + Ok(message) => { + match router_handle_response(message, &mut state, &endpoint).await { + HandleResult::Ok => continue, + HandleResult::Disconnected => { + router_reconnect( + &maybe_connector, + &config, + &mut state, + &endpoint, + ) + .await; + continue 'router; + } + } + } + Err(error) => { + match error { + WsError::ConnectionClosed => { + trace!("Connection successfully closed on the server"); + } + error => { + trace!("{error}"); + } + } + router_reconnect( + &maybe_connector, + &config, + &mut state, + &endpoint, + ) + .await; + continue 'router; + } + } + } + _ = pinger.next() => { + // only ping if we haven't talked to the server recently + if state.last_activity.elapsed() >= PING_INTERVAL { + trace!("Pinging the server"); + if let Err(error) = state.sink.send(ping.clone()).await { + trace!("failed to ping the server; {error:?}"); + router_reconnect( + &maybe_connector, + &config, + &mut state, + &endpoint, + ) + .await; + continue 'router; + } + } + + } + } + } + } } impl Response { @@ -588,8 +589,6 @@ impl Response { } } -pub struct Socket(Option>, Message>>); - #[cfg(test)] mod tests { use super::serialize; diff --git a/lib/src/api/engine/remote/ws/wasm.rs b/lib/src/api/engine/remote/ws/wasm.rs index e8396f1d..f6813b40 100644 --- a/lib/src/api/engine/remote/ws/wasm.rs +++ b/lib/src/api/engine/remote/ws/wasm.rs @@ -1,5 +1,5 @@ -use super::PATH; use super::{deserialize, serialize}; +use super::{HandleResult, PATH}; use crate::api::conn::Connection; use crate::api::conn::DbResponse; use crate::api::conn::Method; @@ -16,27 +16,26 @@ use crate::api::ExtraFeatures; use crate::api::OnceLockExt; use crate::api::Result; use crate::api::Surreal; -use crate::engine::remote::ws::Data; +use crate::engine::remote::ws::{Data, RouterRequest}; use crate::engine::IntervalStream; use crate::opt::WaitFor; use crate::sql::Value; use flume::Receiver; use flume::Sender; +use futures::stream::{SplitSink, SplitStream}; +use futures::FutureExt; use futures::SinkExt; use futures::StreamExt; -use futures_concurrency::stream::Merge as _; -use indexmap::IndexMap; use pharos::Channel; +use pharos::Events; use pharos::Observable; use pharos::ObserveConfig; use revision::revisioned; use serde::Deserialize; use std::collections::hash_map::Entry; use std::collections::BTreeMap; -use std::collections::HashMap; use std::collections::HashSet; use std::future::Future; -use std::marker::PhantomData; use std::mem; use std::pin::Pin; use std::sync::atomic::AtomicI64; @@ -48,16 +47,13 @@ use trice::Instant; use wasm_bindgen_futures::spawn_local; use wasmtimer::tokio as time; use wasmtimer::tokio::MissedTickBehavior; -use ws_stream_wasm::WsEvent; use ws_stream_wasm::WsMessage as Message; use ws_stream_wasm::WsMeta; +use ws_stream_wasm::{WsEvent, WsStream}; -pub(crate) enum Either { - Request(Option), - Response(Message), - Event(WsEvent), - Ping, -} +type MessageStream = SplitStream; +type MessageSink = SplitSink; +type RouterState = super::RouterState; impl crate::api::Connection for Client {} @@ -83,7 +79,7 @@ impl Connection for Client { let (conn_tx, conn_rx) = flume::bounded(1); - router(address, capacity, conn_tx, route_rx); + spawn_local(run_router(address, capacity, conn_tx, route_rx)); conn_rx.into_recv_async().await??; @@ -113,297 +109,367 @@ impl Connection for Client { request: (self.id, self.method, param), response: sender, }; - router.sender.send_async(Some(route)).await?; + router.sender.send_async(route).await?; Ok(receiver) }) } } -pub(crate) fn router( - endpoint: Endpoint, +async fn router_handle_request( + Route { + request, + response, + }: Route, + state: &mut RouterState, + endpoint: &Endpoint, +) -> HandleResult { + let (id, method, param) = request; + let params = match param.query { + Some((query, bindings)) => { + vec![query.into(), bindings.into()] + } + None => param.other, + }; + match method { + Method::Set => { + if let [Value::Strand(key), value] = ¶ms[..2] { + state.var_stash.insert(id, (key.0.clone(), value.clone())); + } + } + Method::Unset => { + if let [Value::Strand(key)] = ¶ms[..1] { + state.vars.swap_remove(&key.0); + } + } + Method::Live => { + if let Some(sender) = param.notification_sender { + if let [Value::Uuid(id)] = ¶ms[..1] { + state.live_queries.insert(id.0, sender); + } + } + if response.into_send_async(Ok(DbResponse::Other(Value::None))).await.is_err() { + trace!("Receiver dropped"); + } + // There is nothing to send to the server here + return HandleResult::Ok; + } + Method::Kill => { + if let [Value::Uuid(id)] = ¶ms[..1] { + state.live_queries.remove(id); + } + } + _ => {} + } + let method_str = match method { + Method::Health => PING_METHOD, + _ => method.as_str(), + }; + let message = { + let request = RouterRequest { + id: Some(Value::from(id)), + method: method_str.into(), + params: (!params.is_empty()).then(|| params.into()), + }; + trace!("Request {:?}", request); + let payload = serialize(&request, endpoint.supports_revision).unwrap(); + Message::Binary(payload) + }; + if let Method::Authenticate + | Method::Invalidate + | Method::Signin + | Method::Signup + | Method::Use = method + { + state.replay.insert(method, message.clone()); + } + match state.sink.send(message).await { + Ok(..) => { + state.last_activity = Instant::now(); + match state.routes.entry(id) { + Entry::Vacant(entry) => { + entry.insert((method, response)); + } + Entry::Occupied(..) => { + let error = Error::DuplicateRequestId(id); + if response.into_send_async(Err(error.into())).await.is_err() { + trace!("Receiver dropped"); + } + } + } + } + Err(error) => { + let error = Error::Ws(error.to_string()); + if response.into_send_async(Err(error.into())).await.is_err() { + trace!("Receiver dropped"); + } + return HandleResult::Disconnected; + } + } + HandleResult::Ok +} + +async fn router_handle_response( + response: Message, + state: &mut RouterState, + endpoint: &Endpoint, +) -> HandleResult { + match Response::try_from(&response, endpoint.supports_revision) { + Ok(option) => { + // We are only interested in responses that are not empty + if let Some(response) = option { + trace!("{response:?}"); + match response.id { + // If `id` is set this is a normal response + Some(id) => { + if let Ok(id) = id.coerce_to_i64() { + // We can only route responses with IDs + if let Some((method, sender)) = state.routes.remove(&id) { + if matches!(method, Method::Set) { + if let Some((key, value)) = state.var_stash.swap_remove(&id) { + state.vars.insert(key, value); + } + } + // Send the response back to the caller + let mut response = response.result; + if matches!(method, Method::Insert) { + // For insert, we need to flatten single responses in an array + if let Ok(Data::Other(Value::Array(value))) = &mut response { + if let [value] = &mut value.0[..] { + response = Ok(Data::Other(mem::take(value))); + } + } + } + let _res = sender.into_send_async(DbResponse::from(response)).await; + } + } + } + // If `id` is not set, this may be a live query notification + None => match response.result { + Ok(Data::Live(notification)) => { + let live_query_id = notification.id; + // Check if this live query is registered + if let Some(sender) = state.live_queries.get(&live_query_id) { + // Send the notification back to the caller or kill live query if the receiver is already dropped + if sender.send(notification).await.is_err() { + state.live_queries.remove(&live_query_id); + let kill = { + let request = RouterRequest { + id: None, + method: Method::Kill.as_str().into(), + params: Some(vec![Value::from(live_query_id)].into()), + }; + let value = serialize(&request, endpoint.supports_revision) + .unwrap(); + Message::Binary(value) + }; + if let Err(error) = state.sink.send(kill).await { + trace!( + "failed to send kill query to the server; {error:?}" + ); + return HandleResult::Disconnected; + } + } + } + } + Ok(..) => { /* Ignored responses like pings */ } + Err(error) => error!("{error:?}"), + }, + } + } + } + Err(error) => { + #[derive(Deserialize)] + #[revisioned(revision = 1)] + struct Response { + id: Option, + } + + // Let's try to find out the ID of the response that failed to deserialise + if let Message::Binary(binary) = response { + if let Ok(Response { + id, + }) = deserialize(&mut &binary[..], endpoint.supports_revision) + { + // Return an error if an ID was returned + if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { + if let Some((_method, sender)) = state.routes.remove(&id) { + let _res = sender.into_send_async(Err(error)).await; + } + } + } else { + // Unfortunately, we don't know which response failed to deserialize + warn!("Failed to deserialise message; {error:?}"); + } + } + } + } + HandleResult::Ok +} + +async fn router_reconnect( + state: &mut RouterState, + events: &mut Events, + endpoint: &Endpoint, capacity: usize, - conn_tx: Sender>, - route_rx: Receiver>, ) { - spawn_local(async move { + loop { + trace!("Reconnecting..."); let connect = match endpoint.supports_revision { true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await, false => WsMeta::connect(&endpoint.url, None).await, }; - let (mut ws, mut socket) = match connect { - Ok(pair) => pair, + match connect { + Ok((mut meta, stream)) => { + let (new_sink, new_stream) = stream.split(); + state.sink = new_sink; + state.stream = new_stream; + *events = { + let result = match capacity { + 0 => meta.observe(ObserveConfig::default()).await, + capacity => meta.observe(Channel::Bounded(capacity).into()).await, + }; + match result { + Ok(events) => events, + Err(error) => { + trace!("{error}"); + time::sleep(Duration::from_secs(1)).await; + continue; + } + } + }; + for (_, message) in &state.replay { + if let Err(error) = state.sink.send(message.clone()).await { + trace!("{error}"); + time::sleep(Duration::from_secs(1)).await; + continue; + } + } + for (key, value) in &state.vars { + let request = RouterRequest { + id: None, + method: Method::Set.as_str().into(), + params: Some(vec![key.as_str().into(), value.clone()].into()), + }; + trace!("Request {:?}", request); + let serialize = serialize(&request, false).unwrap(); + if let Err(error) = state.sink.send(Message::Binary(serialize)).await { + trace!("{error}"); + time::sleep(Duration::from_secs(1)).await; + continue; + } + } + trace!("Reconnected successfully"); + break; + } + Err(error) => { + trace!("Failed to reconnect; {error}"); + time::sleep(Duration::from_secs(1)).await; + } + } + } +} + +pub(crate) async fn run_router( + endpoint: Endpoint, + capacity: usize, + conn_tx: Sender>, + route_rx: Receiver, +) { + let connect = match endpoint.supports_revision { + true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await, + false => WsMeta::connect(&endpoint.url, None).await, + }; + let (mut ws, socket) = match connect { + Ok(pair) => pair, + Err(error) => { + let _ = conn_tx.into_send_async(Err(error.into())).await; + return; + } + }; + + let mut events = { + let result = match capacity { + 0 => ws.observe(ObserveConfig::default()).await, + capacity => ws.observe(Channel::Bounded(capacity).into()).await, + }; + match result { + Ok(events) => events, Err(error) => { let _ = conn_tx.into_send_async(Err(error.into())).await; return; } - }; + } + }; - let mut events = { - let result = match capacity { - 0 => ws.observe(ObserveConfig::default()).await, - capacity => ws.observe(Channel::Bounded(capacity).into()).await, - }; - match result { - Ok(events) => events, - Err(error) => { - let _ = conn_tx.into_send_async(Err(error.into())).await; - return; + let _ = conn_tx.into_send_async(Ok(())).await; + + let ping = { + let mut request = BTreeMap::new(); + request.insert("method".to_owned(), PING_METHOD.into()); + let value = Value::from(request); + let value = serialize(&value, endpoint.supports_revision).unwrap(); + Message::Binary(value) + }; + + let (socket_sink, socket_stream) = socket.split(); + + let mut state = RouterState::new(socket_sink, socket_stream); + + let mut route_stream = route_rx.into_stream(); + + 'router: loop { + let mut interval = time::interval(PING_INTERVAL); + // don't bombard the server with pings if we miss some ticks + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + let mut pinger = IntervalStream::new(interval); + + state.last_activity = Instant::now(); + state.live_queries.clear(); + state.routes.clear(); + + loop { + futures::select! { + route = route_stream.next() => { + let Some(route) = route else { + match ws.close().await { + Ok(..) => trace!("Connection closed successfully"), + Err(error) => { + warn!("Failed to close database connection; {error}") + } + } + break 'router; + }; + + match router_handle_request(route, &mut state,&endpoint).await { + HandleResult::Ok => {}, + HandleResult::Disconnected => { + router_reconnect(&mut state, &mut events, &endpoint, capacity).await; + break + } + } } - } - }; + message = state.stream.next().fuse() => { + let Some(message) = message else { + // socket disconnected, + router_reconnect(&mut state, &mut events, &endpoint, capacity).await; + break + }; - let _ = conn_tx.into_send_async(Ok(())).await; - - let ping = { - let mut request = BTreeMap::new(); - request.insert("method".to_owned(), PING_METHOD.into()); - let value = Value::from(request); - let value = serialize(&value, endpoint.supports_revision).unwrap(); - Message::Binary(value) - }; - - let mut var_stash = IndexMap::new(); - let mut vars = IndexMap::new(); - let mut replay = IndexMap::new(); - - 'router: loop { - let (mut socket_sink, socket_stream) = socket.split(); - - let mut routes = match capacity { - 0 => HashMap::new(), - capacity => HashMap::with_capacity(capacity), - }; - let mut live_queries = HashMap::new(); - - let mut interval = time::interval(PING_INTERVAL); - // don't bombard the server with pings if we miss some ticks - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - - let pinger = IntervalStream::new(interval); - - let streams = ( - socket_stream.map(Either::Response), - route_rx.stream().map(Either::Request), - pinger.map(|_| Either::Ping), - events.map(Either::Event), - ); - - let mut merged = streams.merge(); - let mut last_activity = Instant::now(); - - while let Some(either) = merged.next().await { - match either { - Either::Request(Some(Route { - request, - response, - })) => { - let (id, method, param) = request; - let params = match param.query { - Some((query, bindings)) => { - vec![query.into(), bindings.into()] - } - None => param.other, - }; - match method { - Method::Set => { - if let [Value::Strand(key), value] = ¶ms[..2] { - var_stash.insert(id, (key.0.clone(), value.clone())); - } - } - Method::Unset => { - if let [Value::Strand(key)] = ¶ms[..1] { - vars.swap_remove(&key.0); - } - } - Method::Live => { - if let Some(sender) = param.notification_sender { - if let [Value::Uuid(id)] = ¶ms[..1] { - live_queries.insert(*id, sender); - } - } - if response - .into_send_async(Ok(DbResponse::Other(Value::None))) - .await - .is_err() - { - trace!("Receiver dropped"); - } - // There is nothing to send to the server here - continue; - } - Method::Kill => { - if let [Value::Uuid(id)] = ¶ms[..1] { - live_queries.remove(id); - } - } - _ => {} - } - let method_str = match method { - Method::Health => PING_METHOD, - _ => method.as_str(), - }; - let message = { - let mut request = BTreeMap::new(); - request.insert("id".to_owned(), Value::from(id)); - request.insert("method".to_owned(), method_str.into()); - if !params.is_empty() { - request.insert("params".to_owned(), params.into()); - } - let payload = Value::from(request); - trace!("Request {payload}"); - let payload = serialize(&payload, endpoint.supports_revision).unwrap(); - Message::Binary(payload) - }; - if let Method::Authenticate - | Method::Invalidate - | Method::Signin - | Method::Signup - | Method::Use = method - { - replay.insert(method, message.clone()); - } - match socket_sink.send(message).await { - Ok(..) => { - last_activity = Instant::now(); - match routes.entry(id) { - Entry::Vacant(entry) => { - entry.insert((method, response)); - } - Entry::Occupied(..) => { - let error = Error::DuplicateRequestId(id); - if response - .into_send_async(Err(error.into())) - .await - .is_err() - { - trace!("Receiver dropped"); - } - } - } - } - Err(error) => { - let error = Error::Ws(error.to_string()); - if response.into_send_async(Err(error.into())).await.is_err() { - trace!("Receiver dropped"); - } - break; - } + state.last_activity = Instant::now(); + match router_handle_response(message, &mut state,&endpoint).await { + HandleResult::Ok => {}, + HandleResult::Disconnected => { + router_reconnect(&mut state, &mut events, &endpoint, capacity).await; + break } } - Either::Response(message) => { - last_activity = Instant::now(); - match Response::try_from(&message, endpoint.supports_revision) { - Ok(option) => { - // We are only interested in responses that are not empty - if let Some(response) = option { - trace!("{response:?}"); - match response.id { - // If `id` is set this is a normal response - Some(id) => { - if let Ok(id) = id.coerce_to_i64() { - // We can only route responses with IDs - if let Some((method, sender)) = routes.remove(&id) { - if matches!(method, Method::Set) { - if let Some((key, value)) = - var_stash.swap_remove(&id) - { - vars.insert(key, value); - } - } - // Send the response back to the caller - let mut response = response.result; - if matches!(method, Method::Insert) { - // For insert, we need to flatten single responses in an array - if let Ok(Data::Other(Value::Array( - value, - ))) = &mut response - { - if let [value] = &mut value.0[..] { - response = Ok(Data::Other( - mem::take(value), - )); - } - } - } - let _res = sender - .into_send_async(DbResponse::from(response)) - .await; - } - } - } - // If `id` is not set, this may be a live query notification - None => match response.result { - Ok(Data::Live(notification)) => { - let live_query_id = notification.id; - // Check if this live query is registered - if let Some(sender) = - live_queries.get(&live_query_id) - { - // Send the notification back to the caller or kill live query if the receiver is already dropped - if sender.send(notification).await.is_err() { - live_queries.remove(&live_query_id); - let kill = { - let mut request = BTreeMap::new(); - request.insert( - "method".to_owned(), - Method::Kill.as_str().into(), - ); - request.insert( - "params".to_owned(), - vec![Value::from(live_query_id)] - .into(), - ); - let value = Value::from(request); - let value = serialize( - &value, - endpoint.supports_revision, - ) - .unwrap(); - Message::Binary(value) - }; - if let Err(error) = - socket_sink.send(kill).await - { - trace!("failed to send kill query to the server; {error:?}"); - break; - } - } - } - } - Ok(..) => { /* Ignored responses like pings */ } - Err(error) => error!("{error:?}"), - }, - } - } - } - Err(error) => { - #[derive(Deserialize)] - #[revisioned(revision = 1)] - struct Response { - id: Option, - } - - // Let's try to find out the ID of the response that failed to deserialise - if let Message::Binary(binary) = message { - if let Ok(Response { - id, - }) = deserialize(&mut &binary[..], endpoint.supports_revision) - { - // Return an error if an ID was returned - if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { - if let Some((_method, sender)) = routes.remove(&id) { - let _res = sender.into_send_async(Err(error)).await; - } - } - } else { - // Unfortunately, we don't know which response failed to deserialize - warn!("Failed to deserialise message; {error:?}"); - } - } - } - } - } - Either::Event(event) => match event { + } + event = events.next().fuse() => { + let Some(event) = event else { + continue; + }; + match event { WsEvent::Error => { trace!("connection errored"); break; @@ -413,89 +479,25 @@ pub(crate) fn router( } WsEvent::Closed(..) => { trace!("connection closed"); + router_reconnect(&mut state, &mut events, &endpoint, capacity).await; break; } _ => {} - }, - Either::Ping => { - // only ping if we haven't talked to the server recently - if last_activity.elapsed() >= PING_INTERVAL { - trace!("Pinging the server"); - if let Err(error) = socket_sink.send(ping.clone()).await { - trace!("failed to ping the server; {error:?}"); - break; - } - } - } - // Close connection request received - Either::Request(None) => { - match ws.close().await { - Ok(..) => trace!("Connection closed successfully"), - Err(error) => { - warn!("Failed to close database connection; {error}") - } - } - break 'router; } } - } - - 'reconnect: loop { - trace!("Reconnecting..."); - let connect = match endpoint.supports_revision { - true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await, - false => WsMeta::connect(&endpoint.url, None).await, - }; - match connect { - Ok((mut meta, stream)) => { - socket = stream; - events = { - let result = match capacity { - 0 => meta.observe(ObserveConfig::default()).await, - capacity => meta.observe(Channel::Bounded(capacity).into()).await, - }; - match result { - Ok(events) => events, - Err(error) => { - trace!("{error}"); - time::sleep(Duration::from_secs(1)).await; - continue 'reconnect; - } - } - }; - for (_, message) in &replay { - if let Err(error) = socket.send(message.clone()).await { - trace!("{error}"); - time::sleep(Duration::from_secs(1)).await; - continue 'reconnect; - } + _ = pinger.next().fuse() => { + if state.last_activity.elapsed() >= PING_INTERVAL { + trace!("Pinging the server"); + if let Err(error) = state.sink.send(ping.clone()).await { + trace!("failed to ping the server; {error:?}"); + router_reconnect(&mut state, &mut events, &endpoint, capacity).await; + break; } - for (key, value) in &vars { - let mut request = BTreeMap::new(); - request.insert("method".to_owned(), Method::Set.as_str().into()); - request.insert( - "params".to_owned(), - vec![key.as_str().into(), value.clone()].into(), - ); - let payload = Value::from(request); - trace!("Request {payload}"); - if let Err(error) = socket.send(Message::Binary(payload.into())).await { - trace!("{error}"); - time::sleep(Duration::from_secs(1)).await; - continue 'reconnect; - } - } - trace!("Reconnected successfully"); - break; - } - Err(error) => { - trace!("Failed to reconnect; {error}"); - time::sleep(Duration::from_secs(1)).await; } } } } - }); + } } impl Response { diff --git a/lib/src/api/engine/tasks.rs b/lib/src/api/engine/tasks.rs index f5365506..cac42eb8 100644 --- a/lib/src/api/engine/tasks.rs +++ b/lib/src/api/engine/tasks.rs @@ -1,5 +1,4 @@ -use futures::{FutureExt, StreamExt}; -use futures_concurrency::stream::Merge; +use futures::StreamExt; use reblessive::TreeStack; #[cfg(target_arch = "wasm32")] use std::sync::atomic::{AtomicBool, Ordering}; @@ -93,24 +92,27 @@ fn init(opt: &EngineOptions, dbs: Arc) -> (FutureTask, oneshot::Sende let ret_status = completed_status.clone(); // We create a channel that can be streamed that will indicate termination - let (tx, rx) = oneshot::channel(); + let (tx, mut rx) = oneshot::channel(); let _fut = spawn_future(async move { let _lifecycle = crate::dbs::LoggingLifecycle::new("heartbeat task".to_string()); - let ticker = interval_ticker(tick_interval).await; - let streams = ( - ticker.map(|i| { - trace!("Node agent tick: {:?}", i); - Some(i) - }), - rx.into_stream().map(|_| None), - ); - let mut streams = streams.merge(); + let mut ticker = interval_ticker(tick_interval).await; - while let Some(Some(_)) = streams.next().await { - if let Err(e) = dbs.tick().await { - error!("Error running node agent tick: {}", e); - break; + loop { + tokio::select! { + v = ticker.next() => { + // ticker will never return None; + let i = v.unwrap(); + trace!("Node agent tick: {:?}", i); + if let Err(e) = dbs.tick().await { + error!("Error running node agent tick: {}", e); + break; + } + } + _ = &mut rx => { + // termination requested + break + } } } @@ -136,7 +138,7 @@ fn live_query_change_feed( let ret_status = completed_status.clone(); // We create a channel that can be streamed that will indicate termination - let (tx, rx) = oneshot::channel(); + let (tx, mut rx) = oneshot::channel(); let _fut = spawn_future(async move { let mut stack = TreeStack::new(); @@ -148,25 +150,29 @@ fn live_query_change_feed( completed_status.store(true, Ordering::Relaxed); return; } - let ticker = interval_ticker(tick_interval).await; - let streams = ( - ticker.map(|i| { - trace!("Live query agent tick: {:?}", i); - Some(i) - }), - rx.into_stream().map(|_| None), - ); - let mut streams = streams.merge(); + let mut ticker = interval_ticker(tick_interval).await; let opt = Options::default(); - while let Some(Some(_)) = streams.next().await { - if let Err(e) = - stack.enter(|stk| dbs.process_lq_notifications(stk, &opt)).finish().await - { - error!("Error running node agent tick: {}", e); - break; + loop { + tokio::select! { + v = ticker.next() => { + // ticker will never return None; + let i = v.unwrap(); + trace!("Live query agent tick: {:?}", i); + if let Err(e) = + stack.enter(|stk| dbs.process_lq_notifications(stk, &opt)).finish().await + { + error!("Error running node agent tick: {}", e); + break; + } + } + _ = &mut rx => { + // termination requested, + break + } } } + #[cfg(target_arch = "wasm32")] completed_status.store(true, Ordering::Relaxed); }); diff --git a/lib/src/api/method/tests/protocol.rs b/lib/src/api/method/tests/protocol.rs index dabdde90..663d1b27 100644 --- a/lib/src/api/method/tests/protocol.rs +++ b/lib/src/api/method/tests/protocol.rs @@ -96,13 +96,7 @@ impl Connection for Client { request: (0, self.method, param), response: sender, }; - router - .sender - .send_async(Some(route)) - .await - .as_ref() - .map_err(ToString::to_string) - .unwrap(); + router.sender.send_async(route).await.as_ref().map_err(ToString::to_string).unwrap(); Ok(receiver) }) } diff --git a/lib/src/api/method/tests/server.rs b/lib/src/api/method/tests/server.rs index e3d152a0..4f32783a 100644 --- a/lib/src/api/method/tests/server.rs +++ b/lib/src/api/method/tests/server.rs @@ -8,14 +8,14 @@ use crate::sql::Value; use flume::Receiver; use futures::StreamExt; -pub(super) fn mock(route_rx: Receiver>) { +pub(super) fn mock(route_rx: Receiver) { tokio::spawn(async move { let mut stream = route_rx.into_stream(); - while let Some(Some(Route { + while let Some(Route { request, response, - })) = stream.next().await + }) = stream.next().await { let (_, method, param) = request; let mut params = param.other; diff --git a/lib/tests/function.rs b/lib/tests/function.rs index 62cf0468..2718dab0 100644 --- a/lib/tests/function.rs +++ b/lib/tests/function.rs @@ -4648,7 +4648,7 @@ async fn function_type_is_bytes() -> Result<(), Error> { async fn function_type_is_collection() -> Result<(), Error> { let sql = r#" LET $collection = > { - type: 'GeometryCollection', + type: 'GeometryCollection', geometries: [{ type: 'MultiPoint', coordinates: [[10, 11.2], [10.5, 11.9]] }] }; RETURN type::is::collection($collection); @@ -4902,7 +4902,7 @@ async fn function_type_is_multipoint() -> Result<(), Error> { async fn function_type_is_multipolygon() -> Result<(), Error> { let sql = r#" LET $multipolygon = > { - type: 'MultiPolygon', + type: 'MultiPolygon', coordinates: [[[[10, 11.2], [10.5, 11.9], [10.8, 12], [10, 11.2]]], [[[9, 11.2], [10.5, 11.9], [10.3, 13], [9, 11.2]]]] }; RETURN type::is::multipolygon($multipolygon); @@ -5001,7 +5001,7 @@ async fn function_type_is_point() -> Result<(), Error> { async fn function_type_is_polygon() -> Result<(), Error> { let sql = r#" LET $polygon = > { - type: 'Polygon', + type: 'Polygon', coordinates: [ [ [-0.38314819, 51.37692386], @@ -5953,6 +5953,7 @@ pub async fn function_http_disabled() { "#, ) .await + .unwrap() .expect_errors(&[ "Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled", @@ -5960,7 +5961,8 @@ pub async fn function_http_disabled() { "Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled", "Remote HTTP request functions are not enabled", - ]); + ]) + .unwrap(); } // Tests for custom defined functions diff --git a/pkg/nix/drv/binary.nix b/pkg/nix/drv/binary.nix index f55a3fab..9209fb88 100644 --- a/pkg/nix/drv/binary.nix +++ b/pkg/nix/drv/binary.nix @@ -25,5 +25,4 @@ in craneLib.buildPackage (buildSpec // { inherit cargoArtifacts; inherit (util) version SURREAL_BUILD_METADATA; - RUSTFLAGS = "--cfg surrealdb_unstable"; }) diff --git a/src/cli/start.rs b/src/cli/start.rs index bd0c2fd4..76d26b01 100644 --- a/src/cli/start.rs +++ b/src/cli/start.rs @@ -196,7 +196,7 @@ pub async fn init( net::init(ct.clone()).await?; // Shutdown and stop closed tasks task_chans.into_iter().for_each(|chan| { - if let Err(_empty_tuple) = chan.send(()) { + if chan.send(()).is_err() { error!("Failed to send shutdown signal to task"); } });