Add Surreal::wait_for ()

This commit is contained in:
Rushmore Mushambi 2024-02-29 14:09:01 +02:00 committed by GitHub
parent 886f46e9bc
commit 9d2fe88717
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 129 additions and 5 deletions

View file

@ -144,6 +144,7 @@ use crate::opt::path_to_string;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
use url::Url;
/// A trait for converting inputs to a server address object
@ -240,6 +241,7 @@ impl Surreal<Any> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
}
@ -296,6 +298,7 @@ pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: Arc::new(watch::channel(None)),
response_type: PhantomData,
}
}

View file

@ -21,6 +21,7 @@ use crate::api::Result;
use crate::api::Surreal;
#[allow(unused_imports)]
use crate::error::Db as DbError;
use crate::opt::WaitFor;
use flume::Receiver;
#[cfg(feature = "protocol-http")]
use reqwest::ClientBuilder;
@ -31,6 +32,7 @@ use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
#[cfg(feature = "protocol-ws")]
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
#[cfg(feature = "protocol-ws")]
@ -235,6 +237,7 @@ impl Connection for Any {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -14,6 +14,7 @@ use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::error::Db as DbError;
use crate::opt::WaitFor;
use flume::Receiver;
use std::collections::HashSet;
use std::future::Future;
@ -22,6 +23,7 @@ use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
impl crate::api::Connection for Any {}
@ -188,6 +190,7 @@ impl Connection for Any {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -392,6 +392,7 @@ impl Surreal<Db> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
}

View file

@ -17,6 +17,7 @@ use crate::fflags::FFLAGS;
use crate::iam::Level;
use crate::kvs::Datastore;
use crate::opt::auth::Root;
use crate::opt::WaitFor;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
@ -35,6 +36,7 @@ use std::sync::OnceLock;
use std::task::Poll;
use std::time::Duration;
use surrealdb_core::dbs::Options;
use tokio::sync::watch;
use tokio::time;
use tokio::time::MissedTickBehavior;
@ -73,6 +75,7 @@ impl Connection for Db {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -17,6 +17,7 @@ use crate::fflags::FFLAGS;
use crate::iam::Level;
use crate::kvs::Datastore;
use crate::opt::auth::Root;
use crate::opt::WaitFor;
use flume::Receiver;
use flume::Sender;
use futures::future::Either;
@ -35,6 +36,7 @@ use std::sync::OnceLock;
use std::task::Poll;
use std::time::Duration;
use surrealdb_core::dbs::Options;
use tokio::sync::watch;
use wasm_bindgen_futures::spawn_local;
use wasmtimer::tokio as time;
use wasmtimer::tokio::MissedTickBehavior;
@ -73,6 +75,7 @@ impl Connection for Db {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -104,6 +104,7 @@ impl Surreal<Client> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
}

View file

@ -12,6 +12,7 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::opt::WaitFor;
use flume::Receiver;
use futures::StreamExt;
use indexmap::IndexMap;
@ -24,6 +25,7 @@ use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
use url::Url;
impl crate::api::Connection for Client {}
@ -77,6 +79,7 @@ impl Connection for Client {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -9,6 +9,7 @@ use crate::api::opt::Endpoint;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::opt::WaitFor;
use flume::Receiver;
use flume::Sender;
use futures::StreamExt;
@ -22,6 +23,7 @@ use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
use url::Url;
use wasm_bindgen_futures::spawn_local;
@ -56,6 +58,7 @@ impl Connection for Client {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -78,6 +78,7 @@ impl Surreal<Client> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
}

View file

@ -20,6 +20,7 @@ use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::sql::Strand;
use crate::sql::Value;
use flume::Receiver;
@ -41,6 +42,7 @@ use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::net::TcpStream;
use tokio::sync::watch;
use tokio::time;
use tokio::time::MissedTickBehavior;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
@ -154,6 +156,7 @@ impl Connection for Client {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -18,6 +18,7 @@ use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::sql::Strand;
use crate::sql::Value;
use flume::Receiver;
@ -42,6 +43,7 @@ use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::sync::watch;
use trice::Instant;
use wasm_bindgen_futures::spawn_local;
use wasmtimer::tokio as time;
@ -94,6 +96,7 @@ impl Connection for Client {
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
})

View file

@ -97,6 +97,7 @@ macro_rules! into_future {
rx: Some(rx),
client: Surreal {
router: client.router.clone(),
waiter: client.waiter.clone(),
engine: PhantomData,
},
response_type: PhantomData,

View file

@ -55,6 +55,7 @@ pub use select::Select;
pub use set::Set;
pub use signin::Signin;
pub use signup::Signup;
use tokio::sync::watch;
pub use unset::Unset;
pub use update::Update;
pub use use_db::UseDb;
@ -72,6 +73,7 @@ use crate::api::Connection;
use crate::api::OnceLockExt;
use crate::api::Surreal;
use crate::opt::IntoExportDestination;
use crate::opt::WaitFor;
use crate::sql::to_value;
use crate::sql::Value;
use serde::Serialize;
@ -228,6 +230,7 @@ where
pub fn init() -> Self {
Self {
router: Arc::new(OnceLock::new()),
waiter: Arc::new(watch::channel(None)),
engine: PhantomData,
}
}
@ -258,6 +261,7 @@ where
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: Arc::new(watch::channel(None)),
response_type: PhantomData,
}
}
@ -976,6 +980,21 @@ where
}
}
/// Wait for the selected event to happen before proceeding
pub async fn wait_for(&self, event: WaitFor) {
let mut rx = self.waiter.0.subscribe();
rx.wait_for(|current| match current {
// The connection hasn't been initialised yet.
None => false,
// The connection has been initialised. Only the connection even matches.
Some(WaitFor::Connection) => matches!(event, WaitFor::Connection),
// The database has been selected. Connection and database events both match.
Some(WaitFor::Database) => matches!(event, WaitFor::Connection | WaitFor::Database),
})
.await
.ok();
}
/// Dumps the database contents to a file
///
/// # Support

View file

@ -106,6 +106,7 @@ where
rx: Some(rx),
client: Surreal {
router: self.client.router.clone(),
waiter: self.client.waiter.clone(),
engine: PhantomData,
},
response_type: PhantomData,
@ -128,6 +129,7 @@ where
}
response.client = Surreal {
router: self.client.router.clone(),
waiter: self.client.waiter.clone(),
engine: PhantomData,
};
Ok(response)

View file

@ -20,6 +20,7 @@ use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
use tokio::sync::watch;
use url::Url;
#[derive(Debug)]
@ -49,6 +50,7 @@ impl Surreal<Client> {
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
}
@ -79,6 +81,7 @@ impl Connection for Client {
server::mock(route_rx);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(router)),
waiter: Arc::new(watch::channel(None)),
engine: PhantomData,
})
})

View file

@ -3,6 +3,7 @@ use crate::api::conn::Param;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::opt::WaitFor;
use crate::sql::Value;
use crate::Surreal;
use std::borrow::Cow;
@ -45,7 +46,9 @@ where
self.client.router.extract()?,
Param::new(vec![self.ns, self.db.into()]),
)
.await
.await?;
self.client.waiter.0.send(Some(WaitFor::Database)).ok();
Ok(())
})
}
}

View file

@ -11,6 +11,7 @@ mod conn;
pub use method::query::Response;
use semver::Version;
use tokio::sync::watch;
use crate::api::conn::DbResponse;
use crate::api::conn::Router;
@ -28,10 +29,14 @@ use std::sync::Arc;
use std::sync::OnceLock;
use self::opt::EndpointKind;
use self::opt::WaitFor;
/// A specialized `Result` type
pub type Result<T> = std::result::Result<T, crate::Error>;
// Channel for waiters
type Waiter = (watch::Sender<Option<WaitFor>>, watch::Receiver<Option<WaitFor>>);
const SUPPORTED_VERSIONS: (&str, &str) = (">=1.0.0, <2.0.0", "20230701.55918b7c");
const REVISION_SUPPORTED_SERVER_VERSION: Version = Version::new(1, 2, 0);
@ -47,6 +52,7 @@ pub struct Connect<C: Connection, Response> {
address: Result<Endpoint>,
capacity: usize,
client: PhantomData<C>,
waiter: Arc<Waiter>,
response_type: PhantomData<Response>,
}
@ -109,6 +115,8 @@ where
client = Client::connect(endpoint, self.capacity).await?;
}
}
// Both ends of the channel are still alive at this point
client.waiter.0.send(Some(WaitFor::Connection)).ok();
Ok(client)
})
}
@ -129,10 +137,7 @@ where
}
let mut endpoint = self.address?;
let endpoint_kind = EndpointKind::from(endpoint.url.scheme());
let mut client = Surreal {
router: Client::connect(endpoint.clone(), self.capacity).await?.router,
engine: PhantomData::<Client>,
};
let mut client = Client::connect(endpoint.clone(), self.capacity).await?;
if endpoint_kind.is_remote() {
let mut version = client.version().await?;
// we would like to be able to connect to pre-releases too
@ -148,6 +153,8 @@ where
Arc::into_inner(client.router).expect("new connection to have no references");
let router = cell.into_inner().expect("router to be set");
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
// Both ends of the channel are still alive at this point
self.waiter.0.send(Some(WaitFor::Connection)).ok();
Ok(())
})
}
@ -162,6 +169,7 @@ pub(crate) enum ExtraFeatures {
/// A database client instance for embedded or remote databases
pub struct Surreal<C: Connection> {
router: Arc<OnceLock<Router>>,
waiter: Arc<Waiter>,
engine: PhantomData<C>,
}
@ -199,6 +207,7 @@ where
fn clone(&self) -> Self {
Self {
router: self.router.clone(),
waiter: self.waiter.clone(),
engine: self.engine,
}
}

View file

@ -138,3 +138,13 @@ impl PatchOp {
}))
}
}
/// Makes the client wait for a certain event or call to happen before continuing
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum WaitFor {
/// Waits for the connection to succeed
Connection,
/// Waits for the desired database to be selected
Database,
}

View file

@ -73,6 +73,9 @@ mod api_integration {
#[cfg(feature = "protocol-ws")]
mod ws {
use super::*;
use futures::poll;
use std::pin::pin;
use std::task::Poll;
use surrealdb::engine::remote::ws::Client;
use surrealdb::engine::remote::ws::Ws;
@ -95,6 +98,50 @@ mod api_integration {
drop(permit);
}
#[test_log::test(tokio::test)]
async fn wait_for() {
use surrealdb::opt::WaitFor::{Connection, Database};
let permit = PERMITS.acquire().await.unwrap();
// Create an unconnected client
// At this point wait_for should continue to wait for both the connection and database selection.
let db: Surreal<ws::Client> = Surreal::init();
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Pending);
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
// Connect to the server
// The connection event should fire and allow wait_for to return immediately when waiting for a connection.
// When waiting for a database to be selected, it should continue waiting.
db.connect::<Ws>("127.0.0.1:8000").await.unwrap();
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
// Sign into the server
// At this point the connection has already been established but the database hasn't been selected yet.
db.signin(Root {
username: ROOT_USER,
password: ROOT_PASS,
})
.await
.unwrap();
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
// Selecting a namespace shouldn't fire the database selection event.
db.use_ns("namespace").await.unwrap();
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Pending);
// Select the database to use
// Both the connection and database events have fired, wait_for should return immediately for both.
db.use_db("database").await.unwrap();
assert_eq!(poll!(pin!(db.wait_for(Connection))), Poll::Ready(()));
assert_eq!(poll!(pin!(db.wait_for(Database))), Poll::Ready(()));
drop(permit);
}
include!("api/mod.rs");
include!("api/live.rs");
}