Make communication with the router task more type-safe (#4406)

This commit is contained in:
Mees Delzenne 2024-07-23 16:38:54 +02:00 committed by GitHub
parent 08f4ad6c82
commit 2b1e6a32ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2228 additions and 1628 deletions

View file

@ -1,5 +1,4 @@
use crate::sql::constant::ConstantValue;
use crate::sql::id::Gen;
use crate::sql::Value;
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -174,6 +173,8 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
}
}
// TODO: Checkout blame for why these are here.
/*
#[derive(Serialize)]
enum Id {
Number(i64),
@ -212,6 +213,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
}
}
}
*/
match value {
// These value types are simple values which

View file

@ -5,34 +5,8 @@ use crate::sql::Kind;
use crate::sql::Value;
use ser::Serializer as _;
use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Cast;
type Error = Error;
type SerializeSeq = Impossible<Cast, Error>;
type SerializeTuple = Impossible<Cast, Error>;
type SerializeTupleStruct = SerializeCast;
type SerializeTupleVariant = Impossible<Cast, Error>;
type SerializeMap = Impossible<Cast, Error>;
type SerializeStruct = Impossible<Cast, Error>;
type SerializeStructVariant = Impossible<Cast, Error>;
const EXPECTED: &'static str = "an struct `Cast`";
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Error> {
Ok(SerializeCast::default())
}
}
#[derive(Default)]
pub(super) struct SerializeCast {
index: usize,
@ -74,8 +48,34 @@ impl serde::ser::SerializeTupleStruct for SerializeCast {
#[cfg(test)]
mod tests {
use super::*;
use serde::ser::Impossible;
use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Cast;
type Error = Error;
type SerializeSeq = Impossible<Cast, Error>;
type SerializeTuple = Impossible<Cast, Error>;
type SerializeTupleStruct = SerializeCast;
type SerializeTupleVariant = Impossible<Cast, Error>;
type SerializeMap = Impossible<Cast, Error>;
type SerializeStruct = Impossible<Cast, Error>;
type SerializeStructVariant = Impossible<Cast, Error>;
const EXPECTED: &'static str = "an struct `Cast`";
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Error> {
Ok(SerializeCast::default())
}
}
#[test]
fn cast() {
let cast = Cast(Default::default(), Default::default());

View file

@ -6,35 +6,8 @@ use crate::sql::Tables;
use crate::sql::Thing;
use ser::Serializer as _;
use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Edges;
type Error = Error;
type SerializeSeq = Impossible<Edges, Error>;
type SerializeTuple = Impossible<Edges, Error>;
type SerializeTupleStruct = Impossible<Edges, Error>;
type SerializeTupleVariant = Impossible<Edges, Error>;
type SerializeMap = Impossible<Edges, Error>;
type SerializeStruct = SerializeEdges;
type SerializeStructVariant = Impossible<Edges, Error>;
const EXPECTED: &'static str = "a struct `Edges`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeEdges::default())
}
}
#[derive(Default)]
pub(super) struct SerializeEdges {
dir: Option<Dir>,
@ -83,8 +56,35 @@ impl serde::ser::SerializeStruct for SerializeEdges {
mod tests {
use super::*;
use crate::sql::thing;
use serde::ser::Impossible;
use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Edges;
type Error = Error;
type SerializeSeq = Impossible<Edges, Error>;
type SerializeTuple = Impossible<Edges, Error>;
type SerializeTupleStruct = Impossible<Edges, Error>;
type SerializeTupleVariant = Impossible<Edges, Error>;
type SerializeMap = Impossible<Edges, Error>;
type SerializeStruct = SerializeEdges;
type SerializeStructVariant = Impossible<Edges, Error>;
const EXPECTED: &'static str = "a struct `Edges`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeEdges::default())
}
}
#[test]
fn edges() {
let edges = Edges {

View file

@ -5,42 +5,8 @@ use crate::sql::Operator;
use crate::sql::Value;
use ser::Serializer as _;
use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Expression;
type Error = Error;
type SerializeSeq = Impossible<Expression, Error>;
type SerializeTuple = Impossible<Expression, Error>;
type SerializeTupleStruct = Impossible<Expression, Error>;
type SerializeTupleVariant = Impossible<Expression, Error>;
type SerializeMap = Impossible<Expression, Error>;
type SerializeStruct = Impossible<Expression, Error>;
type SerializeStructVariant = SerializeExpression;
const EXPECTED: &'static str = "an enum `Expression`";
#[inline]
fn serialize_struct_variant(
self,
name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
debug_assert_eq!(name, crate::sql::expression::TOKEN);
match variant {
"Unary" => Ok(SerializeExpression::Unary(Default::default())),
"Binary" => Ok(SerializeExpression::Binary(Default::default())),
_ => Err(Error::custom(format!("unexpected `Expression::{name}`"))),
}
}
}
pub(super) enum SerializeExpression {
Unary(SerializeUnary),
Binary(SerializeBinary),
@ -158,8 +124,42 @@ impl serde::ser::SerializeStructVariant for SerializeBinary {
#[cfg(test)]
mod tests {
use super::*;
use serde::ser::Impossible;
use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Expression;
type Error = Error;
type SerializeSeq = Impossible<Expression, Error>;
type SerializeTuple = Impossible<Expression, Error>;
type SerializeTupleStruct = Impossible<Expression, Error>;
type SerializeTupleVariant = Impossible<Expression, Error>;
type SerializeMap = Impossible<Expression, Error>;
type SerializeStruct = Impossible<Expression, Error>;
type SerializeStructVariant = SerializeExpression;
const EXPECTED: &'static str = "an enum `Expression`";
#[inline]
fn serialize_struct_variant(
self,
name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
debug_assert_eq!(name, crate::sql::expression::TOKEN);
match variant {
"Unary" => Ok(SerializeExpression::Unary(Default::default())),
"Binary" => Ok(SerializeExpression::Binary(Default::default())),
_ => Err(Error::custom(format!("unexpected `Expression::{name}`"))),
}
}
}
#[test]
fn default() {
let expression = Expression::default();

View file

@ -6,36 +6,9 @@ use crate::sql::Id;
use crate::sql::Range;
use ser::Serializer as _;
use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize;
use std::ops::Bound;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Range;
type Error = Error;
type SerializeSeq = Impossible<Range, Error>;
type SerializeTuple = Impossible<Range, Error>;
type SerializeTupleStruct = Impossible<Range, Error>;
type SerializeTupleVariant = Impossible<Range, Error>;
type SerializeMap = Impossible<Range, Error>;
type SerializeStruct = SerializeRange;
type SerializeStructVariant = Impossible<Range, Error>;
const EXPECTED: &'static str = "a struct `Range`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeRange::default())
}
}
#[derive(Default)]
pub(super) struct SerializeRange {
tb: Option<String>,
@ -83,8 +56,35 @@ impl serde::ser::SerializeStruct for SerializeRange {
#[cfg(test)]
mod tests {
use super::*;
use serde::ser::Impossible;
use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Range;
type Error = Error;
type SerializeSeq = Impossible<Range, Error>;
type SerializeTuple = Impossible<Range, Error>;
type SerializeTupleStruct = Impossible<Range, Error>;
type SerializeTupleVariant = Impossible<Range, Error>;
type SerializeMap = Impossible<Range, Error>;
type SerializeStruct = SerializeRange;
type SerializeStructVariant = Impossible<Range, Error>;
const EXPECTED: &'static str = "a struct `Range`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeRange::default())
}
}
#[test]
fn range() {
let range = Range {

481
lib/src/api/conn/cmd.rs Normal file
View file

@ -0,0 +1,481 @@
use super::MlExportConfig;
use crate::Result;
use bincode::Options;
use channel::Sender;
use revision::Revisioned;
use serde::{ser::SerializeMap as _, Serialize};
use std::path::PathBuf;
use std::{collections::BTreeMap, io::Read};
use surrealdb_core::{
dbs::Notification,
sql::{Object, Query, Value},
};
use uuid::Uuid;
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) enum Command {
Use {
namespace: Option<String>,
database: Option<String>,
},
Signup {
credentials: Object,
},
Signin {
credentials: Object,
},
Authenticate {
token: String,
},
Invalidate,
Create {
what: Value,
data: Option<Value>,
},
Upsert {
what: Value,
data: Option<Value>,
},
Update {
what: Value,
data: Option<Value>,
},
Insert {
what: Option<Value>,
data: Value,
},
Patch {
what: Value,
data: Option<Value>,
},
Merge {
what: Value,
data: Option<Value>,
},
Select {
what: Value,
},
Delete {
what: Value,
},
Query {
query: Query,
variables: BTreeMap<String, Value>,
},
ExportFile {
path: PathBuf,
},
ExportMl {
path: PathBuf,
config: MlExportConfig,
},
ExportBytes {
bytes: Sender<Result<Vec<u8>>>,
},
ExportBytesMl {
bytes: Sender<Result<Vec<u8>>>,
config: MlExportConfig,
},
ImportFile {
path: PathBuf,
},
ImportMl {
path: PathBuf,
},
Health,
Version,
Set {
key: String,
value: Value,
},
Unset {
key: String,
},
SubscribeLive {
uuid: Uuid,
notification_sender: Sender<Notification>,
},
Kill {
uuid: Uuid,
},
}
impl Command {
#[cfg(feature = "protocol-ws")]
pub(crate) fn into_router_request(self, id: Option<i64>) -> Option<RouterRequest> {
let id = id.map(Value::from);
let res = match self {
Command::Use {
namespace,
database,
} => RouterRequest {
id,
method: Value::from("use"),
params: Some(vec![Value::from(namespace), Value::from(database)].into()),
},
Command::Signup {
credentials,
} => RouterRequest {
id,
method: "signup".into(),
params: Some(vec![Value::from(credentials)].into()),
},
Command::Signin {
credentials,
} => RouterRequest {
id,
method: "signin".into(),
params: Some(vec![Value::from(credentials)].into()),
},
Command::Authenticate {
token,
} => RouterRequest {
id,
method: "authenticate".into(),
params: Some(vec![Value::from(token)].into()),
},
Command::Invalidate => RouterRequest {
id,
method: "invalidate".into(),
params: None,
},
Command::Create {
what,
data,
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "create".into(),
params: Some(params.into()),
}
}
Command::Upsert {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "upsert".into(),
params: Some(params.into()),
}
}
Command::Update {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "update".into(),
params: Some(params.into()),
}
}
Command::Insert {
what,
data,
} => {
let mut params = if let Some(w) = what {
vec![w]
} else {
vec![Value::None]
};
params.push(data);
RouterRequest {
id,
method: "insert".into(),
params: Some(params.into()),
}
}
Command::Patch {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "patch".into(),
params: Some(params.into()),
}
}
Command::Merge {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "merge".into(),
params: Some(params.into()),
}
}
Command::Select {
what,
..
} => RouterRequest {
id,
method: "select".into(),
params: Some(vec![what].into()),
},
Command::Delete {
what,
..
} => RouterRequest {
id,
method: "delete".into(),
params: Some(vec![what].into()),
},
Command::Query {
query,
variables,
} => {
let params: Vec<Value> = vec![query.into(), variables.into()];
RouterRequest {
id,
method: "query".into(),
params: Some(params.into()),
}
}
Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportFile {
..
}
| Command::ExportBytesMl {
..
}
| Command::ExportMl {
..
}
| Command::ImportMl {
..
} => return None,
Command::Health => RouterRequest {
id,
method: "ping".into(),
params: None,
},
Command::Version => RouterRequest {
id,
method: "version".into(),
params: None,
},
Command::Set {
key,
value,
} => RouterRequest {
id,
method: "let".into(),
params: Some(vec![Value::from(key), value].into()),
},
Command::Unset {
key,
} => RouterRequest {
id,
method: "unset".into(),
params: Some(vec![Value::from(key)].into()),
},
Command::SubscribeLive {
..
} => return None,
Command::Kill {
uuid,
} => RouterRequest {
id,
method: "kill".into(),
params: Some(vec![Value::from(uuid)].into()),
},
};
Some(res)
}
}
/// A struct which will be serialized as a map to behave like the previously used BTreeMap.
///
/// This struct serializes as if it is a surrealdb_core::sql::Value::Object.
#[derive(Debug)]
pub(crate) struct RouterRequest {
id: Option<Value>,
method: Value,
params: Option<Value>,
}
impl Serialize for RouterRequest {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
struct InnerRequest<'a>(&'a RouterRequest);
impl Serialize for InnerRequest<'_> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let size = 1 + self.0.id.is_some() as usize + self.0.params.is_some() as usize;
let mut map = serializer.serialize_map(Some(size))?;
if let Some(id) = self.0.id.as_ref() {
map.serialize_entry("id", id)?;
}
map.serialize_entry("method", &self.0.method)?;
if let Some(params) = self.0.params.as_ref() {
map.serialize_entry("params", params)?;
}
map.end()
}
}
serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self))
}
}
impl Revisioned for RouterRequest {
fn revision() -> u16 {
1
}
fn serialize_revisioned<W: std::io::Write>(
&self,
w: &mut W,
) -> std::result::Result<(), revision::Error> {
// version
Revisioned::serialize_revisioned(&1u32, w)?;
// object variant
Revisioned::serialize_revisioned(&9u32, w)?;
// object wrapper version
Revisioned::serialize_revisioned(&1u32, w)?;
let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize;
size.serialize_revisioned(w)?;
let serializer = bincode::options()
.with_no_limit()
.with_little_endian()
.with_varint_encoding()
.reject_trailing_bytes();
if let Some(x) = self.id.as_ref() {
serializer
.serialize_into(&mut *w, "id")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
serializer
.serialize_into(&mut *w, "method")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
self.method.serialize_revisioned(w)?;
if let Some(x) = self.params.as_ref() {
serializer
.serialize_into(&mut *w, "params")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
Ok(())
}
fn deserialize_revisioned<R: Read>(_: &mut R) -> std::result::Result<Self, revision::Error>
where
Self: Sized,
{
panic!("deliberately unimplemented");
}
}
#[cfg(test)]
mod test {
use std::io::Cursor;
use revision::Revisioned;
use surrealdb_core::sql::Value;
use super::RouterRequest;
fn assert_converts<S, D, I>(req: &RouterRequest, s: S, d: D)
where
S: FnOnce(&RouterRequest) -> I,
D: FnOnce(I) -> Value,
{
let ser = s(req);
let val = d(ser);
let Value::Object(obj) = val else {
panic!("not an object");
};
assert_eq!(obj.get("id").cloned(), req.id);
assert_eq!(obj.get("method").unwrap().clone(), req.method);
assert_eq!(obj.get("params").cloned(), req.params);
}
#[test]
fn router_request_value_conversion() {
let request = RouterRequest {
id: Some(Value::from(1234i64)),
method: Value::from("request"),
params: Some(vec![Value::from(1234i64), Value::from("request")].into()),
};
println!("test convert bincode");
assert_converts(
&request,
|i| bincode::serialize(i).unwrap(),
|b| bincode::deserialize(&b).unwrap(),
);
println!("test convert json");
assert_converts(
&request,
|i| serde_json::to_string(i).unwrap(),
|b| serde_json::from_str(&b).unwrap(),
);
println!("test convert revisioned");
assert_converts(
&request,
|i| {
let mut buf = Vec::new();
i.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap();
buf
},
|b| Value::deserialize_revisioned(&mut Cursor::new(b)).unwrap(),
);
println!("done");
}
}

View file

@ -6,24 +6,29 @@ use crate::api::opt::Endpoint;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Notification;
use crate::sql::from_value;
use crate::sql::Query;
use crate::sql::Value;
use channel::Receiver;
use channel::Sender;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
use surrealdb_core::sql::{from_value, Value};
mod cmd;
pub(crate) use cmd::Command;
#[derive(Debug)]
#[allow(dead_code)] // used by the embedded and remote connections
pub struct RequestData {
pub(crate) id: i64,
pub(crate) command: Command,
}
#[derive(Debug)]
#[allow(dead_code)] // used by the embedded and remote connections
pub(crate) struct Route {
pub(crate) request: (i64, Method, Param),
pub(crate) request: RequestData,
pub(crate) response: Sender<Result<DbResponse>>,
}
@ -42,14 +47,16 @@ impl Router {
pub(crate) fn send(
&self,
method: Method,
param: Param,
command: Command,
) -> BoxFuture<'_, Result<Receiver<Result<DbResponse>>>> {
Box::pin(async move {
let id = self.next_id();
let (sender, receiver) = channel::bounded(1);
let route = Route {
request: (id, method, param),
request: RequestData {
id,
command,
},
response: sender,
};
self.sender.send(route).await?;
@ -86,28 +93,24 @@ impl Router {
}
/// Execute all methods except `query`
pub(crate) fn execute<R>(&self, method: Method, param: Param) -> BoxFuture<'_, Result<R>>
pub(crate) fn execute<R>(&self, command: Command) -> BoxFuture<'_, Result<R>>
where
R: DeserializeOwned,
{
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
let value = self.recv(rx).await?;
from_value(value).map_err(Into::into)
})
}
/// Execute methods that return an optional single response
pub(crate) fn execute_opt<R>(
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Option<R>>>
pub(crate) fn execute_opt<R>(&self, command: Command) -> BoxFuture<'_, Result<Option<R>>>
where
R: DeserializeOwned,
{
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
match self.recv(rx).await? {
Value::None | Value::Null => Ok(None),
value => from_value(value).map_err(Into::into),
@ -116,16 +119,12 @@ impl Router {
}
/// Execute methods that return multiple responses
pub(crate) fn execute_vec<R>(
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Vec<R>>>
pub(crate) fn execute_vec<R>(&self, command: Command) -> BoxFuture<'_, Result<Vec<R>>>
where
R: DeserializeOwned,
{
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
let value = match self.recv(rx).await? {
Value::None | Value::Null => Value::Array(Default::default()),
Value::Array(array) => Value::Array(array),
@ -136,9 +135,9 @@ impl Router {
}
/// Execute methods that return nothing
pub(crate) fn execute_unit(&self, method: Method, param: Param) -> BoxFuture<'_, Result<()>> {
pub(crate) fn execute_unit(&self, command: Command) -> BoxFuture<'_, Result<()>> {
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
match self.recv(rx).await? {
Value::None | Value::Null => Ok(()),
Value::Array(array) if array.is_empty() => Ok(()),
@ -152,82 +151,22 @@ impl Router {
}
/// Execute methods that return a raw value
pub(crate) fn execute_value(
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Value>> {
pub(crate) fn execute_value(&self, command: Command) -> BoxFuture<'_, Result<Value>> {
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
self.recv(rx).await
})
}
/// Execute the `query` method
pub(crate) fn execute_query(
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Response>> {
pub(crate) fn execute_query(&self, command: Command) -> BoxFuture<'_, Result<Response>> {
Box::pin(async move {
let rx = self.send(method, param).await?;
let rx = self.send(command).await?;
self.recv_query(rx).await
})
}
}
/// The query method
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Method {
/// Sends an authentication token to the server
Authenticate,
/// Performs a merge update operation
Merge,
/// Creates a record in a table
Create,
/// Deletes a record from a table
Delete,
/// Exports a database
Export,
/// Checks the health of the server
Health,
/// Imports a database
Import,
/// Invalidates a session
Invalidate,
/// Inserts a record or records into a table
Insert,
/// Kills a live query
#[doc(hidden)] // Not supported yet
Kill,
/// Starts a live query
#[doc(hidden)] // Not supported yet
Live,
/// Performs a patch update operation
Patch,
/// Sends a raw query to the database
Query,
/// Selects a record or records from a table
Select,
/// Sets a parameter on the connection
Set,
/// Signs into the server
Signin,
/// Signs up on the server
Signup,
/// Removes a parameter from a connection
Unset,
/// Performs an update operation
Update,
/// Performs an upsert operation
Upsert,
/// Selects a namespace and database to use
Use,
/// Queries the version of the server
Version,
}
/// The database response sent from the router to the caller
#[derive(Debug)]
pub enum DbResponse {
@ -237,63 +176,13 @@ pub enum DbResponse {
Other(Value),
}
#[derive(Debug)]
#[allow(dead_code)] // used by ML model import and export functions
pub(crate) enum MlConfig {
Import,
Export {
name: String,
version: String,
},
}
/// Holds the parameters given to the caller
#[derive(Debug, Default)]
#[allow(dead_code)] // used by the embedded and remote connections
pub struct Param {
pub(crate) query: Option<(Query, BTreeMap<String, Value>)>,
pub(crate) other: Vec<Value>,
pub(crate) file: Option<PathBuf>,
pub(crate) bytes_sender: Option<channel::Sender<Result<Vec<u8>>>>,
pub(crate) notification_sender: Option<channel::Sender<Notification>>,
pub(crate) ml_config: Option<MlConfig>,
}
impl Param {
pub(crate) fn new(other: Vec<Value>) -> Self {
Self {
other,
..Default::default()
}
}
pub(crate) fn query(query: Query, bindings: BTreeMap<String, Value>) -> Self {
Self {
query: Some((query, bindings)),
..Default::default()
}
}
pub(crate) fn file(file: PathBuf) -> Self {
Self {
file: Some(file),
..Default::default()
}
}
pub(crate) fn bytes_sender(send: channel::Sender<Result<Vec<u8>>>) -> Self {
Self {
bytes_sender: Some(send),
..Default::default()
}
}
pub(crate) fn notification_sender(send: channel::Sender<Notification>) -> Self {
Self {
notification_sender: Some(send),
..Default::default()
}
}
#[derive(Debug, Clone)]
pub(crate) struct MlExportConfig {
// fields are used in http and local non-wasm with ml features
#[allow(dead_code)]
pub(crate) name: String,
#[allow(dead_code)]
pub(crate) version: String,
}
/// Connection trait implemented by supported protocols

View file

@ -142,10 +142,7 @@ impl Connection for Any {
}
let client = builder.build()?;
let base_url = address.url;
engine::remote::http::health(
client.get(base_url.join(crate::api::conn::Method::Health.as_str())?),
)
.await?;
engine::remote::http::health(client.get(base_url.join("health")?)).await?;
tokio::spawn(engine::remote::http::native::run_router(
base_url, client, route_rx,
));

View file

@ -26,78 +26,58 @@ pub(crate) mod native;
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
#[cfg(not(target_arch = "wasm32"))]
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::engine::create_statement;
use crate::api::engine::delete_statement;
use crate::api::engine::insert_statement;
use crate::api::engine::merge_statement;
use crate::api::engine::patch_statement;
use crate::api::engine::select_statement;
use crate::api::engine::update_statement;
use crate::api::engine::upsert_statement;
use crate::{
api::{
conn::{Command, DbResponse, RequestData},
Connect, Response as QueryResponse, Result, Surreal,
},
method::Stats,
opt::IntoEndpoint,
};
use channel::Sender;
use indexmap::IndexMap;
use std::{
collections::{BTreeMap, HashMap},
marker::PhantomData,
mem,
sync::Arc,
time::Duration,
};
use surrealdb_core::{
dbs::{Notification, Response, Session},
kvs::Datastore,
sql::{
statements::{
CreateStatement, DeleteStatement, InsertStatement, KillStatement, SelectStatement,
UpdateStatement, UpsertStatement,
},
Data, Field, Output, Query, Statement, Uuid, Value,
},
};
#[cfg(not(target_arch = "wasm32"))]
use crate::api::err::Error;
use crate::api::Connect;
use crate::api::Response as QueryResponse;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Notification;
use crate::dbs::Response;
use crate::dbs::Session;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::check::check_ns_db;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::Action;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::ResourceKind;
use crate::kvs::Datastore;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::kvs::{LockType, TransactionType};
use crate::method::Stats;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::ml::storage::surml_file::SurMlFile;
use crate::opt::IntoEndpoint;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineModelStatement;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineStatement;
use crate::sql::statements::KillStatement;
use crate::sql::Query;
use crate::sql::Statement;
use crate::sql::Uuid;
use crate::sql::Value;
use channel::Sender;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use futures::StreamExt;
use indexmap::IndexMap;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
#[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
#[cfg(not(target_arch = "wasm32"))]
use tokio::fs::OpenOptions;
#[cfg(not(target_arch = "wasm32"))]
use tokio::io;
#[cfg(not(target_arch = "wasm32"))]
use tokio::io::AsyncReadExt;
#[cfg(not(target_arch = "wasm32"))]
use tokio::io::AsyncWriteExt;
use tokio::{
fs::OpenOptions,
io::{self, AsyncReadExt, AsyncWriteExt},
};
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use crate::api::conn::MlExportConfig;
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use futures::StreamExt;
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use surrealdb_core::{
iam::{check::check_ns_db, Action, ResourceKind},
kvs::{LockType, TransactionType},
ml::storage::surml_file::SurMlFile,
sql::statements::{DefineModelStatement, DefineStatement},
};
use super::value_to_values;
const DEFAULT_TICK_INTERVAL: Duration = Duration::from_secs(10);
@ -429,44 +409,42 @@ async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
}
#[cfg(not(target_arch = "wasm32"))]
async fn export(
async fn export_file(kvs: &Datastore, sess: &Session, chn: channel::Sender<Vec<u8>>) -> Result<()> {
if let Err(error) = kvs.export(sess, chn).await?.await {
if let crate::error::Db::Channel(message) = error {
// This is not really an error. Just logging it for improved visibility.
trace!("{message}");
return Ok(());
}
return Err(error.into());
}
Ok(())
}
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
async fn export_ml(
kvs: &Datastore,
sess: &Session,
chn: channel::Sender<Vec<u8>>,
ml_config: Option<MlConfig>,
MlExportConfig {
name,
version,
}: MlExportConfig,
) -> Result<()> {
match ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Export {
name,
version,
}) => {
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(sess)?;
// Check the permissions level
kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Start a new readonly transaction
let tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?;
// Attempt to get the model definition
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
// Export the file data in to the store
let mut data = crate::obs::stream(info.hash.to_owned()).await?;
// Process all stream values
while let Some(Ok(bytes)) = data.next().await {
if chn.send(bytes.to_vec()).await.is_err() {
break;
}
}
}
_ => {
if let Err(error) = kvs.export(sess, chn).await?.await {
if let crate::error::Db::Channel(message) = error {
// This is not really an error. Just logging it for improved visibility.
trace!("{message}");
return Ok(());
}
return Err(error.into());
}
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(sess)?;
// Check the permissions level
kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Start a new readonly transaction
let tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?;
// Attempt to get the model definition
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
// Export the file data in to the store
let mut data = crate::obs::stream(info.hash.to_owned()).await?;
// Process all stream values
while let Some(Ok(bytes)) = data.next().await {
if chn.send(bytes.to_vec()).await.is_err() {
break;
}
}
Ok(())
@ -505,211 +483,376 @@ async fn kill_live_query(
}
async fn router(
(_, method, param): (i64, Method, Param),
RequestData {
command,
..
}: RequestData,
kvs: &Arc<Datastore>,
session: &mut Session,
vars: &mut BTreeMap<String, Value>,
live_queries: &mut HashMap<Uuid, Sender<Notification>>,
) -> Result<DbResponse> {
let mut params = param.other;
match method {
Method::Use => {
match &mut params[..] {
[Value::Strand(ns), Value::Strand(db)] => {
session.ns = Some(mem::take(&mut ns.0));
session.db = Some(mem::take(&mut db.0));
}
[Value::Strand(ns), Value::None] => {
session.ns = Some(mem::take(&mut ns.0));
}
[Value::None, Value::Strand(db)] => {
session.db = Some(mem::take(&mut db.0));
}
_ => unreachable!(),
match command {
Command::Use {
namespace,
database,
} => {
if let Some(ns) = namespace {
session.ns = Some(ns);
}
if let Some(db) = database {
session.db = Some(db);
}
Ok(DbResponse::Other(Value::None))
}
Method::Signup => {
let credentials = match &mut params[..] {
[Value::Object(credentials)] => mem::take(credentials),
_ => unreachable!(),
};
Command::Signup {
credentials,
} => {
let response = crate::iam::signup::signup(kvs, session, credentials).await?;
Ok(DbResponse::Other(response.into()))
}
Method::Signin => {
let credentials = match &mut params[..] {
[Value::Object(credentials)] => mem::take(credentials),
_ => unreachable!(),
};
Command::Signin {
credentials,
} => {
let response = crate::iam::signin::signin(kvs, session, credentials).await?;
Ok(DbResponse::Other(response.into()))
}
Method::Authenticate => {
let token = match &mut params[..] {
[Value::Strand(token)] => mem::take(&mut token.0),
_ => unreachable!(),
};
Command::Authenticate {
token,
} => {
crate::iam::verify::token(kvs, session, &token).await?;
Ok(DbResponse::Other(Value::None))
}
Method::Invalidate => {
Command::Invalidate => {
crate::iam::clear::clear(session)?;
Ok(DbResponse::Other(Value::None))
}
Method::Create => {
Command::Create {
what,
data,
} => {
let mut query = Query::default();
let statement = create_statement(&mut params);
let statement = {
let mut stmt = CreateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Create(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(true, response).await?;
Ok(DbResponse::Other(value))
}
Method::Upsert => {
Command::Upsert {
what,
data,
} => {
let mut query = Query::default();
let (one, statement) = upsert_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpsertStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Upsert(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Update => {
Command::Update {
what,
data,
} => {
let mut query = Query::default();
let (one, statement) = update_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Insert => {
Command::Insert {
what,
data,
} => {
let mut query = Query::default();
let (one, statement) = insert_statement(&mut params);
let one = !data.is_array();
let statement = {
let mut stmt = InsertStatement::default();
stmt.into = what;
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Insert(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Patch => {
Command::Patch {
what,
data,
} => {
let mut query = Query::default();
let (one, statement) = patch_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::PatchExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Merge => {
Command::Merge {
what,
data,
} => {
let mut query = Query::default();
let (one, statement) = merge_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::MergeExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Select => {
Command::Select {
what,
} => {
let mut query = Query::default();
let (one, statement) = select_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = SelectStatement::default();
stmt.what = value_to_values(what);
stmt.expr.0 = vec![Field::All];
stmt
};
query.0 .0 = vec![Statement::Select(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Delete => {
Command::Delete {
what,
} => {
let mut query = Query::default();
let (one, statement) = delete_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = DeleteStatement::default();
stmt.what = value_to_values(what);
stmt.output = Some(Output::Before);
stmt
};
query.0 .0 = vec![Statement::Delete(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?;
Ok(DbResponse::Other(value))
}
Method::Query => {
let response = match param.query {
Some((query, mut bindings)) => {
let mut vars = vars.clone();
vars.append(&mut bindings);
kvs.process(query, &*session, Some(vars)).await?
}
None => unreachable!(),
};
Command::Query {
query,
mut variables,
} => {
let mut vars = vars.clone();
vars.append(&mut variables);
let response = kvs.process(query, &*session, Some(vars)).await?;
let response = process(response);
Ok(DbResponse::Query(response))
}
#[cfg(target_arch = "wasm32")]
Method::Export | Method::Import => unreachable!(),
Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportFile {
..
} => Err(crate::api::Error::BackupsNotSupported.into()),
#[cfg(any(target_arch = "wasm32", not(feature = "ml")))]
Command::ExportMl {
..
}
| Command::ExportBytesMl {
..
}
| Command::ImportMl {
..
} => Err(crate::api::Error::BackupsNotSupported.into()),
#[cfg(not(target_arch = "wasm32"))]
Method::Export => {
Command::ExportFile {
path: file,
} => {
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
// Write to channel.
let export = export_file(kvs, session, tx);
// Read from channel and write to pipe.
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
}
Ok(())
};
// Output to stdout or file.
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&file)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path: file,
error,
}
.into());
}
};
// Copy from pipe to output.
let copy = copy(file, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(DbResponse::Other(Value::None))
}
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
Command::ExportMl {
path,
config,
} => {
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
// Write to channel.
let export = export_ml(kvs, session, tx, config);
// Read from channel and write to pipe.
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
}
Ok(())
};
// Output to stdout or file.
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
// Copy from pipe to output.
let copy = copy(path, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(DbResponse::Other(Value::None))
}
#[cfg(not(target_arch = "wasm32"))]
Command::ExportBytes {
bytes,
} => {
let (tx, rx) = crate::channel::bounded(1);
match (param.file, param.bytes_sender) {
(Some(path), None) => {
let (mut writer, mut reader) = io::duplex(10_240);
let kvs = kvs.clone();
let session = session.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export_file(&kvs, &session, tx).await {
let _ = bytes.send(Err(error)).await;
}
};
// Write to channel.
let export = export(kvs, session, tx, param.ml_config);
// Read from channel and write to pipe.
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
let bridge = async {
while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
break;
}
Ok(())
};
}
};
// Output to stdout or file.
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
tokio::join!(export, bridge);
});
Ok(DbResponse::Other(Value::None))
}
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
Command::ExportBytesMl {
bytes,
config,
} => {
let (tx, rx) = crate::channel::bounded(1);
let kvs = kvs.clone();
let session = session.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export_ml(&kvs, &session, tx, config).await {
let _ = bytes.send(Err(error)).await;
}
};
let bridge = async {
while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
break;
}
};
}
};
// Copy from pipe to output.
let copy = copy(path, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
}
(None, Some(backup)) => {
let kvs = kvs.clone();
let session = session.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export(&kvs, &session, tx, param.ml_config).await {
let _ = backup.send(Err(error)).await;
}
};
let bridge = async {
while let Ok(bytes) = rx.recv().await {
if backup.send(Ok(bytes)).await.is_err() {
break;
}
}
};
tokio::join!(export, bridge);
});
}
_ => unreachable!(),
}
tokio::join!(export, bridge);
});
Ok(DbResponse::Other(Value::None))
}
#[cfg(not(target_arch = "wasm32"))]
Method::Import => {
let path = param.file.expect("file to import from");
Command::ImportFile {
path,
} => {
let mut file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path,
Err(error) => {
@ -720,76 +863,93 @@ async fn router(
.into());
}
};
let responses = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Import) => {
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(session)?;
// Check the permissions level
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(Error::FileRead {
path,
error: io::Error::new(
io::ErrorKind::InvalidData,
error.message.to_string(),
),
}
.into());
}
};
// Convert the file back in to raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = crate::obs::hash(&data);
// Insert the file data in to the store
crate::obs::put(&hash, data).await?;
// Insert the model in to the database
let mut model = DefineModelStatement::default();
model.name = file.header.name.to_string().into();
model.version = file.header.version.to_string();
model.comment = Some(file.header.description.to_string().into());
model.hash = hash;
let query = DefineStatement::Model(model).into();
kvs.process(query, session, Some(vars.clone())).await?
let mut statements = String::new();
if let Err(error) = file.read_to_string(&mut statements).await {
return Err(Error::FileRead {
path,
error,
}
_ => {
let mut statements = String::new();
if let Err(error) = file.read_to_string(&mut statements).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
kvs.execute(&statements, &*session, Some(vars.clone())).await?
}
};
.into());
}
let responses = kvs.execute(&statements, &*session, Some(vars.clone())).await?;
for response in responses {
response.result?;
}
Ok(DbResponse::Other(Value::None))
}
Method::Health => Ok(DbResponse::Other(Value::None)),
Method::Version => Ok(DbResponse::Other(crate::env::VERSION.into())),
Method::Set => {
let (key, value) = match &mut params[..2] {
[Value::Strand(key), value] => (mem::take(&mut key.0), mem::take(value)),
_ => unreachable!(),
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
Command::ImportMl {
path,
} => {
let mut file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(session)?;
// Check the permissions level
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(Error::FileRead {
path,
error: io::Error::new(
io::ErrorKind::InvalidData,
error.message.to_string(),
),
}
.into());
}
};
// Convert the file back in to raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = crate::obs::hash(&data);
// Insert the file data in to the store
crate::obs::put(&hash, data).await?;
// Insert the model in to the database
let mut model = DefineModelStatement::default();
model.name = file.header.name.to_string().into();
model.version = file.header.version.to_string();
model.comment = Some(file.header.description.to_string().into());
model.hash = hash;
let query = DefineStatement::Model(model).into();
let responses = kvs.process(query, session, Some(vars.clone())).await?;
for response in responses {
response.result?;
}
Ok(DbResponse::Other(Value::None))
}
Command::Health => Ok(DbResponse::Other(Value::None)),
Command::Version => Ok(DbResponse::Other(crate::env::VERSION.into())),
Command::Set {
key,
value,
} => {
let var = Some(map! {
key.clone() => Value::None,
=> vars
@ -800,27 +960,24 @@ async fn router(
};
Ok(DbResponse::Other(Value::None))
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
vars.remove(&key.0);
}
Command::Unset {
key,
} => {
vars.remove(&key);
Ok(DbResponse::Other(Value::None))
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
live_queries.insert(*id, sender);
}
}
Command::SubscribeLive {
uuid,
notification_sender,
} => {
live_queries.insert(uuid.into(), notification_sender);
Ok(DbResponse::Other(Value::None))
}
Method::Kill => {
let id = match &params[..] {
[Value::Uuid(id)] => *id,
_ => unreachable!(),
};
live_queries.remove(&id);
let value = kill_live_query(kvs, id, session, vars.clone()).await?;
Command::Kill {
uuid,
} => {
live_queries.remove(&uuid.into());
let value = kill_live_query(kvs, uuid.into(), session, vars.clone()).await?;
Ok(DbResponse::Other(value))
}
}

View file

@ -15,19 +15,9 @@ pub mod remote;
#[doc(hidden)]
pub mod tasks;
use crate::sql::statements::CreateStatement;
use crate::sql::statements::DeleteStatement;
use crate::sql::statements::InsertStatement;
use crate::sql::statements::SelectStatement;
use crate::sql::statements::UpdateStatement;
use crate::sql::statements::UpsertStatement;
use crate::sql::Data;
use crate::sql::Field;
use crate::sql::Output;
use crate::sql::Value;
use crate::sql::Values;
use futures::Stream;
use std::mem;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
@ -40,133 +30,21 @@ use wasmtimer::std::Instant;
#[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::Interval;
#[allow(dead_code)] // used by the the embedded database and `http`
fn split_params(params: &mut [Value]) -> (bool, Values, Value) {
let (what, data) = match params {
[what] => (mem::take(what), Value::None),
[what, data] => (mem::take(what), mem::take(data)),
_ => unreachable!(),
};
let one = what.is_thing();
let what = match what {
Value::Array(vec) => {
// used in http and all local engines.
#[allow(dead_code)]
fn value_to_values(v: Value) -> Values {
match v {
Value::Array(x) => {
let mut values = Values::default();
values.0 = vec.0;
values.0 = x.0;
values
}
value => {
x => {
let mut values = Values::default();
values.0 = vec![value];
values.0 = vec![x];
values
}
};
(one, what, data)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn create_statement(params: &mut [Value]) -> CreateStatement {
let (_, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = CreateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
stmt
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn upsert_statement(params: &mut [Value]) -> (bool, UpsertStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = UpsertStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn update_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn insert_statement(params: &mut [Value]) -> (bool, InsertStatement) {
let (what, data) = match params {
[what, data] => (mem::take(what), mem::take(data)),
_ => unreachable!(),
};
let one = !data.is_array();
let mut stmt = InsertStatement::default();
stmt.into = match what {
Value::None => None,
Value::Null => None,
what => Some(what),
};
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn patch_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::PatchExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn merge_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::MergeExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn select_statement(params: &mut [Value]) -> (bool, SelectStatement) {
let (one, what, _) = split_params(params);
let mut stmt = SelectStatement::default();
stmt.what = what;
stmt.expr.0 = vec![Field::All];
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn delete_statement(params: &mut [Value]) -> (bool, DeleteStatement) {
let (one, what, _) = split_params(params);
let mut stmt = DeleteStatement::default();
stmt.what = what;
stmt.output = Some(Output::Before);
(one, stmt)
}
}
struct IntervalStream {

View file

@ -5,21 +5,10 @@ pub(crate) mod native;
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm;
use crate::api::conn::Command;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::engine::create_statement;
use crate::api::engine::delete_statement;
use crate::api::engine::insert_statement;
use crate::api::engine::merge_statement;
use crate::api::engine::patch_statement;
use crate::api::conn::RequestData;
use crate::api::engine::remote::duration_from_str;
use crate::api::engine::select_statement;
use crate::api::engine::update_statement;
use crate::api::engine::upsert_statement;
use crate::api::err::Error;
use crate::api::method::query::QueryResult;
use crate::api::Connect;
@ -27,6 +16,7 @@ use crate::api::Response as QueryResponse;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Status;
use crate::engine::value_to_values;
use crate::headers::AUTH_DB;
use crate::headers::AUTH_NS;
use crate::headers::DB;
@ -36,19 +26,29 @@ use crate::opt::IntoEndpoint;
use crate::sql::from_value;
use crate::sql::serde::deserialize;
use crate::sql::Value;
#[cfg(not(target_arch = "wasm32"))]
use futures::TryStreamExt;
use indexmap::IndexMap;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::header::ACCEPT;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE;
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde::Serialize;
use std::marker::PhantomData;
use std::mem;
use surrealdb_core::sql::statements::CreateStatement;
use surrealdb_core::sql::statements::DeleteStatement;
use surrealdb_core::sql::statements::InsertStatement;
use surrealdb_core::sql::statements::SelectStatement;
use surrealdb_core::sql::statements::UpdateStatement;
use surrealdb_core::sql::statements::UpsertStatement;
use surrealdb_core::sql::Data;
use surrealdb_core::sql::Field;
use surrealdb_core::sql::Output;
use url::Url;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE;
#[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf;
#[cfg(not(target_arch = "wasm32"))]
@ -57,7 +57,8 @@ use tokio::fs::OpenOptions;
use tokio::io;
#[cfg(not(target_arch = "wasm32"))]
use tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
const SQL_PATH: &str = "sql";
@ -238,65 +239,61 @@ async fn take(one: bool, request: RequestBuilder) -> Result<Value> {
}
}
#[cfg(not(target_arch = "wasm32"))]
type BackupSender = channel::Sender<Result<Vec<u8>>>;
#[cfg(not(target_arch = "wasm32"))]
async fn export(
request: RequestBuilder,
(file, sender): (Option<PathBuf>, Option<BackupSender>),
) -> Result<Value> {
match (file, sender) {
(Some(path), None) => {
let mut response = request
.send()
.await?
.error_for_status()?
.bytes_stream()
.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
.into_async_read()
.compat();
let mut file = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
if let Err(error) = io::copy(&mut response, &mut file).await {
return Err(Error::FileRead {
async fn export_file(request: RequestBuilder, path: PathBuf) -> Result<Value> {
let mut response = request
.send()
.await?
.error_for_status()?
.bytes_stream()
.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
.into_async_read()
.compat();
let mut file =
match OpenOptions::new().write(true).create(true).truncate(true).open(&path).await {
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
if let Err(error) = io::copy(&mut response, &mut file).await {
return Err(Error::FileRead {
path,
error,
}
(None, Some(tx)) => {
let mut response = request.send().await?.error_for_status()?.bytes_stream();
tokio::spawn(async move {
while let Ok(Some(bytes)) = response.try_next().await {
if tx.send(Ok(bytes.to_vec())).await.is_err() {
break;
}
}
});
}
_ => unreachable!(),
.into());
}
Ok(Value::None)
}
async fn export_bytes(request: RequestBuilder, bytes: BackupSender) -> Result<Value> {
let response = request.send().await?.error_for_status()?;
let future = async move {
let mut response = response.bytes_stream();
while let Ok(Some(b)) = response.try_next().await {
if bytes.send(Ok(b.to_vec())).await.is_err() {
break;
}
}
};
#[cfg(not(target_arch = "wasm32"))]
tokio::spawn(future);
#[cfg(target_arch = "wasm32")]
spawn_local(future);
Ok(Value::None)
}
#[cfg(not(target_arch = "wasm32"))]
async fn import(request: RequestBuilder, path: PathBuf) -> Result<Value> {
let file = match OpenOptions::new().read(true).open(&path).await {
@ -343,28 +340,24 @@ pub(crate) async fn health(request: RequestBuilder) -> Result<Value> {
}
async fn router(
(_, method, param): (i64, Method, Param),
RequestData {
command,
..
}: RequestData,
base_url: &Url,
client: &reqwest::Client,
headers: &mut HeaderMap,
vars: &mut IndexMap<String, String>,
auth: &mut Option<Auth>,
) -> Result<DbResponse> {
let mut params = param.other;
match method {
Method::Use => {
match command {
Command::Use {
namespace,
database,
} => {
let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone());
let (ns, db) = match &mut params[..] {
[Value::Strand(ns), Value::Strand(db)] => {
(Some(mem::take(&mut ns.0)), Some(mem::take(&mut db.0)))
}
[Value::Strand(ns), Value::None] => (Some(mem::take(&mut ns.0)), None),
[Value::None, Value::Strand(db)] => (None, Some(mem::take(&mut db.0))),
_ => unreachable!(),
};
let ns = match ns {
let ns = match namespace {
Some(ns) => match HeaderValue::try_from(&ns) {
Ok(ns) => {
request = request.header(&NS, &ns);
@ -376,7 +369,7 @@ async fn router(
},
None => None,
};
let db = match db {
let db = match database {
Some(db) => match HeaderValue::try_from(&db) {
Ok(db) => {
request = request.header(&DB, &db);
@ -398,52 +391,46 @@ async fn router(
}
Ok(DbResponse::Other(Value::None))
}
Method::Signin => {
let path = base_url.join(Method::Signin.as_str())?;
let credentials = match &mut params[..] {
[credentials] => credentials.to_string(),
_ => unreachable!(),
};
let request = client.post(path).headers(headers.clone()).auth(auth).body(credentials);
Command::Signin {
credentials,
} => {
let path = base_url.join("signin")?;
let request =
client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string());
let value = submit_auth(request).await?;
if let [credentials] = &mut params[..] {
if let Ok(Credentials {
if let Ok(Credentials {
user,
pass,
ns,
db,
}) = from_value(credentials.into())
{
*auth = Some(Auth::Basic {
user,
pass,
ns,
db,
}) = from_value(mem::take(credentials))
{
*auth = Some(Auth::Basic {
user,
pass,
ns,
db,
});
} else {
*auth = Some(Auth::Bearer {
token: value.to_raw_string(),
});
}
});
} else {
*auth = Some(Auth::Bearer {
token: value.to_raw_string(),
});
}
Ok(DbResponse::Other(value))
}
Method::Signup => {
let path = base_url.join(Method::Signup.as_str())?;
let credentials = match &mut params[..] {
[credentials] => credentials.to_string(),
_ => unreachable!(),
};
let request = client.post(path).headers(headers.clone()).auth(auth).body(credentials);
Command::Signup {
credentials,
} => {
let path = base_url.join("signup")?;
let request =
client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string());
let value = submit_auth(request).await?;
Ok(DbResponse::Other(value))
}
Method::Authenticate => {
Command::Authenticate {
token,
} => {
let path = base_url.join(SQL_PATH)?;
let token = match &mut params[..1] {
[Value::Strand(token)] => mem::take(&mut token.0),
_ => unreachable!(),
};
let request =
client.post(path).headers(headers.clone()).bearer_auth(&token).body("RETURN true");
take(true, request).await?;
@ -452,142 +439,276 @@ async fn router(
});
Ok(DbResponse::Other(Value::None))
}
Method::Invalidate => {
Command::Invalidate => {
*auth = None;
Ok(DbResponse::Other(Value::None))
}
Method::Create => {
Command::Create {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let statement = create_statement(&mut params);
let statement = {
let mut stmt = CreateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(true, request).await?;
Ok(DbResponse::Other(value))
}
Method::Upsert => {
Command::Upsert {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = upsert_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpsertStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Update => {
Command::Update {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = update_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Insert => {
Command::Insert {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = insert_statement(&mut params);
let one = !data.is_array();
let statement = {
let mut stmt = InsertStatement::default();
stmt.into = what;
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Patch => {
Command::Patch {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = patch_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::PatchExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Merge => {
Command::Merge {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = merge_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::MergeExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Select => {
Command::Select {
what,
} => {
let path = base_url.join(SQL_PATH)?;
let (one, statement) = select_statement(&mut params);
let one = what.is_thing();
let statement = {
let mut stmt = SelectStatement::default();
stmt.what = value_to_values(what);
stmt.expr.0 = vec![Field::All];
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Delete => {
Command::Delete {
what,
} => {
let one = what.is_thing();
let path = base_url.join(SQL_PATH)?;
let (one, statement) = delete_statement(&mut params);
let (one, statement) = {
let mut stmt = DeleteStatement::default();
stmt.what = value_to_values(what);
stmt.output = Some(Output::Before);
(one, stmt)
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Method::Query => {
Command::Query {
query: q,
variables,
} => {
let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone()).query(&vars).auth(auth);
match param.query {
Some((query, bindings)) => {
let bindings: Vec<_> =
bindings.iter().map(|(key, value)| (key, value.to_string())).collect();
request = request.query(&bindings).body(query.to_string());
}
None => unreachable!(),
}
let bindings: Vec<_> =
variables.iter().map(|(key, value)| (key, value.to_string())).collect();
request = request.query(&bindings).body(q.to_string());
let values = query(request).await?;
Ok(DbResponse::Query(values))
}
#[cfg(target_arch = "wasm32")]
Method::Export | Method::Import => unreachable!(),
Command::ExportFile {
..
}
| Command::ExportMl {
..
}
| Command::ImportFile {
..
}
| Command::ImportMl {
..
} => {
// TODO: Better error message here, some backups are supported
Err(Error::BackupsNotSupported.into())
}
#[cfg(not(target_arch = "wasm32"))]
Method::Export => {
let path = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Export {
name,
version,
}) => base_url.join(&format!("ml/export/{name}/{version}"))?,
_ => base_url.join(Method::Export.as_str())?,
};
Command::ExportFile {
path,
} => {
let req_path = base_url.join("export")?;
let request = client
.get(path)
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export(request, (param.file, param.bytes_sender)).await?;
let value = export_file(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::ExportBytes {
bytes,
} => {
let req_path = base_url.join("export")?;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_bytes(request, bytes).await?;
Ok(DbResponse::Other(value))
}
#[cfg(not(target_arch = "wasm32"))]
Method::Import => {
let path = match param.ml_config {
#[cfg(feature = "ml")]
Some(MlConfig::Import) => base_url.join("ml/import")?,
_ => base_url.join(Method::Import.as_str())?,
};
let file = param.file.expect("file to import from");
Command::ExportMl {
path,
config,
} => {
let req_path =
base_url.join("ml")?.join("export")?.join(&config.name)?.join(&config.version)?;
let request = client
.post(path)
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_file(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::ExportBytesMl {
bytes,
config,
} => {
let req_path =
base_url.join("ml")?.join("export")?.join(&config.name)?.join(&config.version)?;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_bytes(request, bytes).await?;
Ok(DbResponse::Other(value))
}
#[cfg(not(target_arch = "wasm32"))]
Command::ImportFile {
path,
} => {
let req_path = base_url.join("import")?;
let request = client
.post(req_path)
.headers(headers.clone())
.auth(auth)
.header(CONTENT_TYPE, "application/octet-stream");
let value = import(request, file).await?;
let value = import(request, path).await?;
Ok(DbResponse::Other(value))
}
Method::Health => {
let path = base_url.join(Method::Health.as_str())?;
#[cfg(not(target_arch = "wasm32"))]
Command::ImportMl {
path,
} => {
let req_path = base_url.join("ml")?.join("import")?;
let request = client
.post(req_path)
.headers(headers.clone())
.auth(auth)
.header(CONTENT_TYPE, "application/octet-stream");
let value = import(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::Health => {
let path = base_url.join("health")?;
let request = client.get(path);
let value = health(request).await?;
Ok(DbResponse::Other(value))
}
Method::Version => {
let path = base_url.join(method.as_str())?;
Command::Version => {
let path = base_url.join("version")?;
let request = client.get(path);
let value = version(request).await?;
Ok(DbResponse::Other(value))
}
Method::Set => {
Command::Set {
key,
value,
} => {
let path = base_url.join(SQL_PATH)?;
let (key, value) = match &mut params[..2] {
[Value::Strand(key), value] => (mem::take(&mut key.0), value.to_string()),
_ => unreachable!(),
};
let value = value.to_string();
let request = client
.post(path)
.headers(headers.clone())
@ -598,38 +719,24 @@ async fn router(
vars.insert(key, value);
Ok(DbResponse::Other(Value::None))
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
vars.swap_remove(&key.0);
}
Command::Unset {
key,
} => {
vars.shift_remove(&key);
Ok(DbResponse::Other(Value::None))
}
Method::Live => {
Command::SubscribeLive {
..
} => Err(Error::LiveQueriesNotSupported.into()),
Command::Kill {
uuid,
} => {
let path = base_url.join(SQL_PATH)?;
let table = match &params[..] {
[table] => table.to_string(),
_ => unreachable!(),
};
let request = client
.post(path)
.headers(headers.clone())
.auth(auth)
.query(&[("table", table)])
.body("LIVE SELECT * FROM type::table($table)");
let value = take(true, request).await?;
Ok(DbResponse::Other(value))
}
Method::Kill => {
let path = base_url.join(SQL_PATH)?;
let id = match &params[..] {
[id] => id.to_string(),
_ => unreachable!(),
};
let request = client
.post(path)
.headers(headers.clone())
.auth(auth)
.query(&[("id", id)])
.query(&[("id", uuid)])
.body("KILL type::string($id)");
let value = take(true, request).await?;
Ok(DbResponse::Other(value))

View file

@ -1,6 +1,5 @@
use super::Client;
use crate::api::conn::Connection;
use crate::api::conn::Method;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::method::BoxFuture;
@ -47,7 +46,7 @@ impl Connection for Client {
let base_url = address.url;
super::health(client.get(base_url.join(Method::Health.as_str())?)).await?;
super::health(client.get(base_url.join("health")?)).await?;
let (route_tx, route_rx) = match capacity {
0 => channel::unbounded(),

View file

@ -1,6 +1,5 @@
use super::Client;
use crate::api::conn::Connection;
use crate::api::conn::Method;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::method::BoxFuture;
@ -53,7 +52,7 @@ async fn client(base_url: &Url) -> Result<reqwest::Client> {
let headers = super::default_headers();
let builder = ClientBuilder::new().default_headers(headers);
let client = builder.build()?;
let health = base_url.join(Method::Health.as_str())?;
let health = base_url.join("health")?;
super::health(client.get(health)).await?;
Ok(client)
}

View file

@ -6,8 +6,8 @@ pub(crate) mod native;
pub(crate) mod wasm;
use crate::api;
use crate::api::conn::Command;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::engine::remote::duration_from_str;
use crate::api::err::Error;
use crate::api::method::query::QueryResult;
@ -20,15 +20,12 @@ use crate::dbs::Status;
use crate::method::Stats;
use crate::opt::IntoEndpoint;
use crate::sql::Value;
use bincode::Options as _;
use channel::Sender;
use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::ser::SerializeMap;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::io::Read;
use std::marker::PhantomData;
@ -39,127 +36,64 @@ use uuid::Uuid;
pub(crate) const PATH: &str = "rpc";
const PING_INTERVAL: Duration = Duration::from_secs(5);
const PING_METHOD: &str = "ping";
const REVISION_HEADER: &str = "revision";
/// A struct which will be serialized as a map to behave like the previously used BTreeMap.
///
/// This struct serializes as if it is a surrealdb_core::sql::Value::Object.
#[derive(Debug)]
struct RouterRequest {
id: Option<Value>,
method: Value,
params: Option<Value>,
enum RequestEffect {
/// Completing this request sets a variable to a give value.
Set {
key: String,
value: Value,
},
/// Completing this request sets a variable to a give value.
Clear {
key: String,
},
/// Insert requests repsonses need to be flattened in an array.
Insert,
/// No effect
None,
}
impl Serialize for RouterRequest {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
struct InnerRequest<'a>(&'a RouterRequest);
impl Serialize for InnerRequest<'_> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let size = 1 + self.0.id.is_some() as usize + self.0.params.is_some() as usize;
let mut map = serializer.serialize_map(Some(size))?;
if let Some(id) = self.0.id.as_ref() {
map.serialize_entry("id", id)?;
}
map.serialize_entry("method", &self.0.method)?;
if let Some(params) = self.0.params.as_ref() {
map.serialize_entry("params", params)?;
}
map.end()
}
}
serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self))
}
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
enum ReplayMethod {
Use,
Signup,
Signin,
Invalidate,
Authenticate,
}
impl Revisioned for RouterRequest {
fn revision() -> u16 {
1
}
fn serialize_revisioned<W: std::io::Write>(
&self,
w: &mut W,
) -> std::result::Result<(), revision::Error> {
// version
Revisioned::serialize_revisioned(&1u32, w)?;
// object variant
Revisioned::serialize_revisioned(&9u32, w)?;
// object wrapper version
Revisioned::serialize_revisioned(&1u32, w)?;
let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize;
size.serialize_revisioned(w)?;
let serializer = bincode::options()
.with_no_limit()
.with_little_endian()
.with_varint_encoding()
.reject_trailing_bytes();
if let Some(x) = self.id.as_ref() {
serializer
.serialize_into(&mut *w, "id")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
serializer
.serialize_into(&mut *w, "method")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
self.method.serialize_revisioned(w)?;
if let Some(x) = self.params.as_ref() {
serializer
.serialize_into(&mut *w, "params")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
Ok(())
}
fn deserialize_revisioned<R: Read>(_: &mut R) -> std::result::Result<Self, revision::Error>
where
Self: Sized,
{
panic!("deliberately unimplemented");
}
struct PendingRequest {
// Does resolving this request has some effects.
effect: RequestEffect,
// The channel to send the result of the request into.
response_channel: Sender<Result<DbResponse>>,
}
struct RouterState<Sink, Stream, Msg> {
var_stash: IndexMap<i64, (String, Value)>,
struct RouterState<Sink, Stream> {
/// Vars currently set by the set method,
vars: IndexMap<String, Value>,
/// Messages which aught to be replayed on a reconnect.
replay: IndexMap<Method, Msg>,
replay: IndexMap<ReplayMethod, Command>,
/// Pending live queries
live_queries: HashMap<Uuid, channel::Sender<CoreNotification>>,
routes: HashMap<i64, (Method, Sender<Result<DbResponse>>)>,
/// Send requests which are still awaiting an awnser.
pending_requests: HashMap<i64, PendingRequest>,
/// The last time a message was recieved from the server.
last_activity: Instant,
/// The sink into which messages are send to surrealdb
sink: Sink,
/// The stream from which messages are recieved from surrealdb
stream: Stream,
}
impl<Sink, Stream, Msg> RouterState<Sink, Stream, Msg> {
impl<Sink, Stream> RouterState<Sink, Stream> {
pub fn new(sink: Sink, stream: Stream) -> Self {
RouterState {
var_stash: IndexMap::new(),
vars: IndexMap::new(),
replay: IndexMap::new(),
live_queries: HashMap::new(),
routes: HashMap::new(),
pending_requests: HashMap::new(),
last_activity: Instant::now(),
sink,
stream,
@ -317,67 +251,3 @@ where
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}
#[cfg(test)]
mod test {
use std::io::Cursor;
use revision::Revisioned;
use surrealdb_core::sql::Value;
use super::RouterRequest;
fn assert_converts<S, D, I>(req: &RouterRequest, s: S, d: D)
where
S: FnOnce(&RouterRequest) -> I,
D: FnOnce(I) -> Value,
{
let ser = s(req);
let val = d(ser);
let Value::Object(obj) = val else {
panic!("not an object");
};
assert_eq!(obj.get("id").cloned(), req.id);
assert_eq!(obj.get("method").unwrap().clone(), req.method);
assert_eq!(obj.get("params").cloned(), req.params);
}
#[test]
fn router_request_value_conversion() {
let request = RouterRequest {
id: Some(Value::from(1234i64)),
method: Value::from("request"),
params: Some(vec![Value::from(1234i64), Value::from("request")].into()),
};
println!("test convert bincode");
assert_converts(
&request,
|i| bincode::serialize(i).unwrap(),
|b| bincode::deserialize(&b).unwrap(),
);
println!("test convert json");
assert_converts(
&request,
|i| serde_json::to_string(i).unwrap(),
|b| serde_json::from_str(&b).unwrap(),
);
println!("test convert revisioned");
assert_converts(
&request,
|i| {
let mut buf = Vec::new();
i.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap();
buf
},
|b| Value::deserialize_revisioned(&mut Cursor::new(b)).unwrap(),
);
println!("done");
}
}

View file

@ -1,15 +1,13 @@
use super::PATH;
use super::{deserialize, serialize};
use super::{HandleResult, RouterRequest};
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use super::{
deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
};
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::conn::{Command, DbResponse};
use crate::api::conn::{Connection, RequestData};
use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::ws::PING_METHOD;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint;
@ -31,7 +29,6 @@ use revision::revisioned;
use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::HashSet;
use std::mem;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
@ -58,7 +55,7 @@ pub(crate) const NAGLE_ALG: bool = false;
type MessageSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
type MessageStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream>;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector {
@ -153,80 +150,106 @@ async fn router_handle_route(
state: &mut RouterState,
endpoint: &Endpoint,
) -> HandleResult {
let (id, method, param) = request;
let params = match param.query {
Some((query, bindings)) => {
vec![query.into(), bindings.into()]
let RequestData {
id,
command,
} = request;
// We probably shouldn't be sending duplicate id requests.
let entry = state.pending_requests.entry(id);
let Entry::Vacant(entry) = entry else {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
None => param.other,
return HandleResult::Ok;
};
match method {
Method::Set => {
if let [Value::Strand(key), value] = &params[..2] {
state.var_stash.insert(id, (key.0.clone(), value.clone()));
}
let mut effect = RequestEffect::None;
match command {
Command::Set {
ref key,
ref value,
} => {
effect = RequestEffect::Set {
key: key.clone(),
value: value.clone(),
};
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
state.vars.swap_remove(&key.0);
}
Command::Unset {
ref key,
} => {
effect = RequestEffect::Clear {
key: key.clone(),
};
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.insert(id.0, sender);
}
}
Command::Insert {
..
} => {
effect = RequestEffect::Insert;
}
Command::SubscribeLive {
ref uuid,
ref notification_sender,
} => {
state.live_queries.insert(*uuid, notification_sender.clone());
if response.clone().send(Ok(DbResponse::Other(Value::None))).await.is_err() {
trace!("Receiver dropped");
}
// There is nothing to send to the server here
return HandleResult::Ok;
}
Method::Kill => {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.remove(id);
}
Command::Kill {
ref uuid,
} => {
state.live_queries.remove(uuid);
}
Command::Use {
..
} => {
state.replay.insert(ReplayMethod::Use, command.clone());
}
Command::Signup {
..
} => {
state.replay.insert(ReplayMethod::Signup, command.clone());
}
Command::Signin {
..
} => {
state.replay.insert(ReplayMethod::Signin, command.clone());
}
Command::Invalidate {
..
} => {
state.replay.insert(ReplayMethod::Invalidate, command.clone());
}
Command::Authenticate {
..
} => {
state.replay.insert(ReplayMethod::Authenticate, command.clone());
}
_ => {}
}
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let request = RouterRequest {
id: Some(Value::from(id)),
method: method_str.into(),
params: (!params.is_empty()).then(|| params.into()),
};
let message = {
let Some(request) = command.into_router_request(Some(id)) else {
let _ = response.send(Err(Error::BackupsNotSupported.into())).await;
return HandleResult::Ok;
};
trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
state.replay.insert(method, message.clone());
}
match state.sink.send(message).await {
Ok(_) => {
state.last_activity = Instant::now();
match state.routes.entry(id) {
Entry::Vacant(entry) => {
// Register query route
entry.insert((method, response));
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
}
}
entry.insert(PendingRequest {
effect,
response_channel: response,
});
}
Err(error) => {
let error = Error::Ws(error.to_string());
@ -254,23 +277,50 @@ async fn router_handle_response(
Some(id) => {
if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs
if let Some((method, sender)) = state.routes.remove(&id) {
if matches!(method, Method::Set) {
if let Some((key, value)) = state.var_stash.swap_remove(&id) {
state.vars.insert(key, value);
}
}
// Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert) {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(value))) = &mut response {
if let [value] = &mut value.0[..] {
response = Ok(Data::Other(mem::take(value)));
if let Some(pending) = state.pending_requests.remove(&id) {
match pending.effect {
RequestEffect::None => {}
RequestEffect::Insert => {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(value))) =
response.result
{
if value.len() == 1 {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
value.into_iter().next().unwrap(),
))))
.await;
} else {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
Value::Array(value),
))))
.await;
}
return HandleResult::Ok;
}
}
RequestEffect::Set {
key,
value,
} => {
state.vars.insert(key, value);
}
RequestEffect::Clear {
key,
} => {
state.vars.shift_remove(&key);
}
}
let _res = sender.send(DbResponse::from(response)).await;
let _res = pending
.response_channel
.send(DbResponse::from(response.result))
.await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
}
}
}
@ -285,13 +335,11 @@ async fn router_handle_response(
if sender.send(notification).await.is_err() {
state.live_queries.remove(&live_query_id);
let kill = {
let request = RouterRequest {
id: None,
method: Method::Kill.as_str().into(),
params: Some(
vec![Value::from(live_query_id)].into(),
),
};
let request = Command::Kill {
uuid: *live_query_id,
}
.into_router_request(None)
.unwrap();
let value =
serialize(&request, endpoint.supports_revision)
.unwrap();
@ -326,8 +374,10 @@ async fn router_handle_response(
{
// Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = state.routes.remove(&id) {
let _res = sender.send(Err(error)).await;
if let Some(pending) = state.pending_requests.remove(&id) {
let _res = pending.response_channel.send(Err(error)).await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
}
}
} else {
@ -353,19 +403,27 @@ async fn router_reconnect(
let (new_sink, new_stream) = s.split();
state.sink = new_sink;
state.stream = new_stream;
for (_, message) in &state.replay {
if let Err(error) = state.sink.send(message.clone()).await {
for commands in state.replay.values() {
let request = commands
.clone()
.into_router_request(None)
.expect("replay commands should always convert to route requests");
let message = serialize(&request, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(message)).await {
trace!("{error}");
time::sleep(time::Duration::from_secs(1)).await;
continue;
}
}
for (key, value) in &state.vars {
let request = RouterRequest {
id: None,
method: Method::Set.as_str().into(),
params: Some(vec![key.as_str().into(), value.clone()].into()),
};
let request = Command::Set {
key: key.as_str().into(),
value: value.clone(),
}
.into_router_request(None)
.unwrap();
trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(payload)).await {
@ -394,11 +452,7 @@ pub(crate) async fn run_router(
route_rx: Receiver<Route>,
) {
let ping = {
let request = RouterRequest {
id: None,
method: PING_METHOD.into(),
params: None,
};
let request = Command::Health.into_router_request(None).unwrap();
let value = serialize(&request, endpoint.supports_revision).unwrap();
Message::Binary(value)
};
@ -418,7 +472,7 @@ pub(crate) async fn run_router(
state.last_activity = Instant::now();
state.live_queries.clear();
state.routes.clear();
state.pending_requests.clear();
loop {
tokio::select! {

View file

@ -1,14 +1,13 @@
use super::{deserialize, serialize};
use super::{HandleResult, PATH};
use crate::api::conn::Connection;
use super::{
deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
};
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Route;
use crate::api::conn::Router;
use crate::api::conn::{Command, Connection, RequestData};
use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::ws::PING_METHOD;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint;
@ -16,7 +15,7 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt;
use crate::api::Result;
use crate::api::Surreal;
use crate::engine::remote::ws::{Data, RouterRequest};
use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream;
use crate::opt::WaitFor;
use crate::sql::Value;
@ -34,7 +33,6 @@ use serde::Deserialize;
use std::collections::hash_map::Entry;
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::mem;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::sync::OnceLock;
@ -50,7 +48,7 @@ use ws_stream_wasm::{WsEvent, WsStream};
type MessageStream = SplitStream<WsStream>;
type MessageSink = SplitSink<WsStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream>;
impl crate::api::Connection for Client {}
@ -96,79 +94,106 @@ async fn router_handle_request(
state: &mut RouterState,
endpoint: &Endpoint,
) -> HandleResult {
let (id, method, param) = request;
let params = match param.query {
Some((query, bindings)) => {
vec![query.into(), bindings.into()]
let RequestData {
id,
command,
} = request;
let entry = state.pending_requests.entry(id);
// We probably shouldn't be sending duplicate id requests.
let Entry::Vacant(entry) = entry else {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
None => param.other,
return HandleResult::Ok;
};
match method {
Method::Set => {
if let [Value::Strand(key), value] = &params[..2] {
state.var_stash.insert(id, (key.0.clone(), value.clone()));
}
let mut effect = RequestEffect::None;
match command {
Command::Set {
ref key,
ref value,
} => {
effect = RequestEffect::Set {
key: key.clone(),
value: value.clone(),
};
}
Method::Unset => {
if let [Value::Strand(key)] = &params[..1] {
state.vars.swap_remove(&key.0);
}
Command::Unset {
ref key,
} => {
effect = RequestEffect::Clear {
key: key.clone(),
};
}
Method::Live => {
if let Some(sender) = param.notification_sender {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.insert(id.0, sender);
}
}
Command::Insert {
..
} => {
effect = RequestEffect::Insert;
}
Command::SubscribeLive {
ref uuid,
ref notification_sender,
} => {
state.live_queries.insert(*uuid, notification_sender.clone());
if response.send(Ok(DbResponse::Other(Value::None))).await.is_err() {
trace!("Receiver dropped");
}
// There is nothing to send to the server here
return HandleResult::Ok;
}
Method::Kill => {
if let [Value::Uuid(id)] = &params[..1] {
state.live_queries.remove(id);
}
Command::Kill {
ref uuid,
} => {
state.live_queries.remove(uuid);
}
Command::Use {
..
} => {
state.replay.insert(ReplayMethod::Use, command.clone());
}
Command::Signup {
..
} => {
state.replay.insert(ReplayMethod::Signup, command.clone());
}
Command::Signin {
..
} => {
state.replay.insert(ReplayMethod::Signin, command.clone());
}
Command::Invalidate {
..
} => {
state.replay.insert(ReplayMethod::Invalidate, command.clone());
}
Command::Authenticate {
..
} => {
state.replay.insert(ReplayMethod::Authenticate, command.clone());
}
_ => {}
}
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let request = RouterRequest {
id: Some(Value::from(id)),
method: method_str.into(),
params: (!params.is_empty()).then(|| params.into()),
let Some(req) = command.into_router_request(Some(id)) else {
let _ = response.send(Err(Error::BackupsNotSupported.into())).await;
return HandleResult::Ok;
};
trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap();
trace!("Request {:?}", req);
let payload = serialize(&req, endpoint.supports_revision).unwrap();
Message::Binary(payload)
};
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
state.replay.insert(method, message.clone());
}
match state.sink.send(message).await {
Ok(..) => {
state.last_activity = Instant::now();
match state.routes.entry(id) {
Entry::Vacant(entry) => {
entry.insert((method, response));
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
}
}
entry.insert(PendingRequest {
effect,
response_channel: response,
});
}
Err(error) => {
let error = Error::Ws(error.to_string());
@ -196,23 +221,50 @@ async fn router_handle_response(
Some(id) => {
if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs
if let Some((method, sender)) = state.routes.remove(&id) {
if matches!(method, Method::Set) {
if let Some((key, value)) = state.var_stash.swap_remove(&id) {
state.vars.insert(key, value);
}
}
// Send the response back to the caller
let mut response = response.result;
if matches!(method, Method::Insert) {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(value))) = &mut response {
if let [value] = &mut value.0[..] {
response = Ok(Data::Other(mem::take(value)));
if let Some(pending) = state.pending_requests.remove(&id) {
match pending.effect {
RequestEffect::None => {}
RequestEffect::Insert => {
// For insert, we need to flatten single responses in an array
if let Ok(Data::Other(Value::Array(value))) =
response.result
{
if value.len() == 1 {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
value.into_iter().next().unwrap(),
))))
.await;
} else {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
Value::Array(value),
))))
.await;
}
return HandleResult::Ok;
}
}
RequestEffect::Set {
key,
value,
} => {
state.vars.insert(key, value);
}
RequestEffect::Clear {
key,
} => {
state.vars.shift_remove(&key);
}
}
let _res = sender.send(DbResponse::from(response)).await;
let _res = pending
.response_channel
.send(DbResponse::from(response.result))
.await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
}
}
}
@ -226,11 +278,10 @@ async fn router_handle_response(
if sender.send(notification).await.is_err() {
state.live_queries.remove(&live_query_id);
let kill = {
let request = RouterRequest {
id: None,
method: Method::Kill.as_str().into(),
params: Some(vec![Value::from(live_query_id)].into()),
};
let request = Command::Kill {
uuid: live_query_id.0,
}
.into_router_request(None);
let value = serialize(&request, endpoint.supports_revision)
.unwrap();
Message::Binary(value)
@ -265,8 +316,10 @@ async fn router_handle_response(
{
// Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = state.routes.remove(&id) {
let _res = sender.send(Err(error)).await;
if let Some(req) = state.pending_requests.remove(&id) {
let _res = req.response_channel.send(Err(error)).await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
}
}
} else {
@ -311,18 +364,21 @@ async fn router_reconnect(
}
};
for (_, message) in &state.replay {
if let Err(error) = state.sink.send(message.clone()).await {
let message = message.clone().into_router_request(None);
let message = serialize(&message, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(message)).await {
trace!("{error}");
time::sleep(Duration::from_secs(1)).await;
continue;
}
}
for (key, value) in &state.vars {
let request = RouterRequest {
id: None,
method: Method::Set.as_str().into(),
params: Some(vec![key.as_str().into(), value.clone()].into()),
};
let request = Command::Set {
key: key.as_str().into(),
value: value.clone(),
}
.into_router_request(None);
trace!("Request {:?}", request);
let serialize = serialize(&request, false).unwrap();
if let Err(error) = state.sink.send(Message::Binary(serialize)).await {
@ -378,7 +434,7 @@ pub(crate) async fn run_router(
let ping = {
let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into());
request.insert("method".to_owned(), "ping".into());
let value = Value::from(request);
let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value)
@ -397,7 +453,7 @@ pub(crate) async fn run_router(
state.last_activity = Instant::now();
state.live_queries.clear();
state.routes.clear();
state.pending_requests.clear();
loop {
futures::select! {

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::OnceLockExt;
use crate::api::opt::auth::Jwt;
@ -27,7 +26,11 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Authenticate, Param::new(vec![self.token.0.into()])).await
router
.execute_unit(Command::Authenticate {
token: self.token.0,
})
.await
})
}
}

View file

@ -1,17 +1,11 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::opt::Range;
use crate::api::opt::Resource;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::to_value;
use crate::sql::Id;
use crate::sql::Value;
use crate::Surreal;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
@ -21,21 +15,29 @@ use std::marker::PhantomData;
/// Content inserts or replaces the contents of a record entirely
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Content<'r, C: Connection, D, R> {
pub struct Content<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) method: Method,
pub(super) resource: Result<Resource>,
pub(super) range: Option<Range<Id>>,
pub(super) content: D,
pub(super) command: Result<Command>,
pub(super) response_type: PhantomData<R>,
}
impl<C, D, R> Content<'_, C, D, R>
impl<'r, C, R> Content<'r, C, R>
where
C: Connection,
{
pub(crate) fn from_closure<F>(client: Cow<'r, Surreal<C>>, f: F) -> Self
where
F: FnOnce() -> Result<Command>,
{
Content {
client,
command: f(),
response_type: PhantomData,
}
}
/// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Content<'static, C, D, R> {
pub fn into_owned(self) -> Content<'static, C, R> {
Content {
client: Cow::Owned(self.client.into_owned()),
..self
@ -48,33 +50,20 @@ macro_rules! into_future {
fn into_future(self) -> Self::IntoFuture {
let Content {
client,
method,
resource,
range,
content,
command,
..
} = self;
let content = to_value(content);
Box::pin(async move {
let param = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let params = match content? {
Value::None | Value::Null => vec![param],
content => vec![param, content],
};
let router = client.router.extract()?;
router.$method(method, Param::new(params)).await
router.$method(command?).await
})
}
};
}
impl<'r, Client, D> IntoFuture for Content<'r, Client, D, Value>
impl<'r, Client> IntoFuture for Content<'r, Client, Value>
where
Client: Connection,
D: Serialize,
{
type Output = Result<Value>;
type IntoFuture = BoxFuture<'r, Self::Output>;
@ -82,10 +71,9 @@ where
into_future! {execute_value}
}
impl<'r, Client, D, R> IntoFuture for Content<'r, Client, D, Option<R>>
impl<'r, Client, R> IntoFuture for Content<'r, Client, Option<R>>
where
Client: Connection,
D: Serialize,
R: DeserializeOwned,
{
type Output = Result<Option<R>>;
@ -94,10 +82,9 @@ where
into_future! {execute_opt}
}
impl<'r, Client, D, R> IntoFuture for Content<'r, Client, D, Vec<R>>
impl<'r, Client, R> IntoFuture for Content<'r, Client, Vec<R>>
where
Client: Connection,
D: Serialize,
R: DeserializeOwned,
{
type Output = Result<Vec<R>>;

View file

@ -1,7 +1,5 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::Content;
use crate::api::opt::Resource;
use crate::api::Connection;
use crate::api::Result;
@ -13,6 +11,9 @@ use serde::Serialize;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
use super::Content;
/// A record create future
#[derive(Debug)]
@ -46,7 +47,11 @@ macro_rules! into_future {
} = self;
Box::pin(async move {
let router = client.router.extract()?;
router.$method(Method::Create, Param::new(vec![resource?.into()])).await
let cmd = Command::Create {
what: resource?.into(),
data: None,
};
router.$method(cmd).await
})
}
};
@ -89,17 +94,22 @@ where
C: Connection,
{
/// Sets content of a record
pub fn content<D>(self, data: D) -> Content<'r, C, D, R>
pub fn content<D>(self, data: D) -> Content<'r, C, R>
where
D: Serialize,
{
Content {
client: self.client,
method: Method::Create,
resource: self.resource,
range: None,
content: data,
response_type: PhantomData,
}
Content::from_closure(self.client, || {
let content = to_value(data)?;
let data = match content {
Value::None | Value::Null => None,
content => Some(content),
};
Ok(Command::Create {
what: self.resource?.into(),
data,
})
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::opt::Range;
use crate::api::opt::Resource;
@ -47,12 +46,16 @@ macro_rules! into_future {
..
} = self;
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let router = client.router.extract()?;
router.$method(Method::Delete, Param::new(vec![param])).await
router
.$method(Command::Delete {
what: param,
})
.await
})
}
};

View file

@ -1,6 +1,5 @@
use crate::api::conn::Method;
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::conn::MlExportConfig;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Error;
@ -8,7 +7,6 @@ use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::method::Model;
use crate::method::OnceLockExt;
use crate::opt::ExportDestination;
use crate::Surreal;
use channel::Receiver;
use futures::Stream;
@ -27,8 +25,8 @@ use std::task::Poll;
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Export<'r, C: Connection, R, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) target: ExportDestination,
pub(super) ml_config: Option<MlConfig>,
pub(super) target: R,
pub(super) ml_config: Option<MlExportConfig>,
pub(super) response: PhantomData<R>,
pub(super) export_type: PhantomData<T>,
}
@ -42,7 +40,7 @@ where
Export {
client: self.client,
target: self.target,
ml_config: Some(MlConfig::Export {
ml_config: Some(MlExportConfig {
name: name.to_owned(),
version: version.to_string(),
}),
@ -78,12 +76,21 @@ where
if !router.features.contains(&ExtraFeatures::Backup) {
return Err(Error::BackupsNotSupported.into());
}
let mut param = match self.target {
ExportDestination::File(path) => Param::file(path),
ExportDestination::Memory => unreachable!(),
};
param.ml_config = self.ml_config;
router.execute_unit(Method::Export, param).await
if let Some(config) = self.ml_config {
return router
.execute_unit(Command::ExportMl {
path: self.target,
config,
})
.await;
}
router
.execute_unit(Command::ExportFile {
path: self.target,
})
.await
})
}
}
@ -102,12 +109,25 @@ where
return Err(Error::BackupsNotSupported.into());
}
let (tx, rx) = crate::channel::bounded(1);
let ExportDestination::Memory = self.target else {
unreachable!();
};
let mut param = Param::bytes_sender(tx);
param.ml_config = self.ml_config;
router.execute_unit(Method::Export, param).await?;
if let Some(config) = self.ml_config {
router
.execute_unit(Command::ExportBytesMl {
bytes: tx,
config,
})
.await?;
return Ok(Backup {
rx,
});
}
router
.execute_unit(Command::ExportBytes {
bytes: tx,
})
.await?;
Ok(Backup {
rx,
})

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
@ -37,7 +36,7 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Health, Param::new(Vec::new())).await
router.execute_unit(Command::Health).await
})
}
}

View file

@ -1,6 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Error;
@ -20,7 +18,7 @@ use std::path::PathBuf;
pub struct Import<'r, C: Connection, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) file: PathBuf,
pub(super) ml_config: Option<MlConfig>,
pub(super) is_ml: bool,
pub(super) import_type: PhantomData<T>,
}
@ -33,7 +31,7 @@ where
Import {
client: self.client,
file: self.file,
ml_config: Some(MlConfig::Import),
is_ml: true,
import_type: PhantomData,
}
}
@ -65,9 +63,20 @@ where
if !router.features.contains(&ExtraFeatures::Backup) {
return Err(Error::BackupsNotSupported.into());
}
let mut param = Param::file(self.file);
param.ml_config = self.ml_config;
router.execute_unit(Method::Import, param).await
if self.is_ml {
return router
.execute_unit(Command::ImportMl {
path: self.file,
})
.await;
}
router
.execute_unit(Command::ImportFile {
path: self.file,
})
.await
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::method::Content;
@ -60,9 +59,13 @@ macro_rules! into_future {
Resource::Array(arr) => return Err(Error::InsertOnArray(arr).into()),
Resource::Edges(edges) => return Err(Error::InsertOnEdges(edges).into()),
};
let param = vec![table, data];
let cmd = Command::Insert {
what: Some(table),
data,
};
let router = client.router.extract()?;
router.$method(Method::Insert, Param::new(param)).await
router.$method(cmd).await
})
}
};
@ -106,59 +109,41 @@ where
R: DeserializeOwned,
{
/// Specifies the data to insert into the table
pub fn content<D>(self, data: D) -> Content<'r, C, Value, R>
pub fn content<D>(self, data: D) -> Content<'r, C, R>
where
D: Serialize,
{
let mut content = Content {
client: self.client,
method: Method::Insert,
resource: self.resource,
range: None,
content: Value::None,
response_type: PhantomData,
};
match crate::sql::to_value(data) {
Ok(mut data) => match content.resource {
Ok(Resource::Table(table)) => {
content.resource = Ok(table.into());
content.content = data;
}
Ok(Resource::RecordId(record_id)) => match data.is_array() {
true => {
content.resource = Err(Error::InvalidParams(
Content::from_closure(self.client, || {
let mut data = crate::sql::to_value(data)?;
match self.resource? {
Resource::Table(table) => Ok(Command::Insert {
what: Some(table.into()),
data,
}),
Resource::RecordId(thing) => {
if data.is_array() {
Err(Error::InvalidParams(
"Tried to insert multiple records on a record ID".to_owned(),
)
.into());
}
false => {
.into())
} else {
let mut table = Table::default();
table.0.clone_from(&record_id.tb);
content.resource = Ok(table.into());
table.0.clone_from(&thing.tb);
let what = Value::Table(table);
let mut ident = Ident::default();
"id".clone_into(&mut ident.0);
let id = Part::Field(ident);
data.put(&[id], record_id.into());
content.content = data;
data.put(&[id], thing.into());
Ok(Command::Insert {
what: Some(what),
data,
})
}
},
Ok(Resource::Object(obj)) => {
content.resource = Err(Error::InsertOnObject(obj).into());
}
Ok(Resource::Array(arr)) => {
content.resource = Err(Error::InsertOnArray(arr).into());
}
Ok(Resource::Edges(edges)) => {
content.resource = Err(Error::InsertOnEdges(edges).into());
}
Err(error) => {
content.resource = Err(error);
}
},
Err(error) => {
content.resource = Err(error.into());
Resource::Object(obj) => Err(Error::InsertOnObject(obj).into()),
Resource::Array(arr) => Err(Error::InsertOnArray(arr).into()),
Resource::Edges(edges) => Err(Error::InsertOnEdges(edges).into()),
}
};
content
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
@ -37,7 +36,7 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Invalidate, Param::new(Vec::new())).await
router.execute_unit(Command::Invalidate).await
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::conn::Router;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
@ -33,12 +32,12 @@ use futures::StreamExt;
use serde::de::DeserializeOwned;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::mem;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
#[cfg(not(target_arch = "wasm32"))]
use tokio::spawn;
use uuid::Uuid;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local as spawn;
@ -99,11 +98,16 @@ macro_rules! into_future {
Default::default(),
false,
);
let id: Value = query.await?.take(0)?;
let rx = register(router, id.clone()).await?;
let Value::Uuid(id) = query.await?.take(0)? else {
return Err(Error::InternalError(
"successufull live query didn't return a uuid".to_string(),
)
.into());
};
let rx = register(router, id.0).await?;
Ok(Stream::new(
Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone()),
id,
id.0,
Some(rx),
))
})
@ -111,11 +115,14 @@ macro_rules! into_future {
};
}
pub(crate) async fn register(router: &Router, id: Value) -> Result<Receiver<dbs::Notification>> {
pub(crate) async fn register(router: &Router, id: Uuid) -> Result<Receiver<dbs::Notification>> {
let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx);
param.other = vec![id];
router.execute_unit(Method::Live, param).await?;
router
.execute_unit(Command::SubscribeLive {
uuid: id,
notification_sender: tx,
})
.await?;
Ok(rx)
}
@ -158,7 +165,7 @@ pub struct Stream<R> {
pub(crate) client: Surreal<Any>,
// We no longer need the lifetime and the type parameter
// Leaving them in for backwards compatibility
pub(crate) id: Value,
pub(crate) id: Uuid,
pub(crate) rx: Option<Receiver<dbs::Notification>>,
pub(crate) response_type: PhantomData<R>,
}
@ -166,7 +173,7 @@ pub struct Stream<R> {
impl<R> Stream<R> {
pub(crate) fn new(
client: Surreal<Any>,
id: Value,
id: Uuid,
rx: Option<Receiver<dbs::Notification>>,
) -> Self {
Self {
@ -247,14 +254,19 @@ where
poll_next_and_convert! {}
}
pub(crate) fn kill<Client>(client: &Surreal<Client>, id: Value)
pub(crate) fn kill<Client>(client: &Surreal<Client>, uuid: Uuid)
where
Client: Connection,
{
let client = client.clone();
spawn(async move {
if let Ok(router) = client.router.extract() {
router.execute_unit(Method::Kill, Param::new(vec![id.clone()])).await.ok();
router
.execute_unit(Command::Kill {
uuid,
})
.await
.ok();
}
});
}
@ -264,9 +276,8 @@ impl<R> Drop for Stream<R> {
///
/// This kills the live query process responsible for this stream.
fn drop(&mut self) {
if !self.id.is_none() && self.rx.is_some() {
let id = mem::take(&mut self.id);
kill(&self.client, id);
if self.rx.is_some() {
kill(&self.client, self.id);
}
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::opt::Range;
use crate::api::opt::Resource;
@ -52,12 +51,22 @@ macro_rules! into_future {
} = self;
let content = to_value(content);
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let content = match content? {
Value::None | Value::Null => None,
x => Some(x),
};
let router = client.router.extract()?;
router.$method(Method::Merge, Param::new(vec![param, content?])).await
let cmd = Command::Merge {
what: param,
data: content,
};
router.$method(cmd).await
})
}
};

View file

@ -1,4 +1,27 @@
//! Methods to use when interacting with a SurrealDB instance
use self::query::ValidQuery;
use crate::api::err::Error;
use crate::api::opt;
use crate::api::opt::auth;
use crate::api::opt::auth::Credentials;
use crate::api::opt::auth::Jwt;
use crate::api::opt::IntoEndpoint;
use crate::api::Connect;
use crate::api::Connection;
use crate::api::OnceLockExt;
use crate::api::Surreal;
use crate::opt::IntoExportDestination;
use crate::opt::WaitFor;
use crate::sql::to_value;
use crate::sql::Value;
use serde::Serialize;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
pub(crate) mod live;
pub(crate) mod query;
@ -43,8 +66,7 @@ pub use commit::Commit;
pub use content::Content;
pub use create::Create;
pub use delete::Delete;
pub use export::Backup;
pub use export::Export;
pub use export::{Backup, Export};
use futures::Future;
pub use health::Health;
pub use import::Import;
@ -67,31 +89,6 @@ pub use use_db::UseDb;
pub use use_ns::UseNs;
pub use version::Version;
use crate::api::conn::Method;
use crate::api::opt;
use crate::api::opt::auth;
use crate::api::opt::auth::Credentials;
use crate::api::opt::auth::Jwt;
use crate::api::opt::IntoEndpoint;
use crate::api::Connect;
use crate::api::Connection;
use crate::api::OnceLockExt;
use crate::api::Surreal;
use crate::opt::IntoExportDestination;
use crate::opt::WaitFor;
use crate::sql::to_value;
use crate::sql::Value;
use serde::Serialize;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use self::query::ValidQuery;
/// A alias for an often used type of future returned by async methods in this library.
pub(crate) type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;
@ -113,36 +110,6 @@ pub struct Live;
#[derive(Debug)]
pub struct WithStats<T>(T);
impl Method {
#[allow(dead_code)] // used by `ws` and `http`
pub(crate) fn as_str(&self) -> &str {
match self {
Method::Authenticate => "authenticate",
Method::Create => "create",
Method::Delete => "delete",
Method::Export => "export",
Method::Health => "health",
Method::Import => "import",
Method::Invalidate => "invalidate",
Method::Insert => "insert",
Method::Kill => "kill",
Method::Live => "live",
Method::Merge => "merge",
Method::Patch => "patch",
Method::Query => "query",
Method::Select => "select",
Method::Set => "set",
Method::Signin => "signin",
Method::Signup => "signup",
Method::Unset => "unset",
Method::Update => "update",
Method::Upsert => "upsert",
Method::Use => "use",
Method::Version => "version",
}
}
}
impl<C> Surreal<C>
where
C: Connection,
@ -318,7 +285,7 @@ where
pub fn use_db(&self, db: impl Into<String>) -> UseDb<C> {
UseDb {
client: Cow::Borrowed(self),
ns: Value::None,
ns: None,
db: db.into(),
}
}
@ -457,7 +424,16 @@ where
pub fn signup<R>(&self, credentials: impl Credentials<auth::Signup, R>) -> Signup<C, R> {
Signup {
client: Cow::Borrowed(self),
credentials: to_value(credentials).map_err(Into::into),
credentials: to_value(credentials).map_err(Into::into).and_then(|x| {
if let Value::Object(x) = x {
Ok(x)
} else {
Err(Error::InternalError(
"credentials did not serialize to an object".to_string(),
)
.into())
}
}),
response_type: PhantomData,
}
}
@ -576,7 +552,16 @@ where
pub fn signin<R>(&self, credentials: impl Credentials<auth::Signin, R>) -> Signin<C, R> {
Signin {
client: Cow::Borrowed(self),
credentials: to_value(credentials).map_err(Into::into),
credentials: to_value(credentials).map_err(Into::into).and_then(|x| {
if let Value::Object(x) = x {
Ok(x)
} else {
Err(Error::InternalError(
"credentials did not serialize to an object".to_string(),
)
.into())
}
}),
response_type: PhantomData,
}
}
@ -1359,7 +1344,7 @@ where
Import {
client: Cow::Borrowed(self),
file: file.as_ref().to_owned(),
ml_config: None,
is_ml: false,
import_type: PhantomData,
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::opt::PatchOp;
use crate::api::opt::Range;
@ -51,7 +50,7 @@ macro_rules! into_future {
..
} = self;
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
@ -59,9 +58,14 @@ macro_rules! into_future {
for result in patches {
vec.push(result?);
}
let patches = vec.into();
let patches = Value::from(vec);
let router = client.router.extract()?;
router.$method(Method::Patch, Param::new(vec![param, patches])).await
let cmd = Command::Patch {
what: param,
data: Some(patches),
};
router.$method(cmd).await
})
}
};

View file

@ -1,8 +1,7 @@
use super::live;
use super::Stream;
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::opt;
@ -158,8 +157,12 @@ where
let mut query = sql::Query::default();
query.0 .0 = query_statements;
let param = Param::query(query, bindings);
let mut response = router.execute_query(Method::Query, param).await?;
let mut response = router
.execute_query(Command::Query {
query,
variables: bindings,
})
.await?;
for idx in query_indicies {
let Some((_, result)) = response.results.get(&idx) else {
@ -169,16 +172,24 @@ where
// This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error.
let res = match result {
Ok(id) => live::register(router, id.clone()).await.map(|rx| {
Stream::new(
Surreal::new_from_router_waiter(
client.router.clone(),
client.waiter.clone(),
),
id.clone(),
Some(rx),
)
}),
Ok(id) => {
let Value::Uuid(uuid) = id else {
return Err(Error::InternalError(
"successfull live query did not return a uuid".to_string(),
)
.into());
};
live::register(router, uuid.0).await.map(|rx| {
Stream::new(
Surreal::new_from_router_waiter(
client.router.clone(),
client.waiter.clone(),
),
uuid.0,
Some(rx),
)
})
}
Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
};
response.live_queries.insert(idx, res);

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::OnceLockExt;
use crate::api::opt::Range;
@ -49,12 +48,16 @@ macro_rules! into_future {
..
} = self;
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let router = client.router.extract()?;
router.$method(Method::Select, Param::new(vec![param])).await
router
.$method(Command::Select {
what: param,
})
.await
})
}
};

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
@ -41,7 +40,12 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Set, Param::new(vec![self.key.into(), self.value?])).await
router
.execute_unit(Command::Set {
key: self.key,
value: self.value?,
})
.await
})
}
}

View file

@ -1,22 +1,21 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal;
use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::Object;
/// A signin future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Signin<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) credentials: Result<Value>,
pub(super) credentials: Result<Object>,
pub(super) response_type: PhantomData<R>,
}
@ -49,7 +48,11 @@ where
} = self;
Box::pin(async move {
let router = client.router.extract()?;
router.execute(Method::Signin, Param::new(vec![credentials?])).await
router
.execute(Command::Signin {
credentials: credentials?,
})
.await
})
}
}

View file

@ -1,22 +1,21 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal;
use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::Object;
/// A signup future
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Signup<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) credentials: Result<Value>,
pub(super) credentials: Result<Object>,
pub(super) response_type: PhantomData<R>,
}
@ -49,7 +48,11 @@ where
} = self;
Box::pin(async move {
let router = client.router.extract()?;
router.execute(Method::Signup, Param::new(vec![credentials?])).await
router
.execute(Command::Signup {
credentials: credentials?,
})
.await
})
}
}

View file

@ -1,12 +1,11 @@
use channel::Receiver;
use super::types::User;
use crate::api::conn::Command;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Route;
use crate::api::Response as QueryResponse;
use crate::sql::to_value;
use crate::sql::Value;
use channel::Receiver;
pub(super) fn mock(route_rx: Receiver<Route>) {
tokio::spawn(async move {
@ -15,81 +14,111 @@ pub(super) fn mock(route_rx: Receiver<Route>) {
response,
}) = route_rx.recv().await
{
let (_, method, param) = request;
let mut params = param.other;
let cmd = request.command;
let result = match method {
Method::Invalidate | Method::Health => match &params[..] {
[] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Authenticate | Method::Kill | Method::Unset => match &params[..] {
[_] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Live => match &params[..] {
[_] => Ok(DbResponse::Other(
"c6c0e36c-e2cf-42cb-b2d5-75415249b261".to_owned().into(),
)),
_ => unreachable!(),
},
Method::Version => match &params[..] {
[] => Ok(DbResponse::Other("1.0.0".into())),
_ => unreachable!(),
},
Method::Use => match &params[..] {
[_] | [_, _] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Signup | Method::Signin => match &mut params[..] {
[_] => Ok(DbResponse::Other("jwt".to_owned().into())),
_ => unreachable!(),
},
Method::Set => match &params[..] {
[_, _] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Query => match param.query {
Some(_) => Ok(DbResponse::Query(QueryResponse::new())),
_ => unreachable!(),
},
Method::Create => match &params[..] {
[_] => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
[_, user] => Ok(DbResponse::Other(user.clone())),
_ => unreachable!(),
},
Method::Select | Method::Delete => match &params[..] {
[Value::Thing(..)] => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
[Value::Table(..) | Value::Array(..) | Value::Range(..)] => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
},
Method::Upsert | Method::Update | Method::Merge | Method::Patch => {
match &params[..] {
[Value::Thing(..)] | [Value::Thing(..), _] => {
Ok(DbResponse::Other(to_value(User::default()).unwrap()))
}
[Value::Table(..) | Value::Array(..) | Value::Range(..)]
| [Value::Table(..) | Value::Array(..) | Value::Range(..), _] => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
}
let result = match cmd {
Command::Invalidate | Command::Health => Ok(DbResponse::Other(Value::None)),
Command::Authenticate {
..
}
Method::Insert => match &params[..] {
[Value::Table(..), Value::Array(..)] => {
| Command::Kill {
..
}
| Command::Unset {
..
} => Ok(DbResponse::Other(Value::None)),
Command::SubscribeLive {
..
} => Ok(DbResponse::Other("c6c0e36c-e2cf-42cb-b2d5-75415249b261".to_owned().into())),
Command::Version => Ok(DbResponse::Other("1.0.0".into())),
Command::Use {
..
} => Ok(DbResponse::Other(Value::None)),
Command::Signup {
..
}
| Command::Signin {
..
} => Ok(DbResponse::Other("jwt".to_owned().into())),
Command::Set {
..
} => Ok(DbResponse::Other(Value::None)),
Command::Query {
..
} => Ok(DbResponse::Query(QueryResponse::new())),
Command::Create {
data,
..
} => match data {
None => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Some(user) => Ok(DbResponse::Other(user.clone())),
},
Command::Select {
what,
..
}
| Command::Delete {
what,
..
} => match what {
Value::Thing(..) => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Value::Table(..) | Value::Array(..) | Value::Range(..) => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
[Value::Table(..), _] => {
_ => unreachable!(),
},
Command::Upsert {
what,
..
}
| Command::Update {
what,
..
}
| Command::Merge {
what,
..
}
| Command::Patch {
what,
..
} => match what {
Value::Thing(..) => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Value::Table(..) | Value::Array(..) | Value::Range(..) => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
},
Command::Insert {
what,
data,
} => match (what, data) {
(Some(Value::Table(..)), Value::Array(..)) => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
(Some(Value::Table(..)), _) => {
Ok(DbResponse::Other(to_value(User::default()).unwrap()))
}
_ => unreachable!(),
},
Method::Export | Method::Import => match param.file {
Some(_) => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Command::ExportMl {
..
}
| Command::ExportBytesMl {
..
}
| Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportMl {
..
}
| Command::ImportFile {
..
} => Ok(DbResponse::Other(Value::None)),
};
if let Err(message) = response.send(result).await {

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
@ -39,7 +38,11 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Unset, Param::new(vec![self.key.into()])).await
router
.execute_unit(Command::Unset {
key: self.key,
})
.await
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::Content;
use crate::api::method::Merge;
@ -18,6 +17,7 @@ use serde::Serialize;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
/// An update future
#[derive(Debug)]
@ -52,12 +52,17 @@ macro_rules! into_future {
..
} = self;
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let router = client.router.extract()?;
router.$method(Method::Upsert, Param::new(vec![param])).await
router
.$method(Command::Update {
what: param,
data: None,
})
.await
})
}
};
@ -123,18 +128,28 @@ where
R: DeserializeOwned,
{
/// Replaces the current document / record data with the specified data
pub fn content<D>(self, data: D) -> Content<'r, C, D, R>
pub fn content<D>(self, data: D) -> Content<'r, C, R>
where
D: Serialize,
{
Content {
client: self.client,
method: Method::Update,
resource: self.resource,
range: self.range,
content: data,
response_type: PhantomData,
}
Content::from_closure(self.client, || {
let data = to_value(data)?;
let what: Value = match self.range {
Some(range) => self.resource?.with_range(range)?.into(),
None => self.resource?.into(),
};
let data = match data {
Value::None | Value::Null => None,
content => Some(content),
};
Ok(Command::Update {
what,
data,
})
})
}
/// Merges the current document / record data with the specified data

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::Content;
use crate::api::method::Merge;
@ -18,6 +17,7 @@ use serde::Serialize;
use std::borrow::Cow;
use std::future::IntoFuture;
use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
/// An upsert future
#[derive(Debug)]
@ -52,12 +52,17 @@ macro_rules! into_future {
..
} = self;
Box::pin(async move {
let param = match range {
let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let router = client.router.extract()?;
router.$method(Method::Upsert, Param::new(vec![param])).await
router
.$method(Command::Upsert {
what: param,
data: None,
})
.await
})
}
};
@ -123,18 +128,28 @@ where
R: DeserializeOwned,
{
/// Replaces the current document / record data with the specified data
pub fn content<D>(self, data: D) -> Content<'r, C, D, R>
pub fn content<D>(self, data: D) -> Content<'r, C, R>
where
D: Serialize,
{
Content {
client: self.client,
method: Method::Upsert,
resource: self.resource,
range: self.range,
content: data,
response_type: PhantomData,
}
Content::from_closure(self.client, || {
let data = to_value(data)?;
let what: Value = match self.range {
Some(range) => self.resource?.with_range(range)?.into(),
None => self.resource?.into(),
};
let data = match data {
Value::None | Value::Null => None,
content => Some(content),
};
Ok(Command::Upsert {
what,
data,
})
})
}
/// Merges the current document / record data with the specified data

View file

@ -1,11 +1,9 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::opt::WaitFor;
use crate::sql::Value;
use crate::Surreal;
use std::borrow::Cow;
use std::future::IntoFuture;
@ -14,7 +12,7 @@ use std::future::IntoFuture;
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct UseDb<'r, C: Connection> {
pub(super) client: Cow<'r, Surreal<C>>,
pub(super) ns: Value,
pub(super) ns: Option<String>,
pub(super) db: String,
}
@ -41,7 +39,12 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Use, Param::new(vec![self.ns, self.db.into()])).await?;
router
.execute_unit(Command::Use {
namespace: self.ns,
database: Some(self.db),
})
.await?;
self.client.waiter.0.send(Some(WaitFor::Database)).ok();
Ok(())
})

View file

@ -1,11 +1,9 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::method::BoxFuture;
use crate::api::method::UseDb;
use crate::api::Connection;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal;
use std::borrow::Cow;
use std::future::IntoFuture;
@ -55,7 +53,12 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
router.execute_unit(Method::Use, Param::new(vec![self.ns.into(), Value::None])).await
router
.execute_unit(Command::Use {
namespace: Some(self.ns),
database: None,
})
.await
})
}
}

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Command;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::Connection;
@ -38,10 +37,7 @@ where
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
let version = router
.execute_value(Method::Version, Param::new(Vec::new()))
.await?
.convert_to_string()?;
let version = router.execute_value(Command::Version).await?.convert_to_string()?;
let semantic = version.trim_start_matches("surrealdb-");
semantic.parse().map_err(|_| Error::InvalidSemanticVersion(semantic.to_string()).into())
})

View file

@ -3,28 +3,28 @@ use std::path::PathBuf;
#[derive(Debug)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub enum ExportDestination {
File(PathBuf),
Memory,
}
/// A trait for converting inputs into database export locations
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub trait IntoExportDestination<R> {
/// Converts an input into a database export location
fn into_export_destination(self) -> ExportDestination;
fn into_export_destination(self) -> R;
}
impl<T> IntoExportDestination<PathBuf> for T
where
T: AsRef<Path>,
{
fn into_export_destination(self) -> ExportDestination {
ExportDestination::File(self.as_ref().to_path_buf())
fn into_export_destination(self) -> PathBuf {
self.as_ref().to_path_buf()
}
}
impl IntoExportDestination<()> for () {
fn into_export_destination(self) -> ExportDestination {
ExportDestination::Memory
}
fn into_export_destination(self) {}
}

View file

@ -1,4 +1,9 @@
//! The different options and types for use in API functions
use crate::sql::to_value;
use crate::sql::Thing;
use crate::sql::Value;
use dmp::Diff;
use serde::Serialize;
pub mod auth;
pub mod capabilities;
@ -10,12 +15,6 @@ mod query;
mod resource;
mod tls;
use crate::sql::to_value;
use crate::sql::Thing;
use crate::sql::Value;
use dmp::Diff;
use serde::Serialize;
pub use config::*;
pub use endpoint::*;
pub use export::*;