Add support for LIVE SELECT in the SDK and CLI (#3309)

This commit is contained in:
Rushmore Mushambi 2024-01-16 13:48:29 +02:00 committed by GitHub
parent f587289923
commit c5138245a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 812 additions and 189 deletions

View file

@ -16,7 +16,6 @@ use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
@ -31,26 +30,19 @@ pub(crate) struct Route {
/// Message router
#[derive(Debug)]
pub struct Router<C: api::Connection> {
pub(crate) conn: PhantomData<C>,
pub struct Router {
pub(crate) sender: Sender<Option<Route>>,
pub(crate) last_id: AtomicI64,
pub(crate) features: HashSet<ExtraFeatures>,
}
impl<C> Router<C>
where
C: api::Connection,
{
impl Router {
pub(crate) fn next_id(&self) -> i64 {
self.last_id.fetch_add(1, Ordering::SeqCst)
}
}
impl<C> Drop for Router<C>
where
C: api::Connection,
{
impl Drop for Router {
fn drop(&mut self) {
let _res = self.sender.send(None);
}
@ -189,7 +181,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
#[allow(clippy::type_complexity)]
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>>
where
@ -226,7 +218,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute all methods except `query`
fn execute<'r, R>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<R>> + Send + Sync + 'r>>
where
@ -243,7 +235,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return an optional single response
fn execute_opt<'r, R>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Option<R>>> + Send + Sync + 'r>>
where
@ -262,7 +254,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return multiple responses
fn execute_vec<'r, R>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Vec<R>>> + Send + Sync + 'r>>
where
@ -283,7 +275,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return nothing
fn execute_unit<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'r>>
where
@ -306,7 +298,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute methods that return a raw value
fn execute_value<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Value>> + Send + Sync + 'r>>
where
@ -321,7 +313,7 @@ pub trait Connection: Sized + Send + Sync + 'static {
/// Execute the `query` method
fn execute_query<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Response>> + Send + Sync + 'r>>
where

View file

@ -193,6 +193,7 @@ impl Surreal<Any> {
pub fn connect(&self, address: impl IntoEndpoint) -> Connect<Any, ()> {
Connect {
router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -242,6 +243,7 @@ impl Surreal<Any> {
pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> {
Connect {
router: Arc::new(OnceLock::new()),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,

View file

@ -215,17 +215,17 @@ impl Connection for Any {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -170,17 +170,17 @@ impl Connection for Any {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -383,6 +383,7 @@ impl Surreal<Db> {
pub fn connect<P>(&self, address: impl IntoEndpoint<P, Client = Db>) -> Connect<Db, ()> {
Connect {
router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -402,11 +403,14 @@ fn process(responses: Vec<Response>) -> QueryResponse {
Err(error) => map.insert(index, (stats, Err(error.into()))),
};
}
QueryResponse(map)
QueryResponse {
results: map,
..QueryResponse::new()
}
}
async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
if let Some((_stats, result)) = process(responses).0.remove(&0) {
if let Some((_stats, result)) = process(responses).results.remove(&0) {
let value = result?;
match one {
true => match value {

View file

@ -68,17 +68,17 @@ impl Connection for Db {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -68,17 +68,17 @@ impl Connection for Db {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -100,6 +100,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> {
Connect {
router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -210,11 +211,14 @@ async fn query(request: RequestBuilder) -> Result<QueryResponse> {
}
}
Ok(QueryResponse(map))
Ok(QueryResponse {
results: map,
..QueryResponse::new()
})
}
async fn take(one: bool, request: RequestBuilder) -> Result<Value> {
if let Some((_stats, result)) = query(request).await?.0.remove(&0) {
if let Some((_stats, result)) = query(request).await?.results.remove(&0) {
let value = result?;
match one {
true => match value {

View file

@ -74,17 +74,17 @@ impl Connection for Client {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -53,17 +53,17 @@ impl Connection for Client {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features: HashSet::new(),
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -68,6 +68,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> {
Connect {
router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -135,7 +136,10 @@ impl DbResponse {
}
}
Ok(DbResponse::Query(api::Response(map)))
Ok(DbResponse::Query(api::Response {
results: map,
..api::Response::new()
}))
}
// Live notifications don't call this method
Data::Live(..) => unreachable!(),

View file

@ -136,17 +136,17 @@ impl Connection for Client {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -90,17 +90,17 @@ impl Connection for Client {
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -181,6 +181,18 @@ pub enum Error {
/// Tried to use a range query on an edge or edges
#[error("Live queries on edges not supported: {0}")]
LiveOnEdges(Edges),
/// Tried to access a query statement as a live query when it isn't a live query
#[error("Query statement {0} is not a live query")]
NotLiveQuery(usize),
/// Tried to access a query statement falling outside the bounds of the statements supplied
#[error("Query statement {0} is out of bounds")]
QueryIndexOutOfBounds(usize),
/// Called `Response::take` or `Response::stream` on a query response more than once
#[error("Tried to take a query response that has already been taken")]
ResponseAlreadyTaken,
}
#[cfg(feature = "protocol-http")]

View file

@ -1,10 +1,12 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::err::Error;
use crate::api::Connection;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::dbs;
use crate::engine::any::Any;
use crate::method::Live;
use crate::method::OnceLockExt;
use crate::method::Query;
@ -30,7 +32,6 @@ use crate::Surreal;
use channel::Receiver;
use futures::StreamExt;
use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
@ -93,24 +94,40 @@ macro_rules! into_future {
client: client.clone(),
query: vec![Ok(vec![Statement::Live(stmt)])],
bindings: Ok(Default::default()),
register_live_queries: false,
};
let id: Value = query.await?.take(0)?;
let mut conn = Client::new(Method::Live);
let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx);
param.other = vec![id.clone()];
conn.execute_unit(router, param).await?;
let rx = register::<Client>(router, id.clone()).await?;
Ok(Stream {
id,
rx,
client,
rx: Some(rx),
client: Surreal {
router: client.router.clone(),
engine: PhantomData,
},
response_type: PhantomData,
engine: PhantomData,
})
})
}
};
}
pub(crate) async fn register<Client>(
router: &Router,
id: Value,
) -> Result<Receiver<dbs::Notification>>
where
Client: Connection,
{
let mut conn = Client::new(Method::Live);
let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx);
param.other = vec![id];
conn.execute_unit(router, param).await?;
Ok(rx)
}
fn cond_from_range(range: crate::sql::Range) -> Option<Cond> {
match (range.beg, range.end) {
(Bound::Unbounded, Bound::Unbounded) => None,
@ -241,21 +258,23 @@ where
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct Stream<'r, C: Connection, R> {
client: Cow<'r, Surreal<C>>,
id: Value,
rx: Receiver<dbs::Notification>,
response_type: PhantomData<R>,
pub(crate) client: Surreal<Any>,
// We no longer need the lifetime and the type parameter
// Leaving them in for backwards compatibility
pub(crate) engine: PhantomData<&'r C>,
pub(crate) id: Value,
pub(crate) rx: Option<Receiver<dbs::Notification>>,
pub(crate) response_type: PhantomData<R>,
}
macro_rules! poll_next {
($action:ident, $result:ident => $body:expr) => {
($notification:ident => $body:expr) => {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.as_mut().rx.poll_next_unpin(cx) {
Poll::Ready(Some(dbs::Notification {
$action,
$result,
..
})) => $body,
let Some(ref mut rx) = self.as_mut().rx else {
return Poll::Ready(None);
};
match rx.poll_next_unpin(cx) {
Poll::Ready(Some($notification)) => $body,
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
@ -270,15 +289,23 @@ where
type Item = Notification<Value>;
poll_next! {
action, result => Poll::Ready(Some(Notification { action: action.into(), data: result }))
notification => Poll::Ready(Some(Notification {
query_id: notification.id.0,
action: notification.action.into(),
data: notification.result,
}))
}
}
macro_rules! poll_next_and_convert {
() => {
poll_next! {
action, result => match from_value(result) {
Ok(data) => Poll::Ready(Some(Ok(Notification { action: action.into(), data }))),
notification => match from_value(notification.result) {
Ok(data) => Poll::Ready(Some(Ok(Notification {
data,
query_id: notification.id.0,
action: notification.action.into(),
}))),
Err(error) => Poll::Ready(Some(Err(error.into()))),
}
}
@ -305,6 +332,29 @@ where
poll_next_and_convert! {}
}
impl<C, R> futures::Stream for Stream<'_, C, Notification<R>>
where
C: Connection,
R: DeserializeOwned + Unpin,
{
type Item = Result<Notification<R>>;
poll_next_and_convert! {}
}
pub(crate) fn kill<Client>(client: &Surreal<Client>, id: Value)
where
Client: Connection,
{
let client = client.clone();
spawn(async move {
if let Ok(router) = client.router.extract() {
let mut conn = Client::new(Method::Kill);
conn.execute_unit(router, Param::new(vec![id.clone()])).await.ok();
}
});
}
impl<Client, R> Drop for Stream<'_, Client, R>
where
Client: Connection,
@ -313,20 +363,9 @@ where
///
/// This kills the live query process responsible for this stream.
fn drop(&mut self) {
if !self.id.is_none() {
if !self.id.is_none() && self.rx.is_some() {
let id = mem::take(&mut self.id);
let client = self.client.clone().into_owned();
spawn(async move {
if let Ok(router) = client.router.extract() {
let mut conn = Client::new(Method::Kill);
match conn.execute_unit(router, Param::new(vec![id.clone()])).await {
Ok(()) => trace!("Live query {id} dropped successfully"),
Err(error) => warn!("Failed to drop live query {id}; {error}"),
}
}
});
} else {
trace!("Ignoring drop call on an already dropped live::Stream");
kill(&self.client, id);
}
}
}

View file

@ -1,5 +1,6 @@
//! Methods to use when interacting with a SurrealDB instance
pub(crate) mod live;
pub(crate) mod query;
mod authenticate;
@ -13,7 +14,6 @@ mod export;
mod health;
mod import;
mod invalidate;
mod live;
mod merge;
mod patch;
mod select;
@ -50,6 +50,7 @@ pub use live::Stream;
pub use merge::Merge;
pub use patch::Patch;
pub use query::Query;
pub use query::QueryStream;
pub use select::Select;
pub use set::Set;
pub use signin::Signin;
@ -227,6 +228,7 @@ where
pub fn init() -> Self {
Self {
router: Arc::new(OnceLock::new()),
engine: PhantomData,
}
}
@ -252,6 +254,7 @@ where
pub fn new<P>(address: impl IntoEndpoint<P, Client = C>) -> Connect<C, Self> {
Connect {
router: Arc::new(OnceLock::new()),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -638,6 +641,7 @@ where
client: Cow::Borrowed(self),
query: vec![query.into_query()],
bindings: Ok(Default::default()),
register_live_queries: true,
}
}

View file

@ -1,9 +1,14 @@
use super::live;
use super::Stream;
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::err::Error;
use crate::api::opt;
use crate::api::Connection;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::engine::any::Any;
use crate::method::OnceLockExt;
use crate::method::Stats;
use crate::method::WithStats;
@ -15,7 +20,11 @@ use crate::sql::Statement;
use crate::sql::Statements;
use crate::sql::Strand;
use crate::sql::Value;
use crate::Notification;
use crate::Surreal;
use futures::future::Either;
use futures::stream::SelectAll;
use futures::StreamExt;
use indexmap::IndexMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -24,8 +33,11 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
/// A query future
#[derive(Debug)]
@ -34,6 +46,7 @@ pub struct Query<'r, C: Connection> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) query: Vec<Result<Vec<Statement>>>,
pub(super) bindings: Result<BTreeMap<String, Value>>,
pub(crate) register_live_queries: bool,
}
impl<C> Query<'_, C>
@ -58,14 +71,66 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
// Extract the router from the client
let router = self.client.router.extract()?;
// Combine all query statements supplied
let mut statements = Vec::with_capacity(self.query.len());
for query in self.query {
statements.extend(query?);
}
let query = sql::Query(Statements(statements));
// Build the query and execute it
let query = sql::Query(Statements(statements.clone()));
let param = Param::query(query, self.bindings?);
let mut conn = Client::new(Method::Query);
conn.execute_query(self.client.router.extract()?, param).await
let mut response = conn.execute_query(router, param).await?;
// Register live queries if necessary
if self.register_live_queries {
let mut live_queries = IndexMap::new();
let mut checked = false;
// Adjusting offsets as a workaround to https://github.com/surrealdb/surrealdb/issues/3318
let mut offset = 0;
for (index, stmt) in statements.into_iter().enumerate() {
if let Statement::Live(stmt) = stmt {
if !checked && !router.features.contains(&ExtraFeatures::LiveQueries) {
return Err(Error::LiveQueriesNotSupported.into());
}
checked = true;
let index = index - offset;
if let Some((_, result)) = response.results.get(&index) {
let result =
match result {
Ok(id) => live::register::<Client>(router, id.clone())
.await
.map(|rx| Stream {
id: stmt.id.into(),
rx: Some(rx),
client: Surreal {
router: self.client.router.clone(),
engine: PhantomData,
},
response_type: PhantomData,
engine: PhantomData,
}),
// This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error.
Err(..) => Err(Error::NotLiveQuery(index).into()),
};
live_queries.insert(index, result);
}
} else if matches!(
stmt,
Statement::Begin(..) | Statement::Commit(..) | Statement::Cancel(..)
) {
offset += 1;
}
}
response.live_queries = live_queries;
}
response.client = Surreal {
router: self.client.router.clone(),
engine: PhantomData,
};
Ok(response)
})
}
}
@ -172,9 +237,47 @@ pub(crate) type QueryResult = Result<Value>;
/// The response type of a `Surreal::query` request
#[derive(Debug)]
pub struct Response(pub(crate) IndexMap<usize, (Stats, QueryResult)>);
pub struct Response {
pub(crate) client: Surreal<Any>,
pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
pub(crate) live_queries: IndexMap<usize, Result<Stream<'static, Any, Value>>>,
}
/// A `LIVE SELECT` stream from the `query` method
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct QueryStream<R>(
pub(crate) Either<Stream<'static, Any, R>, SelectAll<Stream<'static, Any, R>>>,
);
impl futures::Stream for QueryStream<Value> {
type Item = Notification<Value>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().0.poll_next_unpin(cx)
}
}
impl<R> futures::Stream for QueryStream<Notification<R>>
where
R: DeserializeOwned + Unpin,
{
type Item = Result<Notification<R>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.as_mut().0.poll_next_unpin(cx)
}
}
impl Response {
pub(crate) fn new() -> Self {
Self {
client: Surreal::init(),
results: Default::default(),
live_queries: Default::default(),
}
}
/// Takes and returns records returned from the database
///
/// A query that only returns one result can be deserialized into an
@ -185,7 +288,6 @@ impl Response {
///
/// ```no_run
/// use serde::Deserialize;
/// use surrealdb::sql;
///
/// #[derive(Debug, Deserialize)]
/// # #[allow(dead_code)]
@ -244,6 +346,53 @@ impl Response {
index.query_result(self)
}
/// Takes and streams records returned from a `LIVE SELECT` query
///
/// This is the counterpart to [Response::take] used to stream the results
/// of a live query.
///
/// # Examples
///
/// ```no_run
/// use serde::Deserialize;
/// use surrealdb::Notification;
/// use surrealdb::sql::Value;
///
/// #[derive(Debug, Deserialize)]
/// # #[allow(dead_code)]
/// struct User {
/// id: String,
/// balance: String
/// }
///
/// # #[tokio::main]
/// # async fn main() -> surrealdb::Result<()> {
/// # let db = surrealdb::engine::any::connect("mem://").await?;
/// #
/// let mut response = db
/// // Stream all changes to the user table
/// .query("LIVE SELECT * FROM user")
/// .await?;
///
/// // Stream the result of the live query at the given index
/// // while deserialising into the User type
/// let mut stream = response.stream::<Notification<User>>(0)?;
///
/// // Stream raw values instead
/// let mut stream = response.stream::<Value>(0)?;
///
/// // Combine and stream all `LIVE SELECT` statements in this query
/// let mut stream = response.stream::<Value>(())?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// Consume the stream the same way you would any other type that implements `futures::Stream`.
pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
index.query_stream(self)
}
/// Take all errors from the query response
///
/// The errors are keyed by the corresponding index of the statement that failed.
@ -264,14 +413,14 @@ impl Response {
/// ```
pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> {
let mut keys = Vec::new();
for (key, result) in &self.0 {
for (key, result) in &self.results {
if result.1.is_err() {
keys.push(*key);
}
}
let mut errors = HashMap::with_capacity(keys.len());
for key in keys {
if let Some((_, Err(error))) = self.0.remove(&key) {
if let Some((_, Err(error))) = self.results.remove(&key) {
errors.insert(key, error);
}
}
@ -295,14 +444,14 @@ impl Response {
/// ```
pub fn check(mut self) -> Result<Self> {
let mut first_error = None;
for (key, result) in &self.0 {
for (key, result) in &self.results {
if result.1.is_err() {
first_error = Some(*key);
break;
}
}
if let Some(key) = first_error {
if let Some((_, Err(error))) = self.0.remove(&key) {
if let Some((_, Err(error))) = self.results.remove(&key) {
return Err(error);
}
}
@ -326,7 +475,7 @@ impl Response {
/// # Ok(())
/// # }
pub fn num_statements(&self) -> usize {
self.0.len()
self.results.len()
}
}
@ -423,14 +572,14 @@ impl WithStats<Response> {
/// ```
pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> {
let mut keys = Vec::new();
for (key, result) in &self.0 .0 {
for (key, result) in &self.0.results {
if result.1.is_err() {
keys.push(*key);
}
}
let mut errors = HashMap::with_capacity(keys.len());
for key in keys {
if let Some((stats, Err(error))) = self.0 .0.remove(&key) {
if let Some((stats, Err(error))) = self.0.results.remove(&key) {
errors.insert(key, (stats, error));
}
}
@ -476,6 +625,11 @@ impl WithStats<Response> {
pub fn num_statements(&self) -> usize {
self.0.num_statements()
}
/// Returns the unwrapped response
pub fn into_inner(self) -> Response {
self.0
}
}
#[cfg(test)]
@ -509,36 +663,48 @@ mod tests {
#[test]
fn take_from_an_empty_response() {
let mut response = Response(Default::default());
let mut response = Response::new();
let value: Value = response.take(0).unwrap();
assert!(value.is_none());
let mut response = Response(Default::default());
let mut response = Response::new();
let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none());
let mut response = Response(Default::default());
let mut response = Response::new();
let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty());
}
#[test]
fn take_from_an_errored_query() {
let mut response = Response(to_map(vec![Err(Error::ConnectionUninitialised.into())]));
let mut response = Response {
results: to_map(vec![Err(Error::ConnectionUninitialised.into())]),
..Response::new()
};
response.take::<Option<()>>(0).unwrap_err();
}
#[test]
fn take_from_empty_records() {
let mut response = Response(to_map(vec![]));
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value, Default::default());
let mut response = Response(to_map(vec![]));
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none());
let mut response = Response(to_map(vec![]));
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty());
}
@ -547,36 +713,55 @@ mod tests {
fn take_from_a_scalar_response() {
let scalar = 265;
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value, Value::from(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<usize> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]);
let scalar = true;
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value, Value::from(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar));
let mut response = Response(to_map(vec![Ok(scalar.into())]));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]);
}
#[test]
fn take_preserves_order() {
let mut response = Response(to_map(vec![
let mut response = Response {
results: to_map(vec![
Ok(0.into()),
Ok(1.into()),
Ok(2.into()),
@ -585,7 +770,9 @@ mod tests {
Ok(5.into()),
Ok(6.into()),
Ok(7.into()),
]));
]),
..Response::new()
};
let Some(four): Option<i32> = response.take(4).unwrap() else {
panic!("query not found");
};
@ -609,17 +796,26 @@ mod tests {
};
let value = to_value(summary.clone()).unwrap();
let mut response = Response(to_map(vec![Ok(value.clone())]));
let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let title: Value = response.take("title").unwrap();
assert_eq!(title, Value::from(summary.title.as_str()));
let mut response = Response(to_map(vec![Ok(value.clone())]));
let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found");
};
assert_eq!(title, summary.title);
let mut response = Response(to_map(vec![Ok(value)]));
let mut response = Response {
results: to_map(vec![Ok(value)]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![summary.title]);
@ -629,7 +825,10 @@ mod tests {
};
let value = to_value(article.clone()).unwrap();
let mut response = Response(to_map(vec![Ok(value.clone())]));
let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found");
};
@ -639,27 +838,45 @@ mod tests {
};
assert_eq!(body, article.body);
let mut response = Response(to_map(vec![Ok(value.clone())]));
let mut response = Response {
results: to_map(vec![Ok(value.clone())]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![article.title.clone()]);
let mut response = Response(to_map(vec![Ok(value)]));
let mut response = Response {
results: to_map(vec![Ok(value)]),
..Response::new()
};
let value: Value = response.take("title").unwrap();
assert_eq!(value, Value::from(article.title));
}
#[test]
fn take_partial_records() {
let mut response = Response(to_map(vec![Ok(vec![true, false].into())]));
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value, vec![Value::from(true), Value::from(false)].into());
let mut response = Response(to_map(vec![Ok(vec![true, false].into())]));
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![true, false]);
let mut response = Response(to_map(vec![Ok(vec![true, false].into())]));
let Err(Api(Error::LossyTake(Response(mut map)))): Result<Option<bool>> = response.take(0)
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let Err(Api(Error::LossyTake(Response {
results: mut map,
..
}))): Result<Option<bool>> = response.take(0)
else {
panic!("silently dropping records not allowed");
};
@ -682,7 +899,10 @@ mod tests {
Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()),
];
let response = Response(to_map(response));
let response = Response {
results: to_map(response),
..Response::new()
};
let crate::Error::Api(Error::ConnectionUninitialised) = response.check().unwrap_err()
else {
panic!("check did not return the first error");
@ -704,7 +924,10 @@ mod tests {
Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()),
];
let mut response = Response(to_map(response));
let mut response = Response {
results: to_map(response),
..Response::new()
};
let errors = response.take_errors();
assert_eq!(response.num_statements(), 8);
assert_eq!(errors.len(), 3);

View file

@ -49,6 +49,7 @@ impl Surreal<Client> {
) -> Connect<Client, ()> {
Connect {
router: self.router.clone(),
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
@ -76,20 +77,20 @@ impl Connection for Client {
features.insert(ExtraFeatures::Backup);
let router = Router {
features,
conn: PhantomData,
sender: route_tx,
last_id: AtomicI64::new(0),
};
server::mock(route_rx);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(router)),
engine: PhantomData,
})
})
}
fn send<'r>(
&'r mut self,
router: &'r Router<Self>,
router: &'r Router,
param: Param,
) -> Pin<Box<dyn Future<Output = Result<Receiver<Result<DbResponse>>>> + Send + Sync + 'r>> {
Box::pin(async move {

View file

@ -53,7 +53,7 @@ pub(super) fn mock(route_rx: Receiver<Option<Route>>) {
_ => unreachable!(),
},
Method::Query => match param.query {
Some(_) => Ok(DbResponse::Query(QueryResponse(Default::default()))),
Some(_) => Ok(DbResponse::Query(QueryResponse::new())),
_ => unreachable!(),
},
Method::Create => match &params[..] {

View file

@ -17,6 +17,7 @@ use crate::api::err::Error;
use crate::api::opt::Endpoint;
use semver::BuildMetadata;
use semver::VersionReq;
use std::fmt;
use std::fmt::Debug;
use std::future::Future;
use std::future::IntoFuture;
@ -37,7 +38,8 @@ pub trait Connection: conn::Connection {}
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Connect<C: Connection, Response> {
router: Arc<OnceLock<Router<C>>>,
router: Arc<OnceLock<Router>>,
engine: PhantomData<C>,
address: Result<Endpoint>,
capacity: usize,
client: PhantomData<C>,
@ -115,6 +117,7 @@ where
self.router.set(router).map_err(|_| Error::AlreadyConnected)?;
let client = Surreal {
router: self.router,
engine: PhantomData::<Client>,
};
client.check_server_version().await?;
Ok(())
@ -129,9 +132,9 @@ pub(crate) enum ExtraFeatures {
}
/// A database client instance for embedded or remote databases
#[derive(Debug)]
pub struct Surreal<C: Connection> {
router: Arc<OnceLock<Router<C>>>,
router: Arc<OnceLock<Router>>,
engine: PhantomData<C>,
}
impl<C> Surreal<C>
@ -171,15 +174,25 @@ where
fn clone(&self) -> Self {
Self {
router: self.router.clone(),
engine: self.engine,
}
}
}
trait OnceLockExt<C>
impl<C> Debug for Surreal<C>
where
C: Connection,
{
fn with_value(value: Router<C>) -> OnceLock<Router<C>> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Surreal")
.field("router", &self.router)
.field("engine", &self.engine)
.finish()
}
}
trait OnceLockExt {
fn with_value(value: Router) -> OnceLock<Router> {
let cell = OnceLock::new();
match cell.set(value) {
Ok(()) => cell,
@ -187,14 +200,11 @@ where
}
}
fn extract(&self) -> Result<&Router<C>>;
fn extract(&self) -> Result<&Router>;
}
impl<C> OnceLockExt<C> for OnceLock<Router<C>>
where
C: Connection,
{
fn extract(&self) -> Result<&Router<C>> {
impl OnceLockExt for OnceLock<Router> {
fn extract(&self) -> Result<&Router> {
let router = self.get().ok_or(Error::ConnectionUninitialised)?;
Ok(router)
}

View file

@ -1,8 +1,12 @@
use crate::api::{err::Error, opt::from_value, Response as QueryResponse, Result};
use crate::method::Stats;
use crate::method;
use crate::method::{Stats, Stream};
use crate::sql::{self, statements::*, Array, Object, Statement, Statements, Value};
use crate::syn;
use crate::{syn, Notification};
use futures::future::Either;
use futures::stream::select_all;
use serde::de::DeserializeOwned;
use std::marker::PhantomData;
use std::mem;
/// A trait for converting inputs into SQL statements
@ -178,21 +182,21 @@ where
fn query_result(self, response: &mut QueryResponse) -> Result<Response>;
/// Extracts the statistics from a query response
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(&0).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(&0).map(|x| x.0)
}
}
impl QueryResult<Value> for usize {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Value> {
match map.remove(&self) {
fn query_result(self, response: &mut QueryResponse) -> Result<Value> {
match response.results.remove(&self) {
Some((_, result)) => Ok(result?),
None => Ok(Value::None),
}
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(self).map(|x| x.0)
}
}
@ -200,13 +204,13 @@ impl<T> QueryResult<Option<T>> for usize
where
T: DeserializeOwned,
{
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Option<T>> {
let value = match map.get_mut(&self) {
fn query_result(self, response: &mut QueryResponse) -> Result<Option<T>> {
let value = match response.results.get_mut(&self) {
Some((_, result)) => match result {
Ok(val) => val,
Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&self);
response.results.remove(&self);
return Err(error);
}
},
@ -221,31 +225,36 @@ where
let value = mem::take(value);
from_value(value).map_err(Into::into)
}
_ => Err(Error::LossyTake(QueryResponse(mem::take(map))).into()),
_ => Err(Error::LossyTake(QueryResponse {
results: mem::take(&mut response.results),
live_queries: mem::take(&mut response.live_queries),
..QueryResponse::new()
})
.into()),
},
_ => {
let value = mem::take(value);
from_value(value).map_err(Into::into)
}
};
map.remove(&self);
response.results.remove(&self);
result
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(self).map(|x| x.0)
}
}
impl QueryResult<Value> for (usize, &str) {
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Value> {
fn query_result(self, response: &mut QueryResponse) -> Result<Value> {
let (index, key) = self;
let response = match map.get_mut(&index) {
let value = match response.results.get_mut(&index) {
Some((_, result)) => match result {
Ok(val) => val,
Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index);
response.results.remove(&index);
return Err(error);
}
},
@ -254,16 +263,16 @@ impl QueryResult<Value> for (usize, &str) {
}
};
let response = match response {
let value = match value {
Value::Object(Object(object)) => object.remove(key).unwrap_or_default(),
_ => Value::None,
};
Ok(response)
Ok(value)
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(&self.0).map(|x| x.0)
}
}
@ -271,14 +280,14 @@ impl<T> QueryResult<Option<T>> for (usize, &str)
where
T: DeserializeOwned,
{
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Option<T>> {
fn query_result(self, response: &mut QueryResponse) -> Result<Option<T>> {
let (index, key) = self;
let value = match map.get_mut(&index) {
let value = match response.results.get_mut(&index) {
Some((_, result)) => match result {
Ok(val) => val,
Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index);
response.results.remove(&index);
return Err(error);
}
},
@ -289,24 +298,29 @@ where
let value = match value {
Value::Array(Array(vec)) => match &mut vec[..] {
[] => {
map.remove(&index);
response.results.remove(&index);
return Ok(None);
}
[value] => value,
_ => {
return Err(Error::LossyTake(QueryResponse(mem::take(map))).into());
return Err(Error::LossyTake(QueryResponse {
results: mem::take(&mut response.results),
live_queries: mem::take(&mut response.live_queries),
..QueryResponse::new()
})
.into());
}
},
value => value,
};
match value {
Value::None | Value::Null => {
map.remove(&index);
response.results.remove(&index);
Ok(None)
}
Value::Object(Object(object)) => {
if object.is_empty() {
map.remove(&index);
response.results.remove(&index);
return Ok(None);
}
let Some(value) = object.remove(key) else {
@ -318,8 +332,8 @@ where
}
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(&self.0).map(|x| x.0)
}
}
@ -327,8 +341,8 @@ impl<T> QueryResult<Vec<T>> for usize
where
T: DeserializeOwned,
{
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Vec<T>> {
let vec = match map.remove(&self) {
fn query_result(self, response: &mut QueryResponse) -> Result<Vec<T>> {
let vec = match response.results.remove(&self) {
Some((_, result)) => match result? {
Value::Array(Array(vec)) => vec,
vec => vec![vec],
@ -340,8 +354,8 @@ where
from_value(vec.into()).map_err(Into::into)
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(self).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(self).map(|x| x.0)
}
}
@ -349,9 +363,9 @@ impl<T> QueryResult<Vec<T>> for (usize, &str)
where
T: DeserializeOwned,
{
fn query_result(self, QueryResponse(map): &mut QueryResponse) -> Result<Vec<T>> {
fn query_result(self, response: &mut QueryResponse) -> Result<Vec<T>> {
let (index, key) = self;
let mut response = match map.get_mut(&index) {
let mut response = match response.results.get_mut(&index) {
Some((_, result)) => match result {
Ok(val) => match val {
Value::Array(Array(vec)) => mem::take(vec),
@ -362,7 +376,7 @@ where
},
Err(error) => {
let error = mem::replace(error, Error::ConnectionUninitialised.into());
map.remove(&index);
response.results.remove(&index);
return Err(error);
}
},
@ -381,8 +395,8 @@ where
from_value(vec.into()).map_err(Into::into)
}
fn stats(&self, QueryResponse(map): &QueryResponse) -> Option<Stats> {
map.get(&self.0).map(|x| x.0)
fn stats(&self, response: &QueryResponse) -> Option<Stats> {
response.results.get(&self.0).map(|x| x.0)
}
}
@ -409,3 +423,114 @@ where
(0, self).query_result(response)
}
}
/// A way to take a query stream future from a query response
pub trait QueryStream<R> {
/// Retrieves the query stream future
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<R>>;
}
impl QueryStream<Value> for usize {
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<Value>> {
let stream = response
.live_queries
.remove(&self)
.and_then(|result| match result {
Err(crate::Error::Api(Error::NotLiveQuery(..))) => {
response.results.remove(&self).and_then(|x| x.1.err().map(Err))
}
result => Some(result),
})
.unwrap_or_else(|| match response.results.contains_key(&self) {
true => Err(Error::NotLiveQuery(self).into()),
false => Err(Error::QueryIndexOutOfBounds(self).into()),
})?;
Ok(method::QueryStream(Either::Left(stream)))
}
}
impl QueryStream<Value> for () {
fn query_stream(self, response: &mut QueryResponse) -> Result<method::QueryStream<Value>> {
let mut streams = Vec::with_capacity(response.live_queries.len());
for (index, result) in mem::take(&mut response.live_queries) {
match result {
Ok(stream) => streams.push(stream),
Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) {
Some((stats, Err(error))) => {
response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into())));
return Err(error);
}
Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"),
None => { return Err(Error::ResponseAlreadyTaken.into()); }
}
Err(error) => { return Err(error); }
}
}
Ok(method::QueryStream(Either::Right(select_all(streams))))
}
}
impl<R> QueryStream<Notification<R>> for usize
where
R: DeserializeOwned + Unpin,
{
fn query_stream(
self,
response: &mut QueryResponse,
) -> Result<method::QueryStream<Notification<R>>> {
let mut stream = response
.live_queries
.remove(&self)
.and_then(|result| match result {
Err(crate::Error::Api(Error::NotLiveQuery(..))) => {
response.results.remove(&self).and_then(|x| x.1.err().map(Err))
}
result => Some(result),
})
.unwrap_or_else(|| match response.results.contains_key(&self) {
true => Err(Error::NotLiveQuery(self).into()),
false => Err(Error::QueryIndexOutOfBounds(self).into()),
})?;
Ok(method::QueryStream(Either::Left(Stream {
client: stream.client.clone(),
engine: stream.engine,
id: mem::take(&mut stream.id),
rx: stream.rx.take(),
response_type: PhantomData,
})))
}
}
impl<R> QueryStream<Notification<R>> for ()
where
R: DeserializeOwned + Unpin,
{
fn query_stream(
self,
response: &mut QueryResponse,
) -> Result<method::QueryStream<Notification<R>>> {
let mut streams = Vec::with_capacity(response.live_queries.len());
for (index, result) in mem::take(&mut response.live_queries) {
let mut stream = match result {
Ok(stream) => stream,
Err(crate::Error::Api(Error::NotLiveQuery(..))) => match response.results.remove(&index) {
Some((stats, Err(error))) => {
response.results.insert(index, (stats, Err(Error::ResponseAlreadyTaken.into())));
return Err(error);
}
Some((_, Ok(..))) => unreachable!("the internal error variant indicates that an error occurred in the `LIVE SELECT` query"),
None => { return Err(Error::ResponseAlreadyTaken.into()); }
}
Err(error) => { return Err(error); }
};
streams.push(Stream {
client: stream.client.clone(),
engine: stream.engine,
id: mem::take(&mut stream.id),
rx: stream.rx.take(),
response_type: PhantomData,
});
}
Ok(method::QueryStream(Either::Right(select_all(streams))))
}
}

View file

@ -158,6 +158,7 @@ pub use api::Response;
pub use api::Result;
#[doc(inline)]
pub use api::Surreal;
use uuid::Uuid;
#[doc(hidden)]
/// Channels for receiving a SurrealQL database export
@ -203,6 +204,7 @@ impl From<dbs::Action> for Action {
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[non_exhaustive]
pub struct Notification<R> {
pub query_id: Uuid,
pub action: Action,
pub data: R,
}

View file

@ -174,3 +174,128 @@ async fn live_select_record_ranges() {
drop(permit);
}
#[test_log::test(tokio::test)]
async fn live_select_query() {
let (permit, db) = new_db().await;
db.use_ns(NS).use_db(Ulid::new().to_string()).await.unwrap();
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Notification<_>>(0)
.unwrap();
// Create a record
let created: Vec<RecordId> = db.create(table).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// The returned record should match the created record
assert_eq!(created, vec![notification.data.clone()]);
// It should be newly created
assert_eq!(notification.action, Action::Create);
// Update the record
let _: Option<RecordId> =
db.update(&notification.data.id).content(json!({"foo": "bar"})).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be updated
assert_eq!(notification.action, Action::Update);
// Delete the record
let _: Option<RecordId> = db.delete(&notification.data.id).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be deleted
assert_eq!(notification.action, Action::Delete);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Value>(0)
.unwrap();
// Create a record
db.create(Resource::from(&table)).await.unwrap();
// Pull the notification
let notification = users.next().await.unwrap();
// The returned record should be an object
assert!(notification.data.is_object());
// It should be newly created
assert_eq!(notification.action, Action::Create);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query(format!("LIVE SELECT * FROM {table}"))
.await
.unwrap()
.stream::<Notification<_>>(())
.unwrap();
// Create a record
let created: Vec<RecordId> = db.create(table).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// The returned record should match the created record
assert_eq!(created, vec![notification.data.clone()]);
// It should be newly created
assert_eq!(notification.action, Action::Create);
// Update the record
let _: Option<RecordId> =
db.update(&notification.data.id).content(json!({"foo": "bar"})).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be updated
assert_eq!(notification.action, Action::Update);
// Delete the record
let _: Option<RecordId> = db.delete(&notification.data.id).await.unwrap();
// Pull the notification
let notification: Notification<RecordId> = users.next().await.unwrap().unwrap();
// It should be deleted
assert_eq!(notification.action, Action::Delete);
}
{
let table = Ulid::new().to_string();
// Start listening
let mut users = db
.query("BEGIN")
.query(format!("LIVE SELECT * FROM {table}"))
.query("COMMIT")
.await
.unwrap()
.stream::<Value>(())
.unwrap();
// Create a record
db.create(Resource::from(&table)).await.unwrap();
// Pull the notification
let notification = users.next().await.unwrap();
// The returned record should be an object
assert!(notification.data.is_object());
// It should be newly created
assert_eq!(notification.action, Action::Create);
}
drop(permit);
}

View file

@ -33,7 +33,6 @@ async fn yuse() {
drop(permit);
}
#[ignore]
#[test_log::test(tokio::test)]
async fn invalidate() {
let (permit, db) = new_db().await;

View file

@ -5,6 +5,8 @@ use crate::cli::abstraction::{
use crate::cnf::PKG_VERSION;
use crate::err::Error;
use clap::Args;
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures_util::{SinkExt, StreamExt};
use rustyline::error::ReadlineError;
use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{Completer, Editor, Helper, Highlighter, Hinter};
@ -15,7 +17,7 @@ use surrealdb::engine::any::{connect, IntoEndpoint};
use surrealdb::method::{Stats, WithStats};
use surrealdb::opt::Config;
use surrealdb::sql::{self, Statement, Value};
use surrealdb::Response;
use surrealdb::{Notification, Response};
#[derive(Args, Debug)]
pub struct SqlCommandArguments {
@ -149,6 +151,10 @@ pub async fn init(
);
}
// Set up the print job
let (tx, rx) = mpsc::unbounded();
tokio::spawn(printer(rx));
// Loop over each command-line input
loop {
// Prompt the user to input SQL and check the input.
@ -206,16 +212,13 @@ pub async fn init(
continue;
}
// Run the query provided
let res = client.query(query).with_stats().await;
match process(pretty, json, res) {
Ok(v) => {
println!("{v}\n");
}
Err(e) => {
eprintln!("{e}\n");
let result = client.query(query).with_stats().await;
let result = process(pretty, json, result, tx.clone());
let result_is_error = result.is_err();
tx.clone().send(result).await.expect("print job terminated unexpectedly");
if result_is_error {
continue;
}
}
// Persist the variables extracted from the query
for (key, value) in vars {
let _ = client.set(key, value).await;
@ -253,6 +256,7 @@ fn process(
pretty: bool,
json: bool,
res: surrealdb::Result<WithStats<Response>>,
mut tx: UnboundedSender<Result<String, Error>>,
) -> Result<String, Error> {
// Check query response for an error
let mut response = res?;
@ -271,6 +275,62 @@ fn process(
vec.push((stats, output));
}
tokio::spawn(async move {
let mut stream = match response.into_inner().stream::<Value>(()) {
Ok(stream) => stream,
Err(error) => {
tx.send(Err(error.into())).await.ok();
return;
}
};
while let Some(Notification {
query_id,
action,
data,
..
}) = stream.next().await
{
let message = match (json, pretty) {
// Don't prettify the SurrealQL response
(false, false) => {
let value = Value::from(map! {
String::from("id") => query_id.into(),
String::from("action") => format!("{action:?}").to_ascii_uppercase().into(),
String::from("result") => data,
});
value.to_string()
}
// Yes prettify the SurrealQL response
(false, true) => format!(
"-- Notification (action: {action:?}, live query ID: {query_id})\n{data:#}"
),
// Don't pretty print the JSON response
(true, false) => {
let value = Value::from(map! {
String::from("id") => query_id.into(),
String::from("action") => format!("{action:?}").to_ascii_uppercase().into(),
String::from("result") => data,
});
value.into_json().to_string()
}
// Yes prettify the JSON response
(true, true) => {
let mut buf = Vec::new();
let mut serializer = serde_json::Serializer::with_formatter(
&mut buf,
PrettyFormatter::with_indent(b"\t"),
);
data.into_json().serialize(&mut serializer).unwrap();
let output = String::from_utf8(buf).unwrap();
format!("-- Notification (action: {action:?}, live query ID: {query_id})\n{output:#}")
}
};
if tx.send(Ok(format!("\n{message}"))).await.is_err() {
return;
}
}
});
// Check if we should emit JSON and/or prettify
Ok(match (json, pretty) {
// Don't prettify the SurrealQL response
@ -314,6 +374,19 @@ fn process(
})
}
async fn printer(mut rx: UnboundedReceiver<Result<String, Error>>) {
while let Some(result) = rx.next().await {
match result {
Ok(v) => {
println!("{v}\n");
}
Err(e) => {
eprintln!("{e}\n");
}
}
}
}
#[derive(Completer, Helper, Highlighter, Hinter)]
struct InputValidator {
/// If omitting semicolon causes newline.

View file

@ -799,6 +799,7 @@ mod cli_integration {
}
#[test(tokio::test)]
#[ignore]
async fn test_capabilities() {
// Default capabilities only allow functions
info!("* When default capabilities");
@ -825,7 +826,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap();
assert!(
output.contains("Scripting functions are not allowed"),
output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}"
);
}
@ -855,7 +857,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap();
assert!(
output.contains("Scripting functions are not allowed"),
output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}"
);
}
@ -901,7 +904,8 @@ mod cli_integration {
let query = "RETURN function() { return '1' };";
let output = common::run(&cmd).input(query).output().unwrap();
assert!(
output.contains("Scripting functions are not allowed"),
output.contains("Scripting functions are not allowed")
|| output.contains("Embedded functions are not enabled"),
"unexpected output: {output:?}"
);
}