From 28bd007f729719609baf38c30a1dc07607e6803d Mon Sep 17 00:00:00 2001 From: Finn Bear <finnbearlabs@gmail.com> Date: Tue, 18 Apr 2023 14:32:29 -0700 Subject: [PATCH] Limit the number of concurrent futures run when fetching remote records (#1824) --- Cargo.lock | 1 + lib/Cargo.toml | 1 + lib/src/exe/mod.rs | 23 +--- lib/src/exe/spawn.rs | 19 ++++ lib/src/exe/try_join_all_buffered.rs | 154 +++++++++++++++++++++++++++ lib/src/sql/value/del.rs | 6 +- lib/src/sql/value/get.rs | 6 +- lib/src/sql/value/set.rs | 6 +- lib/tests/future.rs | 59 ++++++++++ 9 files changed, 248 insertions(+), 27 deletions(-) create mode 100644 lib/src/exe/spawn.rs create mode 100644 lib/src/exe/try_join_all_buffered.rs diff --git a/Cargo.lock b/Cargo.lock index 2c35adb9..44783013 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3812,6 +3812,7 @@ dependencies = [ "once_cell", "pbkdf2", "pharos", + "pin-project-lite", "rand 0.8.5", "regex", "reqwest", diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 275f29b0..81e54545 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -80,6 +80,7 @@ native-tls = { version = "0.2.11", optional = true } nom = { version = "7.1.3", features = ["alloc"] } once_cell = "1.17.1" pbkdf2 = { version = "0.12.1", features = ["simple"] } +pin-project-lite = "0.2.9" rand = "0.8.5" regex = "1.7.3" reqwest = { version = "0.11.16", default-features = false, features = ["json", "stream"], optional = true } diff --git a/lib/src/exe/mod.rs b/lib/src/exe/mod.rs index 3b5e19a1..816901f9 100644 --- a/lib/src/exe/mod.rs +++ b/lib/src/exe/mod.rs @@ -1,19 +1,6 @@ -#![cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +pub use spawn::spawn; +pub use try_join_all_buffered::try_join_all_buffered; -use executor::{Executor, Task}; -use once_cell::sync::Lazy; -use std::future::Future; -use std::panic::catch_unwind; - -pub fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> { - static GLOBAL: Lazy<Executor<'_>> = Lazy::new(|| { - std::thread::spawn(|| { - catch_unwind(|| { - futures::executor::block_on(GLOBAL.run(futures::future::pending::<()>())) - }) - .ok(); - }); - Executor::new() - }); - GLOBAL.spawn(future) -} +mod spawn; +mod try_join_all_buffered; diff --git a/lib/src/exe/spawn.rs b/lib/src/exe/spawn.rs new file mode 100644 index 00000000..3b5e19a1 --- /dev/null +++ b/lib/src/exe/spawn.rs @@ -0,0 +1,19 @@ +#![cfg(not(target_arch = "wasm32"))] + +use executor::{Executor, Task}; +use once_cell::sync::Lazy; +use std::future::Future; +use std::panic::catch_unwind; + +pub fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> { + static GLOBAL: Lazy<Executor<'_>> = Lazy::new(|| { + std::thread::spawn(|| { + catch_unwind(|| { + futures::executor::block_on(GLOBAL.run(futures::future::pending::<()>())) + }) + .ok(); + }); + Executor::new() + }); + GLOBAL.spawn(future) +} diff --git a/lib/src/exe/try_join_all_buffered.rs b/lib/src/exe/try_join_all_buffered.rs new file mode 100644 index 00000000..cabdffbe --- /dev/null +++ b/lib/src/exe/try_join_all_buffered.rs @@ -0,0 +1,154 @@ +use futures::{ + future::IntoFuture, ready, stream::FuturesOrdered, TryFuture, TryFutureExt, TryStream, +}; +use pin_project_lite::pin_project; +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`try_join_all_buffered`] function. + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct TryJoinAllBuffered<F, I> + where + F: TryFuture, + I: Iterator<Item = F>, + { + input: I, + #[pin] + active: FuturesOrdered<IntoFuture<F>>, + output: Vec<F::Ok>, + } +} + +/// Creates a future which represents either an in-order collection of the +/// results of the futures given or a (fail-fast) error. +/// +/// Only a limited number of futures are driven at a time. +pub fn try_join_all_buffered<I>(iter: I) -> TryJoinAllBuffered<I::Item, I::IntoIter> +where + I: IntoIterator, + I::Item: TryFuture, +{ + #[cfg(target_arch = "wasm32")] + const LIMIT: usize = 1; + + #[cfg(not(target_arch = "wasm32"))] + const LIMIT: usize = crate::cnf::MAX_CONCURRENT_TASKS; + + let mut input = iter.into_iter(); + let (lo, hi) = input.size_hint(); + let initial_capacity = hi.unwrap_or(lo); + let mut active = FuturesOrdered::new(); + + while active.len() < LIMIT { + if let Some(next) = input.next() { + active.push_back(TryFutureExt::into_future(next)); + } else { + break; + } + } + + TryJoinAllBuffered { + input, + active, + output: Vec::with_capacity(initial_capacity), + } +} + +impl<F, I> Future for TryJoinAllBuffered<F, I> +where + F: TryFuture, + I: Iterator<Item = F>, +{ + type Output = Result<Vec<F::Ok>, F::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + Poll::Ready(Ok(loop { + match ready!(this.active.as_mut().try_poll_next(cx)?) { + Some(x) => { + if let Some(next) = this.input.next() { + this.active.push_back(TryFutureExt::into_future(next)); + } + this.output.push(x) + } + None => break mem::take(this.output), + } + })) + } +} + +#[cfg(test)] +mod tests { + use super::try_join_all_buffered; + use futures::ready; + use pin_project_lite::pin_project; + use rand::{thread_rng, Rng}; + use std::{ + future::Future, + task::Poll, + time::{Duration, Instant}, + }; + use tokio::time::{sleep, Sleep}; + + pin_project! { + struct BenchFuture { + #[pin] + sleep: Sleep, + } + } + + impl Future for BenchFuture { + type Output = Result<usize, &'static str>; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Self::Output> { + let me = self.project(); + ready!(me.sleep.poll(cx)); + Poll::Ready(if true { + Ok(42) + } else { + Err("no good") + }) + } + } + + /// Returns average # of seconds. + async fn benchmark_try_join_all<F: Future<Output = Result<Vec<usize>, &'static str>>>( + try_join_all: fn(Vec<BenchFuture>) -> F, + count: usize, + ) -> f32 { + let mut rng = thread_rng(); + let mut total = Duration::ZERO; + let samples = (250 / count.max(1)).max(10); + for _ in 0..samples { + let futures = Vec::from_iter((0..count).map(|_| BenchFuture { + sleep: sleep(Duration::from_millis(rng.gen_range(0..5))), + })); + let start = Instant::now(); + try_join_all(futures).await.unwrap(); + total += start.elapsed(); + } + total.as_secs_f32() / samples as f32 + } + + #[tokio::test] + async fn comparison() { + for i in (0..10).chain((20..100).step_by(20)).chain((500..10000).step_by(500)) { + let unbuffered = benchmark_try_join_all(futures::future::try_join_all, i).await; + let buffered = benchmark_try_join_all(try_join_all_buffered, i).await; + println!( + "with {i:<4} futs, buf. exe. takes {buffered:.4}s = {:>5.1}% the time", + 100.0 * buffered / unbuffered + ); + + if i > 7000 { + assert!(buffered < unbuffered, "buf: {buffered:.5}s unbuf: {unbuffered:.5}s"); + } + } + } +} diff --git a/lib/src/sql/value/del.rs b/lib/src/sql/value/del.rs index 8e3949cd..3743d4d9 100644 --- a/lib/src/sql/value/del.rs +++ b/lib/src/sql/value/del.rs @@ -2,12 +2,12 @@ use crate::ctx::Context; use crate::dbs::Options; use crate::dbs::Transaction; use crate::err::Error; +use crate::exe::try_join_all_buffered; use crate::sql::array::Abolish; use crate::sql::part::Next; use crate::sql::part::Part; use crate::sql::value::Value; use async_recursion::async_recursion; -use futures::future::try_join_all; use std::collections::HashSet; impl Value { @@ -47,7 +47,7 @@ impl Value { _ => { let path = path.next(); let futs = v.iter_mut().map(|v| v.del(ctx, opt, txn, path)); - try_join_all(futs).await?; + try_join_all_buffered(futs).await?; Ok(()) } }, @@ -114,7 +114,7 @@ impl Value { }, _ => { let futs = v.iter_mut().map(|v| v.del(ctx, opt, txn, path)); - try_join_all(futs).await?; + try_join_all_buffered(futs).await?; Ok(()) } }, diff --git a/lib/src/sql/value/get.rs b/lib/src/sql/value/get.rs index ac60af1a..7a3f145d 100644 --- a/lib/src/sql/value/get.rs +++ b/lib/src/sql/value/get.rs @@ -2,6 +2,7 @@ use crate::ctx::Context; use crate::dbs::Options; use crate::dbs::Transaction; use crate::err::Error; +use crate::exe::try_join_all_buffered; use crate::sql::edges::Edges; use crate::sql::field::{Field, Fields}; use crate::sql::id::Id; @@ -12,7 +13,6 @@ use crate::sql::statements::select::SelectStatement; use crate::sql::thing::Thing; use crate::sql::value::{Value, Values}; use async_recursion::async_recursion; -use futures::future::try_join_all; impl Value { #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] @@ -68,7 +68,7 @@ impl Value { Part::All => { let path = path.next(); let futs = v.iter().map(|v| v.get(ctx, opt, txn, path)); - try_join_all(futs).await.map(Into::into) + try_join_all_buffered(futs).await.map(Into::into) } Part::First => match v.first() { Some(v) => v.get(ctx, opt, txn, path.next()).await, @@ -94,7 +94,7 @@ impl Value { } _ => { let futs = v.iter().map(|v| v.get(ctx, opt, txn, path)); - try_join_all(futs).await.map(Into::into) + try_join_all_buffered(futs).await.map(Into::into) } }, // Current path part is an edges diff --git a/lib/src/sql/value/set.rs b/lib/src/sql/value/set.rs index 586c2229..8d9fd49c 100644 --- a/lib/src/sql/value/set.rs +++ b/lib/src/sql/value/set.rs @@ -2,11 +2,11 @@ use crate::ctx::Context; use crate::dbs::Options; use crate::dbs::Transaction; use crate::err::Error; +use crate::exe::try_join_all_buffered; use crate::sql::part::Next; use crate::sql::part::Part; use crate::sql::value::Value; use async_recursion::async_recursion; -use futures::future::try_join_all; impl Value { #[cfg_attr(not(target_arch = "wasm32"), async_recursion)] @@ -49,7 +49,7 @@ impl Value { Part::All => { let path = path.next(); let futs = v.iter_mut().map(|v| v.set(ctx, opt, txn, path, val.clone())); - try_join_all(futs).await?; + try_join_all_buffered(futs).await?; Ok(()) } Part::First => match v.first_mut() { @@ -75,7 +75,7 @@ impl Value { } _ => { let futs = v.iter_mut().map(|v| v.set(ctx, opt, txn, path, val.clone())); - try_join_all(futs).await?; + try_join_all_buffered(futs).await?; Ok(()) } }, diff --git a/lib/tests/future.rs b/lib/tests/future.rs index bd767de4..8eb9df91 100644 --- a/lib/tests/future.rs +++ b/lib/tests/future.rs @@ -65,3 +65,62 @@ async fn future_function_arguments() -> Result<(), Error> { // Ok(()) } + +#[tokio::test] +async fn concurrency() -> Result<(), Error> { + // cargo test --package surrealdb --test future --features kv-mem --release -- concurrency --nocapture + + const MILLIS: usize = 50; + + // If all futures complete in less than double `MILLIS`, then they must have executed + // concurrently. Otherwise, some executed sequentially. + const TIMEOUT: usize = MILLIS * 19 / 10; + + /// Returns a query that will execute `count` futures that each wait for `millis` + fn query(count: usize, millis: usize) -> String { + // TODO: Find a simpler way to trigger the concurrent future case. + format!( + "SELECT foo FROM [[{}]] TIMEOUT {TIMEOUT}ms;", + (0..count) + .map(|i| format!("<future>{{[sleep({millis}ms), {{foo: {i}}}]}}")) + .collect::<Vec<_>>() + .join(", ") + ) + } + + /// Returns `true` iif `limit` futures are concurrently executed. + async fn test_limit(limit: usize) -> Result<bool, Error> { + let sql = query(limit, MILLIS); + let dbs = Datastore::new("memory").await?; + let ses = Session::for_kv().with_ns("test").with_db("test"); + let res = dbs.execute(&sql, &ses, None, false).await; + + if matches!(res, Err(Error::QueryTimedout)) { + Ok(false) + } else { + let res = res?; + assert_eq!(res.len(), 1); + + let res = res.into_iter().next().unwrap(); + + let elapsed = res.time.as_millis() as usize; + + Ok(elapsed < TIMEOUT) + } + } + + // Diagnostics. + /* + for i in (1..=80).step_by(8) { + println!("{i} futures => {}", test_limit(i).await?); + } + */ + + assert!(test_limit(3).await?); + + // Too slow to *parse* query in debug mode. + #[cfg(not(debug_assertions))] + assert!(!test_limit(64 /* surrealdb::cnf::MAX_CONCURRENT_TASKS */ + 1).await?); + + Ok(()) +}