Add Surreal::wait_for
(#3581)
This commit is contained in:
parent
886f46e9bc
commit
9d2fe88717
20 changed files with 129 additions and 5 deletions
|
@ -144,6 +144,7 @@ use crate::opt::path_to_string;
|
|||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
use url::Url;
|
||||
|
||||
/// A trait for converting inputs to a server address object
|
||||
|
@ -240,6 +241,7 @@ impl Surreal<Any> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: self.waiter.clone(),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -296,6 +298,7 @@ pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: Arc::new(watch::channel(None)),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ use crate::api::Result;
|
|||
use crate::api::Surreal;
|
||||
#[allow(unused_imports)]
|
||||
use crate::error::Db as DbError;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
#[cfg(feature = "protocol-http")]
|
||||
use reqwest::ClientBuilder;
|
||||
|
@ -31,6 +32,7 @@ use std::pin::Pin;
|
|||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
#[cfg(feature = "protocol-ws")]
|
||||
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
|
||||
#[cfg(feature = "protocol-ws")]
|
||||
|
@ -235,6 +237,7 @@ impl Connection for Any {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -14,6 +14,7 @@ use crate::api::OnceLockExt;
|
|||
use crate::api::Result;
|
||||
use crate::api::Surreal;
|
||||
use crate::error::Db as DbError;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
use std::collections::HashSet;
|
||||
use std::future::Future;
|
||||
|
@ -22,6 +23,7 @@ use std::pin::Pin;
|
|||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
|
||||
impl crate::api::Connection for Any {}
|
||||
|
||||
|
@ -188,6 +190,7 @@ impl Connection for Any {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -392,6 +392,7 @@ impl Surreal<Db> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: self.waiter.clone(),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ use crate::fflags::FFLAGS;
|
|||
use crate::iam::Level;
|
||||
use crate::kvs::Datastore;
|
||||
use crate::opt::auth::Root;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
use flume::Sender;
|
||||
use futures::future::Either;
|
||||
|
@ -35,6 +36,7 @@ use std::sync::OnceLock;
|
|||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
use surrealdb_core::dbs::Options;
|
||||
use tokio::sync::watch;
|
||||
use tokio::time;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
|
||||
|
@ -73,6 +75,7 @@ impl Connection for Db {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -17,6 +17,7 @@ use crate::fflags::FFLAGS;
|
|||
use crate::iam::Level;
|
||||
use crate::kvs::Datastore;
|
||||
use crate::opt::auth::Root;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
use flume::Sender;
|
||||
use futures::future::Either;
|
||||
|
@ -35,6 +36,7 @@ use std::sync::OnceLock;
|
|||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
use surrealdb_core::dbs::Options;
|
||||
use tokio::sync::watch;
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
use wasmtimer::tokio as time;
|
||||
use wasmtimer::tokio::MissedTickBehavior;
|
||||
|
@ -73,6 +75,7 @@ impl Connection for Db {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -104,6 +104,7 @@ impl Surreal<Client> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: self.waiter.clone(),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ use crate::api::ExtraFeatures;
|
|||
use crate::api::OnceLockExt;
|
||||
use crate::api::Result;
|
||||
use crate::api::Surreal;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
use futures::StreamExt;
|
||||
use indexmap::IndexMap;
|
||||
|
@ -24,6 +25,7 @@ use std::pin::Pin;
|
|||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
use url::Url;
|
||||
|
||||
impl crate::api::Connection for Client {}
|
||||
|
@ -77,6 +79,7 @@ impl Connection for Client {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::api::opt::Endpoint;
|
|||
use crate::api::OnceLockExt;
|
||||
use crate::api::Result;
|
||||
use crate::api::Surreal;
|
||||
use crate::opt::WaitFor;
|
||||
use flume::Receiver;
|
||||
use flume::Sender;
|
||||
use futures::StreamExt;
|
||||
|
@ -22,6 +23,7 @@ use std::pin::Pin;
|
|||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
use url::Url;
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
|
||||
|
@ -56,6 +58,7 @@ impl Connection for Client {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -78,6 +78,7 @@ impl Surreal<Client> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: self.waiter.clone(),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ use crate::api::Result;
|
|||
use crate::api::Surreal;
|
||||
use crate::engine::remote::ws::Data;
|
||||
use crate::engine::IntervalStream;
|
||||
use crate::opt::WaitFor;
|
||||
use crate::sql::Strand;
|
||||
use crate::sql::Value;
|
||||
use flume::Receiver;
|
||||
|
@ -41,6 +42,7 @@ use std::sync::atomic::AtomicI64;
|
|||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::watch;
|
||||
use tokio::time;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
|
@ -154,6 +156,7 @@ impl Connection for Client {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -18,6 +18,7 @@ use crate::api::Result;
|
|||
use crate::api::Surreal;
|
||||
use crate::engine::remote::ws::Data;
|
||||
use crate::engine::IntervalStream;
|
||||
use crate::opt::WaitFor;
|
||||
use crate::sql::Strand;
|
||||
use crate::sql::Value;
|
||||
use flume::Receiver;
|
||||
|
@ -42,6 +43,7 @@ use std::sync::atomic::AtomicI64;
|
|||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::watch;
|
||||
use trice::Instant;
|
||||
use wasm_bindgen_futures::spawn_local;
|
||||
use wasmtimer::tokio as time;
|
||||
|
@ -94,6 +96,7 @@ impl Connection for Client {
|
|||
sender: route_tx,
|
||||
last_id: AtomicI64::new(0),
|
||||
})),
|
||||
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -97,6 +97,7 @@ macro_rules! into_future {
|
|||
rx: Some(rx),
|
||||
client: Surreal {
|
||||
router: client.router.clone(),
|
||||
waiter: client.waiter.clone(),
|
||||
engine: PhantomData,
|
||||
},
|
||||
response_type: PhantomData,
|
||||
|
|
|
@ -55,6 +55,7 @@ pub use select::Select;
|
|||
pub use set::Set;
|
||||
pub use signin::Signin;
|
||||
pub use signup::Signup;
|
||||
use tokio::sync::watch;
|
||||
pub use unset::Unset;
|
||||
pub use update::Update;
|
||||
pub use use_db::UseDb;
|
||||
|
@ -72,6 +73,7 @@ use crate::api::Connection;
|
|||
use crate::api::OnceLockExt;
|
||||
use crate::api::Surreal;
|
||||
use crate::opt::IntoExportDestination;
|
||||
use crate::opt::WaitFor;
|
||||
use crate::sql::to_value;
|
||||
use crate::sql::Value;
|
||||
use serde::Serialize;
|
||||
|
@ -228,6 +230,7 @@ where
|
|||
pub fn init() -> Self {
|
||||
Self {
|
||||
router: Arc::new(OnceLock::new()),
|
||||
waiter: Arc::new(watch::channel(None)),
|
||||
engine: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -258,6 +261,7 @@ where
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: Arc::new(watch::channel(None)),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -976,6 +980,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Wait for the selected event to happen before proceeding
|
||||
pub async fn wait_for(&self, event: WaitFor) {
|
||||
let mut rx = self.waiter.0.subscribe();
|
||||
rx.wait_for(|current| match current {
|
||||
// The connection hasn't been initialised yet.
|
||||
None => false,
|
||||
// The connection has been initialised. Only the connection even matches.
|
||||
Some(WaitFor::Connection) => matches!(event, WaitFor::Connection),
|
||||
// The database has been selected. Connection and database events both match.
|
||||
Some(WaitFor::Database) => matches!(event, WaitFor::Connection | WaitFor::Database),
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
/// Dumps the database contents to a file
|
||||
///
|
||||
/// # Support
|
||||
|
|
|
@ -106,6 +106,7 @@ where
|
|||
rx: Some(rx),
|
||||
client: Surreal {
|
||||
router: self.client.router.clone(),
|
||||
waiter: self.client.waiter.clone(),
|
||||
engine: PhantomData,
|
||||
},
|
||||
response_type: PhantomData,
|
||||
|
@ -128,6 +129,7 @@ where
|
|||
}
|
||||
response.client = Surreal {
|
||||
router: self.client.router.clone(),
|
||||
waiter: self.client.waiter.clone(),
|
||||
engine: PhantomData,
|
||||
};
|
||||
Ok(response)
|
||||
|
|
|
@ -20,6 +20,7 @@ use std::pin::Pin;
|
|||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tokio::sync::watch;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -49,6 +50,7 @@ impl Surreal<Client> {
|
|||
address: address.into_endpoint(),
|
||||
capacity: 0,
|
||||
client: PhantomData,
|
||||
waiter: self.waiter.clone(),
|
||||
response_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -79,6 +81,7 @@ impl Connection for Client {
|
|||
server::mock(route_rx);
|
||||
Ok(Surreal {
|
||||
router: Arc::new(OnceLock::with_value(router)),
|
||||
waiter: Arc::new(watch::channel(None)),
|
||||
engine: PhantomData,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::api::conn::Param;
|
|||
use crate::api::Connection;
|
||||
use crate::api::Result;
|
||||
use crate::method::OnceLockExt;
|
||||
use crate::opt::WaitFor;
|
||||
use crate::sql::Value;
|
||||
use crate::Surreal;
|
||||
use std::borrow::Cow;
|
||||
|
@ -45,7 +46,9 @@ where
|
|||
self.client.router.extract()?,
|
||||
Param::new(vec![self.ns, self.db.into()]),
|
||||
)
|
||||
.await
|
||||
.await?;
|
||||
self.client.waiter.0.send(Some(WaitFor::Database)).ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ mod conn;
|
|||
|
||||
pub use method::query::Response;
|
||||
use semver::Version;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use crate::api::conn::DbResponse;
|
||||
use crate::api::conn::Router;
|
||||
|
@ -28,10 +29,14 @@ use std::sync::Arc;
|
|||
use std::sync::OnceLock;
|
||||
|
||||
use self::opt::EndpointKind;
|
||||
use self::opt::WaitFor;
|
||||
|
||||
/// A specialized `Result` type
|
||||
pub type Result<T> = std::result::Result<T, crate::Error>;
|
||||
|
||||
// Channel for waiters
|
||||
type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
|
||||
|
||||
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c");
|
||||
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
|
||||
|
||||
|
@ -47,6 +52,7 @@ pub struct Connect<C: Connection, Response> {
|
|||
address: Result<Endpoint>,
|
||||
capacity: usize,
|
||||
client: PhantomData<C>,
|
||||
waiter: Arc<Waiter>,
|
||||
response_type: PhantomData<Response>,
|
||||
}
|
||||
|
||||
|
@ -109,6 +115,8 @@ where
|
|||
client = Client::connect(endpoint, self.capacity).await?;
|
||||
}
|
||||
}
|
||||
// Both ends of the channel are still alive at this point
|
||||
client.waiter.0.send(Some(WaitFor::Connection)).ok();
|
||||
Ok(client)
|
||||
})
|
||||
}
|
||||
|
@ -129,10 +137,7 @@ where
|
|||
}
|
||||
let mut endpoint = self.address?;
|
||||
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
|
||||
let mut client = Surreal {
|
||||
router: Client::connect(endpoint.clone(), self.capacity).await?.router,
|
||||
engine: PhantomData::<Client>,
|
||||
};
|
||||
let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
|
||||
if endpoint_kind.is_remote() {
|
||||
let mut version = client.version().await?;
|
||||
// we would like to be able to connect to pre-releases too
|
||||
|
@ -148,6 +153,8 @@ where
|
|||
Arc::into_inner(client.router).expect("new connection to have no references");
|
||||
let router = cell.into_inner().expect("router to be set");
|
||||
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
|
||||
// Both ends of the channel are still alive at this point
|
||||
self.waiter.0.send(Some(WaitFor::Connection)).ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -162,6 +169,7 @@ pub(crate) enum ExtraFeatures {
|
|||
/// A database client instance for embedded or remote databases
|
||||
pub struct Surreal<C: Connection> {
|
||||
router: Arc<OnceLock<Router>>,
|
||||
waiter: Arc<Waiter>,
|
||||
engine: PhantomData<C>,
|
||||
}
|
||||
|
||||
|
@ -199,6 +207,7 @@ where
|
|||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
router: self.router.clone(),
|
||||
waiter: self.waiter.clone(),
|
||||
engine: self.engine,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,3 +138,13 @@ impl PatchOp {
|
|||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Makes the client wait for a certain event or call to happen before continuing
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[non_exhaustive]
|
||||
pub enum WaitFor {
|
||||
/// Waits for the connection to succeed
|
||||
Connection,
|
||||
/// Waits for the desired database to be selected
|
||||
Database,
|
||||
}
|
||||
|
|
|
@ -73,6 +73,9 @@ mod api_integration {
|
|||
#[cfg(feature = "protocol-ws")]
|
||||
mod ws {
|
||||
use super::*;
|
||||
use futures::poll;
|
||||
use std::pin::pin;
|
||||
use std::task::Poll;
|
||||
use surrealdb::engine::remote::ws::Client;
|
||||
use surrealdb::engine::remote::ws::Ws;
|
||||
|
||||
|
@ -95,6 +98,50 @@ mod api_integration {
|
|||
drop(permit);
|
||||
}
|
||||
|
||||
#[test_log::test(tokio::test)]
|
||||
async fn wait_for() {
|
||||
use surrealdb::opt::WaitFor::{Connection, Database};
|
||||
|
||||
let permit = PERMITS.acquire().await.unwrap();
|
||||
|
||||
// Create an unconnected client
|
||||
// At this point wait_for should continue to wait for both the connection and database selection.
|
||||
let db: Surreal<ws::Client> = Surreal::init();
|
||||
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Pending);
|
||||
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
|
||||
|
||||
// Connect to the server
|
||||
// The connection event should fire and allow wait_for to return immediately when waiting for a connection.
|
||||
// When waiting for a database to be selected, it should continue waiting.
|
||||
db.connect::<Ws>("127.0.0.1:8000").await.unwrap();
|
||||
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
|
||||
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
|
||||
|
||||
// Sign into the server
|
||||
// At this point the connection has already been established but the database hasn't been selected yet.
|
||||
db.signin(Root {
|
||||
username: ROOT_USER,
|
||||
password: ROOT_PASS,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
|
||||
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
|
||||
|
||||
// Selecting a namespace shouldn't fire the database selection event.
|
||||
db.use_ns("namespace").await.unwrap();
|
||||
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
|
||||
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
|
||||
|
||||
// Select the database to use
|
||||
// Both the connection and database events have fired, wait_for should return immediately for both.
|
||||
db.use_db("database").await.unwrap();
|
||||
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
|
||||
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Ready(()));
|
||||
|
||||
drop(permit);
|
||||
}
|
||||
|
||||
include!("api/mod.rs");
|
||||
include!("api/live.rs");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue