Add support for LIVE SELECT in the SDK and CLI (#3309)

This commit is contained in:
Rushmore Mushambi 2024-01-16 13:48:29 +02:00 committed by GitHub
parent f587289923
commit c5138245a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 812 additions and 189 deletions

View file

@ -16,7 +16,6 @@ use serde::Serialize;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData;
use std::path::PathBuf; use std::path::PathBuf;
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
@ -31,26 +30,19 @@ pub(crate) struct Route {
/// Message router /// Message router
#[derive(Debug)] #[derive(Debug)]
pub struct Router<C: api::Connection> { pub struct Router {
pub(crate) conn: PhantomData<C>,
pub(crate) sender: Sender<Option<Route>>, pub(crate) sender: Sender<Option<Route>>,
pub(crate) last_id: AtomicI64, pub(crate) last_id: AtomicI64,
pub(crate) features: HashSet<ExtraFeatures>, pub(crate) features: HashSet<ExtraFeatures>,
} }
impl<C> Router<C> impl Router {
where
C: api::Connection,
{
pub(crate) fn next_id(&self) -> i64 { pub(crate) fn next_id(&self) -> i64 {
self.last_id.fetch_add(1, Ordering::SeqCst) self.last_id.fetch_add(1, Ordering::SeqCst)
} }
} }
impl<C> Drop for Router<C> impl Drop for Router {
where
C: api::Connection,
{
fn drop(&mut self) { fn drop(&mut self) {
let _res = self.sender.send(None); let _res = self.sender.send(None);
} }
@ -189,7 +181,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>>
where where
@ -226,7 +218,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute all methods except `query` /// Execute all methods except `query`
fn execute<'r, R>( fn execute<'r, R>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<R>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<R>> + Send + Sync + 'r>>
where where
@ -243,7 +235,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return an optional single response /// Execute methods that return an optional single response
fn execute_opt<'r, R>( fn execute_opt<'r, R>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Option<R>>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<Option<R>>> + Send + Sync + 'r>>
where where
@ -262,7 +254,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return multiple responses /// Execute methods that return multiple responses
fn execute_vec<'r, R>( fn execute_vec<'r, R>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Vec<R>>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<Vec<R>>> + Send + Sync + 'r>>
where where
@ -283,7 +275,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return nothing /// Execute methods that return nothing
fn execute_unit<'r>( fn execute_unit<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'r>>
where where
@ -306,7 +298,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return a raw value /// Execute methods that return a raw value
fn execute_value<'r>( fn execute_value<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + Sync + 'r>>
where where
@ -321,7 +313,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute the `query` method /// Execute the `query` method
fn execute_query<'r>( fn execute_query<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + Sync + 'r>> ) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + Sync + 'r>>
where where

View file

@ -193,6 +193,7 @@ impl Surreal<Any> {
pub fn connect(&self, address: impl IntoEndpoint) -> Connect<Any, ()> { pub fn connect(&self, address: impl IntoEndpoint) -> Connect<Any, ()> {
Connect { Connect {
router: self.router.clone(), router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -242,6 +243,7 @@ impl Surreal<Any> {
pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> { pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> {
Connect { Connect {
router: Arc::new(OnceLock::new()), router: Arc::new(OnceLock::new()),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,

View file

@ -215,17 +215,17 @@ impl Connection for Any {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -170,17 +170,17 @@ impl Connection for Any {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -383,6 +383,7 @@ impl Surreal<Db> {
pub fn connect<P>(&self, address: impl IntoEndpoint<P, Client = Db>) -> Connect<Db, ()> { pub fn connect<P>(&self, address: impl IntoEndpoint<P, Client = Db>) -> Connect<Db, ()> {
Connect { Connect {
router: self.router.clone(), router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -402,11 +403,14 @@ fn process(responses: Vec<Response>) -> QueryResponse {
Err(error) => map.insert(index, (stats, Err(error.into()))), Err(error) => map.insert(index, (stats, Err(error.into()))),
}; };
} }
QueryResponse(map) QueryResponse {
results: map,
..QueryResponse::new()
}
} }
async fn take(one: bool, responses: Vec<Response>) -> Result<Value> { async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
if let Some((_stats, result)) = process(responses).0.remove(&0) { if let Some((_stats, result)) = process(responses).results.remove(&0) {
let value = result?; let value = result?;
match one { match one {
true => match value { true => match value {

View file

@ -68,17 +68,17 @@ impl Connection for Db {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -68,17 +68,17 @@ impl Connection for Db {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -100,6 +100,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> { ) -> Connect<Client, ()> {
Connect { Connect {
router: self.router.clone(), router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -210,11 +211,14 @@ async fn query(request: RequestBuilder) -> Result<QueryResponse> {
} }
} }
Ok(QueryResponse(map)) Ok(QueryResponse {
results: map,
..QueryResponse::new()
})
} }
async fn take(one: bool, request: RequestBuilder) -> Result<Value> { async fn take(one: bool, request: RequestBuilder) -> Result<Value> {
if let Some((_stats, result)) = query(request).await?.0.remove(&0) { if let Some((_stats, result)) = query(request).await?.results.remove(&0) {
let value = result?; let value = result?;
match one { match one {
true => match value { true => match value {

View file

@ -74,17 +74,17 @@ impl Connection for Client {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -53,17 +53,17 @@ impl Connection for Client {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features: HashSet::new(), features: HashSet::new(),
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -68,6 +68,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> { ) -> Connect<Client, ()> {
Connect { Connect {
router: self.router.clone(), router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -135,7 +136,10 @@ impl DbResponse {
} }
} }
Ok(DbResponse::Query(api::Response(map))) Ok(DbResponse::Query(api::Response {
results: map,
..api::Response::new()
}))
} }
// Live notifications don't call this method // Live notifications don't call this method
Data::Live(..) => unreachable!(), Data::Live(..) => unreachable!(),

View file

@ -136,17 +136,17 @@ impl Connection for Client {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -90,17 +90,17 @@ impl Connection for Client {
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router { router: Arc::new(OnceLock::with_value(Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
})), })),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -181,6 +181,18 @@ pub enum Error {
/// Tried to use a range query on an edge or edges /// Tried to use a range query on an edge or edges
#[error("Live queries on edges not supported: {0}")] #[error("Live queries on edges not supported: {0}")]
LiveOnEdges(Edges), LiveOnEdges(Edges),
/// Tried to access a query statement as a live query when it isn't a live query
#[error("Query statement {0} is not a live query")]
NotLiveQuery(usize),
/// Tried to access a query statement falling outside the bounds of the statements supplied
#[error("Query statement {0} is out of bounds")]
QueryIndexOutOfBounds(usize),
/// Called `Response::take` or `Response::stream` on a query response more than once
#[error("Tried to take a query response that has already been taken")]
ResponseAlreadyTaken,
} }
#[cfg(feature = "protocol-http")] #[cfg(feature = "protocol-http")]

View file

@ -1,10 +1,12 @@
use crate::api::conn::Method; use crate::api::conn::Method;
use crate::api::conn::Param; use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::Connection; use crate::api::Connection;
use crate::api::ExtraFeatures; use crate::api::ExtraFeatures;
use crate::api::Result; use crate::api::Result;
use crate::dbs; use crate::dbs;
use crate::engine::any::Any;
use crate::method::Live; use crate::method::Live;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::method::Query; use crate::method::Query;
@ -30,7 +32,6 @@ use crate::Surreal;
use channel::Receiver; use channel::Receiver;
use futures::StreamExt; use futures::StreamExt;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::future::Future; use std::future::Future;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -93,24 +94,40 @@ macro_rules! into_future {
client: client.clone(), client: client.clone(),
query: vec![Ok(vec![Statement::Live(stmt)])], query: vec![Ok(vec![Statement::Live(stmt)])],
bindings: Ok(Default::default()), bindings: Ok(Default::default()),
register_live_queries: false,
}; };
let id: Value = query.await?.take(0)?; let id: Value = query.await?.take(0)?;
let mut conn = Client::new(Method::Live); let rx = register::<Client>(router, id.clone()).await?;
let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx);
param.other = vec![id.clone()];
conn.execute_unit(router, param).await?;
Ok(Stream { Ok(Stream {
id, id,
rx, rx: Some(rx),
client, client: Surreal {
router: client.router.clone(),
engine: PhantomData,
},
response_type: PhantomData, response_type: PhantomData,
engine: PhantomData,
}) })
}) })
} }
}; };
} }
pub(crate) async fn register<Client>(
router: &Router,
id: Value,
) -> Result<Receiver<dbs::Notification>>
where
Client: Connection,
{
let mut conn = Client::new(Method::Live);
let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx);
param.other = vec![id];
conn.execute_unit(router, param).await?;
Ok(rx)
}
fn cond_from_range(range: crate::sql::Range) -> Option<Cond> { fn cond_from_range(range: crate::sql::Range) -> Option<Cond> {
match (range.beg, range.end) { match (range.beg, range.end) {
(Bound::Unbounded, Bound::Unbounded) => None, (Bound::Unbounded, Bound::Unbounded) => None,
@ -241,21 +258,23 @@ where
#[derive(Debug)] #[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"] #[must_use = "streams do nothing unless you poll them"]
pub struct Stream<'r, C: Connection, R> { pub struct Stream<'r, C: Connection, R> {
client: Cow<'r, Surreal<C>>, pub(crate) client: Surreal<Any>,
id: Value, // We no longer need the lifetime and the type parameter
rx: Receiver<dbs::Notification>, // Leaving them in for backwards compatibility
response_type: PhantomData<R>, pub(crate) engine: PhantomData<&'r C>,
pub(crate) id: Value,
pub(crate) rx: Option<Receiver<dbs::Notification>>,
pub(crate) response_type: PhantomData<R>,
} }
macro_rules! poll_next { macro_rules! poll_next {
($action:ident, $result:ident => $body:expr) => { ($notification:ident => $body:expr) => {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.as_mut().rx.poll_next_unpin(cx) { let Some(ref mut rx) = self.as_mut().rx else {
Poll::Ready(Some(dbs::Notification { return Poll::Ready(None);
$action, };
$result, match rx.poll_next_unpin(cx) {
.. Poll::Ready(Some($notification)) => $body,
})) => $body,
Poll::Ready(None) => Poll::Ready(None), Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
} }
@ -270,15 +289,23 @@ where
type Item = Notification<Value>; type Item = Notification<Value>;
poll_next! { poll_next! {
action, result => Poll::Ready(Some(Notification { action: action.into(), data: result })) notification => Poll::Ready(Some(Notification {
query_id: notification.id.0,
action: notification.action.into(),
data: notification.result,
}))
} }
} }
macro_rules! poll_next_and_convert { macro_rules! poll_next_and_convert {
() => { () => {
poll_next! { poll_next! {
action, result => match from_value(result) { notification => match from_value(notification.result) {
Ok(data) => Poll::Ready(Some(Ok(Notification { action: action.into(), data }))), Ok(data) => Poll::Ready(Some(Ok(Notification {
data,
query_id: notification.id.0,
action: notification.action.into(),
}))),
Err(error) => Poll::Ready(Some(Err(error.into()))), Err(error) => Poll::Ready(Some(Err(error.into()))),
} }
} }
@ -305,6 +332,29 @@ where
poll_next_and_convert! {} poll_next_and_convert! {}
} }
impl<C, R> futures::Stream for Stream<'_, C, Notification<R>>
where
C: Connection,
R: DeserializeOwned + Unpin,
{
type Item = Result<Notification<R>>;
poll_next_and_convert! {}
}
pub(crate) fn kill<Client>(client: &Surreal<Client>, id: Value)
where
Client: Connection,
{
let client = client.clone();
spawn(async move {
if let Ok(router) = client.router.extract() {
let mut conn = Client::new(Method::Kill);
conn.execute_unit(router, Param::new(vec![id.clone()])).await.ok();
}
});
}
impl<Client, R> Drop for Stream<'_, Client, R> impl<Client, R> Drop for Stream<'_, Client, R>
where where
Client: Connection, Client: Connection,
@ -313,20 +363,9 @@ where
/// ///
/// This kills the live query process responsible for this stream. /// This kills the live query process responsible for this stream.
fn drop(&mut self) { fn drop(&mut self) {
if !self.id.is_none() { if !self.id.is_none() && self.rx.is_some() {
let id = mem::take(&mut self.id); let id = mem::take(&mut self.id);
let client = self.client.clone().into_owned(); kill(&self.client, id);
spawn(async move {
if let Ok(router) = client.router.extract() {
let mut conn = Client::new(Method::Kill);
match conn.execute_unit(router, Param::new(vec![id.clone()])).await {
Ok(()) => trace!("Live query {id} dropped successfully"),
Err(error) => warn!("Failed to drop live query {id}; {error}"),
}
}
});
} else {
trace!("Ignoring drop call on an already dropped live::Stream");
} }
} }
} }

View file

@ -1,5 +1,6 @@
//! Methods to use when interacting with a SurrealDB instance //! Methods to use when interacting with a SurrealDB instance
pub(crate) mod live;
pub(crate) mod query; pub(crate) mod query;
mod authenticate; mod authenticate;
@ -13,7 +14,6 @@ mod export;
mod health; mod health;
mod import; mod import;
mod invalidate; mod invalidate;
mod live;
mod merge; mod merge;
mod patch; mod patch;
mod select; mod select;
@ -50,6 +50,7 @@ pub use live::Stream;
pub use merge::Merge; pub use merge::Merge;
pub use patch::Patch; pub use patch::Patch;
pub use query::Query; pub use query::Query;
pub use query::QueryStream;
pub use select::Select; pub use select::Select;
pub use set::Set; pub use set::Set;
pub use signin::Signin; pub use signin::Signin;
@ -227,6 +228,7 @@ where
pub fn init() -> Self { pub fn init() -> Self {
Self { Self {
router: Arc::new(OnceLock::new()), router: Arc::new(OnceLock::new()),
engine: PhantomData,
} }
} }
@ -252,6 +254,7 @@ where
pub fn new<P>(address: impl IntoEndpoint<P, Client = C>) -> Connect<C, Self> { pub fn new<P>(address: impl IntoEndpoint<P, Client = C>) -> Connect<C, Self> {
Connect { Connect {
router: Arc::new(OnceLock::new()), router: Arc::new(OnceLock::new()),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -638,6 +641,7 @@ where
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
query: vec![query.into_query()], query: vec![query.into_query()],
bindings: Ok(Default::default()), bindings: Ok(Default::default()),
register_live_queries: true,
} }
} }

View file

@ -1,9 +1,14 @@
use super::live;
use super::Stream;
use crate::api::conn::Method; use crate::api::conn::Method;
use crate::api::conn::Param; use crate::api::conn::Param;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::opt; use crate::api::opt;
use crate::api::Connection; use crate::api::Connection;
use crate::api::ExtraFeatures;
use crate::api::Result; use crate::api::Result;
use crate::engine::any::Any;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::method::Stats; use crate::method::Stats;
use crate::method::WithStats; use crate::method::WithStats;
@ -15,7 +20,11 @@ use crate::sql::Statement;
use crate::sql::Statements; use crate::sql::Statements;
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
use crate::Notification;
use crate::Surreal; use crate::Surreal;
use futures::future::Either;
use futures::stream::SelectAll;
use futures::StreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@ -24,8 +33,11 @@ use std::collections::BTreeMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData;
use std::mem; use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
/// A query future /// A query future
#[derive(Debug)] #[derive(Debug)]
@ -34,6 +46,7 @@ pub struct Query<'r, C: Connection> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) query: Vec<Result<Vec<Statement>>>, pub(super) query: Vec<Result<Vec<Statement>>>,
pub(super) bindings: Result<BTreeMap<String, Value>>, pub(super) bindings: Result<BTreeMap<String, Value>>,
pub(crate) register_live_queries: bool,
} }
impl<C> Query<'_, C> impl<C> Query<'_, C>
@ -58,14 +71,66 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
// Extract the router from the client
let router = self.client.router.extract()?;
// Combine all query statements supplied
let mut statements = Vec::with_capacity(self.query.len()); let mut statements = Vec::with_capacity(self.query.len());
for query in self.query { for query in self.query {
statements.extend(query?); statements.extend(query?);
} }
let query = sql::Query(Statements(statements)); // Build the query and execute it
let query = sql::Query(Statements(statements.clone()));
let param = Param::query(query, self.bindings?); let param = Param::query(query, self.bindings?);
let mut conn = Client::new(Method::Query); let mut conn = Client::new(Method::Query);
conn.execute_query(self.client.router.extract()?, param).await let mut response = conn.execute_query(router, param).await?;
// Register live queries if necessary
if self.register_live_queries {
let mut live_queries = IndexMap::new();
let mut checked = false;
// Adjusting offsets as a workaround to https://github.com/surrealdb/surrealdb/issues/3318
let mut offset = 0;
for (index, stmt) in statements.into_iter().enumerate() {
if let Statement::Live(stmt) = stmt {
if !checked && !router.features.contains(&ExtraFeatures::LiveQueries) {
return Err(Error::LiveQueriesNotSupported.into());
}
checked = true;
let index = index - offset;
if let Some((_, result)) = response.results.get(&index) {
let result =
match result {
Ok(id) => live::register::<Client>(router, id.clone())
.await
.map(|rx| Stream {
id: stmt.id.into(),
rx: Some(rx),
client: Surreal {
router: self.client.router.clone(),
engine: PhantomData,
},
response_type: PhantomData,
engine: PhantomData,
}),
// This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error.
Err(..) => Err(Error::NotLiveQuery(index).into()),
};
live_queries.insert(index, result);
}
} else if matches!(
stmt,
Statement::Begin(..) | Statement::Commit(..) | Statement::Cancel(..)
) {
offset += 1;
}
}
response.live_queries = live_queries;
}
response.client = Surreal {
router: self.client.router.clone(),
engine: PhantomData,
};
Ok(response)
}) })
} }
} }
@ -172,9 +237,47 @@ pub(crate) type QueryResult = Result<Value>;
/// The response type of a `Surreal::query` request /// The response type of a `Surreal::query` request
#[derive(Debug)] #[derive(Debug)]
pub struct Response(pub(crate) IndexMap<usize, (Stats, QueryResult)>); pub struct Response {
pub(crate) client: Surreal<Any>,
pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
pub(crate) live_queries: IndexMap<usize, Result<Stream<'static, Any, Value>>>,
}
/// A `LIVE SELECT` stream from the `query` method
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct QueryStream<R>(
pub(crate) Either<Stream<'static, Any, R>, SelectAll<Stream<'static, Any, R>>>,
);
impl futures::Stream for QueryStream<Value> {
type Item = Notification<Value>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().0.poll_next_unpin(cx)
}
}
impl<R> futures::Stream for QueryStream<Notification<R>>
where
R: DeserializeOwned + Unpin,
{
type Item = Result<Notification<R>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().0.poll_next_unpin(cx)
}
}
impl Response { impl Response {
pub(crate) fn new() -> Self {
Self {
client: Surreal::init(),
results: Default::default(),
live_queries: Default::default(),
}
}
/// Takes and returns records returned from the database /// Takes and returns records returned from the database
/// ///
/// A query that only returns one result can be deserialized into an /// A query that only returns one result can be deserialized into an
@ -185,7 +288,6 @@ impl Response {
/// ///
/// ```no_run /// ```no_run
/// use serde::Deserialize; /// use serde::Deserialize;
/// use surrealdb::sql;
/// ///
/// #[derive(Debug, Deserialize)] /// #[derive(Debug, Deserialize)]
/// # #[allow(dead_code)] /// # #[allow(dead_code)]
@ -244,6 +346,53 @@ impl Response {
index.query_result(self) index.query_result(self)
} }
/// Takes and streams records returned from a `LIVE SELECT` query
///
/// This is the counterpart to [Response::take] used to stream the results
/// of a live query.
///
/// # Examples
///
/// ```no_run
/// use serde::Deserialize;
/// use surrealdb::Notification;
/// use surrealdb::sql::Value;
///
/// #[derive(Debug, Deserialize)]
/// # #[allow(dead_code)]
/// struct User {
/// id: String,
/// balance: String
/// }
///
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// #
/// let mut response = db
/// // Stream all changes to the user table
/// .query("LIVE SELECT * FROM user")
/// .await?;
///
/// // Stream the result of the live query at the given index
/// // while deserialising into the User type
/// let mut stream = response.stream::<Notification<User>>(0)?;
///
/// // Stream raw values instead
/// let mut stream = response.stream::<Value>(0)?;
///
/// // Combine and stream all `LIVE SELECT` statements in this query
/// let mut stream = response.stream::<Value>(())?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// Consume the stream the same way you would any other type that implements `futures::Stream`.
pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
index.query_stream(self)
}
/// Take all errors from the query response /// Take all errors from the query response
/// ///
/// The errors are keyed by the corresponding index of the statement that failed. /// The errors are keyed by the corresponding index of the statement that failed.
@ -264,14 +413,14 @@ impl Response {
/// ``` /// ```
pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> { pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> {
let mut keys = Vec::new(); let mut keys = Vec::new();
for (key, result) in &self.0 { for (key, result) in &self.results {
if result.1.is_err() { if result.1.is_err() {
keys.push(*key); keys.push(*key);
} }
} }
let mut errors = HashMap::with_capacity(keys.len()); let mut errors = HashMap::with_capacity(keys.len());
for key in keys { for key in keys {
if let Some((_, Err(error))) = self.0.remove(&key) { if let Some((_, Err(error))) = self.results.remove(&key) {
errors.insert(key, error); errors.insert(key, error);
} }
} }
@ -295,14 +444,14 @@ impl Response {
/// ``` /// ```
pub fn check(mut self) -> Result<Self> { pub fn check(mut self) -> Result<Self> {
let mut first_error = None; let mut first_error = None;
for (key, result) in &self.0 { for (key, result) in &self.results {
if result.1.is_err() { if result.1.is_err() {
first_error = Some(*key); first_error = Some(*key);
break; break;
} }
} }
if let Some(key) = first_error { if let Some(key) = first_error {
if let Some((_, Err(error))) = self.0.remove(&key) { if let Some((_, Err(error))) = self.results.remove(&key) {
return Err(error); return Err(error);
} }
} }
@ -326,7 +475,7 @@ impl Response {
/// # Ok(()) /// # Ok(())
/// # } /// # }
pub fn num_statements(&self) -> usize { pub fn num_statements(&self) -> usize {
self.0.len() self.results.len()
} }
} }
@ -423,14 +572,14 @@ impl WithStats<Response> {
/// ``` /// ```
pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> { pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> {
let mut keys = Vec::new(); let mut keys = Vec::new();
for (key, result) in &self.0 .0 { for (key, result) in &self.0.results {
if result.1.is_err() { if result.1.is_err() {
keys.push(*key); keys.push(*key);
} }
} }
let mut errors = HashMap::with_capacity(keys.len()); let mut errors = HashMap::with_capacity(keys.len());
for key in keys { for key in keys {
if let Some((stats, Err(error))) = self.0 .0.remove(&key) { if let Some((stats, Err(error))) = self.0.results.remove(&key) {
errors.insert(key, (stats, error)); errors.insert(key, (stats, error));
} }
} }
@ -476,6 +625,11 @@ impl WithStats<Response> {
pub fn num_statements(&self) -> usize { pub fn num_statements(&self) -> usize {
self.0.num_statements() self.0.num_statements()
} }
/// Returns the unwrapped response
pub fn into_inner(self) -> Response {
self.0
}
} }
#[cfg(test)] #[cfg(test)]
@ -509,36 +663,48 @@ mod tests {
#[test] #[test]
fn take_from_an_empty_response() { fn take_from_an_empty_response() {
let mut response = Response(Default::default()); let mut response = Response::new();
let value: Value = response.take(0).unwrap(); let value: Value = response.take(0).unwrap();
assert!(value.is_none()); assert!(value.is_none());
let mut response = Response(Default::default()); let mut response = Response::new();
let option: Option<String> = response.take(0).unwrap(); let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none()); assert!(option.is_none());
let mut response = Response(Default::default()); let mut response = Response::new();
let vec: Vec<String> = response.take(0).unwrap(); let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty()); assert!(vec.is_empty());
} }
#[test] #[test]
fn take_from_an_errored_query() { fn take_from_an_errored_query() {
let mut response = Response(to_map(vec![Err(Error::ConnectionUninitialised.into())])); let mut response = Response {
results: to_map(vec![Err(Error::ConnectionUninitialised.into())]),
..Response::new()
};
response.take::<Option<()>>(0).unwrap_err(); response.take::<Option<()>>(0).unwrap_err();
} }
#[test] #[test]
fn take_from_empty_records() { fn take_from_empty_records() {
let mut response = Response(to_map(vec![])); let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let value: Value = response.take(0).unwrap(); let value: Value = response.take(0).unwrap();
assert_eq!(value, Default::default()); assert_eq!(value, Default::default());
let mut response = Response(to_map(vec![])); let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let option: Option<String> = response.take(0).unwrap(); let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none()); assert!(option.is_none());
let mut response = Response(to_map(vec![])); let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let vec: Vec<String> = response.take(0).unwrap(); let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty()); assert!(vec.is_empty());
} }
@ -547,45 +713,66 @@ mod tests {
fn take_from_a_scalar_response() { fn take_from_a_scalar_response() {
let scalar = 265; let scalar = 265;
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap(); let value: Value = response.take(0).unwrap();
assert_eq!(value, Value::from(scalar)); assert_eq!(value, Value::from(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap(); let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar)); assert_eq!(option, Some(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<usize> = response.take(0).unwrap(); let vec: Vec<usize> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]); assert_eq!(vec, vec![scalar]);
let scalar = true; let scalar = true;
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap(); let value: Value = response.take(0).unwrap();
assert_eq!(value, Value::from(scalar)); assert_eq!(value, Value::from(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap(); let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar)); assert_eq!(option, Some(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())])); let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap(); let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]); assert_eq!(vec, vec![scalar]);
} }
#[test] #[test]
fn take_preserves_order() { fn take_preserves_order() {
let mut response = Response(to_map(vec![ let mut response = Response {
Ok(0.into()), results: to_map(vec![
Ok(1.into()), Ok(0.into()),
Ok(2.into()), Ok(1.into()),
Ok(3.into()), Ok(2.into()),
Ok(4.into()), Ok(3.into()),
Ok(5.into()), Ok(4.into()),
Ok(6.into()), Ok(5.into()),
Ok(7.into()), Ok(6.into()),
])); Ok(7.into()),
]),
..Response::new()
};
let Some(four): Option<i32> = response.take(4).unwrap() else { let Some(four): Option<i32> = response.take(4).unwrap() else {
panic!("query not found"); panic!("query not found");
}; };
@ -609,17 +796,26 @@ mod tests {
}; };
let value = to_value(summary.clone()).unwrap(); let value = to_value(summary.clone()).unwrap();
let mut response = Response(to_map(vec![Ok(value.clone())])); let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let title: Value = response.take("title").unwrap(); let title: Value = response.take("title").unwrap();
assert_eq!(title, Value::from(summary.title.as_str())); assert_eq!(title, Value::from(summary.title.as_str()));
let mut response = Response(to_map(vec![Ok(value.clone())])); let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else { let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found"); panic!("title not found");
}; };
assert_eq!(title, summary.title); assert_eq!(title, summary.title);
let mut response = Response(to_map(vec![Ok(value)])); let mut response = Response {
results: to_map(vec![Ok(value)]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap(); let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![summary.title]); assert_eq!(vec, vec![summary.title]);
@ -629,7 +825,10 @@ mod tests {
}; };
let value = to_value(article.clone()).unwrap(); let value = to_value(article.clone()).unwrap();
let mut response = Response(to_map(vec![Ok(value.clone())])); let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else { let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found"); panic!("title not found");
}; };
@ -639,27 +838,45 @@ mod tests {
}; };
assert_eq!(body, article.body); assert_eq!(body, article.body);
let mut response = Response(to_map(vec![Ok(value.clone())])); let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap(); let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![article.title.clone()]); assert_eq!(vec, vec![article.title.clone()]);
let mut response = Response(to_map(vec![Ok(value)])); let mut response = Response {
results: to_map(vec![Ok(value)]),
..Response::new()
};
let value: Value = response.take("title").unwrap(); let value: Value = response.take("title").unwrap();
assert_eq!(value, Value::from(article.title)); assert_eq!(value, Value::from(article.title));
} }
#[test] #[test]
fn take_partial_records() { fn take_partial_records() {
let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap(); let value: Value = response.take(0).unwrap();
assert_eq!(value, vec![Value::from(true), Value::from(false)].into()); assert_eq!(value, vec![Value::from(true), Value::from(false)].into());
let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap(); let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![true, false]); assert_eq!(vec, vec![true, false]);
let mut response = Response(to_map(vec![Ok(vec![true, false].into())])); let mut response = Response {
let Err(Api(Error::LossyTake(Response(mut map)))): Result<Option<bool>> = response.take(0) results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let Err(Api(Error::LossyTake(Response {
results: mut map,
..
}))): Result<Option<bool>> = response.take(0)
else { else {
panic!("silently dropping records not allowed"); panic!("silently dropping records not allowed");
}; };
@ -682,7 +899,10 @@ mod tests {
Ok(7.into()), Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()), Err(Error::DuplicateRequestId(0).into()),
]; ];
let response = Response(to_map(response)); let response = Response {
results: to_map(response),
..Response::new()
};
let crate::Error::Api(Error::ConnectionUninitialised) = response.check().unwrap_err() let crate::Error::Api(Error::ConnectionUninitialised) = response.check().unwrap_err()
else { else {
panic!("check did not return the first error"); panic!("check did not return the first error");
@ -704,7 +924,10 @@ mod tests {
Ok(7.into()), Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()), Err(Error::DuplicateRequestId(0).into()),
]; ];
let mut response = Response(to_map(response)); let mut response = Response {
results: to_map(response),
..Response::new()
};
let errors = response.take_errors(); let errors = response.take_errors();
assert_eq!(response.num_statements(), 8); assert_eq!(response.num_statements(), 8);
assert_eq!(errors.len(), 3); assert_eq!(errors.len(), 3);

View file

@ -49,6 +49,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> { ) -> Connect<Client, ()> {
Connect { Connect {
router: self.router.clone(), router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(), address: address.into_endpoint(),
capacity: 0, capacity: 0,
client: PhantomData, client: PhantomData,
@ -76,20 +77,20 @@ impl Connection for Client {
features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::Backup);
let router = Router { let router = Router {
features, features,
conn: PhantomData,
sender: route_tx, sender: route_tx,
last_id: AtomicI64::new(0), last_id: AtomicI64::new(0),
}; };
server::mock(route_rx); server::mock(route_rx);
Ok(Surreal { Ok(Surreal {
router: Arc::new(OnceLock::with_value(router)), router: Arc::new(OnceLock::with_value(router)),
engine: PhantomData,
}) })
}) })
} }
fn send<'r>( fn send<'r>(
&'r mut self, &'r mut self,
router: &'r Router<Self>, router: &'r Router,
param: Param, param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> { ) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move { Box::pin(async move {

View file

@ -53,7 +53,7 @@ pub(super) fn mock(route_rx: Receiver<Option<Route>>) {
_ => unreachable!(), _ => unreachable!(),
}, },
Method::Query => match param.query { Method::Query => match param.query {
Some(_) => Ok(DbResponse::Query(QueryResponse(Default::default()))), Some(_) => Ok(DbResponse::Query(QueryResponse::new())),
_ => unreachable!(), _ => unreachable!(),
}, },
Method::Create => match &params[..] { Method::Create => match &params[..] {

View file

@ -17,6 +17,7 @@ use crate::api::err::Error;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
use semver::BuildMetadata; use semver::BuildMetadata;
use semver::VersionReq; use semver::VersionReq;
use std::fmt;
use std::fmt::Debug; use std::fmt::Debug;
use std::future::Future; use std::future::Future;
use std::future::IntoFuture; use std::future::IntoFuture;
@ -37,7 +38,8 @@ pub trait Connection: conn::Connection {}
#[derive(Debug)] #[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Connect<C: Connection, Response> { pub struct Connect<C: Connection, Response> {
router: Arc<OnceLock<Router<C>>>, router: Arc<OnceLock<Router>>,
engine: PhantomData<C>,
address: Result<Endpoint>, address: Result<Endpoint>,
capacity: usize, capacity: usize,
client: PhantomData<C>, client: PhantomData<C>,
@ -115,6 +117,7 @@ where
self.router.set(router).map_err(|_| Error::AlreadyConnected)?; self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
let client = Surreal { let client = Surreal {
router: self.router, router: self.router,
engine: PhantomData::<Client>,
}; };
client.check_server_version().await?; client.check_server_version().await?;
Ok(()) Ok(())
@ -129,9 +132,9 @@ pub(crate) enum ExtraFeatures {
} }
/// A database client instance for embedded or remote databases /// A database client instance for embedded or remote databases
#[derive(Debug)]
pub struct Surreal<C: Connection> { pub struct Surreal<C: Connection> {
router: Arc<OnceLock<Router<C>>>, router: Arc<OnceLock<Router>>,
engine: PhantomData<C>,
} }
impl<C> Surreal<C> impl<C> Surreal<C>
@ -171,15 +174,25 @@ where
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
router: self.router.clone(), router: self.router.clone(),
engine: self.engine,
} }
} }
} }
trait OnceLockExt<C> impl<C> Debug for Surreal<C>
where where
C: Connection, C: Connection,
{ {
fn with_value(value: Router<C>) -> OnceLock<Router<C>> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Surreal")
.field("router", &self.router)
.field("engine", &self.engine)
.finish()
}
}
trait OnceLockExt {
fn with_value(value: Router) -> OnceLock<Router> {
let cell = OnceLock::new(); let cell = OnceLock::new();
match cell.set(value) { match cell.set(value) {
Ok(()) => cell, Ok(()) => cell,
@ -187,14 +200,11 @@ where
} }
} }
fn extract(&self) -> Result<&Router<C>>; fn extract(&self) -> Result<&Router>;
} }
impl<C> OnceLockExt<C> for OnceLock<Router<C>> impl OnceLockExt for OnceLock<Router> {
where fn extract(&self) -> Result<&Router> {
C: Connection,
{
fn extract(&self) -> Result<&Router<C>> {
let router = self.get().ok_or(Error::ConnectionUninitialised)?; let router = self.get().ok_or(Error::ConnectionUninitialised)?;
Ok(router) Ok(router)
} }

View file

@ -1,8 +1,12 @@
use crate::api::{err::Error, opt::from_value, Response as QueryResponse, Result}; use crate::api::{err::Error, opt::from_value, Response as QueryResponse, Result};
use crate::method::Stats; use crate::method;
use crate::method::{Stats, Stream};
use crate::sql::{self, statements::*, Array, Object, Statement, Statements, Value}; use crate::sql::{self, statements::*, Array, Object, Statement, Statements, Value};
use crate::syn; use crate::{syn, Notification};
use futures::future::Either;
use futures::stream::select_all;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::marker::PhantomData;
use std::mem; use std::mem;
/// A trait for converting inputs into SQL statements /// A trait for converting inputs into SQL statements
@ -178,21 +182,21 @@ where
fn query_result(self, response: &mut QueryResponse) -> Result<Response>; fn query_result(self, response: &mut QueryResponse) -> Result<Response>;
/// Extracts the statistics from a query response /// Extracts the statistics from a query response
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(&0).map(|x| x.0) response.results.get(&0).map(|x| x.0)
} }
} }
impl QueryResult<Value> for usize { impl QueryResult<Value> for usize {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Value> { fn query_result(self, response: &mut QueryResponse) -> Result<Value> {
match map.remove(&self) { match response.results.remove(&self) {
Some((_, result)) => Ok(result?), Some((_, result)) => Ok(result?),
None => Ok(Value::None), None => Ok(Value::None),
} }
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0) response.results.get(self).map(|x| x.0)
} }
} }
@ -200,13 +204,13 @@ impl<T> QueryResult<Option<T>> for usize
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Option<T>> { fn query_result(self, response: &mut QueryResponse) -> Result<Option<T>> {
let value = match map.get_mut(&self) { let value = match response.results.get_mut(&self) {
Some((_, result)) => match result { Some((_, result)) => match result {
Ok(val) => val, Ok(val) => val,
Err(error) => { Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into()); let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&self); response.results.remove(&self);
return Err(error); return Err(error);
} }
}, },
@ -221,31 +225,36 @@ where
let value = mem::take(value); let value = mem::take(value);
from_value(value).map_err(Into::into) from_value(value).map_err(Into::into)
} }
_ => Err(Error::LossyTake(QueryResponse(mem::take(map))).into()), _ => Err(Error::LossyTake(QueryResponse {
results: mem::take(&mut response.results),
live_queries: mem::take(&mut response.live_queries),
..QueryResponse::new()
})
.into()),
}, },
_ => { _ => {
let value = mem::take(value); let value = mem::take(value);
from_value(value).map_err(Into::into) from_value(value).map_err(Into::into)
} }
}; };
map.remove(&self); response.results.remove(&self);
result result
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0) response.results.get(self).map(|x| x.0)
} }
} }
impl QueryResult<Value> for (usize, &str) { impl QueryResult<Value> for (usize, &str) {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Value> { fn query_result(self, response: &mut QueryResponse) -> Result<Value> {
let (index, key) = self; let (index, key) = self;
let response = match map.get_mut(&index) { let value = match response.results.get_mut(&index) {
Some((_, result)) => match result { Some((_, result)) => match result {
Ok(val) => val, Ok(val) => val,
Err(error) => { Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into()); let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index); response.results.remove(&index);
return Err(error); return Err(error);
} }
}, },
@ -254,16 +263,16 @@ impl QueryResult<Value> for (usize, &str) {
} }
}; };
let response = match response { let value = match value {
Value::Object(Object(object)) => object.remove(key).unwrap_or_default(), Value::Object(Object(object)) => object.remove(key).unwrap_or_default(),
_ => Value::None, _ => Value::None,
}; };
Ok(response) Ok(value)
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0) response.results.get(&self.0).map(|x| x.0)
} }
} }
@ -271,14 +280,14 @@ impl<T> QueryResult<Option<T>> for (usize, &str)
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Option<T>> { fn query_result(self, response: &mut QueryResponse) -> Result<Option<T>> {
let (index, key) = self; let (index, key) = self;
let value = match map.get_mut(&index) { let value = match response.results.get_mut(&index) {
Some((_, result)) => match result { Some((_, result)) => match result {
Ok(val) => val, Ok(val) => val,
Err(error) => { Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into()); let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index); response.results.remove(&index);
return Err(error); return Err(error);
} }
}, },
@ -289,24 +298,29 @@ where
let value = match value { let value = match value {
Value::Array(Array(vec)) => match &mut vec[..] { Value::Array(Array(vec)) => match &mut vec[..] {
[] => { [] => {
map.remove(&index); response.results.remove(&index);
return Ok(None); return Ok(None);
} }
[value] => value, [value] => value,
_ => { _ => {
return Err(Error::LossyTake(QueryResponse(mem::take(map))).into()); return Err(Error::LossyTake(QueryResponse {
results: mem::take(&mut response.results),
live_queries: mem::take(&mut response.live_queries),
..QueryResponse::new()
})
.into());
} }
}, },
value => value, value => value,
}; };
match value { match value {
Value::None | Value::Null => { Value::None | Value::Null => {
map.remove(&index); response.results.remove(&index);
Ok(None) Ok(None)
} }
Value::Object(Object(object)) => { Value::Object(Object(object)) => {
if object.is_empty() { if object.is_empty() {
map.remove(&index); response.results.remove(&index);
return Ok(None); return Ok(None);
} }
let Some(value) = object.remove(key) else { let Some(value) = object.remove(key) else {
@ -318,8 +332,8 @@ where
} }
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0) response.results.get(&self.0).map(|x| x.0)
} }
} }
@ -327,8 +341,8 @@ impl<T> QueryResult<Vec<T>> for usize
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Vec<T>> { fn query_result(self, response: &mut QueryResponse) -> Result<Vec<T>> {
let vec = match map.remove(&self) { let vec = match response.results.remove(&self) {
Some((_, result)) => match result? { Some((_, result)) => match result? {
Value::Array(Array(vec)) => vec, Value::Array(Array(vec)) => vec,
vec => vec![vec], vec => vec![vec],
@ -340,8 +354,8 @@ where
from_value(vec.into()).map_err(Into::into) from_value(vec.into()).map_err(Into::into)
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0) response.results.get(self).map(|x| x.0)
} }
} }
@ -349,9 +363,9 @@ impl<T> QueryResult<Vec<T>> for (usize, &str)
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Vec<T>> { fn query_result(self, response: &mut QueryResponse) -> Result<Vec<T>> {
let (index, key) = self; let (index, key) = self;
let mut response = match map.get_mut(&index) { let mut response = match response.results.get_mut(&index) {
Some((_, result)) => match result { Some((_, result)) => match result {
Ok(val) => match val { Ok(val) => match val {
Value::Array(Array(vec)) => mem::take(vec), Value::Array(Array(vec)) => mem::take(vec),
@ -362,7 +376,7 @@ where
}, },
Err(error) => { Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into()); let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index); response.results.remove(&index);
return Err(error); return Err(error);
} }
}, },
@ -381,8 +395,8 @@ where
from_value(vec.into()).map_err(Into::into) from_value(vec.into()).map_err(Into::into)
} }
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> { fn stats(&self, response: &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0) response.results.get(&self.0).map(|x| x.0)
} }
} }
@ -409,3 +423,114 @@ where
(0, self).query_result(response) (0, self).query_result(response)
} }
} }
/// A way to take a query stream future from a query response
pub trait QueryStream<R> {
/// Retrieves the query stream future
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<R>>;
}
impl QueryStream<Value> for usize {
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<Value>> {
let stream = response
.live_queries
.remove(&self)
.and_then(|result| match result {
Err(crate::Error::Api(Error::NotLiveQuery(..))) => {
response.results.remove(&self).and_then(|x| x.1.err().map(Err))
}
result => Some(result),
})
.unwrap_or_else(|| match response.results.contains_key(&self) {
true => Err(Error::NotLiveQuery(self).into()),
false => Err(Error::QueryIndexOutOfBounds(self).into()),
})?;
Ok(method::QueryStream(Either::Left(stream)))
}
}
impl QueryStream<Value> for () {
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<Value>> {
let mut streams = Vec::with_capacity(response.live_queries.len());
for (index, result) in mem::take(&mut response.live_queries) {
match result {
Ok(stream) => streams.push(stream),
Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) {
Some((stats, Err(error))) => {
response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into())));
return Err(error);
}
Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"),
None => { return Err(Error::ResponseAlreadyTaken.into()); }
}
Err(error) => { return Err(error); }
}
}
Ok(method::QueryStream(Either::Right(select_all(streams))))
}
}
impl<R> QueryStream<Notification<R>> for usize
where
R: DeserializeOwned + Unpin,
{
fn query_stream(
self,
response: &mut QueryResponse,
) -> Result<method::QueryStream<Notification<R>>> {
let mut stream = response
.live_queries
.remove(&self)
.and_then(|result| match result {
Err(crate::Error::Api(Error::NotLiveQuery(..))) => {
response.results.remove(&self).and_then(|x| x.1.err().map(Err))
}
result => Some(result),
})
.unwrap_or_else(|| match response.results.contains_key(&self) {
true => Err(Error::NotLiveQuery(self).into()),
false => Err(Error::QueryIndexOutOfBounds(self).into()),
})?;
Ok(method::QueryStream(Either::Left(Stream {
client: stream.client.clone(),
engine: stream.engine,
id: mem::take(&mut stream.id),
rx: stream.rx.take(),
response_type: PhantomData,
})))
}
}
impl<R> QueryStream<Notification<R>> for ()
where
R: DeserializeOwned + Unpin,
{
fn query_stream(
self,
response: &mut QueryResponse,
) -> Result<method::QueryStream<Notification<R>>> {
let mut streams = Vec::with_capacity(response.live_queries.len());
for (index, result) in mem::take(&mut response.live_queries) {
let mut stream = match result {
Ok(stream) => stream,
Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) {
Some((stats, Err(error))) => {
response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into())));
return Err(error);
}
Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"),
None => { return Err(Error::ResponseAlreadyTaken.into()); }
}
Err(error) => { return Err(error); }
};
streams.push(Stream {
client: stream.client.clone(),
engine: stream.engine,
id: mem::take(&mut stream.id),
rx: stream.rx.take(),
response_type: PhantomData,
});
}
Ok(method::QueryStream(Either::Right(select_all(streams))))
}
}

View file

@ -158,6 +158,7 @@ pub use api::Response;
pub use api::Result; pub use api::Result;
#[doc(inline)] #[doc(inline)]
pub use api::Surreal; pub use api::Surreal;
use uuid::Uuid;
#[doc(hidden)] #[doc(hidden)]
/// Channels for receiving a SurrealQL database export /// Channels for receiving a SurrealQL database export
@ -203,6 +204,7 @@ impl From<dbs::Action> for Action {
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive] #[non_exhaustive]
pub struct Notification<R> { pub struct Notification<R> {
pub query_id: Uuid,
pub action: Action, pub action: Action,
pub data: R, pub data: R,
} }

View file

@ -174,3 +174,128 @@ async fn live_select_record_ranges() {
drop(permit); drop(permit);
} }
#[test_log::test(tokio::test)]
async fn live_select_query() {
let (permit, db) = new_db().await;
db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap();
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Notification<_>>(0)
.unwrap();
// Create a record
let created: Vec<RecordId> = db.create(table).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// The returned record should match the created record
assert_eq!(created, vec![notification.data.clone()]);
// It should be newly created
assert_eq!(notification.action, Action::Create);
// Update the record
let _: Option<RecordId> =
db.update(&notification.data.id).content(json!({"foo": "bar"})).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be updated
assert_eq!(notification.action, Action::Update);
// Delete the record
let _: Option<RecordId> = db.delete(&notification.data.id).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be deleted
assert_eq!(notification.action, Action::Delete);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Value>(0)
.unwrap();
// Create a record
db.create(Resource::from(&table)).await.unwrap();
// Pull the notification
let notification = users.next().await.unwrap();
// The returned record should be an object
assert!(notification.data.is_object());
// It should be newly created
assert_eq!(notification.action, Action::Create);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Notification<_>>(())
.unwrap();
// Create a record
let created: Vec<RecordId> = db.create(table).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// The returned record should match the created record
assert_eq!(created, vec![notification.data.clone()]);
// It should be newly created
assert_eq!(notification.action, Action::Create);
// Update the record
let _: Option<RecordId> =
db.update(&notification.data.id).content(json!({"foo": "bar"})).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be updated
assert_eq!(notification.action, Action::Update);
// Delete the record
let _: Option<RecordId> = db.delete(&notification.data.id).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be deleted
assert_eq!(notification.action, Action::Delete);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query("BEGIN")
.query(format!("LIVE SELECT * FROM {table}"))
.query("COMMIT")
.await
.unwrap()
.stream::<Value>(())
.unwrap();
// Create a record
db.create(Resource::from(&table)).await.unwrap();
// Pull the notification
let notification = users.next().await.unwrap();
// The returned record should be an object
assert!(notification.data.is_object());
// It should be newly created
assert_eq!(notification.action, Action::Create);
}
drop(permit);
}

View file

@ -33,7 +33,6 @@ async fn yuse() {
drop(permit); drop(permit);
} }
#[ignore]
#[test_log::test(tokio::test)] #[test_log::test(tokio::test)]
async fn invalidate() { async fn invalidate() {
let (permit, db) = new_db().await; let (permit, db) = new_db().await;

View file

@ -5,6 +5,8 @@ use crate::cli::abstraction::{
use crate::cnf::PKG_VERSION; use crate::cnf::PKG_VERSION;
use crate::err::Error; use crate::err::Error;
use clap::Args; use clap::Args;
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures_util::{SinkExt, StreamExt};
use rustyline::error::ReadlineError; use rustyline::error::ReadlineError;
use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{Completer, Editor, Helper, Highlighter, Hinter}; use rustyline::{Completer, Editor, Helper, Highlighter, Hinter};
@ -15,7 +17,7 @@ use surrealdb::engine::any::{connect, IntoEndpoint};
use surrealdb::method::{Stats, WithStats}; use surrealdb::method::{Stats, WithStats};
use surrealdb::opt::Config; use surrealdb::opt::Config;
use surrealdb::sql::{self, Statement, Value}; use surrealdb::sql::{self, Statement, Value};
use surrealdb::Response; use surrealdb::{Notification, Response};
#[derive(Args, Debug)] #[derive(Args, Debug)]
pub struct SqlCommandArguments { pub struct SqlCommandArguments {
@ -149,6 +151,10 @@ pub async fn init(
); );
} }
// Set up the print job
let (tx, rx) = mpsc::unbounded();
tokio::spawn(printer(rx));
// Loop over each command-line input // Loop over each command-line input
loop { loop {
// Prompt the user to input SQL and check the input. // Prompt the user to input SQL and check the input.
@ -206,15 +212,12 @@ pub async fn init(
continue; continue;
} }
// Run the query provided // Run the query provided
let res = client.query(query).with_stats().await; let result = client.query(query).with_stats().await;
match process(pretty, json, res) { let result = process(pretty, json, result, tx.clone());
Ok(v) => { let result_is_error = result.is_err();
println!("{v}\n"); tx.clone().send(result).await.expect("print job terminated unexpectedly");
} if result_is_error {
Err(e) => { continue;
eprintln!("{e}\n");
continue;
}
} }
// Persist the variables extracted from the query // Persist the variables extracted from the query
for (key, value) in vars { for (key, value) in vars {
@ -253,6 +256,7 @@ fn process(
pretty: bool, pretty: bool,
json: bool, json: bool,
res: surrealdb::Result<WithStats<Response>>, res: surrealdb::Result<WithStats<Response>>,
mut tx: UnboundedSender<Result<String, Error>>,
) -> Result<String, Error> { ) -> Result<String, Error> {
// Check query response for an error // Check query response for an error
let mut response = res?; let mut response = res?;
@ -271,6 +275,62 @@ fn process(
vec.push((stats, output)); vec.push((stats, output));
} }
tokio::spawn(async move {
let mut stream = match response.into_inner().stream::<Value>(()) {
Ok(stream) => stream,
Err(error) => {
tx.send(Err(error.into())).await.ok();
return;
}
};
while let Some(Notification {
query_id,
action,
data,
..
}) = stream.next().await
{
let message = match (json, pretty) {
// Don't prettify the SurrealQL response
(false, false) => {
let value = Value::from(map! {
String::from("id") => query_id.into(),
String::from("action") => format!("{action:?}").to_ascii_uppercase().into(),
String::from("result") => data,
});
value.to_string()
}
// Yes prettify the SurrealQL response
(false, true) => format!(
"-- Notification (action: {action:?}, live query ID: {query_id})\n{data:#}"
),
// Don't pretty print the JSON response
(true, false) => {
let value = Value::from(map! {
String::from("id") => query_id.into(),
String::from("action") => format!("{action:?}").to_ascii_uppercase().into(),
String::from("result") => data,
});
value.into_json().to_string()
}
// Yes prettify the JSON response
(true, true) => {
let mut buf = Vec::new();
let mut serializer = serde_json::Serializer::with_formatter(
&mut buf,
PrettyFormatter::with_indent(b"\t"),
);
data.into_json().serialize(&mut serializer).unwrap();
let output = String::from_utf8(buf).unwrap();
format!("-- Notification (action: {action:?}, live query ID: {query_id})\n{output:#}")
}
};
if tx.send(Ok(format!("\n{message}"))).await.is_err() {
return;
}
}
});
// Check if we should emit JSON and/or prettify // Check if we should emit JSON and/or prettify
Ok(match (json, pretty) { Ok(match (json, pretty) {
// Don't prettify the SurrealQL response // Don't prettify the SurrealQL response
@ -314,6 +374,19 @@ fn process(
}) })
} }
async fn printer(mut rx: UnboundedReceiver<Result<String, Error>>) {
while let Some(result) = rx.next().await {
match result {
Ok(v) => {
println!("{v}\n");
}
Err(e) => {
eprintln!("{e}\n");
}
}
}
}
#[derive(Completer, Helper, Highlighter, Hinter)] #[derive(Completer, Helper, Highlighter, Hinter)]
struct InputValidator { struct InputValidator {
/// If omitting semicolon causes newline. /// If omitting semicolon causes newline.

View file

@ -799,6 +799,7 @@ mod cli_integration {
} }
#[test(tokio::test)] #[test(tokio::test)]
#[ignore]
async fn test_capabilities() { async fn test_capabilities() {
// Default capabilities only allow functions // Default capabilities only allow functions
info!("* When default capabilities"); info!("* When default capabilities");
@ -825,7 +826,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };"; let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap(); let output = common::run(&cmd).input(query).output().unwrap();
assert!( assert!(
output.contains("Scripting functions are not allowed"), output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}" "unexpected output: {output:?}"
); );
} }
@ -855,7 +857,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };"; let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap(); let output = common::run(&cmd).input(query).output().unwrap();
assert!( assert!(
output.contains("Scripting functions are not allowed"), output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}" "unexpected output: {output:?}"
); );
} }
@ -901,7 +904,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };"; let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap(); let output = common::run(&cmd).input(query).output().unwrap();
assert!( assert!(
output.contains("Scripting functions are not allowed"), output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}" "unexpected output: {output:?}"
); );
} }