Minor lib improvements (#4261)

This commit is contained in:
Mees Delzenne 2024-07-05 11:19:04 +02:00 committed by GitHub
parent a701230e9d
commit 07a88383fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1323 additions and 1146 deletions

14
Cargo.lock generated
View file

@ -2177,19 +2177,6 @@ dependencies = [
"futures-sink",
]
[[package]]
name = "futures-concurrency"
version = "7.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b590a729e1cbaf9ae3ec294143ea034d93cbb1de01c884d04bcd0af8b613d02"
dependencies = [
"bitvec",
"futures-core",
"pin-project",
"slab",
"smallvec",
]
[[package]]
name = "futures-core"
version = "0.3.30"
@ -5952,7 +5939,6 @@ dependencies = [
"flate2",
"flume",
"futures",
"futures-concurrency",
"geo 0.27.0",
"hashbrown 0.14.5",
"indexmap 2.2.6",

View file

@ -10,13 +10,11 @@ args = ["check", "--locked", "--workspace"]
[tasks.ci-check-wasm]
category = "CI - CHECK"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["check", "--locked", "--package", "surrealdb", "--features", "protocol-ws,protocol-http,kv-mem,kv-indxdb,http,jwks", "--target", "wasm32-unknown-unknown"]
[tasks.ci-clippy]
category = "CI - CHECK"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["clippy", "--all-targets", "--features", "storage-mem,storage-rocksdb,storage-tikv,storage-fdb,scripting,http,jwks", "--tests", "--benches", "--examples", "--bins", "--", "-D", "warnings"]
#
@ -26,31 +24,43 @@ args = ["clippy", "--all-targets", "--features", "storage-mem,storage-rocksdb,st
[tasks.ci-cli-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "cli_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,storage-surrealkv,http,scripting,jwks", "--workspace", "--test", "cli_integration", "--", "cli_integration"]
[tasks.ci-http-integration]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "http_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem,http-compression,jwks", "--workspace", "--test", "http_integration", "--", "http_integration"]
[tasks.ci-ws-integration]
category = "WS - INTEGRATION TESTS"
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "ws_integration=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--no-default-features", "--features", "storage-mem", "--workspace", "--test", "ws_integration", "--", "ws_integration"]
[tasks.ci-ml-integration]
category = "ML - INTEGRATION TESTS"
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable", RUST_LOG = { value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
env = { RUST_BACKTRACE = 1, RUST_LOG = { value = "cli_integration::common=debug", condition = { env_not_set = ["RUST_LOG"] } } }
args = ["test", "--locked", "--features", "storage-mem,ml", "--workspace", "--test", "ml_integration", "--", "ml_integration", "--nocapture"]
[tasks.ci-test-workspace]
category = "CI - INTEGRATION TESTS"
command = "cargo"
args = [
"test", "--locked", "--no-default-features", "--features", "storage-mem,scripting,http,jwks", "--workspace", "--",
"--skip", "api_integration",
"--skip", "cli_integration",
"--skip", "http_integration",
"--skip", "ws_integration",
"--skip", "database_upgrade"
]
[tasks.ci-workspace-coverage]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = [
"llvm-cov", "--html", "--locked", "--no-default-features", "--features", "storage-mem,scripting,http,jwks", "--workspace", "--",
"--skip", "api_integration",
@ -63,7 +73,6 @@ args = [
[tasks.test-workspace-coverage-complete]
category = "CI - INTEGRATION TESTS"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["llvm-cov", "--html", "--locked", "--no-default-features", "--features", "protocol-ws,protocol-http,kv-mem,kv-rocksdb", "--workspace"]
[tasks.ci-workspace-coverage-complete]
@ -83,25 +92,25 @@ run_task = { name = ["test-database-upgrade"], fork = true }
[tasks.test-kvs]
private = true
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { RUST_BACKTRACE = 1 }
args = ["test", "--locked", "--package", "surrealdb", "--no-default-features", "--features", "${_TEST_FEATURES}", "--lib", "kvs"]
[tasks.test-api-integration]
private = true
command = "cargo"
env = { RUST_BACKTRACE = 1, RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { RUST_BACKTRACE = 1 }
args = ["test", "--locked", "--package", "surrealdb", "--no-default-features", "--features", "${_TEST_FEATURES}", "--test", "api", "api_integration::${_TEST_API_ENGINE}"]
[tasks.ci-api-integration]
env = { RUST_BACKTRACE = 1, _START_SURREALDB_PATH = "memory", RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { RUST_BACKTRACE = 1, _START_SURREALDB_PATH = "memory" }
private = true
run_task = { name = ["start-surrealdb", "test-api-integration", "stop-surrealdb"], fork = true }
[tasks.test-database-upgrade]
private = true
command = "cargo"
env = { RUST_BACKTRACE = 1, RUST_LOG = "info", RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { RUST_BACKTRACE = 1, RUST_LOG = "info" }
args = ["test", "--locked", "--no-default-features", "--features", "${_TEST_FEATURES}", "--workspace", "--test", "database_upgrade", "--", "database_upgrade", "--show-output"]
@ -111,17 +120,17 @@ args = ["test", "--locked", "--no-default-features", "--features", "${_TEST_FEAT
[tasks.ci-api-integration-http]
category = "CI - INTEGRATION TESTS"
env = { _TEST_API_ENGINE = "http", _TEST_FEATURES = "protocol-http", RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { _TEST_API_ENGINE = "http", _TEST_FEATURES = "protocol-http" }
run_task = "ci-api-integration"
[tasks.ci-api-integration-ws]
category = "CI - INTEGRATION TESTS"
env = { _TEST_API_ENGINE = "ws", _TEST_FEATURES = "protocol-ws", RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { _TEST_API_ENGINE = "ws", _TEST_FEATURES = "protocol-ws" }
run_task = "ci-api-integration"
[tasks.ci-api-integration-any]
category = "CI - INTEGRATION TESTS"
env = { _TEST_API_ENGINE = "any", _TEST_FEATURES = "protocol-http", RUSTFLAGS = "--cfg surrealdb_unstable" }
env = { _TEST_API_ENGINE = "any", _TEST_FEATURES = "protocol-http" }
run_task = "ci-api-integration"
#
@ -256,7 +265,6 @@ ${HOME}/.tiup/bin/tiup clean --all
[tasks.build-surrealdb]
category = "CI - BUILD"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["build", "--locked", "--no-default-features", "--features", "storage-mem"]
#
@ -265,7 +273,6 @@ args = ["build", "--locked", "--no-default-features", "--features", "storage-mem
[tasks.ci-bench]
category = "CI - BENCHMARK"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["bench", "--quiet", "--package", "surrealdb", "--no-default-features", "--features", "kv-mem,scripting,http,jwks", "${@}"]
#
@ -277,13 +284,11 @@ BENCH_NUM_OPS = { value = "1000", condition = { env_not_set = ["BENCH_NUM_OPS"]
BENCH_DURATION = { value = "30", condition = { env_not_set = ["BENCH_DURATION"] } }
BENCH_SAMPLE_SIZE = { value = "10", condition = { env_not_set = ["BENCH_SAMPLE_SIZE"] } }
BENCH_FEATURES = { value = "protocol-ws,kv-mem,kv-rocksdb,kv-fdb-7_1,kv-surrealkv", condition = { env_not_set = ["BENCH_FEATURES"] } }
RUSTFLAGS = "--cfg surrealdb_unstable"
[tasks.bench-target]
private = true
category = "CI - BENCHMARK - SurrealDB Target"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["bench", "--package", "surrealdb", "--bench", "sdb", "--no-default-features", "--features", "${BENCH_FEATURES}", "${@}"]
[tasks.bench-lib-mem]

View file

@ -17,21 +17,19 @@ dependencies = ["cargo-upgrade", "cargo-update"]
[tasks.docs]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTDOCFLAGS = "--cfg surrealdb_unstable" }
args = ["doc", "--open", "--no-deps", "--package", "surrealdb", "--features", "rustls,native-tls,protocol-ws,protocol-http,kv-mem,kv-rocksdb,kv-tikv,http,scripting,jwks"]
# Test
[tasks.test]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable", RUSTDOCFLAGS = "--cfg surrealdb_unstable", RUST_MIN_STACK={ value = "4194304", condition = { env_not_set = ["RUST_MIN_STACK"] } } }
env = { RUST_MIN_STACK={ value = "4194304", condition = { env_not_set = ["RUST_MIN_STACK"] } } }
args = ["test", "--workspace", "--no-fail-fast"]
# Check
[tasks.cargo-check]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["check", "--workspace", "--features", "${DEV_FEATURES}"]
[tasks.cargo-fmt]
@ -50,7 +48,6 @@ script = """
[tasks.cargo-clippy]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["clippy", "--all-targets", "--all-features", "--", "-D", "warnings"]
[tasks.check]
@ -71,28 +68,24 @@ args = ["clean"]
[tasks.bench]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["bench", "--package", "surrealdb", "--no-default-features", "--features", "kv-mem,http,scripting,jwks", "--", "${@}"]
# Run
[tasks.run]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "${@}"]
# Serve
[tasks.serve]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "start", "--allow-all", "${@}"]
# SQL
[tasks.sql]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["run", "--no-default-features", "--features", "${DEV_FEATURES}", "--", "sql", "--pretty", "${@}"]
# Quick
@ -105,7 +98,6 @@ args = ["build", "${@}"]
[tasks.build]
category = "LOCAL USAGE"
command = "cargo"
env = { RUSTFLAGS = "--cfg surrealdb_unstable" }
args = ["build", "--release", "${@}"]
# Default

View file

@ -16,7 +16,9 @@ use crate::sql::range::Range;
use crate::sql::table::Table;
use crate::sql::thing::Thing;
use crate::sql::value::Value;
use reblessive::{tree::Stk, TreeStack};
use reblessive::tree::Stk;
#[cfg(not(target_arch = "wasm32"))]
use reblessive::TreeStack;
use std::mem;
#[derive(Clone)]

View file

@ -9,6 +9,7 @@ pub static ID: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("id")]);
pub static METHOD: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("method")]);
pub static PARAMS: Lazy<[Part; 1]> = Lazy::new(|| [Part::from("params")]);
#[derive(Debug)]
pub struct Request {
pub id: Option<Value>,
pub method: String,

View file

@ -79,7 +79,6 @@ chrono = { version = "0.4.31", features = ["serde"] }
dmp = "0.2.0"
flume = "0.11.0"
futures = "0.3.29"
futures-concurrency = "7.4.3"
geo = { version = "0.27.0", features = ["use-serde"] }
indexmap = { version = "2.1.0", features = ["serde"] }
native-tls = { version = "0.2.11", optional = true }

View file

@ -31,7 +31,7 @@ pub(crate) struct Route {
/// Message router
#[derive(Debug)]
pub struct Router {
pub(crate) sender: Sender<Option<Route>>,
pub(crate) sender: Sender<Route>,
pub(crate) last_id: AtomicI64,
pub(crate) features: HashSet<ExtraFeatures>,
}
@ -42,12 +42,6 @@ impl Router {
}
}
impl Drop for Router {
fn drop(&mut self) {
let _res = self.sender.send(None);
}
}
/// The query method
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]

View file

@ -68,7 +68,7 @@ impl Connection for Any {
{
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
engine::local::native::router(address, conn_tx, route_rx);
tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??
}
@ -83,7 +83,7 @@ impl Connection for Any {
{
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
engine::local::native::router(address, conn_tx, route_rx);
tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??
}
@ -98,7 +98,7 @@ impl Connection for Any {
{
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
engine::local::native::router(address, conn_tx, route_rx);
tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??
}
@ -114,7 +114,7 @@ impl Connection for Any {
{
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
engine::local::native::router(address, conn_tx, route_rx);
tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??
}
@ -129,7 +129,7 @@ impl Connection for Any {
{
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
engine::local::native::router(address, conn_tx, route_rx);
tokio::spawn(engine::local::native::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??
}
@ -162,7 +162,9 @@ impl Connection for Any {
client.get(base_url.join(Method::Health.as_str())?),
)
.await?;
engine::remote::http::native::router(base_url, client, route_rx);
tokio::spawn(engine::remote::http::native::run_router(
base_url, client, route_rx,
));
}
#[cfg(not(feature = "protocol-http"))]
@ -195,14 +197,14 @@ impl Connection for Any {
maybe_connector.clone(),
)
.await?;
engine::remote::ws::native::router(
tokio::spawn(engine::remote::ws::native::run_router(
endpoint,
maybe_connector,
capacity,
config,
socket,
route_rx,
);
));
}
#[cfg(not(feature = "protocol-ws"))]
@ -237,7 +239,7 @@ impl Connection for Any {
request: (self.id, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}

View file

@ -18,12 +18,12 @@ use crate::opt::WaitFor;
use flume::Receiver;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
use wasm_bindgen_futures::spawn_local;
impl crate::api::Connection for Any {}
@ -54,7 +54,7 @@ impl Connection for Any {
#[cfg(feature = "kv-fdb")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -68,7 +68,7 @@ impl Connection for Any {
#[cfg(feature = "kv-indxdb")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -82,7 +82,7 @@ impl Connection for Any {
#[cfg(feature = "kv-mem")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -96,7 +96,7 @@ impl Connection for Any {
#[cfg(feature = "kv-rocksdb")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -111,7 +111,7 @@ impl Connection for Any {
#[cfg(feature = "kv-surrealkv")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -126,7 +126,7 @@ impl Connection for Any {
#[cfg(feature = "kv-tikv")]
{
features.insert(ExtraFeatures::LiveQueries);
engine::local::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::local::wasm::run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
}
@ -139,7 +139,9 @@ impl Connection for Any {
EndpointKind::Http | EndpointKind::Https => {
#[cfg(feature = "protocol-http")]
{
engine::remote::http::wasm::router(address, conn_tx, route_rx);
spawn_local(engine::remote::http::wasm::run_router(
address, conn_tx, route_rx,
));
}
#[cfg(not(feature = "protocol-http"))]
@ -155,7 +157,9 @@ impl Connection for Any {
features.insert(ExtraFeatures::LiveQueries);
let mut endpoint = address;
endpoint.url = endpoint.url.join(engine::remote::ws::PATH)?;
engine::remote::ws::wasm::router(endpoint, capacity, conn_tx, route_rx);
spawn_local(engine::remote::ws::wasm::run_router(
endpoint, capacity, conn_tx, route_rx,
));
conn_rx.into_recv_async().await??;
}
@ -192,7 +196,7 @@ impl Connection for Any {
request: (self.id, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}

View file

@ -19,10 +19,8 @@ use crate::opt::WaitFor;
use crate::options::EngineOptions;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
use futures::stream::poll_fn;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
@ -55,7 +53,7 @@ impl Connection for Db {
let (conn_tx, conn_rx) = flume::bounded(1);
router(address, conn_tx, route_rx);
tokio::spawn(run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
@ -85,144 +83,142 @@ impl Connection for Db {
request: (0, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
}
pub(crate) fn router(
pub(crate) async fn run_router(
address: Endpoint,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>,
route_rx: Receiver<Route>,
) {
tokio::spawn(async move {
let configured_root = match address.config.auth {
Level::Root => Some(Root {
username: &address.config.username,
password: &address.config.password,
}),
_ => None,
};
let configured_root = match address.config.auth {
Level::Root => Some(Root {
username: &address.config.username,
password: &address.config.password,
}),
_ => None,
};
let endpoint = match EndpointKind::from(address.url.scheme()) {
EndpointKind::TiKv => address.url.as_str(),
_ => &address.path,
};
let endpoint = match EndpointKind::from(address.url.scheme()) {
EndpointKind::TiKv => address.url.as_str(),
_ => &address.path,
};
let kvs = match Datastore::new(endpoint).await {
Ok(kvs) => {
if let Err(error) = kvs.bootstrap().await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
// If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root {
if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await
{
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
}
let _ = conn_tx.into_send_async(Ok(())).await;
kvs.with_auth_enabled(configured_root.is_some())
}
Err(error) => {
let kvs = match Datastore::new(endpoint).await {
Ok(kvs) => {
if let Err(error) = kvs.bootstrap().await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
let kvs = match address.config.capabilities.allows_live_query_notifications() {
true => kvs.with_notifications(),
false => kvs,
};
let kvs = kvs
.with_strict_mode(address.config.strict)
.with_query_timeout(address.config.query_timeout)
.with_transaction_timeout(address.config.transaction_timeout)
.with_capabilities(address.config.capabilities);
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
))]
let kvs = match address.config.temporary_directory {
Some(tmp_dir) => kvs.with_temporary_directory(tmp_dir),
_ => kvs,
};
let kvs = Arc::new(kvs);
let mut vars = BTreeMap::new();
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let opt = {
let mut engine_options = EngineOptions::default();
engine_options.tick_interval = address
.config
.tick_interval
.unwrap_or(crate::api::engine::local::DEFAULT_TICK_INTERVAL);
engine_options
};
let (tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let notification_stream = poll_fn(move |cx| match &mut notifications {
Some(rx) => rx.poll_next_unpin(cx),
None => Poll::Ready(None),
});
let streams = (route_rx.stream().map(Either::Left), notification_stream.map(Either::Right));
let mut merged = streams.merge();
while let Some(either) = merged.next().await {
match either {
Either::Left(None) => break, // Received a shutdown signal
Either::Left(Some(route)) => {
match super::router(
route.request,
&kvs,
&mut session,
&mut vars,
&mut live_queries,
)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = route.response.into_send_async(Err(error)).await;
}
}
}
Either::Right(notification) => {
let id = notification.id;
if let Some(sender) = live_queries.get(&id) {
if sender.send(notification).await.is_err() {
live_queries.remove(&id);
if let Err(error) =
super::kill_live_query(&kvs, id, &session, vars.clone()).await
{
warn!("Failed to kill live query '{id}'; {error}");
}
}
}
// If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root {
if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
}
let _ = conn_tx.into_send_async(Ok(())).await;
kvs.with_auth_enabled(configured_root.is_some())
}
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
// Stop maintenance tasks
for chan in task_chans {
if let Err(_empty_tuple) = chan.send(()) {
error!("Error sending shutdown signal to task");
}
}
tasks.resolve().await.unwrap();
let kvs = match address.config.capabilities.allows_live_query_notifications() {
true => kvs.with_notifications(),
false => kvs,
};
let kvs = kvs
.with_strict_mode(address.config.strict)
.with_query_timeout(address.config.query_timeout)
.with_transaction_timeout(address.config.transaction_timeout)
.with_capabilities(address.config.capabilities);
#[cfg(any(
feature = "kv-mem",
feature = "kv-surrealkv",
feature = "kv-rocksdb",
feature = "kv-fdb",
feature = "kv-tikv",
))]
let kvs = match address.config.temporary_directory {
Some(tmp_dir) => kvs.with_temporary_directory(tmp_dir),
_ => kvs,
};
let kvs = Arc::new(kvs);
let mut vars = BTreeMap::new();
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let opt = {
let mut engine_options = EngineOptions::default();
engine_options.tick_interval = address
.config
.tick_interval
.unwrap_or(crate::api::engine::local::DEFAULT_TICK_INTERVAL);
engine_options
};
let (tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let mut notification_stream = poll_fn(move |cx| match &mut notifications {
Some(rx) => rx.poll_next_unpin(cx),
// return poll pending so that this future is never woken up again and therefore not
// constantly polled.
None => Poll::Pending,
});
let mut route_stream = route_rx.into_stream();
loop {
tokio::select! {
route = route_stream.next() => {
let Some(route) = route else {
break
};
match super::router(route.request, &kvs, &mut session, &mut vars, &mut live_queries)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = route.response.into_send_async(Err(error)).await;
}
}
}
notification = notification_stream.next() => {
let Some(notification) = notification else {
// TODO: Maybe we should do something more then ignore a closed notifications
// channel?
continue
};
let id = notification.id;
if let Some(sender) = live_queries.get(&id) {
if sender.send(notification).await.is_err() {
live_queries.remove(&id);
if let Err(error) =
super::kill_live_query(&kvs, id, &session, vars.clone()).await
{
warn!("Failed to kill live query '{id}'; {error}");
}
}
}
}
}
}
// Stop maintenance tasks
for chan in task_chans {
if chan.send(()).is_err() {
error!("Error sending shutdown signal to task");
}
}
tasks.resolve().await.unwrap();
}

View file

@ -20,15 +20,13 @@ use crate::opt::WaitFor;
use crate::options::EngineOptions;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
use futures::stream::poll_fn;
use futures::FutureExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
@ -58,7 +56,7 @@ impl Connection for Db {
let (conn_tx, conn_rx) = flume::bounded(1);
router(address, conn_tx, route_rx);
spawn_local(run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
@ -87,120 +85,127 @@ impl Connection for Db {
request: (0, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
}
pub(crate) fn router(
pub(crate) async fn run_router(
address: Endpoint,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>,
route_rx: Receiver<Route>,
) {
spawn_local(async move {
let configured_root = match address.config.auth {
Level::Root => Some(Root {
username: &address.config.username,
password: &address.config.password,
}),
_ => None,
};
let configured_root = match address.config.auth {
Level::Root => Some(Root {
username: &address.config.username,
password: &address.config.password,
}),
_ => None,
};
let kvs = match Datastore::new(&address.path).await {
Ok(kvs) => {
if let Err(error) = kvs.bootstrap().await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
// If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root {
if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await
{
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
}
let _ = conn_tx.into_send_async(Ok(())).await;
kvs.with_auth_enabled(configured_root.is_some())
}
Err(error) => {
let kvs = match Datastore::new(&address.path).await {
Ok(kvs) => {
if let Err(error) = kvs.bootstrap().await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
let kvs = match address.config.capabilities.allows_live_query_notifications() {
true => kvs.with_notifications(),
false => kvs,
};
let kvs = kvs
.with_strict_mode(address.config.strict)
.with_query_timeout(address.config.query_timeout)
.with_transaction_timeout(address.config.transaction_timeout)
.with_capabilities(address.config.capabilities);
let kvs = Arc::new(kvs);
let mut vars = BTreeMap::new();
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let mut opt = EngineOptions::default();
opt.tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL);
let (_tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let notification_stream = poll_fn(move |cx| match &mut notifications {
Some(rx) => rx.poll_next_unpin(cx),
None => Poll::Ready(None),
});
let streams = (route_rx.stream().map(Either::Left), notification_stream.map(Either::Right));
let mut merged = streams.merge();
while let Some(either) = merged.next().await {
match either {
Either::Left(None) => break, // Received a shutdown signal
Either::Left(Some(route)) => {
match super::router(
route.request,
&kvs,
&mut session,
&mut vars,
&mut live_queries,
)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = route.response.into_send_async(Err(error)).await;
}
}
}
Either::Right(notification) => {
let id = notification.id;
if let Some(sender) = live_queries.get(&id) {
if sender.send(notification).await.is_err() {
live_queries.remove(&id);
if let Err(error) =
super::kill_live_query(&kvs, id, &session, vars.clone()).await
{
warn!("Failed to kill live query '{id}'; {error}");
}
}
}
// If a root user is specified, setup the initial datastore credentials
if let Some(root) = configured_root {
if let Err(error) = kvs.setup_initial_creds(root.username, root.password).await {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
}
let _ = conn_tx.into_send_async(Ok(())).await;
kvs.with_auth_enabled(configured_root.is_some())
}
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
// Stop maintenance tasks
for chan in task_chans {
if let Err(_empty_tuple) = chan.send(()) {
error!("Error sending shutdown signal to maintenance task");
}
}
let kvs = match address.config.capabilities.allows_live_query_notifications() {
true => kvs.with_notifications(),
false => kvs,
};
let kvs = kvs
.with_strict_mode(address.config.strict)
.with_query_timeout(address.config.query_timeout)
.with_transaction_timeout(address.config.transaction_timeout)
.with_capabilities(address.config.capabilities);
let kvs = Arc::new(kvs);
let mut vars = BTreeMap::new();
let mut live_queries = HashMap::new();
let mut session = Session::default().with_rt(true);
let mut opt = EngineOptions::default();
opt.tick_interval = address.config.tick_interval.unwrap_or(DEFAULT_TICK_INTERVAL);
let (_tasks, task_chans) = start_tasks(&opt, kvs.clone());
let mut notifications = kvs.notifications();
let mut notification_stream = poll_fn(move |cx| match &mut notifications {
Some(rx) => rx.poll_next_unpin(cx),
None => Poll::Pending,
});
let mut route_stream = route_rx.into_stream();
loop {
// use the less ergonomic futures::select as tokio::select is not available.
futures::select! {
route = route_stream.next().fuse() => {
let Some(route) = route else {
// termination requested
break
};
match super::router(
route.request,
&kvs,
&mut session,
&mut vars,
&mut live_queries,
)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = route.response.into_send_async(Err(error)).await;
}
}
}
notification = notification_stream.next().fuse() => {
let Some(notification) = notification else {
// TODO: maybe do something else then ignore a disconnected notification
// channel.
continue;
};
let id = notification.id;
if let Some(sender) = live_queries.get(&id) {
if sender.send(notification).await.is_err() {
live_queries.remove(&id);
if let Err(error) =
super::kill_live_query(&kvs, id, &session, vars.clone()).await
{
warn!("Failed to kill live query '{id}'; {error}");
}
}
}
}
}
}
// Stop maintenance tasks
for chan in task_chans {
if chan.send(()).is_err() {
error!("Error sending shutdown signal to maintenance task");
}
}
}

View file

@ -67,7 +67,7 @@ impl Connection for Client {
capacity => flume::bounded(capacity),
};
router(base_url, client, route_rx);
tokio::spawn(run_router(base_url, client, route_rx));
let mut features = HashSet::new();
features.insert(ExtraFeatures::Backup);
@ -94,30 +94,22 @@ impl Connection for Client {
request: (0, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
}
pub(crate) fn router(base_url: Url, client: reqwest::Client, route_rx: Receiver<Option<Route>>) {
tokio::spawn(async move {
let mut headers = HeaderMap::new();
let mut vars = IndexMap::new();
let mut auth = None;
let mut stream = route_rx.into_stream();
pub(crate) async fn run_router(base_url: Url, client: reqwest::Client, route_rx: Receiver<Route>) {
let mut headers = HeaderMap::new();
let mut vars = IndexMap::new();
let mut auth = None;
let mut stream = route_rx.into_stream();
while let Some(Some(route)) = stream.next().await {
let result = super::router(
route.request,
&base_url,
&client,
&mut headers,
&mut vars,
&mut auth,
)
.await;
let _ = route.response.into_send_async(result).await;
}
});
while let Some(route) = stream.next().await {
let result =
super::router(route.request, &base_url, &client, &mut headers, &mut vars, &mut auth)
.await;
let _ = route.response.into_send_async(result).await;
}
}

View file

@ -18,7 +18,6 @@ use reqwest::header::HeaderMap;
use reqwest::ClientBuilder;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
@ -48,7 +47,7 @@ impl Connection for Client {
let (conn_tx, conn_rx) = flume::bounded(1);
router(address, conn_tx, route_rx);
spawn_local(run_router(address, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
@ -75,7 +74,7 @@ impl Connection for Client {
request: (0, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
@ -90,48 +89,39 @@ async fn client(base_url: &Url) -> Result<reqwest::Client> {
Ok(client)
}
pub(crate) fn router(
pub(crate) async fn run_router(
address: Endpoint,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>,
route_rx: Receiver<Route>,
) {
spawn_local(async move {
let base_url = address.url;
let base_url = address.url;
let client = match client(&base_url).await {
Ok(client) => {
let _ = conn_tx.into_send_async(Ok(())).await;
client
let client = match client(&base_url).await {
Ok(client) => {
let _ = conn_tx.into_send_async(Ok(())).await;
client
}
Err(error) => {
let _ = conn_tx.into_send_async(Err(error)).await;
return;
}
};
let mut headers = HeaderMap::new();
let mut vars = IndexMap::new();
let mut auth = None;
let mut stream = route_rx.into_stream();
while let Some(route) = stream.next().await {
match super::router(route.request, &base_url, &client, &mut headers, &mut vars, &mut auth)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = conn_tx.into_send_async(Err(error)).await;
return;
}
};
let mut headers = HeaderMap::new();
let mut vars = IndexMap::new();
let mut auth = None;
let mut stream = route_rx.into_stream();
while let Some(Some(route)) = stream.next().await {
match super::router(
route.request,
&base_url,
&client,
&mut headers,
&mut vars,
&mut auth,
)
.await
{
Ok(value) => {
let _ = route.response.into_send_async(Ok(value)).await;
}
Err(error) => {
let _ = route.response.into_send_async(Err(error)).await;
}
let _ = route.response.into_send_async(Err(error)).await;
}
}
});
}
}

View file

@ -20,20 +20,160 @@ use crate::dbs::Status;
use crate::method::Stats;
use crate::opt::IntoEndpoint;
use crate::sql::Value;
use bincode::Options as _;
use flume::Sender;
use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::ser::SerializeMap;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::io::Read;
use std::marker::PhantomData;
use std::time::Duration;
use surrealdb_core::dbs::Notification as CoreNotification;
use trice::Instant;
use uuid::Uuid;
/// URL path of the WebSocket RPC endpoint.
pub(crate) const PATH: &str = "rpc";
/// How often the router pings the server when the connection is otherwise idle.
const PING_INTERVAL: Duration = Duration::from_secs(5);
/// RPC method name used for keep-alive pings (health checks are mapped to it too).
const PING_METHOD: &str = "ping";
// NOTE(review): presumably the protocol-negotiation header value for the
// revisioned binary format — confirm at the connect/handshake site.
const REVISION_HEADER: &str = "revision";
/// A struct which will be serialized as a map to behave like the previously used BTreeMap.
///
/// This struct serializes as if it is a surrealdb_core::sql::Value::Object.
#[derive(Debug)]
struct RouterRequest {
    // Request id; `None` for fire-and-forget messages (pings, kill-on-drop).
    id: Option<Value>,
    // The RPC method name, stored as a `Value` strand.
    method: Value,
    // Positional parameters; `None` when there are no parameters, in which
    // case the "params" key is omitted from the serialized map entirely.
    params: Option<Value>,
}
impl Serialize for RouterRequest {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        /// Serializes the request body itself as a map with an optional `id`,
        /// a mandatory `method` and an optional `params` entry.
        struct InnerRequest<'a>(&'a RouterRequest);

        impl Serialize for InnerRequest<'_> {
            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let req = self.0;
                // `method` is always present; `id` and `params` only when set.
                let len = 1 + usize::from(req.id.is_some()) + usize::from(req.params.is_some());
                let mut map = serializer.serialize_map(Some(len))?;
                if let Some(id) = req.id.as_ref() {
                    map.serialize_entry("id", id)?;
                }
                map.serialize_entry("method", &req.method)?;
                if let Some(params) = req.params.as_ref() {
                    map.serialize_entry("params", params)?;
                }
                map.end()
            }
        }

        // Wrap the map in enum variant 9 ("Object") of "Value" so the payload
        // decodes exactly like a `surrealdb_core::sql::Value::Object`.
        serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self))
    }
}
// Hand-rolled revisioned encoding that mimics what a real
// `Value::Object` would produce, so the server can decode this request
// as a `Value` without this crate constructing an actual object.
impl Revisioned for RouterRequest {
    fn revision() -> u16 {
        1
    }

    fn serialize_revisioned<W: std::io::Write>(
        &self,
        w: &mut W,
    ) -> std::result::Result<(), revision::Error> {
        // version
        Revisioned::serialize_revisioned(&1u32, w)?;
        // object variant
        Revisioned::serialize_revisioned(&9u32, w)?;
        // object wrapper version
        Revisioned::serialize_revisioned(&1u32, w)?;
        // Number of map entries: `method` plus whichever of `id`/`params` are set.
        let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize;
        size.serialize_revisioned(w)?;

        // Map keys are strings encoded with bincode using the same options the
        // revision crate uses, so the bytes match a revisioned `Object`.
        let serializer = bincode::options()
            .with_no_limit()
            .with_little_endian()
            .with_varint_encoding()
            .reject_trailing_bytes();

        if let Some(x) = self.id.as_ref() {
            serializer
                .serialize_into(&mut *w, "id")
                .map_err(|err| revision::Error::Serialize(err.to_string()))?;
            x.serialize_revisioned(w)?;
        }
        serializer
            .serialize_into(&mut *w, "method")
            .map_err(|err| revision::Error::Serialize(err.to_string()))?;
        self.method.serialize_revisioned(w)?;

        if let Some(x) = self.params.as_ref() {
            serializer
                .serialize_into(&mut *w, "params")
                .map_err(|err| revision::Error::Serialize(err.to_string()))?;
            x.serialize_revisioned(w)?;
        }
        Ok(())
    }

    // Requests are only ever written by this client, never read back.
    fn deserialize_revisioned<R: Read>(_: &mut R) -> std::result::Result<Self, revision::Error>
    where
        Self: Sized,
    {
        panic!("deliberately unimplemented");
    }
}
/// Mutable state owned by the WebSocket router task, generic over the socket
/// halves and message type so the native and WASM engines can share it.
struct RouterState<Sink, Stream, Msg> {
    /// Variables from pending `set` requests, keyed by request id; moved into
    /// `vars` once the server acknowledges the corresponding request.
    var_stash: IndexMap<i64, (String, Value)>,
    /// Vars currently set by the set method,
    vars: IndexMap<String, Value>,
    /// Messages which ought to be replayed on a reconnect.
    replay: IndexMap<Method, Msg>,
    /// Pending live queries
    live_queries: HashMap<Uuid, channel::Sender<CoreNotification>>,
    /// In-flight requests awaiting a server response, keyed by request id.
    routes: HashMap<i64, (Method, Sender<Result<DbResponse>>)>,
    /// Instant of the last send/receive; used to decide whether a keep-alive
    /// ping is needed.
    last_activity: Instant,
    /// Write half of the socket.
    sink: Sink,
    /// Read half of the socket.
    stream: Stream,
}
impl<Sink, Stream, Msg> RouterState<Sink, Stream, Msg> {
    /// Creates a fresh router state wrapping the given socket halves, with all
    /// bookkeeping collections empty and the activity clock set to now.
    pub fn new(sink: Sink, stream: Stream) -> Self {
        Self {
            sink,
            stream,
            last_activity: Instant::now(),
            var_stash: IndexMap::new(),
            vars: IndexMap::new(),
            replay: IndexMap::new(),
            live_queries: HashMap::new(),
            routes: HashMap::new(),
        }
    }
}
/// Outcome of handling a single router event, telling the main loop whether
/// the connection is still usable.
enum HandleResult {
    /// Socket disconnected, should continue to reconnect
    Disconnected,
    /// Nothing went wrong; continue as normal.
    Ok,
}
/// The WS scheme used to connect to `ws://` endpoints
#[derive(Debug)]
pub struct Ws;
@ -156,7 +296,10 @@ pub(crate) struct Response {
pub(crate) result: ServerResult,
}
fn serialize(value: &Value, revisioned: bool) -> Result<Vec<u8>> {
fn serialize<V>(value: &V, revisioned: bool) -> Result<Vec<u8>>
where
V: serde::Serialize + Revisioned,
{
if revisioned {
let mut buf = Vec::new();
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
@ -177,3 +320,67 @@ where
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}
#[cfg(test)]
mod test {
    use std::io::Cursor;

    use revision::Revisioned;
    use surrealdb_core::sql::Value;

    use super::RouterRequest;

    /// Round-trips `req` through the supplied serialize/deserialize pair and
    /// checks the decoded object carries the same `id`, `method` and `params`.
    fn assert_converts<S, D, I>(req: &RouterRequest, s: S, d: D)
    where
        S: FnOnce(&RouterRequest) -> I,
        D: FnOnce(I) -> Value,
    {
        let decoded = d(s(req));
        let Value::Object(obj) = decoded else {
            panic!("not an object");
        };
        assert_eq!(obj.get("id").cloned(), req.id);
        assert_eq!(obj.get("method").unwrap().clone(), req.method);
        assert_eq!(obj.get("params").cloned(), req.params);
    }

    #[test]
    fn router_request_value_conversion() {
        let request = RouterRequest {
            id: Some(Value::from(1234i64)),
            method: Value::from("request"),
            params: Some(vec![Value::from(1234i64), Value::from("request")].into()),
        };

        println!("test convert bincode");
        assert_converts(
            &request,
            |r| bincode::serialize(r).unwrap(),
            |bytes| bincode::deserialize(&bytes).unwrap(),
        );

        println!("test convert json");
        assert_converts(
            &request,
            |r| serde_json::to_string(r).unwrap(),
            |text| serde_json::from_str(&text).unwrap(),
        );

        println!("test convert revisioned");
        assert_converts(
            &request,
            |r| {
                let mut buf = Vec::new();
                r.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap();
                buf
            },
            |buf| Value::deserialize_revisioned(&mut Cursor::new(buf)).unwrap(),
        );

        println!("done");
    }
}

View file

@ -1,5 +1,6 @@
use super::PATH;
use super::{deserialize, serialize};
use super::{HandleResult, RouterRequest};
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
@ -23,16 +24,12 @@ use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::sql::Value;
use flume::Receiver;
use futures::stream::SplitSink;
use futures::stream::{SplitSink, SplitStream};
use futures::SinkExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use indexmap::IndexMap;
use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::mem;
@ -55,19 +52,15 @@ use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use trice::Instant;
type WsResult<T> = std::result::Result<T, WsError>;
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20; // 64 MiB
pub(crate) const MAX_FRAME_SIZE: usize = 16 << 20; // 16 MiB
pub(crate) const WRITE_BUFFER_SIZE: usize = 128000; // tungstenite default
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = WRITE_BUFFER_SIZE + MAX_MESSAGE_SIZE; // Recommended max according to tungstenite docs
pub(crate) const NAGLE_ALG: bool = false;
pub(crate) enum Either {
Request(Option<Route>),
Response(WsResult<Message>),
Ping,
}
type MessageSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
type MessageStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
@ -144,7 +137,7 @@ impl Connection for Client {
capacity => flume::bounded(capacity),
};
router(address, maybe_connector, capacity, config, socket, route_rx);
tokio::spawn(run_router(address, maybe_connector, capacity, config, socket, route_rx));
let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries);
@ -172,384 +165,392 @@ impl Connection for Client {
request: (self.id, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
}
#[allow(clippy::too_many_lines)]
pub(crate) fn router(
endpoint: Endpoint,
maybe_connector: Option<Connector>,
capacity: usize,
config: WebSocketConfig,
mut socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
route_rx: Receiver<Option<Route>>,
) {
tokio::spawn(async move {
let ping = {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request);
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
let mut var_stash = IndexMap::new();
let mut vars = IndexMap::new();
let mut replay = IndexMap::new();
'router: loop {
let (socket_sink, socket_stream) = socket.split();
let mut socket_sink = Socket(Some(socket_sink));
if let Socket(Some(socket_sink)) = &mut socket_sink {
let mut routes = match capacity {
0 => HashMap::new(),
capacity => HashMap::with_capacity(capacity),
};
let mut live_queries = HashMap::new();
let mut interval = time::interval(PING_INTERVAL);
// don't bombard the server with pings if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
let pinger = IntervalStream::new(interval);
let streams = (
socket_stream.map(Either::Response),
route_rx.stream().map(Either::Request),
pinger.map(|_| Either::Ping),
);
let mut merged = streams.merge();
let mut last_activity = Instant::now();
while let Some(either) = merged.next().await {
match either {
Either::Request(Some(Route {
request,
response,
})) => {
let (id, method, param) = request;
let params = match param.query {
Some((query, bindings)) => {
vec![query.into(), bindings.into()]
}
None => param.other,
};
match method {
Method::Set => {
if let [Value::Strand(key), value] = &params[..2] {
var_stash.insert(id, (key.0.clone(), value.clone()));
}
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
vars.swap_remove(&key.0);
}
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
live_queries.insert(*id, sender);
}
}
if response
.into_send_async(Ok(DbResponse::Other(Value::None)))
.await
.is_err()
{
trace!("Receiver dropped");
}
// There is nothing to send to the server here
continue;
}
Method::Kill => {
if let [Value::Uuid(id)] = &params[..1] {
live_queries.remove(id);
}
}
_ => {}
}
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let mut request = BTreeMap::new();
request.insert("id".to_owned(), Value::from(id));
request.insert("method".to_owned(), method_str.into());
if !params.is_empty() {
request.insert("params".to_owned(), params.into());
}
let payload = Value::from(request);
trace!("Request {payload}");
let payload =
serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
replay.insert(method, message.clone());
}
match socket_sink.send(message).await {
Ok(..) => {
last_activity = Instant::now();
match routes.entry(id) {
Entry::Vacant(entry) => {
// Register query route
entry.insert((method, response));
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response
.into_send_async(Err(error.into()))
.await
.is_err()
{
trace!("Receiver dropped");
}
}
}
}
Err(error) => {
let error = Error::Ws(error.to_string());
if response.into_send_async(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
break;
}
}
}
Either::Response(result) => {
last_activity = Instant::now();
match result {
Ok(message) => {
match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => {
// We are only interested in responses that are not empty
if let Some(response) = option {
trace!("{response:?}");
match response.id {
// If `id` is set this is a normal response
Some(id) => {
if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs
if let Some((method, sender)) =
routes.remove(&id)
{
if matches!(method, Method::Set) {
if let Some((key, value)) =
var_stash.swap_remove(&id)
{
vars.insert(key, value);
}
}
// Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert)
{
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(
Value::Array(value),
)) = &mut response
{
if let [value] =
&mut value.0[..]
{
response =
Ok(Data::Other(
mem::take(
value,
),
));
}
}
}
let _res = sender
.into_send_async(
DbResponse::from(response),
)
.await;
}
}
}
// If `id` is not set, this may be a live query notification
None => match response.result {
Ok(Data::Live(notification)) => {
let live_query_id = notification.id;
// Check if this live query is registered
if let Some(sender) =
live_queries.get(&live_query_id)
{
// Send the notification back to the caller or kill live query if the receiver is already dropped
if sender
.send(notification)
.await
.is_err()
{
live_queries
.remove(&live_query_id);
let kill = {
let mut request =
BTreeMap::new();
request.insert(
"method".to_owned(),
Method::Kill
.as_str()
.into(),
);
request.insert(
"params".to_owned(),
vec![Value::from(
live_query_id,
)]
.into(),
);
let value =
Value::from(request);
let value = serialize(
&value,
endpoint
.supports_revision,
)
.unwrap();
Message::Binary(value)
};
if let Err(error) =
socket_sink.send(kill).await
{
trace!("failed to send kill query to the server; {error:?}");
break;
}
}
}
}
Ok(..) => { /* Ignored responses like pings */
}
Err(error) => error!("{error:?}"),
},
}
}
}
Err(error) => {
#[revisioned(revision = 1)]
#[derive(Deserialize)]
struct Response {
id: Option<Value>,
}
// Let's try to find out the ID of the response that failed to deserialise
if let Message::Binary(binary) = message {
if let Ok(Response {
id,
}) = deserialize(
&mut &binary[..],
endpoint.supports_revision,
) {
// Return an error if an ID was returned
if let Some(Ok(id)) =
id.map(Value::coerce_to_i64)
{
if let Some((_method, sender)) =
routes.remove(&id)
{
let _res = sender
.into_send_async(Err(error))
.await;
}
}
} else {
// Unfortunately, we don't know which response failed to deserialize
warn!(
"Failed to deserialise message; {error:?}"
);
}
}
}
}
}
Err(error) => {
match error {
WsError::ConnectionClosed => {
trace!("Connection successfully closed on the server");
}
error => {
trace!("{error}");
}
}
break;
}
}
}
Either::Ping => {
// only ping if we haven't talked to the server recently
if last_activity.elapsed() >= PING_INTERVAL {
trace!("Pinging the server");
if let Err(error) = socket_sink.send(ping.clone()).await {
trace!("failed to ping the server; {error:?}");
break;
}
}
}
// Close connection request received
Either::Request(None) => {
match socket_sink.send(Message::Close(None)).await {
Ok(..) => trace!("Connection closed successfully"),
Err(error) => {
warn!("Failed to close database connection; {error}")
}
}
break 'router;
}
}
async fn router_handle_route(
Route {
request,
response,
}: Route,
state: &mut RouterState,
endpoint: &Endpoint,
) -> HandleResult {
let (id, method, param) = request;
let params = match param.query {
Some((query, bindings)) => {
vec![query.into(), bindings.into()]
}
None => param.other,
};
match method {
Method::Set => {
if let [Value::Strand(key), value] = &params[..2] {
state.var_stash.insert(id, (key.0.clone(), value.clone()));
}
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
state.vars.swap_remove(&key.0);
}
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.insert(id.0, sender);
}
}
if response.clone().into_send_async(Ok(DbResponse::Other(Value::None))).await.is_err() {
trace!("Receiver dropped");
}
// There is nothing to send to the server here
}
Method::Kill => {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.remove(id);
}
}
_ => {}
}
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let request = RouterRequest {
id: Some(Value::from(id)),
method: method_str.into(),
params: (!params.is_empty()).then(|| params.into()),
};
'reconnect: loop {
trace!("Reconnecting...");
match connect(&endpoint, Some(config), maybe_connector.clone()).await {
Ok(s) => {
socket = s;
for (_, message) in &replay {
if let Err(error) = socket.send(message.clone()).await {
trace!("{error}");
time::sleep(time::Duration::from_secs(1)).await;
continue 'reconnect;
}
}
for (key, value) in &vars {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), Method::Set.as_str().into());
request.insert(
"params".to_owned(),
vec![key.as_str().into(), value.clone()].into(),
);
let payload = Value::from(request);
trace!("Request {payload}");
if let Err(error) = socket.send(Message::Binary(payload.into())).await {
trace!("{error}");
time::sleep(time::Duration::from_secs(1)).await;
continue 'reconnect;
}
}
trace!("Reconnected successfully");
break;
}
Err(error) => {
trace!("Failed to reconnect; {error}");
time::sleep(time::Duration::from_secs(1)).await;
trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
state.replay.insert(method, message.clone());
}
match state.sink.send(message).await {
Ok(_) => {
state.last_activity = Instant::now();
match state.routes.entry(id) {
Entry::Vacant(entry) => {
// Register query route
entry.insert((method, response));
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response.into_send_async(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
}
}
}
});
Err(error) => {
let error = Error::Ws(error.to_string());
if response.into_send_async(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
return HandleResult::Disconnected;
}
}
HandleResult::Ok
}
/// Handles a single message received from the server: routes id-carrying
/// responses back to their waiting callers, dispatches live-query
/// notifications, and attempts id recovery for messages that fail to decode.
async fn router_handle_response(
    response: Message,
    state: &mut RouterState,
    endpoint: &Endpoint,
) -> HandleResult {
    match Response::try_from(&response, endpoint.supports_revision) {
        Ok(option) => {
            // We are only interested in responses that are not empty
            if let Some(response) = option {
                trace!("{response:?}");
                match response.id {
                    // If `id` is set this is a normal response
                    Some(id) => {
                        if let Ok(id) = id.coerce_to_i64() {
                            // We can only route responses with IDs
                            if let Some((method, sender)) = state.routes.remove(&id) {
                                if matches!(method, Method::Set) {
                                    // The server confirmed the `set`; promote the
                                    // stashed variable into the active vars.
                                    if let Some((key, value)) = state.var_stash.swap_remove(&id) {
                                        state.vars.insert(key, value);
                                    }
                                }
                                // Send the response back to the caller
                                let mut response = response.result;
                                if matches!(method, Method::Insert) {
                                    // For insert, we need to flatten single responses in an array
                                    if let Ok(Data::Other(Value::Array(value))) = &mut response {
                                        if let [value] = &mut value.0[..] {
                                            response = Ok(Data::Other(mem::take(value)));
                                        }
                                    }
                                }
                                let _res = sender.into_send_async(DbResponse::from(response)).await;
                            }
                        }
                    }
                    // If `id` is not set, this may be a live query notification
                    None => {
                        match response.result {
                            Ok(Data::Live(notification)) => {
                                let live_query_id = notification.id;
                                // Check if this live query is registered
                                if let Some(sender) = state.live_queries.get(&live_query_id) {
                                    // Send the notification back to the caller or kill live query if the receiver is already dropped
                                    if sender.send(notification).await.is_err() {
                                        state.live_queries.remove(&live_query_id);
                                        // Fire-and-forget kill request (no id), so
                                        // the server stops sending notifications.
                                        let kill = {
                                            let request = RouterRequest {
                                                id: None,
                                                method: Method::Kill.as_str().into(),
                                                params: Some(
                                                    vec![Value::from(live_query_id)].into(),
                                                ),
                                            };
                                            let value =
                                                serialize(&request, endpoint.supports_revision)
                                                    .unwrap();
                                            Message::Binary(value)
                                        };
                                        if let Err(error) = state.sink.send(kill).await {
                                            trace!("failed to send kill query to the server; {error:?}");
                                            return HandleResult::Disconnected;
                                        }
                                    }
                                }
                            }
                            Ok(..) => { /* Ignored responses like pings */ }
                            Err(error) => error!("{error:?}"),
                        }
                    }
                }
            }
        }
        Err(error) => {
            // Minimal shape used to fish the request id out of an otherwise
            // undecodable payload.
            #[revisioned(revision = 1)]
            #[derive(Deserialize)]
            struct Response {
                id: Option<Value>,
            }

            // Let's try to find out the ID of the response that failed to deserialise
            if let Message::Binary(binary) = response {
                if let Ok(Response {
                    id,
                }) = deserialize(&mut &binary[..], endpoint.supports_revision)
                {
                    // Return an error if an ID was returned
                    if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
                        if let Some((_method, sender)) = state.routes.remove(&id) {
                            let _res = sender.into_send_async(Err(error)).await;
                        }
                    }
                } else {
                    // Unfortunately, we don't know which response failed to deserialize
                    warn!("Failed to deserialise message; {error:?}");
                }
            }
        }
    }
    HandleResult::Ok
}
/// Re-establishes the WebSocket connection, retrying forever (with a 1s
/// backoff) until it succeeds, then replays session-restoring messages and
/// previously set variables on the new connection.
async fn router_reconnect(
    maybe_connector: &Option<Connector>,
    config: &WebSocketConfig,
    state: &mut RouterState,
    endpoint: &Endpoint,
) {
    'reconnect: loop {
        trace!("Reconnecting...");
        match connect(endpoint, Some(*config), maybe_connector.clone()).await {
            Ok(s) => {
                let (new_sink, new_stream) = s.split();
                state.sink = new_sink;
                state.stream = new_stream;
                // Replay auth/session messages first; if any of them fails the
                // new connection is unusable.
                for (_, message) in &state.replay {
                    if let Err(error) = state.sink.send(message.clone()).await {
                        trace!("{error}");
                        time::sleep(time::Duration::from_secs(1)).await;
                        // BUGFIX: a bare `continue` here would only skip this
                        // replay message and carry on with a half-restored
                        // session; retry the whole connection instead.
                        continue 'reconnect;
                    }
                }
                // Restore the variables previously set via the `set` method.
                for (key, value) in &state.vars {
                    let request = RouterRequest {
                        id: None,
                        method: Method::Set.as_str().into(),
                        params: Some(vec![key.as_str().into(), value.clone()].into()),
                    };
                    trace!("Request {:?}", request);
                    let payload = serialize(&request, endpoint.supports_revision).unwrap();
                    if let Err(error) = state.sink.send(Message::Binary(payload)).await {
                        trace!("{error}");
                        time::sleep(time::Duration::from_secs(1)).await;
                        // BUGFIX: same as above — restart the reconnect, don't
                        // silently drop this variable.
                        continue 'reconnect;
                    }
                }
                trace!("Reconnected successfully");
                break;
            }
            Err(error) => {
                trace!("Failed to reconnect; {error}");
                time::sleep(time::Duration::from_secs(1)).await;
            }
        }
    }
}
/// Drives the native WebSocket connection: multiplexes frontend requests,
/// server responses and a keep-alive pinger via `tokio::select!`, and
/// reconnects whenever the socket drops. Exits when the frontend drops its
/// request channel.
pub(crate) async fn run_router(
    endpoint: Endpoint,
    maybe_connector: Option<Connector>,
    _capacity: usize,
    config: WebSocketConfig,
    socket: WebSocketStream<MaybeTlsStream<TcpStream>>,
    route_rx: Receiver<Route>,
) {
    // Pre-serialized ping message (no id), reused on every keep-alive tick.
    let ping = {
        let request = RouterRequest {
            id: None,
            method: PING_METHOD.into(),
            params: None,
        };
        let value = serialize(&request, endpoint.supports_revision).unwrap();
        Message::Binary(value)
    };

    let (socket_sink, socket_stream) = socket.split();
    let mut state = RouterState::new(socket_sink, socket_stream);
    let mut route_stream = route_rx.into_stream();

    'router: loop {
        let mut interval = time::interval(PING_INTERVAL);
        // don't bombard the server with pings if we miss some ticks
        interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
        let mut pinger = IntervalStream::new(interval);
        // Turn into a stream instead of calling recv_async
        // The stream seems to be able to keep some state which would otherwise need to be
        // recreated with each next.

        // Per-connection bookkeeping is reset on every (re)connect: in-flight
        // requests and live queries do not survive a reconnect.
        state.last_activity = Instant::now();
        state.live_queries.clear();
        state.routes.clear();

        loop {
            tokio::select! {
                route = route_stream.next() => {
                    // handle incoming route
                    let Some(response) = route else {
                        // route returned none, frontend dropped the channel, meaning the router
                        // should quit.
                        match state.sink.send(Message::Close(None)).await {
                            Ok(..) => trace!("Connection closed successfully"),
                            Err(error) => {
                                warn!("Failed to close database connection; {error}")
                            }
                        }
                        break 'router;
                    };

                    match router_handle_route(response, &mut state, &endpoint).await {
                        HandleResult::Ok => {},
                        HandleResult::Disconnected => {
                            router_reconnect(
                                &maybe_connector,
                                &config,
                                &mut state,
                                &endpoint,
                            )
                            .await;
                            continue 'router;
                        }
                    }
                }
                result = state.stream.next() => {
                    // Handle result from database.
                    let Some(result) = result else {
                        // stream returned none meaning the connection dropped, try to reconnect.
                        router_reconnect(
                            &maybe_connector,
                            &config,
                            &mut state,
                            &endpoint,
                        )
                        .await;
                        continue 'router;
                    };
                    state.last_activity = Instant::now();
                    match result {
                        Ok(message) => {
                            match router_handle_response(message, &mut state, &endpoint).await {
                                HandleResult::Ok => continue,
                                HandleResult::Disconnected => {
                                    router_reconnect(
                                        &maybe_connector,
                                        &config,
                                        &mut state,
                                        &endpoint,
                                    )
                                    .await;
                                    continue 'router;
                                }
                            }
                        }
                        Err(error) => {
                            match error {
                                WsError::ConnectionClosed => {
                                    trace!("Connection successfully closed on the server");
                                }
                                error => {
                                    trace!("{error}");
                                }
                            }
                            router_reconnect(
                                &maybe_connector,
                                &config,
                                &mut state,
                                &endpoint,
                            )
                            .await;
                            continue 'router;
                        }
                    }
                }
                _ = pinger.next() => {
                    // only ping if we haven't talked to the server recently
                    if state.last_activity.elapsed() >= PING_INTERVAL {
                        trace!("Pinging the server");
                        if let Err(error) = state.sink.send(ping.clone()).await {
                            trace!("failed to ping the server; {error:?}");
                            router_reconnect(
                                &maybe_connector,
                                &config,
                                &mut state,
                                &endpoint,
                            )
                            .await;
                            continue 'router;
                        }
                    }
                }
            }
        }
    }
}
impl Response {
@ -588,8 +589,6 @@ impl Response {
}
}
pub struct Socket(Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>);
#[cfg(test)]
mod tests {
use super::serialize;

View file

@ -1,5 +1,5 @@
use super::PATH;
use super::{deserialize, serialize};
use super::{HandleResult, PATH};
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
@ -16,27 +16,26 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::remote::ws::{Data, RouterRequest};
use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::sql::Value;
use flume::Receiver;
use flume::Sender;
use futures::stream::{SplitSink, SplitStream};
use futures::FutureExt;
use futures::SinkExt;
use futures::StreamExt;
use futures_concurrency::stream::Merge as _;
use indexmap::IndexMap;
use pharos::Channel;
use pharos::Events;
use pharos::Observable;
use pharos::ObserveConfig;
use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
@ -48,16 +47,13 @@ use trice::Instant;
use wasm_bindgen_futures::spawn_local;
use wasmtimer::tokio as time;
use wasmtimer::tokio::MissedTickBehavior;
use ws_stream_wasm::WsEvent;
use ws_stream_wasm::WsMessage as Message;
use ws_stream_wasm::WsMeta;
use ws_stream_wasm::{WsEvent, WsStream};
pub(crate) enum Either {
Request(Option<Route>),
Response(Message),
Event(WsEvent),
Ping,
}
type MessageStream = SplitStream<WsStream>;
type MessageSink = SplitSink<WsStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>;
impl crate::api::Connection for Client {}
@ -83,7 +79,7 @@ impl Connection for Client {
let (conn_tx, conn_rx) = flume::bounded(1);
router(address, capacity, conn_tx, route_rx);
spawn_local(run_router(address, capacity, conn_tx, route_rx));
conn_rx.into_recv_async().await??;
@ -113,297 +109,367 @@ impl Connection for Client {
request: (self.id, self.method, param),
response: sender,
};
router.sender.send_async(Some(route)).await?;
router.sender.send_async(route).await?;
Ok(receiver)
})
}
}
pub(crate) fn router(
endpoint: Endpoint,
/// Handles a single frontend request on the WASM engine: performs local
/// bookkeeping (vars, live-query registration), then forwards the request to
/// the server and registers the response channel under the request id.
async fn router_handle_request(
    Route {
        request,
        response,
    }: Route,
    state: &mut RouterState,
    endpoint: &Endpoint,
) -> HandleResult {
    let (id, method, param) = request;
    // A query is sent as `[query, bindings]`; other methods pass their
    // parameters through unchanged.
    let params = match param.query {
        Some((query, bindings)) => {
            vec![query.into(), bindings.into()]
        }
        None => param.other,
    };
    match method {
        Method::Set => {
            // Stash the variable until the server acknowledges the `set`.
            if let [Value::Strand(key), value] = &params[..2] {
                state.var_stash.insert(id, (key.0.clone(), value.clone()));
            }
        }
        Method::Unset => {
            if let [Value::Strand(key)] = &params[..1] {
                state.vars.swap_remove(&key.0);
            }
        }
        Method::Live => {
            if let Some(sender) = param.notification_sender {
                if let [Value::Uuid(id)] = &params[..1] {
                    state.live_queries.insert(id.0, sender);
                }
            }
            if response.into_send_async(Ok(DbResponse::Other(Value::None))).await.is_err() {
                trace!("Receiver dropped");
            }
            // There is nothing to send to the server here
            return HandleResult::Ok;
        }
        Method::Kill => {
            if let [Value::Uuid(id)] = &params[..1] {
                state.live_queries.remove(id);
            }
        }
        _ => {}
    }
    let method_str = match method {
        // Health checks are answered by the server's ping handler.
        Method::Health => PING_METHOD,
        _ => method.as_str(),
    };
    let message = {
        let request = RouterRequest {
            id: Some(Value::from(id)),
            method: method_str.into(),
            params: (!params.is_empty()).then(|| params.into()),
        };
        trace!("Request {:?}", request);
        let payload = serialize(&request, endpoint.supports_revision).unwrap();
        Message::Binary(payload)
    };
    // Auth- and session-affecting requests must be replayed after a reconnect
    // to restore the session on the new connection.
    if let Method::Authenticate
    | Method::Invalidate
    | Method::Signin
    | Method::Signup
    | Method::Use = method
    {
        state.replay.insert(method, message.clone());
    }
    match state.sink.send(message).await {
        Ok(..) => {
            state.last_activity = Instant::now();
            match state.routes.entry(id) {
                Entry::Vacant(entry) => {
                    // Register the in-flight request so the response can be routed back.
                    entry.insert((method, response));
                }
                Entry::Occupied(..) => {
                    let error = Error::DuplicateRequestId(id);
                    if response.into_send_async(Err(error.into())).await.is_err() {
                        trace!("Receiver dropped");
                    }
                }
            }
        }
        Err(error) => {
            let error = Error::Ws(error.to_string());
            if response.into_send_async(Err(error.into())).await.is_err() {
                trace!("Receiver dropped");
            }
            return HandleResult::Disconnected;
        }
    }
    HandleResult::Ok
}
/// Processes one WebSocket message received from the server and routes it
/// to the caller that is waiting for it.
///
/// Behavior visible in this body:
/// * Messages that parse into a `Response` with an `id` are matched against
///   `state.routes` and the result is forwarded to the stored sender.
/// * Messages without an `id` are treated as live-query notifications and
///   forwarded via `state.live_queries`; if the receiver was dropped, a
///   `kill` request is sent back to the server.
/// * Messages that fail to deserialize are best-effort matched by re-reading
///   only the `id` field so the waiting caller still gets the error.
///
/// Returns `HandleResult::Disconnected` only when sending the kill request
/// over the socket fails; every other path returns `HandleResult::Ok`.
async fn router_handle_response(
response: Message,
state: &mut RouterState,
endpoint: &Endpoint,
) -> HandleResult {
match Response::try_from(&response, endpoint.supports_revision) {
Ok(option) => {
// We are only interested in responses that are not empty
if let Some(response) = option {
trace!("{response:?}");
match response.id {
// If `id` is set this is a normal response
Some(id) => {
if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs
if let Some((method, sender)) = state.routes.remove(&id) {
if matches!(method, Method::Set) {
// A successful `set` commits the stashed key/value
// (stored at request time) into the session vars so
// they can be replayed after a reconnect.
if let Some((key, value)) = state.var_stash.swap_remove(&id) {
state.vars.insert(key, value);
}
}
// Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert) {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(value))) = &mut response {
if let [value] = &mut value.0[..] {
response = Ok(Data::Other(mem::take(value)));
}
}
}
// The caller may have stopped waiting; a send failure
// here is deliberately ignored.
let _res = sender.into_send_async(DbResponse::from(response)).await;
}
}
}
// If `id` is not set, this may be a live query notification
None => match response.result {
Ok(Data::Live(notification)) => {
let live_query_id = notification.id;
// Check if this live query is registered
if let Some(sender) = state.live_queries.get(&live_query_id) {
// Send the notification back to the caller or kill live query if the receiver is already dropped
if sender.send(notification).await.is_err() {
state.live_queries.remove(&live_query_id);
// Build a fire-and-forget `kill` request (no `id`,
// so no response will be routed back for it).
let kill = {
let request = RouterRequest {
id: None,
method: Method::Kill.as_str().into(),
params: Some(vec![Value::from(live_query_id)].into()),
};
let value = serialize(&request, endpoint.supports_revision)
.unwrap();
Message::Binary(value)
};
if let Err(error) = state.sink.send(kill).await {
trace!(
"failed to send kill query to the server; {error:?}"
);
// Socket write failed: signal the router loop to
// reconnect.
return HandleResult::Disconnected;
}
}
}
}
Ok(..) => { /* Ignored responses like pings */ }
Err(error) => error!("{error:?}"),
},
}
}
}
Err(error) => {
// Minimal shadow type: re-parse only the `id` field of the payload
// that failed full deserialization, so the error can still be
// delivered to the caller that issued the request.
#[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response {
id: Option<Value>,
}
// Let's try to find out the ID of the response that failed to deserialise
if let Message::Binary(binary) = response {
if let Ok(Response {
id,
}) = deserialize(&mut &binary[..], endpoint.supports_revision)
{
// Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = state.routes.remove(&id) {
let _res = sender.into_send_async(Err(error)).await;
}
}
} else {
// Unfortunately, we don't know which response failed to deserialize
warn!("Failed to deserialise message; {error:?}");
}
}
}
}
HandleResult::Ok
}
async fn router_reconnect(
state: &mut RouterState,
events: &mut Events<WsEvent>,
endpoint: &Endpoint,
capacity: usize,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Option<Route>>,
) {
spawn_local(async move {
loop {
trace!("Reconnecting...");
let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
let (mut ws, mut socket) = match connect {
Ok(pair) => pair,
match connect {
Ok((mut meta, stream)) => {
let (new_sink, new_stream) = stream.split();
state.sink = new_sink;
state.stream = new_stream;
*events = {
let result = match capacity {
0 => meta.observe(ObserveConfig::default()).await,
capacity => meta.observe(Channel::Bounded(capacity).into()).await,
};
match result {
Ok(events) => events,
Err(error) => {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
};
for (_, message) in &state.replay {
if let Err(error) = state.sink.send(message.clone()).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
for (key, value) in &state.vars {
let request = RouterRequest {
id: None,
method: Method::Set.as_str().into(),
params: Some(vec![key.as_str().into(), value.clone()].into()),
};
trace!("Request {:?}", request);
let serialize = serialize(&request, false).unwrap();
if let Err(error) = state.sink.send(Message::Binary(serialize)).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
trace!("Reconnected successfully");
break;
}
Err(error) => {
trace!("Failed to reconnect; {error}");
time::sleep(Duration::from_secs(1)).await;
}
}
}
}
pub(crate) async fn run_router(
endpoint: Endpoint,
capacity: usize,
conn_tx: Sender<Result<()>>,
route_rx: Receiver<Route>,
) {
let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
let (mut ws, socket) = match connect {
Ok(pair) => pair,
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
let mut events = {
let result = match capacity {
0 => ws.observe(ObserveConfig::default()).await,
capacity => ws.observe(Channel::Bounded(capacity).into()).await,
};
match result {
Ok(events) => events,
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
}
};
}
};
let mut events = {
let result = match capacity {
0 => ws.observe(ObserveConfig::default()).await,
capacity => ws.observe(Channel::Bounded(capacity).into()).await,
};
match result {
Ok(events) => events,
Err(error) => {
let _ = conn_tx.into_send_async(Err(error.into())).await;
return;
let _ = conn_tx.into_send_async(Ok(())).await;
let ping = {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request);
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
let (socket_sink, socket_stream) = socket.split();
let mut state = RouterState::new(socket_sink, socket_stream);
let mut route_stream = route_rx.into_stream();
'router: loop {
let mut interval = time::interval(PING_INTERVAL);
// don't bombard the server with pings if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
let mut pinger = IntervalStream::new(interval);
state.last_activity = Instant::now();
state.live_queries.clear();
state.routes.clear();
loop {
futures::select! {
route = route_stream.next() => {
let Some(route) = route else {
match ws.close().await {
Ok(..) => trace!("Connection closed successfully"),
Err(error) => {
warn!("Failed to close database connection; {error}")
}
}
break 'router;
};
match router_handle_request(route, &mut state,&endpoint).await {
HandleResult::Ok => {},
HandleResult::Disconnected => {
router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
break
}
}
}
}
};
message = state.stream.next().fuse() => {
let Some(message) = message else {
// socket disconnected,
router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
break
};
let _ = conn_tx.into_send_async(Ok(())).await;
let ping = {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
let value = Value::from(request);
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
let mut var_stash = IndexMap::new();
let mut vars = IndexMap::new();
let mut replay = IndexMap::new();
'router: loop {
let (mut socket_sink, socket_stream) = socket.split();
let mut routes = match capacity {
0 => HashMap::new(),
capacity => HashMap::with_capacity(capacity),
};
let mut live_queries = HashMap::new();
let mut interval = time::interval(PING_INTERVAL);
// don't bombard the server with pings if we miss some ticks
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
let pinger = IntervalStream::new(interval);
let streams = (
socket_stream.map(Either::Response),
route_rx.stream().map(Either::Request),
pinger.map(|_| Either::Ping),
events.map(Either::Event),
);
let mut merged = streams.merge();
let mut last_activity = Instant::now();
while let Some(either) = merged.next().await {
match either {
Either::Request(Some(Route {
request,
response,
})) => {
let (id, method, param) = request;
let params = match param.query {
Some((query, bindings)) => {
vec![query.into(), bindings.into()]
}
None => param.other,
};
match method {
Method::Set => {
if let [Value::Strand(key), value] = &params[..2] {
var_stash.insert(id, (key.0.clone(), value.clone()));
}
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
vars.swap_remove(&key.0);
}
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
live_queries.insert(*id, sender);
}
}
if response
.into_send_async(Ok(DbResponse::Other(Value::None)))
.await
.is_err()
{
trace!("Receiver dropped");
}
// There is nothing to send to the server here
continue;
}
Method::Kill => {
if let [Value::Uuid(id)] = &params[..1] {
live_queries.remove(id);
}
}
_ => {}
}
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let mut request = BTreeMap::new();
request.insert("id".to_owned(), Value::from(id));
request.insert("method".to_owned(), method_str.into());
if !params.is_empty() {
request.insert("params".to_owned(), params.into());
}
let payload = Value::from(request);
trace!("Request {payload}");
let payload = serialize(&payload, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
replay.insert(method, message.clone());
}
match socket_sink.send(message).await {
Ok(..) => {
last_activity = Instant::now();
match routes.entry(id) {
Entry::Vacant(entry) => {
entry.insert((method, response));
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response
.into_send_async(Err(error.into()))
.await
.is_err()
{
trace!("Receiver dropped");
}
}
}
}
Err(error) => {
let error = Error::Ws(error.to_string());
if response.into_send_async(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
break;
}
state.last_activity = Instant::now();
match router_handle_response(message, &mut state,&endpoint).await {
HandleResult::Ok => {},
HandleResult::Disconnected => {
router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
break
}
}
Either::Response(message) => {
last_activity = Instant::now();
match Response::try_from(&message, endpoint.supports_revision) {
Ok(option) => {
// We are only interested in responses that are not empty
if let Some(response) = option {
trace!("{response:?}");
match response.id {
// If `id` is set this is a normal response
Some(id) => {
if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs
if let Some((method, sender)) = routes.remove(&id) {
if matches!(method, Method::Set) {
if let Some((key, value)) =
var_stash.swap_remove(&id)
{
vars.insert(key, value);
}
}
// Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert) {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(
value,
))) = &mut response
{
if let [value] = &mut value.0[..] {
response = Ok(Data::Other(
mem::take(value),
));
}
}
}
let _res = sender
.into_send_async(DbResponse::from(response))
.await;
}
}
}
// If `id` is not set, this may be a live query notification
None => match response.result {
Ok(Data::Live(notification)) => {
let live_query_id = notification.id;
// Check if this live query is registered
if let Some(sender) =
live_queries.get(&live_query_id)
{
// Send the notification back to the caller or kill live query if the receiver is already dropped
if sender.send(notification).await.is_err() {
live_queries.remove(&live_query_id);
let kill = {
let mut request = BTreeMap::new();
request.insert(
"method".to_owned(),
Method::Kill.as_str().into(),
);
request.insert(
"params".to_owned(),
vec![Value::from(live_query_id)]
.into(),
);
let value = Value::from(request);
let value = serialize(
&value,
endpoint.supports_revision,
)
.unwrap();
Message::Binary(value)
};
if let Err(error) =
socket_sink.send(kill).await
{
trace!("failed to send kill query to the server; {error:?}");
break;
}
}
}
}
Ok(..) => { /* Ignored responses like pings */ }
Err(error) => error!("{error:?}"),
},
}
}
}
Err(error) => {
#[derive(Deserialize)]
#[revisioned(revision = 1)]
struct Response {
id: Option<Value>,
}
// Let's try to find out the ID of the response that failed to deserialise
if let Message::Binary(binary) = message {
if let Ok(Response {
id,
}) = deserialize(&mut &binary[..], endpoint.supports_revision)
{
// Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = routes.remove(&id) {
let _res = sender.into_send_async(Err(error)).await;
}
}
} else {
// Unfortunately, we don't know which response failed to deserialize
warn!("Failed to deserialise message; {error:?}");
}
}
}
}
}
Either::Event(event) => match event {
}
event = events.next().fuse() => {
let Some(event) = event else {
continue;
};
match event {
WsEvent::Error => {
trace!("connection errored");
break;
@ -413,89 +479,25 @@ pub(crate) fn router(
}
WsEvent::Closed(..) => {
trace!("connection closed");
router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
break;
}
_ => {}
},
Either::Ping => {
// only ping if we haven't talked to the server recently
if last_activity.elapsed() >= PING_INTERVAL {
trace!("Pinging the server");
if let Err(error) = socket_sink.send(ping.clone()).await {
trace!("failed to ping the server; {error:?}");
break;
}
}
}
// Close connection request received
Either::Request(None) => {
match ws.close().await {
Ok(..) => trace!("Connection closed successfully"),
Err(error) => {
warn!("Failed to close database connection; {error}")
}
}
break 'router;
}
}
}
'reconnect: loop {
trace!("Reconnecting...");
let connect = match endpoint.supports_revision {
true => WsMeta::connect(&endpoint.url, vec![super::REVISION_HEADER]).await,
false => WsMeta::connect(&endpoint.url, None).await,
};
match connect {
Ok((mut meta, stream)) => {
socket = stream;
events = {
let result = match capacity {
0 => meta.observe(ObserveConfig::default()).await,
capacity => meta.observe(Channel::Bounded(capacity).into()).await,
};
match result {
Ok(events) => events,
Err(error) => {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue 'reconnect;
}
}
};
for (_, message) in &replay {
if let Err(error) = socket.send(message.clone()).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue 'reconnect;
}
_ = pinger.next().fuse() => {
if state.last_activity.elapsed() >= PING_INTERVAL {
trace!("Pinging the server");
if let Err(error) = state.sink.send(ping.clone()).await {
trace!("failed to ping the server; {error:?}");
router_reconnect(&mut state, &mut events, &endpoint, capacity).await;
break;
}
for (key, value) in &vars {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), Method::Set.as_str().into());
request.insert(
"params".to_owned(),
vec![key.as_str().into(), value.clone()].into(),
);
let payload = Value::from(request);
trace!("Request {payload}");
if let Err(error) = socket.send(Message::Binary(payload.into())).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue 'reconnect;
}
}
trace!("Reconnected successfully");
break;
}
Err(error) => {
trace!("Failed to reconnect; {error}");
time::sleep(Duration::from_secs(1)).await;
}
}
}
}
});
}
}
impl Response {

View file

@ -1,5 +1,4 @@
use futures::{FutureExt, StreamExt};
use futures_concurrency::stream::Merge;
use futures::StreamExt;
use reblessive::TreeStack;
#[cfg(target_arch = "wasm32")]
use std::sync::atomic::{AtomicBool, Ordering};
@ -93,24 +92,27 @@ fn init(opt: &EngineOptions, dbs: Arc<Datastore>) -> (FutureTask, oneshot::Sende
let ret_status = completed_status.clone();
// We create a channel that can be streamed that will indicate termination
let (tx, rx) = oneshot::channel();
let (tx, mut rx) = oneshot::channel();
let _fut = spawn_future(async move {
let _lifecycle = crate::dbs::LoggingLifecycle::new("heartbeat task".to_string());
let ticker = interval_ticker(tick_interval).await;
let streams = (
ticker.map(|i| {
trace!("Node agent tick: {:?}", i);
Some(i)
}),
rx.into_stream().map(|_| None),
);
let mut streams = streams.merge();
let mut ticker = interval_ticker(tick_interval).await;
while let Some(Some(_)) = streams.next().await {
if let Err(e) = dbs.tick().await {
error!("Error running node agent tick: {}", e);
break;
loop {
tokio::select! {
v = ticker.next() => {
// ticker will never return None;
let i = v.unwrap();
trace!("Node agent tick: {:?}", i);
if let Err(e) = dbs.tick().await {
error!("Error running node agent tick: {}", e);
break;
}
}
_ = &mut rx => {
// termination requested
break
}
}
}
@ -136,7 +138,7 @@ fn live_query_change_feed(
let ret_status = completed_status.clone();
// We create a channel that can be streamed that will indicate termination
let (tx, rx) = oneshot::channel();
let (tx, mut rx) = oneshot::channel();
let _fut = spawn_future(async move {
let mut stack = TreeStack::new();
@ -148,25 +150,29 @@ fn live_query_change_feed(
completed_status.store(true, Ordering::Relaxed);
return;
}
let ticker = interval_ticker(tick_interval).await;
let streams = (
ticker.map(|i| {
trace!("Live query agent tick: {:?}", i);
Some(i)
}),
rx.into_stream().map(|_| None),
);
let mut streams = streams.merge();
let mut ticker = interval_ticker(tick_interval).await;
let opt = Options::default();
while let Some(Some(_)) = streams.next().await {
if let Err(e) =
stack.enter(|stk| dbs.process_lq_notifications(stk, &opt)).finish().await
{
error!("Error running node agent tick: {}", e);
break;
loop {
tokio::select! {
v = ticker.next() => {
// ticker will never return None;
let i = v.unwrap();
trace!("Live query agent tick: {:?}", i);
if let Err(e) =
stack.enter(|stk| dbs.process_lq_notifications(stk, &opt)).finish().await
{
error!("Error running node agent tick: {}", e);
break;
}
}
_ = &mut rx => {
// termination requested,
break
}
}
}
#[cfg(target_arch = "wasm32")]
completed_status.store(true, Ordering::Relaxed);
});

View file

@ -96,13 +96,7 @@ impl Connection for Client {
request: (0, self.method, param),
response: sender,
};
router
.sender
.send_async(Some(route))
.await
.as_ref()
.map_err(ToString::to_string)
.unwrap();
router.sender.send_async(route).await.as_ref().map_err(ToString::to_string).unwrap();
Ok(receiver)
})
}

View file

@ -8,14 +8,14 @@ use crate::sql::Value;
use flume::Receiver;
use futures::StreamExt;
pub(super) fn mock(route_rx: Receiver<Option<Route>>) {
pub(super) fn mock(route_rx: Receiver<Route>) {
tokio::spawn(async move {
let mut stream = route_rx.into_stream();
while let Some(Some(Route {
while let Some(Route {
request,
response,
})) = stream.next().await
}) = stream.next().await
{
let (_, method, param) = request;
let mut params = param.other;

View file

@ -4648,7 +4648,7 @@ async fn function_type_is_bytes() -> Result<(), Error> {
async fn function_type_is_collection() -> Result<(), Error> {
let sql = r#"
LET $collection = <geometry<collection>> {
type: 'GeometryCollection',
type: 'GeometryCollection',
geometries: [{ type: 'MultiPoint', coordinates: [[10, 11.2], [10.5, 11.9]] }]
};
RETURN type::is::collection($collection);
@ -4902,7 +4902,7 @@ async fn function_type_is_multipoint() -> Result<(), Error> {
async fn function_type_is_multipolygon() -> Result<(), Error> {
let sql = r#"
LET $multipolygon = <geometry<multipolygon>> {
type: 'MultiPolygon',
type: 'MultiPolygon',
coordinates: [[[[10, 11.2], [10.5, 11.9], [10.8, 12], [10, 11.2]]], [[[9, 11.2], [10.5, 11.9], [10.3, 13], [9, 11.2]]]]
};
RETURN type::is::multipolygon($multipolygon);
@ -5001,7 +5001,7 @@ async fn function_type_is_point() -> Result<(), Error> {
async fn function_type_is_polygon() -> Result<(), Error> {
let sql = r#"
LET $polygon = <geometry<polygon>> {
type: 'Polygon',
type: 'Polygon',
coordinates: [
[
[-0.38314819, 51.37692386],
@ -5953,6 +5953,7 @@ pub async fn function_http_disabled() {
"#,
)
.await
.unwrap()
.expect_errors(&[
"Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled",
@ -5960,7 +5961,8 @@ pub async fn function_http_disabled() {
"Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled",
"Remote HTTP request functions are not enabled",
]);
])
.unwrap();
}
// Tests for custom defined functions

View file

@ -25,5 +25,4 @@ in craneLib.buildPackage (buildSpec // {
inherit cargoArtifacts;
inherit (util) version SURREAL_BUILD_METADATA;
RUSTFLAGS = "--cfg surrealdb_unstable";
})

View file

@ -196,7 +196,7 @@ pub async fn init(
net::init(ct.clone()).await?;
// Shutdown and stop closed tasks
task_chans.into_iter().for_each(|chan| {
if let Err(_empty_tuple) = chan.send(()) {
if chan.send(()).is_err() {
error!("Failed to send shutdown signal to task");
}
});