use crate::api; use crate::api::err::Error; use crate::api::method::query::Response; use crate::api::opt::Endpoint; use crate::api::ExtraFeatures; use crate::api::Result; use crate::api::Surreal; use crate::dbs::Notification; use crate::sql::from_value; use crate::sql::Query; use crate::sql::Value; use flume::Receiver; use flume::Sender; use serde::de::DeserializeOwned; use serde::Serialize; use std::collections::BTreeMap; use std::collections::HashSet; use std::future::Future; use std::path::PathBuf; use std::pin::Pin; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; #[derive(Debug)] #[allow(dead_code)] // used by the embedded and remote connections pub(crate) struct Route { pub(crate) request: (i64, Method, Param), pub(crate) response: Sender>, } /// Message router #[derive(Debug)] pub struct Router { pub(crate) sender: Sender>, pub(crate) last_id: AtomicI64, pub(crate) features: HashSet, } impl Router { pub(crate) fn next_id(&self) -> i64 { self.last_id.fetch_add(1, Ordering::SeqCst) } } impl Drop for Router { fn drop(&mut self) { let _res = self.sender.send(None); } } /// The query method #[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)] #[serde(rename_all = "lowercase")] pub enum Method { /// Sends an authentication token to the server Authenticate, /// Performs a merge update operation Merge, /// Creates a record in a table Create, /// Deletes a record from a table Delete, /// Exports a database Export, /// Checks the health of the server Health, /// Imports a database Import, /// Invalidates a session Invalidate, /// Kills a live query #[doc(hidden)] // Not supported yet Kill, /// Starts a live query #[doc(hidden)] // Not supported yet Live, /// Performs a patch update operation Patch, /// Sends a raw query to the database Query, /// Selects a record or records from a table Select, /// Sets a parameter on the connection Set, /// Signs into the server Signin, /// Signs up on the server Signup, /// Removes a parameter from a connection Unset, /// Performs an update operation Update, /// Selects a namespace and database to use Use, /// Queries the version of the server Version, } /// The database response sent from the router to the caller #[derive(Debug)] pub enum DbResponse { /// The response sent for the `query` method Query(Response), /// The response sent for any method except `query` Other(Value), } #[derive(Debug)] #[allow(dead_code)] // used by ML model import and export functions pub(crate) enum MlConfig { Import, Export { name: String, version: String, }, } /// Holds the parameters given to the caller #[derive(Debug, Default)] #[allow(dead_code)] // used by the embedded and remote connections pub struct Param { pub(crate) query: Option<(Query, BTreeMap)>, pub(crate) other: Vec, pub(crate) file: Option, pub(crate) bytes_sender: Option>>>, pub(crate) notification_sender: Option>, pub(crate) ml_config: Option, } impl Param { pub(crate) fn new(other: Vec) -> Self { Self { other, ..Default::default() } } pub(crate) fn query(query: Query, bindings: BTreeMap) -> Self { Self { query: Some((query, bindings)), ..Default::default() } } pub(crate) fn file(file: PathBuf) -> Self { Self { file: Some(file), ..Default::default() } } pub(crate) fn bytes_sender(send: channel::Sender>>) -> Self { Self { bytes_sender: Some(send), ..Default::default() } } pub(crate) fn notification_sender(send: channel::Sender) -> Self { Self { notification_sender: Some(send), ..Default::default() } } } /// Connection trait implemented by supported protocols pub trait Connection: Sized + Send + Sync + 'static { /// Constructs a new client without connecting to the server fn new(method: Method) -> Self; /// Connect to the server fn connect( address: Endpoint, capacity: usize, ) -> Pin>> + Send + Sync + 'static>> where Self: api::Connection; /// Send a query to the server #[allow(clippy::type_complexity)] fn send<'r>( &'r mut self, router: &'r Router, param: Param, ) -> Pin>>> + Send + Sync + 'r>> where Self: api::Connection; /// Receive responses for all methods except `query` fn recv( &mut self, receiver: Receiver>, ) -> Pin> + Send + Sync + '_>> { Box::pin(async move { let response = receiver.into_recv_async().await?; match response? { DbResponse::Other(value) => Ok(value), DbResponse::Query(..) => unreachable!(), } }) } /// Receive the response of the `query` method fn recv_query( &mut self, receiver: Receiver>, ) -> Pin> + Send + Sync + '_>> { Box::pin(async move { let response = receiver.into_recv_async().await?; match response? { DbResponse::Query(results) => Ok(results), DbResponse::Other(..) => unreachable!(), } }) } /// Execute all methods except `query` fn execute<'r, R>( &'r mut self, router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where R: DeserializeOwned, Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; let value = self.recv(rx).await?; from_value(value).map_err(Into::into) }) } /// Execute methods that return an optional single response fn execute_opt<'r, R>( &'r mut self, router: &'r Router, param: Param, ) -> Pin>> + Send + Sync + 'r>> where R: DeserializeOwned, Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; match self.recv(rx).await? { Value::None | Value::Null => Ok(None), value => from_value(value).map_err(Into::into), } }) } /// Execute methods that return multiple responses fn execute_vec<'r, R>( &'r mut self, router: &'r Router, param: Param, ) -> Pin>> + Send + Sync + 'r>> where R: DeserializeOwned, Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; let value = match self.recv(rx).await? { Value::None | Value::Null => Value::Array(Default::default()), Value::Array(array) => Value::Array(array), value => vec![value].into(), }; from_value(value).map_err(Into::into) }) } /// Execute methods that return nothing fn execute_unit<'r>( &'r mut self, router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; match self.recv(rx).await? { Value::None | Value::Null => Ok(()), Value::Array(array) if array.is_empty() => Ok(()), value => Err(Error::FromValue { value, error: "expected the database to return nothing".to_owned(), } .into()), } }) } /// Execute methods that return a raw value fn execute_value<'r>( &'r mut self, router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; self.recv(rx).await }) } /// Execute the `query` method fn execute_query<'r>( &'r mut self, router: &'r Router, param: Param, ) -> Pin> + Send + Sync + 'r>> where Self: api::Connection, { Box::pin(async move { let rx = self.send(router, param).await?; self.recv_query(rx).await }) } }