Restructure the Query struct in the Rust library ()

Mees Delzenne 2024-06-11 14:01:58 +02:00 committed by GitHub
parent b9b2974883
commit 1e0eddceaa
17 changed files with 229 additions and 165 deletions

View file

@ -240,7 +240,6 @@ impl Surreal<Any> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
@ -297,7 +296,6 @@ pub fn connect(address: impl IntoEndpoint) -> Connect<Any, Surreal<Any>> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: Arc::new(watch::channel(None)),
response_type: PhantomData,
}
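Taken together with the matching hunks in the other engines below, the change in these two hunks is the removal of one redundant field: Connect carried both engine: PhantomData<C> and client: PhantomData<C>, two zero-sized markers for the same type parameter, so every initializer drops the client line. A minimal sketch of the slimmed-down struct, with the field types copied from the definition hunk near the end of this diff:

pub struct Connect<C: Connection, Response> {
    engine: PhantomData<C>,     // already pins the engine type C
    address: Result<Endpoint>,
    capacity: usize,
    // client: PhantomData<C>,  // removed: duplicated the information in `engine`
    waiter: Arc<Waiter>,
    response_type: PhantomData<Response>,
}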

View file

@ -27,7 +27,6 @@ use flume::Receiver;
use reqwest::ClientBuilder;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
@ -231,15 +230,14 @@ impl Connection for Any {
EndpointKind::Unsupported(v) => return Err(Error::Scheme(v).into()),
}
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}
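The same substitution appears in every engine backend in this diff: instead of a Surreal struct literal, which forces each call site to spell out the private engine: PhantomData field, connect now returns through the crate-private Surreal::new_from_router_waiter constructor added in the final hunk of this diff. Roughly, as a before/after sketch with the Router literal abbreviated to a prebuilt router value:

// Before: struct literal repeated in every engine's connect()
Ok(Surreal {
    router: Arc::new(OnceLock::with_value(router)),
    waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
    engine: PhantomData,
})

// After: shared constructor, no PhantomData at the call site
Ok(Surreal::new_from_router_waiter(
    Arc::new(OnceLock::with_value(router)),
    Arc::new(watch::channel(Some(WaitFor::Connection))),
))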

View file

@ -184,15 +184,14 @@ impl Connection for Any {
EndpointKind::Unsupported(v) => return Err(Error::Scheme(v).into()),
}
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -389,7 +389,6 @@ impl Surreal<Db> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}

View file

@ -27,7 +27,6 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
@ -64,15 +63,14 @@ impl Connection for Db {
features.insert(ExtraFeatures::Backup);
features.insert(ExtraFeatures::LiveQueries);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -65,15 +65,14 @@ impl Connection for Db {
let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -102,7 +102,6 @@ impl Surreal<Client> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}

View file

@ -20,7 +20,6 @@ use reqwest::header::HeaderMap;
use reqwest::ClientBuilder;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
@ -73,15 +72,14 @@ impl Connection for Client {
let mut features = HashSet::new();
features.insert(ExtraFeatures::Backup);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -52,15 +52,14 @@ impl Connection for Client {
conn_rx.into_recv_async().await??;
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features: HashSet::new(),
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -77,7 +77,6 @@ impl Surreal<Client> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}

View file

@ -35,7 +35,6 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::sync::atomic::AtomicI64;
@ -150,15 +149,14 @@ impl Connection for Client {
let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -90,15 +90,14 @@ impl Connection for Client {
let mut features = HashSet::new();
features.insert(ExtraFeatures::LiveQueries);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(Router {
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(Router {
features,
sender: route_tx,
last_id: AtomicI64::new(0),
})),
waiter: Arc::new(watch::channel(Some(WaitFor::Connection))),
engine: PhantomData,
})
Arc::new(watch::channel(Some(WaitFor::Connection))),
))
})
}

View file

@ -93,25 +93,19 @@ macro_rules! into_future {
Resource::Edges(edges) => return Err(Error::LiveOnEdges(edges).into()),
},
}
let query = Query {
client: client.clone(),
query: vec![Ok(vec![Statement::Live(stmt)])],
bindings: Ok(Default::default()),
register_live_queries: false,
};
let query = Query::new(
client.clone(),
vec![Statement::Live(stmt)],
Default::default(),
false,
);
let id: Value = query.await?.take(0)?;
let rx = register::<Client>(router, id.clone()).await?;
Ok(Stream {
Ok(Stream::new(
Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone()),
id,
rx: Some(rx),
client: Surreal {
router: client.router.clone(),
waiter: client.waiter.clone(),
engine: PhantomData,
},
response_type: PhantomData,
engine: PhantomData,
})
Some(rx),
))
})
}
};
@ -177,6 +171,22 @@ pub struct Stream<'r, C: Connection, R> {
pub(crate) response_type: PhantomData<R>,
}
impl<'r, C: Connection, R> Stream<'r, C, R> {
pub(crate) fn new(
client: Surreal<Any>,
id: Value,
rx: Option<Receiver<dbs::Notification>>,
) -> Self {
Self {
id,
rx,
client,
response_type: PhantomData,
engine: PhantomData,
}
}
}
macro_rules! poll_next {
($notification:ident => $body:expr) => {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
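Stream gets the same treatment: the new Stream::new constructor above hides the rx/response_type/engine plumbing, so the live() future in the earlier hunk shrinks to a single call. A sketch of that call site, using the names visible in the hunk (the local binding name here is ours; the diff passes the expression inline):

// Inside the into_future! macro for live(): build the client handle, then the stream.
let handle = Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone());
Ok(Stream::new(handle, id, Some(rx)))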

View file

@ -86,6 +86,8 @@ use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use self::query::ValidQuery;
/// Query statistics
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
@ -262,7 +264,6 @@ where
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: Arc::new(watch::channel(None)),
response_type: PhantomData,
}
@ -662,11 +663,15 @@ where
/// # }
/// ```
pub fn query(&self, query: impl opt::IntoQuery) -> Query<C> {
Query {
let inner = query.into_query().map(|x| ValidQuery {
client: Cow::Borrowed(self),
query: vec![query.into_query()],
bindings: Ok(Default::default()),
query: x,
bindings: Default::default(),
register_live_queries: true,
});
Query {
inner,
}
}
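With the new layout, Surreal::query no longer collects a Vec<Result<Vec<Statement>>>: the Result from into_query() is mapped directly into the single inner: Result<ValidQuery> field and only surfaces when the query future is awaited. A hedged usage sketch of that deferred-error behaviour (table and parameter names are made up for the example):

// Assumed behaviour: the builder never fails eagerly; a parse error, if any,
// is stored in Query::inner and returned from .await.
let query = db.query("SELECT * FROM person WHERE age > $min")
    .bind(("min", 18));
let response = query.await?;  // this is where a stored error would be returned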

View file

@ -29,7 +29,6 @@ use std::collections::BTreeMap;
use std::collections::HashMap;
use std::future::Future;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::task::Context;
@ -39,21 +38,70 @@ use std::task::Poll;
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Query<'r, C: Connection> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) query: Vec<Result<Vec<Statement>>>,
pub(super) bindings: Result<BTreeMap<String, Value>>,
pub(crate) register_live_queries: bool,
pub(crate) inner: Result<ValidQuery<'r, C>>,
}
impl<C> Query<'_, C>
#[derive(Debug)]
pub(crate) struct ValidQuery<'r, C: Connection> {
pub client: Cow<'r, Surreal<C>>,
pub query: Vec<Statement>,
pub bindings: BTreeMap<String, Value>,
pub register_live_queries: bool,
}
impl<'r, C> Query<'r, C>
where
C: Connection,
{
pub(crate) fn new(
client: Cow<'r, Surreal<C>>,
query: Vec<Statement>,
bindings: BTreeMap<String, Value>,
register_live_queries: bool,
) -> Self {
Query {
inner: Ok(ValidQuery {
client,
query,
bindings,
register_live_queries,
}),
}
}
pub(crate) fn map_valid<F>(self, f: F) -> Self
where
F: FnOnce(ValidQuery<'r, C>) -> Result<ValidQuery<'r, C>>,
{
match self.inner {
Ok(x) => Query {
inner: f(x),
},
x => Query {
inner: x,
},
}
}
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Query<'static, C> {
let inner = match self.inner {
Ok(ValidQuery {
client,
query,
bindings,
register_live_queries,
}) => Ok(ValidQuery::<'static, C> {
client: Cow::Owned(client.into_owned()),
query,
bindings,
register_live_queries,
}),
Err(e) => Err(e),
};
Query {
client: Cow::Owned(self.client.into_owned()),
..self
inner,
}
}
}
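map_valid is what keeps the builder API infallible at the call site: its closure only runs while inner is still Ok, and an earlier error is carried forward untouched, so query() and bind() further down can be rewritten without manual error bookkeeping. A sketch of the resulting behaviour, assuming the first string fails to parse in into_query():

// Hypothetical chain: once inner holds an error, later map_valid closures are skipped
// and that original error is what .await returns.
let q = db.query("THIS IS NOT VALID SURREALQL")  // inner = Err(parse error)
    .query("SELECT * FROM person")               // no-op: closure never runs
    .bind(("min", 18));                          // no-op: closure never runs
assert!(q.await.is_err());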
@ -66,69 +114,81 @@ where
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;
fn into_future(self) -> Self::IntoFuture {
let ValidQuery {
client,
query,
bindings,
register_live_queries,
} = match self.inner {
Ok(x) => x,
Err(error) => return Box::pin(async move { Err(error) }),
};
let query_statements = query;
Box::pin(async move {
// Extract the router from the client
let router = self.client.router.extract()?;
// Combine all query statements supplied
let mut statements = Vec::with_capacity(self.query.len());
for query in self.query {
statements.extend(query?);
let router = client.router.extract()?;
// Collect the indexes of the live queries which should be registered.
let query_indicies = if register_live_queries {
query_statements
.iter()
// BEGIN, COMMIT, and CANCEL don't return a result.
.filter(|x| {
!matches!(
x,
Statement::Begin(_) | Statement::Commit(_) | Statement::Cancel(_)
)
})
.enumerate()
.filter(|(_, x)| matches!(x, Statement::Live(_)))
.map(|(i, _)| i)
.collect()
} else {
Vec::new()
};
// If there are live queries and it is not supported, return an error.
if !query_indicies.is_empty() && !router.features.contains(&ExtraFeatures::LiveQueries)
{
return Err(Error::LiveQueriesNotSupported.into());
}
// Build the query and execute it
let mut query = sql::Query::default();
query.0 .0.clone_from(&statements);
let param = Param::query(query, self.bindings?);
query.0 .0 = query_statements;
let param = Param::query(query, bindings);
let mut conn = Client::new(Method::Query);
let mut response = conn.execute_query(router, param).await?;
// Register live queries if necessary
if self.register_live_queries {
let mut live_queries = IndexMap::new();
let mut checked = false;
// Adjusting offsets as a workaround to https://github.com/surrealdb/surrealdb/issues/3318
let mut offset = 0;
for (index, stmt) in statements.into_iter().enumerate() {
if let Statement::Live(stmt) = stmt {
if !checked && !router.features.contains(&ExtraFeatures::LiveQueries) {
return Err(Error::LiveQueriesNotSupported.into());
}
checked = true;
let index = index - offset;
if let Some((_, result)) = response.results.get(&index) {
let result =
match result {
Ok(id) => live::register::<Client>(router, id.clone())
.await
.map(|rx| Stream {
id: stmt.id.into(),
rx: Some(rx),
client: Surreal {
router: self.client.router.clone(),
waiter: self.client.waiter.clone(),
engine: PhantomData,
},
response_type: PhantomData,
engine: PhantomData,
}),
// This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error.
Err(..) => Err(Error::NotLiveQuery(index).into()),
};
live_queries.insert(index, result);
}
} else if matches!(
stmt,
Statement::Begin(..) | Statement::Commit(..) | Statement::Cancel(..)
) {
offset += 1;
}
}
response.live_queries = live_queries;
for idx in query_indicies {
let Some((_, result)) = response.results.get(&idx) else {
continue;
};
// This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error.
let res = match result {
Ok(id) => live::register::<Client>(router, id.clone()).await.map(|rx| {
Stream::new(
Surreal::new_from_router_waiter(
client.router.clone(),
client.waiter.clone(),
),
id.clone(),
Some(rx),
)
}),
Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
};
response.live_queries.insert(idx, res);
}
response.client = Surreal {
router: self.client.router.clone(),
waiter: self.client.waiter.clone(),
engine: PhantomData,
};
response.client =
Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone());
Ok(response)
})
}
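The rewritten into_future computes the response indices of LIVE statements up front instead of mutating an offset while iterating: BEGIN, COMMIT and CANCEL never occupy a slot in the result map, so filtering them out before enumerate() yields indices that already line up with response.results. A small self-contained illustration of the same trick on a stand-in enum (Stmt here is hypothetical, standing in for sql::Statement):

#[allow(dead_code)]
enum Stmt {
    Begin,
    Commit,
    Cancel,
    Live,
    Select,
}

// Indices returned here match the server's result map, because transaction
// control statements produce no result entry.
fn live_indices(stmts: &[Stmt]) -> Vec<usize> {
    stmts
        .iter()
        .filter(|s| !matches!(s, Stmt::Begin | Stmt::Commit | Stmt::Cancel))
        .enumerate()
        .filter(|(_, s)| matches!(s, Stmt::Live))
        .map(|(i, _)| i)
        .collect()
}

fn main() {
    // LIVE is the second statement overall, but the first that returns a result.
    let stmts = [Stmt::Begin, Stmt::Live, Stmt::Select, Stmt::Commit];
    assert_eq!(live_indices(&stmts), vec![0]);
}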
@ -154,9 +214,12 @@ where
C: Connection,
{
/// Chains a query onto an existing query
pub fn query(mut self, query: impl opt::IntoQuery) -> Self {
self.query.push(query.into_query());
self
pub fn query(self, query: impl opt::IntoQuery) -> Self {
self.map_valid(move |mut valid| {
let new_query = query.into_query()?;
valid.query.extend(new_query);
Ok(valid)
})
}
/// Return query statistics along with its results
@ -205,30 +268,25 @@ where
/// # Ok(())
/// # }
/// ```
pub fn bind(mut self, bindings: impl Serialize) -> Self {
if let Ok(current) = &mut self.bindings {
match to_value(bindings) {
Ok(mut bindings) => {
if let Value::Array(array) = &mut bindings {
if let [Value::Strand(key), value] = &mut array.0[..] {
let mut map = BTreeMap::new();
map.insert(mem::take(&mut key.0), mem::take(value));
bindings = map.into();
}
}
match &mut bindings {
Value::Object(map) => current.append(&mut map.0),
_ => {
self.bindings = Err(Error::InvalidBindings(bindings).into());
}
}
}
Err(error) => {
self.bindings = Err(error.into());
pub fn bind(self, bindings: impl Serialize) -> Self {
self.map_valid(move |mut valid| {
let mut bindings = to_value(bindings)?;
if let Value::Array(array) = &mut bindings {
if let [Value::Strand(key), value] = &mut array.0[..] {
let mut map = BTreeMap::new();
map.insert(mem::take(&mut key.0), mem::take(value));
bindings = map.into();
}
}
}
self
match &mut bindings {
Value::Object(map) => valid.bindings.append(&mut map.0),
_ => {
return Err(Error::InvalidBindings(bindings).into());
}
}
Ok(valid)
})
}
}
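bind keeps the old normalisation rules: a ("key", value) tuple becomes a one-entry object, anything that serialises to an object is merged into ValidQuery::bindings, and everything else is rejected with InvalidBindings; the structural change is only that the work now happens inside map_valid. Assumed usage, mirroring the SDK's documented pattern (PersonParams is a made-up Serialize struct):

// Sketch: both calls end up as entries in the same bindings map.
db.query("CREATE person SET name = $name, age = $age")
    .bind(("name", "Tobie"))          // tuple form: single key/value pair
    .bind(PersonParams { age: 40 })   // struct form: serialised to an object and merged
    .await?;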

View file

@ -49,7 +49,6 @@ impl Surreal<Client> {
engine: PhantomData,
address: address.into_endpoint(),
capacity: 0,
client: PhantomData,
waiter: self.waiter.clone(),
response_type: PhantomData,
}
@ -79,11 +78,10 @@ impl Connection for Client {
last_id: AtomicI64::new(0),
};
server::mock(route_rx);
Ok(Surreal {
router: Arc::new(OnceLock::with_value(router)),
waiter: Arc::new(watch::channel(None)),
engine: PhantomData,
})
Ok(Surreal::new_from_router_waiter(
Arc::new(OnceLock::with_value(router)),
Arc::new(watch::channel(None)),
))
})
}

View file

@ -51,7 +51,6 @@ pub struct Connect<C: Connection, Response> {
engine: PhantomData<C>,
address: Result<Endpoint>,
capacity: usize,
client: PhantomData<C>,
waiter: Arc<Waiter>,
response_type: PhantomData<Response>,
}
@ -177,6 +176,17 @@ impl<C> Surreal<C>
where
C: Connection,
{
pub(crate) fn new_from_router_waiter(
router: Arc<OnceLock<Router>>,
waiter: Arc<Waiter>,
) -> Self {
Surreal {
router,
waiter,
engine: PhantomData,
}
}
async fn check_server_version(&self, version: &Version) -> Result<()> {
let (versions, build_meta) = SUPPORTED_VERSIONS;
// invalid version requirements should be caught during development