From c5138245a02e7575c24f2d2416a3cec9c76241a0 Mon Sep 17 00:00:00 2001 From: Rushmore Mushambi Date: Tue, 16 Jan 2024 13:48:29 +0200 Subject: [PATCH] Add support for `LIVE SELECT` in the SDK and CLI (#3309) --- lib/src/api/conn.rs | 28 +- lib/src/api/engine/any/mod.rs | 2 + lib/src/api/engine/any/native.rs | 4 +- lib/src/api/engine/any/wasm.rs | 4 +- lib/src/api/engine/local/mod.rs | 8 +- lib/src/api/engine/local/native.rs | 4 +- lib/src/api/engine/local/wasm.rs | 4 +- lib/src/api/engine/remote/http/mod.rs | 8 +- lib/src/api/engine/remote/http/native.rs | 4 +- lib/src/api/engine/remote/http/wasm.rs | 4 +- lib/src/api/engine/remote/ws/mod.rs | 6 +- lib/src/api/engine/remote/ws/native.rs | 4 +- lib/src/api/engine/remote/ws/wasm.rs | 4 +- lib/src/api/err/mod.rs | 12 + lib/src/api/method/live.rs | 109 +++++--- lib/src/api/method/mod.rs | 6 +- lib/src/api/method/query.rs | 315 +++++++++++++++++++---- lib/src/api/method/tests/protocol.rs | 5 +- lib/src/api/method/tests/server.rs | 2 +- lib/src/api/mod.rs | 32 ++- lib/src/api/opt/query.rs | 205 ++++++++++++--- lib/src/lib.rs | 2 + lib/tests/api/live.rs | 125 +++++++++ lib/tests/api/mod.rs | 1 - src/cli/sql.rs | 93 ++++++- tests/cli_integration.rs | 10 +- 26 files changed, 812 insertions(+), 189 deletions(-) diff --git a/lib/src/api/conn.rs b/lib/src/api/conn.rs index 76b98086..7d2e3a4e 100644 --- a/lib/src/api/conn.rs +++ b/lib/src/api/conn.rs @@ -16,7 +16,6 @@ use serde::Serialize; use std::collections::BTreeMap; use std::collections::HashSet; use std::future::Future; -use std::marker::PhantomData; use std::path::PathBuf; use std::pin::Pin; use std::sync::atomic::AtomicI64; @@ -31,26 +30,19 @@ pub(crate) struct Route { /// Message router #[derive(Debug)] -pub struct Router { - pub(crate) conn: PhantomData, +pub struct Router { pub(crate) sender: Sender>, pub(crate) last_id: AtomicI64, pub(crate) features: HashSet, } -impl Router -where - C: api::Connection, -{ +impl Router { pub(crate) fn next_id(&self) -> i64 { self.last_id.fetch_add(1, Ordering::SeqCst) } } -impl Drop for Router -where - C: api::Connection, -{ +impl Drop for Router { fn drop(&mut self) { let _res = self.sender.send(None); } @@ -189,7 +181,7 @@ pub trait Connection: Sized + Send + Sync + 'static { #[allow(clippy::type_complexity)] fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> where @@ -226,7 +218,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute all methods except `query` fn execute<'r, R>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where @@ -243,7 +235,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute methods that return an optional single response fn execute_opt<'r, R>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>> + Send + Sync + 'r>> where @@ -262,7 +254,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute methods that return multiple responses fn execute_vec<'r, R>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>> + Send + Sync + 'r>> where @@ -283,7 +275,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute methods that return nothing fn execute_unit<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where @@ -306,7 +298,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute methods that return a raw value fn execute_value<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where @@ -321,7 +313,7 @@ pub trait Connection: Sized + Send + Sync + 'static { /// Execute the `query` method fn execute_query<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where diff --git a/lib/src/api/engine/any/mod.rs b/lib/src/api/engine/any/mod.rs index 81364f7e..31a6c372 100644 --- a/lib/src/api/engine/any/mod.rs +++ b/lib/src/api/engine/any/mod.rs @@ -193,6 +193,7 @@ impl Surreal { pub fn connect(&self, address: impl IntoEndpoint) -> Connect { Connect { router: self.router.clone(), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -242,6 +243,7 @@ impl Surreal { pub fn connect(address: impl IntoEndpoint) -> Connect> { Connect { router: Arc::new(OnceLock::new()), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, diff --git a/lib/src/api/engine/any/native.rs b/lib/src/api/engine/any/native.rs index 6ba27dcf..c5643da6 100644 --- a/lib/src/api/engine/any/native.rs +++ b/lib/src/api/engine/any/native.rs @@ -215,17 +215,17 @@ impl Connection for Any { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/any/wasm.rs b/lib/src/api/engine/any/wasm.rs index a3029c88..8caf4d0d 100644 --- a/lib/src/api/engine/any/wasm.rs +++ b/lib/src/api/engine/any/wasm.rs @@ -170,17 +170,17 @@ impl Connection for Any { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/local/mod.rs b/lib/src/api/engine/local/mod.rs index 2fa9c548..b21950c5 100644 --- a/lib/src/api/engine/local/mod.rs +++ b/lib/src/api/engine/local/mod.rs @@ -383,6 +383,7 @@ impl Surreal { pub fn connect

(&self, address: impl IntoEndpoint) -> Connect { Connect { router: self.router.clone(), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -402,11 +403,14 @@ fn process(responses: Vec) -> QueryResponse { Err(error) => map.insert(index, (stats, Err(error.into()))), }; } - QueryResponse(map) + QueryResponse { + results: map, + ..QueryResponse::new() + } } async fn take(one: bool, responses: Vec) -> Result { - if let Some((_stats, result)) = process(responses).0.remove(&0) { + if let Some((_stats, result)) = process(responses).results.remove(&0) { let value = result?; match one { true => match value { diff --git a/lib/src/api/engine/local/native.rs b/lib/src/api/engine/local/native.rs index c93c33f5..907e4983 100644 --- a/lib/src/api/engine/local/native.rs +++ b/lib/src/api/engine/local/native.rs @@ -68,17 +68,17 @@ impl Connection for Db { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/local/wasm.rs b/lib/src/api/engine/local/wasm.rs index 3e39f086..f9a3a764 100644 --- a/lib/src/api/engine/local/wasm.rs +++ b/lib/src/api/engine/local/wasm.rs @@ -68,17 +68,17 @@ impl Connection for Db { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/remote/http/mod.rs b/lib/src/api/engine/remote/http/mod.rs index 146bfe64..53e59791 100644 --- a/lib/src/api/engine/remote/http/mod.rs +++ b/lib/src/api/engine/remote/http/mod.rs @@ -100,6 +100,7 @@ impl Surreal { ) -> Connect { Connect { router: self.router.clone(), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -210,11 +211,14 @@ async fn query(request: RequestBuilder) -> Result { } } - Ok(QueryResponse(map)) + Ok(QueryResponse { + results: map, + ..QueryResponse::new() + }) } async fn take(one: bool, request: RequestBuilder) -> Result { - if let Some((_stats, result)) = query(request).await?.0.remove(&0) { + if let Some((_stats, result)) = query(request).await?.results.remove(&0) { let value = result?; match one { true => match value { diff --git a/lib/src/api/engine/remote/http/native.rs b/lib/src/api/engine/remote/http/native.rs index 6288b16f..6c6a9432 100644 --- a/lib/src/api/engine/remote/http/native.rs +++ b/lib/src/api/engine/remote/http/native.rs @@ -74,17 +74,17 @@ impl Connection for Client { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/remote/http/wasm.rs b/lib/src/api/engine/remote/http/wasm.rs index ff0d4335..df773cec 100644 --- a/lib/src/api/engine/remote/http/wasm.rs +++ b/lib/src/api/engine/remote/http/wasm.rs @@ -53,17 +53,17 @@ impl Connection for Client { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features: HashSet::new(), - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/remote/ws/mod.rs b/lib/src/api/engine/remote/ws/mod.rs index e9828cd9..da2f88ec 100644 --- a/lib/src/api/engine/remote/ws/mod.rs +++ b/lib/src/api/engine/remote/ws/mod.rs @@ -68,6 +68,7 @@ impl Surreal { ) -> Connect { Connect { router: self.router.clone(), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -135,7 +136,10 @@ impl DbResponse { } } - Ok(DbResponse::Query(api::Response(map))) + Ok(DbResponse::Query(api::Response { + results: map, + ..api::Response::new() + })) } // Live notifications don't call this method Data::Live(..) => unreachable!(), diff --git a/lib/src/api/engine/remote/ws/native.rs b/lib/src/api/engine/remote/ws/native.rs index 6f34652e..7c8c6cb4 100644 --- a/lib/src/api/engine/remote/ws/native.rs +++ b/lib/src/api/engine/remote/ws/native.rs @@ -136,17 +136,17 @@ impl Connection for Client { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/engine/remote/ws/wasm.rs b/lib/src/api/engine/remote/ws/wasm.rs index 30c30605..475de9f7 100644 --- a/lib/src/api/engine/remote/ws/wasm.rs +++ b/lib/src/api/engine/remote/ws/wasm.rs @@ -90,17 +90,17 @@ impl Connection for Client { Ok(Surreal { router: Arc::new(OnceLock::with_value(Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), })), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/err/mod.rs b/lib/src/api/err/mod.rs index 097c5410..13f34da8 100644 --- a/lib/src/api/err/mod.rs +++ b/lib/src/api/err/mod.rs @@ -181,6 +181,18 @@ pub enum Error { /// Tried to use a range query on an edge or edges #[error("Live queries on edges not supported: {0}")] LiveOnEdges(Edges), + + /// Tried to access a query statement as a live query when it isn't a live query + #[error("Query statement {0} is not a live query")] + NotLiveQuery(usize), + + /// Tried to access a query statement falling outside the bounds of the statements supplied + #[error("Query statement {0} is out of bounds")] + QueryIndexOutOfBounds(usize), + + /// Called `Response::take` or `Response::stream` on a query response more than once + #[error("Tried to take a query response that has already been taken")] + ResponseAlreadyTaken, } #[cfg(feature = "protocol-http")] diff --git a/lib/src/api/method/live.rs b/lib/src/api/method/live.rs index f84a5b82..6b4fdc5f 100644 --- a/lib/src/api/method/live.rs +++ b/lib/src/api/method/live.rs @@ -1,10 +1,12 @@ use crate::api::conn::Method; use crate::api::conn::Param; +use crate::api::conn::Router; use crate::api::err::Error; use crate::api::Connection; use crate::api::ExtraFeatures; use crate::api::Result; use crate::dbs; +use crate::engine::any::Any; use crate::method::Live; use crate::method::OnceLockExt; use crate::method::Query; @@ -30,7 +32,6 @@ use crate::Surreal; use channel::Receiver; use futures::StreamExt; use serde::de::DeserializeOwned; -use std::borrow::Cow; use std::future::Future; use std::future::IntoFuture; use std::marker::PhantomData; @@ -93,24 +94,40 @@ macro_rules! into_future { client: client.clone(), query: vec![Ok(vec![Statement::Live(stmt)])], bindings: Ok(Default::default()), + register_live_queries: false, }; let id: Value = query.await?.take(0)?; - let mut conn = Client::new(Method::Live); - let (tx, rx) = channel::unbounded(); - let mut param = Param::notification_sender(tx); - param.other = vec![id.clone()]; - conn.execute_unit(router, param).await?; + let rx = register::(router, id.clone()).await?; Ok(Stream { id, - rx, - client, + rx: Some(rx), + client: Surreal { + router: client.router.clone(), + engine: PhantomData, + }, response_type: PhantomData, + engine: PhantomData, }) }) } }; } +pub(crate) async fn register( + router: &Router, + id: Value, +) -> Result> +where + Client: Connection, +{ + let mut conn = Client::new(Method::Live); + let (tx, rx) = channel::unbounded(); + let mut param = Param::notification_sender(tx); + param.other = vec![id]; + conn.execute_unit(router, param).await?; + Ok(rx) +} + fn cond_from_range(range: crate::sql::Range) -> Option { match (range.beg, range.end) { (Bound::Unbounded, Bound::Unbounded) => None, @@ -241,21 +258,23 @@ where #[derive(Debug)] #[must_use = "streams do nothing unless you poll them"] pub struct Stream<'r, C: Connection, R> { - client: Cow<'r, Surreal>, - id: Value, - rx: Receiver, - response_type: PhantomData, + pub(crate) client: Surreal, + // We no longer need the lifetime and the type parameter + // Leaving them in for backwards compatibility + pub(crate) engine: PhantomData<&'r C>, + pub(crate) id: Value, + pub(crate) rx: Option>, + pub(crate) response_type: PhantomData, } macro_rules! poll_next { - ($action:ident, $result:ident => $body:expr) => { + ($notification:ident => $body:expr) => { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.as_mut().rx.poll_next_unpin(cx) { - Poll::Ready(Some(dbs::Notification { - $action, - $result, - .. - })) => $body, + let Some(ref mut rx) = self.as_mut().rx else { + return Poll::Ready(None); + }; + match rx.poll_next_unpin(cx) { + Poll::Ready(Some($notification)) => $body, Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } @@ -270,15 +289,23 @@ where type Item = Notification; poll_next! { - action, result => Poll::Ready(Some(Notification { action: action.into(), data: result })) + notification => Poll::Ready(Some(Notification { + query_id: notification.id.0, + action: notification.action.into(), + data: notification.result, + })) } } macro_rules! poll_next_and_convert { () => { poll_next! { - action, result => match from_value(result) { - Ok(data) => Poll::Ready(Some(Ok(Notification { action: action.into(), data }))), + notification => match from_value(notification.result) { + Ok(data) => Poll::Ready(Some(Ok(Notification { + data, + query_id: notification.id.0, + action: notification.action.into(), + }))), Err(error) => Poll::Ready(Some(Err(error.into()))), } } @@ -305,6 +332,29 @@ where poll_next_and_convert! {} } +impl futures::Stream for Stream<'_, C, Notification> +where + C: Connection, + R: DeserializeOwned + Unpin, +{ + type Item = Result>; + + poll_next_and_convert! {} +} + +pub(crate) fn kill(client: &Surreal, id: Value) +where + Client: Connection, +{ + let client = client.clone(); + spawn(async move { + if let Ok(router) = client.router.extract() { + let mut conn = Client::new(Method::Kill); + conn.execute_unit(router, Param::new(vec![id.clone()])).await.ok(); + } + }); +} + impl Drop for Stream<'_, Client, R> where Client: Connection, @@ -313,20 +363,9 @@ where /// /// This kills the live query process responsible for this stream. fn drop(&mut self) { - if !self.id.is_none() { + if !self.id.is_none() && self.rx.is_some() { let id = mem::take(&mut self.id); - let client = self.client.clone().into_owned(); - spawn(async move { - if let Ok(router) = client.router.extract() { - let mut conn = Client::new(Method::Kill); - match conn.execute_unit(router, Param::new(vec![id.clone()])).await { - Ok(()) => trace!("Live query {id} dropped successfully"), - Err(error) => warn!("Failed to drop live query {id}; {error}"), - } - } - }); - } else { - trace!("Ignoring drop call on an already dropped live::Stream"); + kill(&self.client, id); } } } diff --git a/lib/src/api/method/mod.rs b/lib/src/api/method/mod.rs index e113b15b..de06af58 100644 --- a/lib/src/api/method/mod.rs +++ b/lib/src/api/method/mod.rs @@ -1,5 +1,6 @@ //! Methods to use when interacting with a SurrealDB instance +pub(crate) mod live; pub(crate) mod query; mod authenticate; @@ -13,7 +14,6 @@ mod export; mod health; mod import; mod invalidate; -mod live; mod merge; mod patch; mod select; @@ -50,6 +50,7 @@ pub use live::Stream; pub use merge::Merge; pub use patch::Patch; pub use query::Query; +pub use query::QueryStream; pub use select::Select; pub use set::Set; pub use signin::Signin; @@ -227,6 +228,7 @@ where pub fn init() -> Self { Self { router: Arc::new(OnceLock::new()), + engine: PhantomData, } } @@ -252,6 +254,7 @@ where pub fn new

(address: impl IntoEndpoint) -> Connect { Connect { router: Arc::new(OnceLock::new()), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -638,6 +641,7 @@ where client: Cow::Borrowed(self), query: vec![query.into_query()], bindings: Ok(Default::default()), + register_live_queries: true, } } diff --git a/lib/src/api/method/query.rs b/lib/src/api/method/query.rs index 47b17fd0..eb0061b6 100644 --- a/lib/src/api/method/query.rs +++ b/lib/src/api/method/query.rs @@ -1,9 +1,14 @@ +use super::live; +use super::Stream; + use crate::api::conn::Method; use crate::api::conn::Param; use crate::api::err::Error; use crate::api::opt; use crate::api::Connection; +use crate::api::ExtraFeatures; use crate::api::Result; +use crate::engine::any::Any; use crate::method::OnceLockExt; use crate::method::Stats; use crate::method::WithStats; @@ -15,7 +20,11 @@ use crate::sql::Statement; use crate::sql::Statements; use crate::sql::Strand; use crate::sql::Value; +use crate::Notification; use crate::Surreal; +use futures::future::Either; +use futures::stream::SelectAll; +use futures::StreamExt; use indexmap::IndexMap; use serde::de::DeserializeOwned; use serde::Serialize; @@ -24,8 +33,11 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::future::Future; use std::future::IntoFuture; +use std::marker::PhantomData; use std::mem; use std::pin::Pin; +use std::task::Context; +use std::task::Poll; /// A query future #[derive(Debug)] @@ -34,6 +46,7 @@ pub struct Query<'r, C: Connection> { pub(super) client: Cow<'r, Surreal>, pub(super) query: Vec>>, pub(super) bindings: Result>, + pub(crate) register_live_queries: bool, } impl Query<'_, C> @@ -58,14 +71,66 @@ where fn into_future(self) -> Self::IntoFuture { Box::pin(async move { + // Extract the router from the client + let router = self.client.router.extract()?; + // Combine all query statements supplied let mut statements = Vec::with_capacity(self.query.len()); for query in self.query { statements.extend(query?); } - let query = sql::Query(Statements(statements)); + // Build the query and execute it + let query = sql::Query(Statements(statements.clone())); let param = Param::query(query, self.bindings?); let mut conn = Client::new(Method::Query); - conn.execute_query(self.client.router.extract()?, param).await + let mut response = conn.execute_query(router, param).await?; + // Register live queries if necessary + if self.register_live_queries { + let mut live_queries = IndexMap::new(); + let mut checked = false; + // Adjusting offsets as a workaround to https://github.com/surrealdb/surrealdb/issues/3318 + let mut offset = 0; + for (index, stmt) in statements.into_iter().enumerate() { + if let Statement::Live(stmt) = stmt { + if !checked && !router.features.contains(&ExtraFeatures::LiveQueries) { + return Err(Error::LiveQueriesNotSupported.into()); + } + checked = true; + let index = index - offset; + if let Some((_, result)) = response.results.get(&index) { + let result = + match result { + Ok(id) => live::register::(router, id.clone()) + .await + .map(|rx| Stream { + id: stmt.id.into(), + rx: Some(rx), + client: Surreal { + router: self.client.router.clone(), + engine: PhantomData, + }, + response_type: PhantomData, + engine: PhantomData, + }), + // This is a live query. We are using this as a workaround to avoid + // creating another public error variant for this internal error. + Err(..) => Err(Error::NotLiveQuery(index).into()), + }; + live_queries.insert(index, result); + } + } else if matches!( + stmt, + Statement::Begin(..) | Statement::Commit(..) | Statement::Cancel(..) + ) { + offset += 1; + } + } + response.live_queries = live_queries; + } + response.client = Surreal { + router: self.client.router.clone(), + engine: PhantomData, + }; + Ok(response) }) } } @@ -172,9 +237,47 @@ pub(crate) type QueryResult = Result; /// The response type of a `Surreal::query` request #[derive(Debug)] -pub struct Response(pub(crate) IndexMap); +pub struct Response { + pub(crate) client: Surreal, + pub(crate) results: IndexMap, + pub(crate) live_queries: IndexMap>>, +} + +/// A `LIVE SELECT` stream from the `query` method +#[derive(Debug)] +#[must_use = "streams do nothing unless you poll them"] +pub struct QueryStream( + pub(crate) Either, SelectAll>>, +); + +impl futures::Stream for QueryStream { + type Item = Notification; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().0.poll_next_unpin(cx) + } +} + +impl futures::Stream for QueryStream> +where + R: DeserializeOwned + Unpin, +{ + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().0.poll_next_unpin(cx) + } +} impl Response { + pub(crate) fn new() -> Self { + Self { + client: Surreal::init(), + results: Default::default(), + live_queries: Default::default(), + } + } + /// Takes and returns records returned from the database /// /// A query that only returns one result can be deserialized into an @@ -185,7 +288,6 @@ impl Response { /// /// ```no_run /// use serde::Deserialize; - /// use surrealdb::sql; /// /// #[derive(Debug, Deserialize)] /// # #[allow(dead_code)] @@ -244,6 +346,53 @@ impl Response { index.query_result(self) } + /// Takes and streams records returned from a `LIVE SELECT` query + /// + /// This is the counterpart to [Response::take] used to stream the results + /// of a live query. + /// + /// # Examples + /// + /// ```no_run + /// use serde::Deserialize; + /// use surrealdb::Notification; + /// use surrealdb::sql::Value; + /// + /// #[derive(Debug, Deserialize)] + /// # #[allow(dead_code)] + /// struct User { + /// id: String, + /// balance: String + /// } + /// + /// # #[tokio::main] + /// # async fn main() -> surrealdb::Result<()> { + /// # let db = surrealdb::engine::any::connect("mem://").await?; + /// # + /// let mut response = db + /// // Stream all changes to the user table + /// .query("LIVE SELECT * FROM user") + /// .await?; + /// + /// // Stream the result of the live query at the given index + /// // while deserialising into the User type + /// let mut stream = response.stream::>(0)?; + /// + /// // Stream raw values instead + /// let mut stream = response.stream::(0)?; + /// + /// // Combine and stream all `LIVE SELECT` statements in this query + /// let mut stream = response.stream::(())?; + /// # + /// # Ok(()) + /// # } + /// ``` + /// + /// Consume the stream the same way you would any other type that implements `futures::Stream`. + pub fn stream(&mut self, index: impl opt::QueryStream) -> Result> { + index.query_stream(self) + } + /// Take all errors from the query response /// /// The errors are keyed by the corresponding index of the statement that failed. @@ -264,14 +413,14 @@ impl Response { /// ``` pub fn take_errors(&mut self) -> HashMap { let mut keys = Vec::new(); - for (key, result) in &self.0 { + for (key, result) in &self.results { if result.1.is_err() { keys.push(*key); } } let mut errors = HashMap::with_capacity(keys.len()); for key in keys { - if let Some((_, Err(error))) = self.0.remove(&key) { + if let Some((_, Err(error))) = self.results.remove(&key) { errors.insert(key, error); } } @@ -295,14 +444,14 @@ impl Response { /// ``` pub fn check(mut self) -> Result { let mut first_error = None; - for (key, result) in &self.0 { + for (key, result) in &self.results { if result.1.is_err() { first_error = Some(*key); break; } } if let Some(key) = first_error { - if let Some((_, Err(error))) = self.0.remove(&key) { + if let Some((_, Err(error))) = self.results.remove(&key) { return Err(error); } } @@ -326,7 +475,7 @@ impl Response { /// # Ok(()) /// # } pub fn num_statements(&self) -> usize { - self.0.len() + self.results.len() } } @@ -423,14 +572,14 @@ impl WithStats { /// ``` pub fn take_errors(&mut self) -> HashMap { let mut keys = Vec::new(); - for (key, result) in &self.0 .0 { + for (key, result) in &self.0.results { if result.1.is_err() { keys.push(*key); } } let mut errors = HashMap::with_capacity(keys.len()); for key in keys { - if let Some((stats, Err(error))) = self.0 .0.remove(&key) { + if let Some((stats, Err(error))) = self.0.results.remove(&key) { errors.insert(key, (stats, error)); } } @@ -476,6 +625,11 @@ impl WithStats { pub fn num_statements(&self) -> usize { self.0.num_statements() } + + /// Returns the unwrapped response + pub fn into_inner(self) -> Response { + self.0 + } } #[cfg(test)] @@ -509,36 +663,48 @@ mod tests { #[test] fn take_from_an_empty_response() { - let mut response = Response(Default::default()); + let mut response = Response::new(); let value: Value = response.take(0).unwrap(); assert!(value.is_none()); - let mut response = Response(Default::default()); + let mut response = Response::new(); let option: Option = response.take(0).unwrap(); assert!(option.is_none()); - let mut response = Response(Default::default()); + let mut response = Response::new(); let vec: Vec = response.take(0).unwrap(); assert!(vec.is_empty()); } #[test] fn take_from_an_errored_query() { - let mut response = Response(to_map(vec![Err(Error::ConnectionUninitialised.into())])); + let mut response = Response { + results: to_map(vec![Err(Error::ConnectionUninitialised.into())]), + ..Response::new() + }; response.take::>(0).unwrap_err(); } #[test] fn take_from_empty_records() { - let mut response = Response(to_map(vec![])); + let mut response = Response { + results: to_map(vec![]), + ..Response::new() + }; let value: Value = response.take(0).unwrap(); assert_eq!(value, Default::default()); - let mut response = Response(to_map(vec![])); + let mut response = Response { + results: to_map(vec![]), + ..Response::new() + }; let option: Option = response.take(0).unwrap(); assert!(option.is_none()); - let mut response = Response(to_map(vec![])); + let mut response = Response { + results: to_map(vec![]), + ..Response::new() + }; let vec: Vec = response.take(0).unwrap(); assert!(vec.is_empty()); } @@ -547,45 +713,66 @@ mod tests { fn take_from_a_scalar_response() { let scalar = 265; - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let value: Value = response.take(0).unwrap(); assert_eq!(value, Value::from(scalar)); - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let option: Option<_> = response.take(0).unwrap(); assert_eq!(option, Some(scalar)); - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let vec: Vec = response.take(0).unwrap(); assert_eq!(vec, vec![scalar]); let scalar = true; - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let value: Value = response.take(0).unwrap(); assert_eq!(value, Value::from(scalar)); - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let option: Option<_> = response.take(0).unwrap(); assert_eq!(option, Some(scalar)); - let mut response = Response(to_map(vec![Ok(scalar.into())])); + let mut response = Response { + results: to_map(vec![Ok(scalar.into())]), + ..Response::new() + }; let vec: Vec = response.take(0).unwrap(); assert_eq!(vec, vec![scalar]); } #[test] fn take_preserves_order() { - let mut response = Response(to_map(vec![ - Ok(0.into()), - Ok(1.into()), - Ok(2.into()), - Ok(3.into()), - Ok(4.into()), - Ok(5.into()), - Ok(6.into()), - Ok(7.into()), - ])); + let mut response = Response { + results: to_map(vec![ + Ok(0.into()), + Ok(1.into()), + Ok(2.into()), + Ok(3.into()), + Ok(4.into()), + Ok(5.into()), + Ok(6.into()), + Ok(7.into()), + ]), + ..Response::new() + }; let Some(four): Option = response.take(4).unwrap() else { panic!("query not found"); }; @@ -609,17 +796,26 @@ mod tests { }; let value = to_value(summary.clone()).unwrap(); - let mut response = Response(to_map(vec![Ok(value.clone())])); + let mut response = Response { + results: to_map(vec![Ok(value.clone())]), + ..Response::new() + }; let title: Value = response.take("title").unwrap(); assert_eq!(title, Value::from(summary.title.as_str())); - let mut response = Response(to_map(vec![Ok(value.clone())])); + let mut response = Response { + results: to_map(vec![Ok(value.clone())]), + ..Response::new() + }; let Some(title): Option = response.take("title").unwrap() else { panic!("title not found"); }; assert_eq!(title, summary.title); - let mut response = Response(to_map(vec![Ok(value)])); + let mut response = Response { + results: to_map(vec![Ok(value)]), + ..Response::new() + }; let vec: Vec = response.take("title").unwrap(); assert_eq!(vec, vec![summary.title]); @@ -629,7 +825,10 @@ mod tests { }; let value = to_value(article.clone()).unwrap(); - let mut response = Response(to_map(vec![Ok(value.clone())])); + let mut response = Response { + results: to_map(vec![Ok(value.clone())]), + ..Response::new() + }; let Some(title): Option = response.take("title").unwrap() else { panic!("title not found"); }; @@ -639,27 +838,45 @@ mod tests { }; assert_eq!(body, article.body); - let mut response = Response(to_map(vec![Ok(value.clone())])); + let mut response = Response { + results: to_map(vec![Ok(value.clone())]), + ..Response::new() + }; let vec: Vec = response.take("title").unwrap(); assert_eq!(vec, vec![article.title.clone()]); - let mut response = Response(to_map(vec![Ok(value)])); + let mut response = Response { + results: to_map(vec![Ok(value)]), + ..Response::new() + }; let value: Value = response.take("title").unwrap(); assert_eq!(value, Value::from(article.title)); } #[test] fn take_partial_records() { - let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); + let mut response = Response { + results: to_map(vec![Ok(vec![true, false].into())]), + ..Response::new() + }; let value: Value = response.take(0).unwrap(); assert_eq!(value, vec![Value::from(true), Value::from(false)].into()); - let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); + let mut response = Response { + results: to_map(vec![Ok(vec![true, false].into())]), + ..Response::new() + }; let vec: Vec = response.take(0).unwrap(); assert_eq!(vec, vec![true, false]); - let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); - let Err(Api(Error::LossyTake(Response(mut map)))): Result> = response.take(0) + let mut response = Response { + results: to_map(vec![Ok(vec![true, false].into())]), + ..Response::new() + }; + let Err(Api(Error::LossyTake(Response { + results: mut map, + .. + }))): Result> = response.take(0) else { panic!("silently dropping records not allowed"); }; @@ -682,7 +899,10 @@ mod tests { Ok(7.into()), Err(Error::DuplicateRequestId(0).into()), ]; - let response = Response(to_map(response)); + let response = Response { + results: to_map(response), + ..Response::new() + }; let crate::Error::Api(Error::ConnectionUninitialised) = response.check().unwrap_err() else { panic!("check did not return the first error"); @@ -704,7 +924,10 @@ mod tests { Ok(7.into()), Err(Error::DuplicateRequestId(0).into()), ]; - let mut response = Response(to_map(response)); + let mut response = Response { + results: to_map(response), + ..Response::new() + }; let errors = response.take_errors(); assert_eq!(response.num_statements(), 8); assert_eq!(errors.len(), 3); diff --git a/lib/src/api/method/tests/protocol.rs b/lib/src/api/method/tests/protocol.rs index 0b88bc37..109f58a6 100644 --- a/lib/src/api/method/tests/protocol.rs +++ b/lib/src/api/method/tests/protocol.rs @@ -49,6 +49,7 @@ impl Surreal { ) -> Connect { Connect { router: self.router.clone(), + engine: PhantomData, address: address.into_endpoint(), capacity: 0, client: PhantomData, @@ -76,20 +77,20 @@ impl Connection for Client { features.insert(ExtraFeatures::Backup); let router = Router { features, - conn: PhantomData, sender: route_tx, last_id: AtomicI64::new(0), }; server::mock(route_rx); Ok(Surreal { router: Arc::new(OnceLock::with_value(router)), + engine: PhantomData, }) }) } fn send<'r>( &'r mut self, - router: &'r Router, + router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> { Box::pin(async move { diff --git a/lib/src/api/method/tests/server.rs b/lib/src/api/method/tests/server.rs index c3cea484..9fb419c1 100644 --- a/lib/src/api/method/tests/server.rs +++ b/lib/src/api/method/tests/server.rs @@ -53,7 +53,7 @@ pub(super) fn mock(route_rx: Receiver>) { _ => unreachable!(), }, Method::Query => match param.query { - Some(_) => Ok(DbResponse::Query(QueryResponse(Default::default()))), + Some(_) => Ok(DbResponse::Query(QueryResponse::new())), _ => unreachable!(), }, Method::Create => match ¶ms[..] { diff --git a/lib/src/api/mod.rs b/lib/src/api/mod.rs index 36b39d08..795e1658 100644 --- a/lib/src/api/mod.rs +++ b/lib/src/api/mod.rs @@ -17,6 +17,7 @@ use crate::api::err::Error; use crate::api::opt::Endpoint; use semver::BuildMetadata; use semver::VersionReq; +use std::fmt; use std::fmt::Debug; use std::future::Future; use std::future::IntoFuture; @@ -37,7 +38,8 @@ pub trait Connection: conn::Connection {} #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Connect { - router: Arc>>, + router: Arc>, + engine: PhantomData, address: Result, capacity: usize, client: PhantomData, @@ -115,6 +117,7 @@ where self.router.set(router).map_err(|_| Error::AlreadyConnected)?; let client = Surreal { router: self.router, + engine: PhantomData::, }; client.check_server_version().await?; Ok(()) @@ -129,9 +132,9 @@ pub(crate) enum ExtraFeatures { } /// A database client instance for embedded or remote databases -#[derive(Debug)] pub struct Surreal { - router: Arc>>, + router: Arc>, + engine: PhantomData, } impl Surreal @@ -171,15 +174,25 @@ where fn clone(&self) -> Self { Self { router: self.router.clone(), + engine: self.engine, } } } -trait OnceLockExt +impl Debug for Surreal where C: Connection, { - fn with_value(value: Router) -> OnceLock> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Surreal") + .field("router", &self.router) + .field("engine", &self.engine) + .finish() + } +} + +trait OnceLockExt { + fn with_value(value: Router) -> OnceLock { let cell = OnceLock::new(); match cell.set(value) { Ok(()) => cell, @@ -187,14 +200,11 @@ where } } - fn extract(&self) -> Result<&Router>; + fn extract(&self) -> Result<&Router>; } -impl OnceLockExt for OnceLock> -where - C: Connection, -{ - fn extract(&self) -> Result<&Router> { +impl OnceLockExt for OnceLock { + fn extract(&self) -> Result<&Router> { let router = self.get().ok_or(Error::ConnectionUninitialised)?; Ok(router) } diff --git a/lib/src/api/opt/query.rs b/lib/src/api/opt/query.rs index 85691957..78ebae85 100644 --- a/lib/src/api/opt/query.rs +++ b/lib/src/api/opt/query.rs @@ -1,8 +1,12 @@ use crate::api::{err::Error, opt::from_value, Response as QueryResponse, Result}; -use crate::method::Stats; +use crate::method; +use crate::method::{Stats, Stream}; use crate::sql::{self, statements::*, Array, Object, Statement, Statements, Value}; -use crate::syn; +use crate::{syn, Notification}; +use futures::future::Either; +use futures::stream::select_all; use serde::de::DeserializeOwned; +use std::marker::PhantomData; use std::mem; /// A trait for converting inputs into SQL statements @@ -178,21 +182,21 @@ where fn query_result(self, response: &mut QueryResponse) -> Result; /// Extracts the statistics from a query response - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(&0).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(&0).map(|x| x.0) } } impl QueryResult for usize { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result { - match map.remove(&self) { + fn query_result(self, response: &mut QueryResponse) -> Result { + match response.results.remove(&self) { Some((_, result)) => Ok(result?), None => Ok(Value::None), } } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(self).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(self).map(|x| x.0) } } @@ -200,13 +204,13 @@ impl QueryResult> for usize where T: DeserializeOwned, { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result> { - let value = match map.get_mut(&self) { + fn query_result(self, response: &mut QueryResponse) -> Result> { + let value = match response.results.get_mut(&self) { Some((_, result)) => match result { Ok(val) => val, Err(error) => { let error = mem::replace(error, Error::ConnectionUninitialised.into()); - map.remove(&self); + response.results.remove(&self); return Err(error); } }, @@ -221,31 +225,36 @@ where let value = mem::take(value); from_value(value).map_err(Into::into) } - _ => Err(Error::LossyTake(QueryResponse(mem::take(map))).into()), + _ => Err(Error::LossyTake(QueryResponse { + results: mem::take(&mut response.results), + live_queries: mem::take(&mut response.live_queries), + ..QueryResponse::new() + }) + .into()), }, _ => { let value = mem::take(value); from_value(value).map_err(Into::into) } }; - map.remove(&self); + response.results.remove(&self); result } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(self).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(self).map(|x| x.0) } } impl QueryResult for (usize, &str) { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result { + fn query_result(self, response: &mut QueryResponse) -> Result { let (index, key) = self; - let response = match map.get_mut(&index) { + let value = match response.results.get_mut(&index) { Some((_, result)) => match result { Ok(val) => val, Err(error) => { let error = mem::replace(error, Error::ConnectionUninitialised.into()); - map.remove(&index); + response.results.remove(&index); return Err(error); } }, @@ -254,16 +263,16 @@ impl QueryResult for (usize, &str) { } }; - let response = match response { + let value = match value { Value::Object(Object(object)) => object.remove(key).unwrap_or_default(), _ => Value::None, }; - Ok(response) + Ok(value) } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(&self.0).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(&self.0).map(|x| x.0) } } @@ -271,14 +280,14 @@ impl QueryResult> for (usize, &str) where T: DeserializeOwned, { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result> { + fn query_result(self, response: &mut QueryResponse) -> Result> { let (index, key) = self; - let value = match map.get_mut(&index) { + let value = match response.results.get_mut(&index) { Some((_, result)) => match result { Ok(val) => val, Err(error) => { let error = mem::replace(error, Error::ConnectionUninitialised.into()); - map.remove(&index); + response.results.remove(&index); return Err(error); } }, @@ -289,24 +298,29 @@ where let value = match value { Value::Array(Array(vec)) => match &mut vec[..] { [] => { - map.remove(&index); + response.results.remove(&index); return Ok(None); } [value] => value, _ => { - return Err(Error::LossyTake(QueryResponse(mem::take(map))).into()); + return Err(Error::LossyTake(QueryResponse { + results: mem::take(&mut response.results), + live_queries: mem::take(&mut response.live_queries), + ..QueryResponse::new() + }) + .into()); } }, value => value, }; match value { Value::None | Value::Null => { - map.remove(&index); + response.results.remove(&index); Ok(None) } Value::Object(Object(object)) => { if object.is_empty() { - map.remove(&index); + response.results.remove(&index); return Ok(None); } let Some(value) = object.remove(key) else { @@ -318,8 +332,8 @@ where } } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(&self.0).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(&self.0).map(|x| x.0) } } @@ -327,8 +341,8 @@ impl QueryResult> for usize where T: DeserializeOwned, { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result> { - let vec = match map.remove(&self) { + fn query_result(self, response: &mut QueryResponse) -> Result> { + let vec = match response.results.remove(&self) { Some((_, result)) => match result? { Value::Array(Array(vec)) => vec, vec => vec![vec], @@ -340,8 +354,8 @@ where from_value(vec.into()).map_err(Into::into) } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(self).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(self).map(|x| x.0) } } @@ -349,9 +363,9 @@ impl QueryResult> for (usize, &str) where T: DeserializeOwned, { - fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result> { + fn query_result(self, response: &mut QueryResponse) -> Result> { let (index, key) = self; - let mut response = match map.get_mut(&index) { + let mut response = match response.results.get_mut(&index) { Some((_, result)) => match result { Ok(val) => match val { Value::Array(Array(vec)) => mem::take(vec), @@ -362,7 +376,7 @@ where }, Err(error) => { let error = mem::replace(error, Error::ConnectionUninitialised.into()); - map.remove(&index); + response.results.remove(&index); return Err(error); } }, @@ -381,8 +395,8 @@ where from_value(vec.into()).map_err(Into::into) } - fn stats(&self, QueryResponse(map): &QueryResponse) -> Option { - map.get(&self.0).map(|x| x.0) + fn stats(&self, response: &QueryResponse) -> Option { + response.results.get(&self.0).map(|x| x.0) } } @@ -409,3 +423,114 @@ where (0, self).query_result(response) } } + +/// A way to take a query stream future from a query response +pub trait QueryStream { + /// Retrieves the query stream future + fn query_stream(self, response: &mut QueryResponse) -> Result>; +} + +impl QueryStream for usize { + fn query_stream(self, response: &mut QueryResponse) -> Result> { + let stream = response + .live_queries + .remove(&self) + .and_then(|result| match result { + Err(crate::Error::Api(Error::NotLiveQuery(..))) => { + response.results.remove(&self).and_then(|x| x.1.err().map(Err)) + } + result => Some(result), + }) + .unwrap_or_else(|| match response.results.contains_key(&self) { + true => Err(Error::NotLiveQuery(self).into()), + false => Err(Error::QueryIndexOutOfBounds(self).into()), + })?; + Ok(method::QueryStream(Either::Left(stream))) + } +} + +impl QueryStream for () { + fn query_stream(self, response: &mut QueryResponse) -> Result> { + let mut streams = Vec::with_capacity(response.live_queries.len()); + for (index, result) in mem::take(&mut response.live_queries) { + match result { + Ok(stream) => streams.push(stream), + Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) { + Some((stats, Err(error))) => { + response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into()))); + return Err(error); + } + Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"), + None => { return Err(Error::ResponseAlreadyTaken.into()); } + } + Err(error) => { return Err(error); } + } + } + Ok(method::QueryStream(Either::Right(select_all(streams)))) + } +} + +impl QueryStream> for usize +where + R: DeserializeOwned + Unpin, +{ + fn query_stream( + self, + response: &mut QueryResponse, + ) -> Result>> { + let mut stream = response + .live_queries + .remove(&self) + .and_then(|result| match result { + Err(crate::Error::Api(Error::NotLiveQuery(..))) => { + response.results.remove(&self).and_then(|x| x.1.err().map(Err)) + } + result => Some(result), + }) + .unwrap_or_else(|| match response.results.contains_key(&self) { + true => Err(Error::NotLiveQuery(self).into()), + false => Err(Error::QueryIndexOutOfBounds(self).into()), + })?; + Ok(method::QueryStream(Either::Left(Stream { + client: stream.client.clone(), + engine: stream.engine, + id: mem::take(&mut stream.id), + rx: stream.rx.take(), + response_type: PhantomData, + }))) + } +} + +impl QueryStream> for () +where + R: DeserializeOwned + Unpin, +{ + fn query_stream( + self, + response: &mut QueryResponse, + ) -> Result>> { + let mut streams = Vec::with_capacity(response.live_queries.len()); + for (index, result) in mem::take(&mut response.live_queries) { + let mut stream = match result { + Ok(stream) => stream, + Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) { + Some((stats, Err(error))) => { + response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into()))); + return Err(error); + } + Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"), + None => { return Err(Error::ResponseAlreadyTaken.into()); } + } + Err(error) => { return Err(error); } + }; + streams.push(Stream { + client: stream.client.clone(), + engine: stream.engine, + id: mem::take(&mut stream.id), + rx: stream.rx.take(), + response_type: PhantomData, + }); + } + Ok(method::QueryStream(Either::Right(select_all(streams)))) + } +} diff --git a/lib/src/lib.rs b/lib/src/lib.rs index c9b4b78d..7377230a 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -158,6 +158,7 @@ pub use api::Response; pub use api::Result; #[doc(inline)] pub use api::Surreal; +use uuid::Uuid; #[doc(hidden)] /// Channels for receiving a SurrealQL database export @@ -203,6 +204,7 @@ impl From for Action { #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] #[non_exhaustive] pub struct Notification { + pub query_id: Uuid, pub action: Action, pub data: R, } diff --git a/lib/tests/api/live.rs b/lib/tests/api/live.rs index 7f68a832..4ea4148e 100644 --- a/lib/tests/api/live.rs +++ b/lib/tests/api/live.rs @@ -174,3 +174,128 @@ async fn live_select_record_ranges() { drop(permit); } + +#[test_log::test(tokio::test)] +async fn live_select_query() { + let (permit, db) = new_db().await; + + db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap(); + + { + let table = Ulid::new().to_string(); + + // Start listening + let mut users = db + .query(format!("LIVE SELECT * FROM {table}")) + .await + .unwrap() + .stream::>(0) + .unwrap(); + + // Create a record + let created: Vec = db.create(table).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // The returned record should match the created record + assert_eq!(created, vec![notification.data.clone()]); + // It should be newly created + assert_eq!(notification.action, Action::Create); + + // Update the record + let _: Option = + db.update(¬ification.data.id).content(json!({"foo": "bar"})).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // It should be updated + assert_eq!(notification.action, Action::Update); + + // Delete the record + let _: Option = db.delete(¬ification.data.id).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // It should be deleted + assert_eq!(notification.action, Action::Delete); + } + + { + let table = Ulid::new().to_string(); + + // Start listening + let mut users = db + .query(format!("LIVE SELECT * FROM {table}")) + .await + .unwrap() + .stream::(0) + .unwrap(); + + // Create a record + db.create(Resource::from(&table)).await.unwrap(); + // Pull the notification + let notification = users.next().await.unwrap(); + // The returned record should be an object + assert!(notification.data.is_object()); + // It should be newly created + assert_eq!(notification.action, Action::Create); + } + + { + let table = Ulid::new().to_string(); + + // Start listening + let mut users = db + .query(format!("LIVE SELECT * FROM {table}")) + .await + .unwrap() + .stream::>(()) + .unwrap(); + + // Create a record + let created: Vec = db.create(table).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // The returned record should match the created record + assert_eq!(created, vec![notification.data.clone()]); + // It should be newly created + assert_eq!(notification.action, Action::Create); + + // Update the record + let _: Option = + db.update(¬ification.data.id).content(json!({"foo": "bar"})).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // It should be updated + assert_eq!(notification.action, Action::Update); + + // Delete the record + let _: Option = db.delete(¬ification.data.id).await.unwrap(); + // Pull the notification + let notification: Notification = users.next().await.unwrap().unwrap(); + // It should be deleted + assert_eq!(notification.action, Action::Delete); + } + + { + let table = Ulid::new().to_string(); + + // Start listening + let mut users = db + .query("BEGIN") + .query(format!("LIVE SELECT * FROM {table}")) + .query("COMMIT") + .await + .unwrap() + .stream::(()) + .unwrap(); + + // Create a record + db.create(Resource::from(&table)).await.unwrap(); + // Pull the notification + let notification = users.next().await.unwrap(); + // The returned record should be an object + assert!(notification.data.is_object()); + // It should be newly created + assert_eq!(notification.action, Action::Create); + } + + drop(permit); +} diff --git a/lib/tests/api/mod.rs b/lib/tests/api/mod.rs index 31589041..1860e3f2 100644 --- a/lib/tests/api/mod.rs +++ b/lib/tests/api/mod.rs @@ -33,7 +33,6 @@ async fn yuse() { drop(permit); } -#[ignore] #[test_log::test(tokio::test)] async fn invalidate() { let (permit, db) = new_db().await; diff --git a/src/cli/sql.rs b/src/cli/sql.rs index c15529c1..faba140d 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -5,6 +5,8 @@ use crate::cli::abstraction::{ use crate::cnf::PKG_VERSION; use crate::err::Error; use clap::Args; +use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use futures_util::{SinkExt, StreamExt}; use rustyline::error::ReadlineError; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{Completer, Editor, Helper, Highlighter, Hinter}; @@ -15,7 +17,7 @@ use surrealdb::engine::any::{connect, IntoEndpoint}; use surrealdb::method::{Stats, WithStats}; use surrealdb::opt::Config; use surrealdb::sql::{self, Statement, Value}; -use surrealdb::Response; +use surrealdb::{Notification, Response}; #[derive(Args, Debug)] pub struct SqlCommandArguments { @@ -149,6 +151,10 @@ pub async fn init( ); } + // Set up the print job + let (tx, rx) = mpsc::unbounded(); + tokio::spawn(printer(rx)); + // Loop over each command-line input loop { // Prompt the user to input SQL and check the input. @@ -206,15 +212,12 @@ pub async fn init( continue; } // Run the query provided - let res = client.query(query).with_stats().await; - match process(pretty, json, res) { - Ok(v) => { - println!("{v}\n"); - } - Err(e) => { - eprintln!("{e}\n"); - continue; - } + let result = client.query(query).with_stats().await; + let result = process(pretty, json, result, tx.clone()); + let result_is_error = result.is_err(); + tx.clone().send(result).await.expect("print job terminated unexpectedly"); + if result_is_error { + continue; } // Persist the variables extracted from the query for (key, value) in vars { @@ -253,6 +256,7 @@ fn process( pretty: bool, json: bool, res: surrealdb::Result>, + mut tx: UnboundedSender>, ) -> Result { // Check query response for an error let mut response = res?; @@ -271,6 +275,62 @@ fn process( vec.push((stats, output)); } + tokio::spawn(async move { + let mut stream = match response.into_inner().stream::(()) { + Ok(stream) => stream, + Err(error) => { + tx.send(Err(error.into())).await.ok(); + return; + } + }; + while let Some(Notification { + query_id, + action, + data, + .. + }) = stream.next().await + { + let message = match (json, pretty) { + // Don't prettify the SurrealQL response + (false, false) => { + let value = Value::from(map! { + String::from("id") => query_id.into(), + String::from("action") => format!("{action:?}").to_ascii_uppercase().into(), + String::from("result") => data, + }); + value.to_string() + } + // Yes prettify the SurrealQL response + (false, true) => format!( + "-- Notification (action: {action:?}, live query ID: {query_id})\n{data:#}" + ), + // Don't pretty print the JSON response + (true, false) => { + let value = Value::from(map! { + String::from("id") => query_id.into(), + String::from("action") => format!("{action:?}").to_ascii_uppercase().into(), + String::from("result") => data, + }); + value.into_json().to_string() + } + // Yes prettify the JSON response + (true, true) => { + let mut buf = Vec::new(); + let mut serializer = serde_json::Serializer::with_formatter( + &mut buf, + PrettyFormatter::with_indent(b"\t"), + ); + data.into_json().serialize(&mut serializer).unwrap(); + let output = String::from_utf8(buf).unwrap(); + format!("-- Notification (action: {action:?}, live query ID: {query_id})\n{output:#}") + } + }; + if tx.send(Ok(format!("\n{message}"))).await.is_err() { + return; + } + } + }); + // Check if we should emit JSON and/or prettify Ok(match (json, pretty) { // Don't prettify the SurrealQL response @@ -314,6 +374,19 @@ fn process( }) } +async fn printer(mut rx: UnboundedReceiver>) { + while let Some(result) = rx.next().await { + match result { + Ok(v) => { + println!("{v}\n"); + } + Err(e) => { + eprintln!("{e}\n"); + } + } + } +} + #[derive(Completer, Helper, Highlighter, Hinter)] struct InputValidator { /// If omitting semicolon causes newline. diff --git a/tests/cli_integration.rs b/tests/cli_integration.rs index 730afb10..1ed106b0 100644 --- a/tests/cli_integration.rs +++ b/tests/cli_integration.rs @@ -799,6 +799,7 @@ mod cli_integration { } #[test(tokio::test)] + #[ignore] async fn test_capabilities() { // Default capabilities only allow functions info!("* When default capabilities"); @@ -825,7 +826,8 @@ mod cli_integration { let query = "RETURN function() { return '1' };"; let output = common::run(&cmd).input(query).output().unwrap(); assert!( - output.contains("Scripting functions are not allowed"), + output.contains("Scripting functions are not allowed") + || output.contains("Embedded functions are not enabled"), "unexpected output: {output:?}" ); } @@ -855,7 +857,8 @@ mod cli_integration { let query = "RETURN function() { return '1' };"; let output = common::run(&cmd).input(query).output().unwrap(); assert!( - output.contains("Scripting functions are not allowed"), + output.contains("Scripting functions are not allowed") + || output.contains("Embedded functions are not enabled"), "unexpected output: {output:?}" ); } @@ -901,7 +904,8 @@ mod cli_integration { let query = "RETURN function() { return '1' };"; let output = common::run(&cmd).input(query).output().unwrap(); assert!( - output.contains("Scripting functions are not allowed"), + output.contains("Scripting functions are not allowed") + || output.contains("Embedded functions are not enabled"), "unexpected output: {output:?}" ); }