Make communication with the router task more type-safe (#4406)

This commit is contained in:
Mees Delzenne 2024-07-23 16:38:54 +02:00 committed by GitHub
parent 08f4ad6c82
commit 2b1e6a32ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2228 additions and 1628 deletions

View file

@ -1,5 +1,4 @@
use crate::sql::constant::ConstantValue; use crate::sql::constant::ConstantValue;
use crate::sql::id::Gen;
use crate::sql::Value; use crate::sql::Value;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@ -174,6 +173,8 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
} }
} }
// TODO: Checkout blame for why these are here.
/*
#[derive(Serialize)] #[derive(Serialize)]
enum Id { enum Id {
Number(i64), Number(i64),
@ -212,6 +213,7 @@ fn into_json(value: Value, simplify: bool) -> JsonValue {
} }
} }
} }
*/
match value { match value {
// These value types are simple values which // These value types are simple values which

View file

@ -5,34 +5,8 @@ use crate::sql::Kind;
use crate::sql::Value; use crate::sql::Value;
use ser::Serializer as _; use ser::Serializer as _;
use serde::ser::Error as _; use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize; use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Cast;
type Error = Error;
type SerializeSeq = Impossible<Cast, Error>;
type SerializeTuple = Impossible<Cast, Error>;
type SerializeTupleStruct = SerializeCast;
type SerializeTupleVariant = Impossible<Cast, Error>;
type SerializeMap = Impossible<Cast, Error>;
type SerializeStruct = Impossible<Cast, Error>;
type SerializeStructVariant = Impossible<Cast, Error>;
const EXPECTED: &'static str = "an struct `Cast`";
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Error> {
Ok(SerializeCast::default())
}
}
#[derive(Default)] #[derive(Default)]
pub(super) struct SerializeCast { pub(super) struct SerializeCast {
index: usize, index: usize,
@ -74,8 +48,34 @@ impl serde::ser::SerializeTupleStruct for SerializeCast {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde::ser::Impossible;
use serde::Serialize; use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Cast;
type Error = Error;
type SerializeSeq = Impossible<Cast, Error>;
type SerializeTuple = Impossible<Cast, Error>;
type SerializeTupleStruct = SerializeCast;
type SerializeTupleVariant = Impossible<Cast, Error>;
type SerializeMap = Impossible<Cast, Error>;
type SerializeStruct = Impossible<Cast, Error>;
type SerializeStructVariant = Impossible<Cast, Error>;
const EXPECTED: &'static str = "an struct `Cast`";
fn serialize_tuple_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleStruct, Error> {
Ok(SerializeCast::default())
}
}
#[test] #[test]
fn cast() { fn cast() {
let cast = Cast(Default::default(), Default::default()); let cast = Cast(Default::default(), Default::default());

View file

@ -6,35 +6,8 @@ use crate::sql::Tables;
use crate::sql::Thing; use crate::sql::Thing;
use ser::Serializer as _; use ser::Serializer as _;
use serde::ser::Error as _; use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize; use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Edges;
type Error = Error;
type SerializeSeq = Impossible<Edges, Error>;
type SerializeTuple = Impossible<Edges, Error>;
type SerializeTupleStruct = Impossible<Edges, Error>;
type SerializeTupleVariant = Impossible<Edges, Error>;
type SerializeMap = Impossible<Edges, Error>;
type SerializeStruct = SerializeEdges;
type SerializeStructVariant = Impossible<Edges, Error>;
const EXPECTED: &'static str = "a struct `Edges`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeEdges::default())
}
}
#[derive(Default)] #[derive(Default)]
pub(super) struct SerializeEdges { pub(super) struct SerializeEdges {
dir: Option<Dir>, dir: Option<Dir>,
@ -83,8 +56,35 @@ impl serde::ser::SerializeStruct for SerializeEdges {
mod tests { mod tests {
use super::*; use super::*;
use crate::sql::thing; use crate::sql::thing;
use serde::ser::Impossible;
use serde::Serialize; use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Edges;
type Error = Error;
type SerializeSeq = Impossible<Edges, Error>;
type SerializeTuple = Impossible<Edges, Error>;
type SerializeTupleStruct = Impossible<Edges, Error>;
type SerializeTupleVariant = Impossible<Edges, Error>;
type SerializeMap = Impossible<Edges, Error>;
type SerializeStruct = SerializeEdges;
type SerializeStructVariant = Impossible<Edges, Error>;
const EXPECTED: &'static str = "a struct `Edges`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeEdges::default())
}
}
#[test] #[test]
fn edges() { fn edges() {
let edges = Edges { let edges = Edges {

View file

@ -5,42 +5,8 @@ use crate::sql::Operator;
use crate::sql::Value; use crate::sql::Value;
use ser::Serializer as _; use ser::Serializer as _;
use serde::ser::Error as _; use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize; use serde::ser::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Expression;
type Error = Error;
type SerializeSeq = Impossible<Expression, Error>;
type SerializeTuple = Impossible<Expression, Error>;
type SerializeTupleStruct = Impossible<Expression, Error>;
type SerializeTupleVariant = Impossible<Expression, Error>;
type SerializeMap = Impossible<Expression, Error>;
type SerializeStruct = Impossible<Expression, Error>;
type SerializeStructVariant = SerializeExpression;
const EXPECTED: &'static str = "an enum `Expression`";
#[inline]
fn serialize_struct_variant(
self,
name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
debug_assert_eq!(name, crate::sql::expression::TOKEN);
match variant {
"Unary" => Ok(SerializeExpression::Unary(Default::default())),
"Binary" => Ok(SerializeExpression::Binary(Default::default())),
_ => Err(Error::custom(format!("unexpected `Expression::{name}`"))),
}
}
}
pub(super) enum SerializeExpression { pub(super) enum SerializeExpression {
Unary(SerializeUnary), Unary(SerializeUnary),
Binary(SerializeBinary), Binary(SerializeBinary),
@ -158,8 +124,42 @@ impl serde::ser::SerializeStructVariant for SerializeBinary {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde::ser::Impossible;
use serde::Serialize; use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Expression;
type Error = Error;
type SerializeSeq = Impossible<Expression, Error>;
type SerializeTuple = Impossible<Expression, Error>;
type SerializeTupleStruct = Impossible<Expression, Error>;
type SerializeTupleVariant = Impossible<Expression, Error>;
type SerializeMap = Impossible<Expression, Error>;
type SerializeStruct = Impossible<Expression, Error>;
type SerializeStructVariant = SerializeExpression;
const EXPECTED: &'static str = "an enum `Expression`";
#[inline]
fn serialize_struct_variant(
self,
name: &'static str,
_variant_index: u32,
variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant, Self::Error> {
debug_assert_eq!(name, crate::sql::expression::TOKEN);
match variant {
"Unary" => Ok(SerializeExpression::Unary(Default::default())),
"Binary" => Ok(SerializeExpression::Binary(Default::default())),
_ => Err(Error::custom(format!("unexpected `Expression::{name}`"))),
}
}
}
#[test] #[test]
fn default() { fn default() {
let expression = Expression::default(); let expression = Expression::default();

View file

@ -6,36 +6,9 @@ use crate::sql::Id;
use crate::sql::Range; use crate::sql::Range;
use ser::Serializer as _; use ser::Serializer as _;
use serde::ser::Error as _; use serde::ser::Error as _;
use serde::ser::Impossible;
use serde::ser::Serialize; use serde::ser::Serialize;
use std::ops::Bound; use std::ops::Bound;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Range;
type Error = Error;
type SerializeSeq = Impossible<Range, Error>;
type SerializeTuple = Impossible<Range, Error>;
type SerializeTupleStruct = Impossible<Range, Error>;
type SerializeTupleVariant = Impossible<Range, Error>;
type SerializeMap = Impossible<Range, Error>;
type SerializeStruct = SerializeRange;
type SerializeStructVariant = Impossible<Range, Error>;
const EXPECTED: &'static str = "a struct `Range`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeRange::default())
}
}
#[derive(Default)] #[derive(Default)]
pub(super) struct SerializeRange { pub(super) struct SerializeRange {
tb: Option<String>, tb: Option<String>,
@ -83,8 +56,35 @@ impl serde::ser::SerializeStruct for SerializeRange {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde::ser::Impossible;
use serde::Serialize; use serde::Serialize;
pub(super) struct Serializer;
impl ser::Serializer for Serializer {
type Ok = Range;
type Error = Error;
type SerializeSeq = Impossible<Range, Error>;
type SerializeTuple = Impossible<Range, Error>;
type SerializeTupleStruct = Impossible<Range, Error>;
type SerializeTupleVariant = Impossible<Range, Error>;
type SerializeMap = Impossible<Range, Error>;
type SerializeStruct = SerializeRange;
type SerializeStructVariant = Impossible<Range, Error>;
const EXPECTED: &'static str = "a struct `Range`";
#[inline]
fn serialize_struct(
self,
_name: &'static str,
_len: usize,
) -> Result<Self::SerializeStruct, Error> {
Ok(SerializeRange::default())
}
}
#[test] #[test]
fn range() { fn range() {
let range = Range { let range = Range {

481
lib/src/api/conn/cmd.rs Normal file
View file

@ -0,0 +1,481 @@
use super::MlExportConfig;
use crate::Result;
use bincode::Options;
use channel::Sender;
use revision::Revisioned;
use serde::{ser::SerializeMap as _, Serialize};
use std::path::PathBuf;
use std::{collections::BTreeMap, io::Read};
use surrealdb_core::{
dbs::Notification,
sql::{Object, Query, Value},
};
use uuid::Uuid;
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) enum Command {
Use {
namespace: Option<String>,
database: Option<String>,
},
Signup {
credentials: Object,
},
Signin {
credentials: Object,
},
Authenticate {
token: String,
},
Invalidate,
Create {
what: Value,
data: Option<Value>,
},
Upsert {
what: Value,
data: Option<Value>,
},
Update {
what: Value,
data: Option<Value>,
},
Insert {
what: Option<Value>,
data: Value,
},
Patch {
what: Value,
data: Option<Value>,
},
Merge {
what: Value,
data: Option<Value>,
},
Select {
what: Value,
},
Delete {
what: Value,
},
Query {
query: Query,
variables: BTreeMap<String, Value>,
},
ExportFile {
path: PathBuf,
},
ExportMl {
path: PathBuf,
config: MlExportConfig,
},
ExportBytes {
bytes: Sender<Result<Vec<u8>>>,
},
ExportBytesMl {
bytes: Sender<Result<Vec<u8>>>,
config: MlExportConfig,
},
ImportFile {
path: PathBuf,
},
ImportMl {
path: PathBuf,
},
Health,
Version,
Set {
key: String,
value: Value,
},
Unset {
key: String,
},
SubscribeLive {
uuid: Uuid,
notification_sender: Sender<Notification>,
},
Kill {
uuid: Uuid,
},
}
impl Command {
#[cfg(feature = "protocol-ws")]
pub(crate) fn into_router_request(self, id: Option<i64>) -> Option<RouterRequest> {
let id = id.map(Value::from);
let res = match self {
Command::Use {
namespace,
database,
} => RouterRequest {
id,
method: Value::from("use"),
params: Some(vec![Value::from(namespace), Value::from(database)].into()),
},
Command::Signup {
credentials,
} => RouterRequest {
id,
method: "signup".into(),
params: Some(vec![Value::from(credentials)].into()),
},
Command::Signin {
credentials,
} => RouterRequest {
id,
method: "signin".into(),
params: Some(vec![Value::from(credentials)].into()),
},
Command::Authenticate {
token,
} => RouterRequest {
id,
method: "authenticate".into(),
params: Some(vec![Value::from(token)].into()),
},
Command::Invalidate => RouterRequest {
id,
method: "invalidate".into(),
params: None,
},
Command::Create {
what,
data,
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "create".into(),
params: Some(params.into()),
}
}
Command::Upsert {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "upsert".into(),
params: Some(params.into()),
}
}
Command::Update {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "update".into(),
params: Some(params.into()),
}
}
Command::Insert {
what,
data,
} => {
let mut params = if let Some(w) = what {
vec![w]
} else {
vec![Value::None]
};
params.push(data);
RouterRequest {
id,
method: "insert".into(),
params: Some(params.into()),
}
}
Command::Patch {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "patch".into(),
params: Some(params.into()),
}
}
Command::Merge {
what,
data,
..
} => {
let mut params = vec![what];
if let Some(data) = data {
params.push(data);
}
RouterRequest {
id,
method: "merge".into(),
params: Some(params.into()),
}
}
Command::Select {
what,
..
} => RouterRequest {
id,
method: "select".into(),
params: Some(vec![what].into()),
},
Command::Delete {
what,
..
} => RouterRequest {
id,
method: "delete".into(),
params: Some(vec![what].into()),
},
Command::Query {
query,
variables,
} => {
let params: Vec<Value> = vec![query.into(), variables.into()];
RouterRequest {
id,
method: "query".into(),
params: Some(params.into()),
}
}
Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportFile {
..
}
| Command::ExportBytesMl {
..
}
| Command::ExportMl {
..
}
| Command::ImportMl {
..
} => return None,
Command::Health => RouterRequest {
id,
method: "ping".into(),
params: None,
},
Command::Version => RouterRequest {
id,
method: "version".into(),
params: None,
},
Command::Set {
key,
value,
} => RouterRequest {
id,
method: "let".into(),
params: Some(vec![Value::from(key), value].into()),
},
Command::Unset {
key,
} => RouterRequest {
id,
method: "unset".into(),
params: Some(vec![Value::from(key)].into()),
},
Command::SubscribeLive {
..
} => return None,
Command::Kill {
uuid,
} => RouterRequest {
id,
method: "kill".into(),
params: Some(vec![Value::from(uuid)].into()),
},
};
Some(res)
}
}
/// A struct which will be serialized as a map to behave like the previously used BTreeMap.
///
/// This struct serializes as if it is a surrealdb_core::sql::Value::Object.
#[derive(Debug)]
pub(crate) struct RouterRequest {
id: Option<Value>,
method: Value,
params: Option<Value>,
}
impl Serialize for RouterRequest {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
struct InnerRequest<'a>(&'a RouterRequest);
impl Serialize for InnerRequest<'_> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let size = 1 + self.0.id.is_some() as usize + self.0.params.is_some() as usize;
let mut map = serializer.serialize_map(Some(size))?;
if let Some(id) = self.0.id.as_ref() {
map.serialize_entry("id", id)?;
}
map.serialize_entry("method", &self.0.method)?;
if let Some(params) = self.0.params.as_ref() {
map.serialize_entry("params", params)?;
}
map.end()
}
}
serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self))
}
}
impl Revisioned for RouterRequest {
fn revision() -> u16 {
1
}
fn serialize_revisioned<W: std::io::Write>(
&self,
w: &mut W,
) -> std::result::Result<(), revision::Error> {
// version
Revisioned::serialize_revisioned(&1u32, w)?;
// object variant
Revisioned::serialize_revisioned(&9u32, w)?;
// object wrapper version
Revisioned::serialize_revisioned(&1u32, w)?;
let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize;
size.serialize_revisioned(w)?;
let serializer = bincode::options()
.with_no_limit()
.with_little_endian()
.with_varint_encoding()
.reject_trailing_bytes();
if let Some(x) = self.id.as_ref() {
serializer
.serialize_into(&mut *w, "id")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
serializer
.serialize_into(&mut *w, "method")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
self.method.serialize_revisioned(w)?;
if let Some(x) = self.params.as_ref() {
serializer
.serialize_into(&mut *w, "params")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
Ok(())
}
fn deserialize_revisioned<R: Read>(_: &mut R) -> std::result::Result<Self, revision::Error>
where
Self: Sized,
{
panic!("deliberately unimplemented");
}
}
#[cfg(test)]
mod test {
use std::io::Cursor;
use revision::Revisioned;
use surrealdb_core::sql::Value;
use super::RouterRequest;
fn assert_converts<S, D, I>(req: &RouterRequest, s: S, d: D)
where
S: FnOnce(&RouterRequest) -> I,
D: FnOnce(I) -> Value,
{
let ser = s(req);
let val = d(ser);
let Value::Object(obj) = val else {
panic!("not an object");
};
assert_eq!(obj.get("id").cloned(), req.id);
assert_eq!(obj.get("method").unwrap().clone(), req.method);
assert_eq!(obj.get("params").cloned(), req.params);
}
#[test]
fn router_request_value_conversion() {
let request = RouterRequest {
id: Some(Value::from(1234i64)),
method: Value::from("request"),
params: Some(vec![Value::from(1234i64), Value::from("request")].into()),
};
println!("test convert bincode");
assert_converts(
&request,
|i| bincode::serialize(i).unwrap(),
|b| bincode::deserialize(&b).unwrap(),
);
println!("test convert json");
assert_converts(
&request,
|i| serde_json::to_string(i).unwrap(),
|b| serde_json::from_str(&b).unwrap(),
);
println!("test convert revisioned");
assert_converts(
&request,
|i| {
let mut buf = Vec::new();
i.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap();
buf
},
|b| Value::deserialize_revisioned(&mut Cursor::new(b)).unwrap(),
);
println!("done");
}
}

View file

@ -6,24 +6,29 @@ use crate::api::opt::Endpoint;
use crate::api::ExtraFeatures; use crate::api::ExtraFeatures;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::dbs::Notification;
use crate::sql::from_value;
use crate::sql::Query;
use crate::sql::Value;
use channel::Receiver; use channel::Receiver;
use channel::Sender; use channel::Sender;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use surrealdb_core::sql::{from_value, Value};
mod cmd;
pub(crate) use cmd::Command;
#[derive(Debug)]
#[allow(dead_code)] // used by the embedded and remote connections
pub struct RequestData {
pub(crate) id: i64,
pub(crate) command: Command,
}
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] // used by the embedded and remote connections #[allow(dead_code)] // used by the embedded and remote connections
pub(crate) struct Route { pub(crate) struct Route {
pub(crate) request: (i64, Method, Param), pub(crate) request: RequestData,
pub(crate) response: Sender<Result<DbResponse>>, pub(crate) response: Sender<Result<DbResponse>>,
} }
@ -42,14 +47,16 @@ impl Router {
pub(crate) fn send( pub(crate) fn send(
&self, &self,
method: Method, command: Command,
param: Param,
) -> BoxFuture<'_, Result<Receiver<Result<DbResponse>>>> { ) -> BoxFuture<'_, Result<Receiver<Result<DbResponse>>>> {
Box::pin(async move { Box::pin(async move {
let id = self.next_id(); let id = self.next_id();
let (sender, receiver) = channel::bounded(1); let (sender, receiver) = channel::bounded(1);
let route = Route { let route = Route {
request: (id, method, param), request: RequestData {
id,
command,
},
response: sender, response: sender,
}; };
self.sender.send(route).await?; self.sender.send(route).await?;
@ -86,28 +93,24 @@ impl Router {
} }
/// Execute all methods except `query` /// Execute all methods except `query`
pub(crate) fn execute<R>(&self, method: Method, param: Param) -> BoxFuture<'_, Result<R>> pub(crate) fn execute<R>(&self, command: Command) -> BoxFuture<'_, Result<R>>
where where
R: DeserializeOwned, R: DeserializeOwned,
{ {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
let value = self.recv(rx).await?; let value = self.recv(rx).await?;
from_value(value).map_err(Into::into) from_value(value).map_err(Into::into)
}) })
} }
/// Execute methods that return an optional single response /// Execute methods that return an optional single response
pub(crate) fn execute_opt<R>( pub(crate) fn execute_opt<R>(&self, command: Command) -> BoxFuture<'_, Result<Option<R>>>
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Option<R>>>
where where
R: DeserializeOwned, R: DeserializeOwned,
{ {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
match self.recv(rx).await? { match self.recv(rx).await? {
Value::None | Value::Null => Ok(None), Value::None | Value::Null => Ok(None),
value => from_value(value).map_err(Into::into), value => from_value(value).map_err(Into::into),
@ -116,16 +119,12 @@ impl Router {
} }
/// Execute methods that return multiple responses /// Execute methods that return multiple responses
pub(crate) fn execute_vec<R>( pub(crate) fn execute_vec<R>(&self, command: Command) -> BoxFuture<'_, Result<Vec<R>>>
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Vec<R>>>
where where
R: DeserializeOwned, R: DeserializeOwned,
{ {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
let value = match self.recv(rx).await? { let value = match self.recv(rx).await? {
Value::None | Value::Null => Value::Array(Default::default()), Value::None | Value::Null => Value::Array(Default::default()),
Value::Array(array) => Value::Array(array), Value::Array(array) => Value::Array(array),
@ -136,9 +135,9 @@ impl Router {
} }
/// Execute methods that return nothing /// Execute methods that return nothing
pub(crate) fn execute_unit(&self, method: Method, param: Param) -> BoxFuture<'_, Result<()>> { pub(crate) fn execute_unit(&self, command: Command) -> BoxFuture<'_, Result<()>> {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
match self.recv(rx).await? { match self.recv(rx).await? {
Value::None | Value::Null => Ok(()), Value::None | Value::Null => Ok(()),
Value::Array(array) if array.is_empty() => Ok(()), Value::Array(array) if array.is_empty() => Ok(()),
@ -152,82 +151,22 @@ impl Router {
} }
/// Execute methods that return a raw value /// Execute methods that return a raw value
pub(crate) fn execute_value( pub(crate) fn execute_value(&self, command: Command) -> BoxFuture<'_, Result<Value>> {
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Value>> {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
self.recv(rx).await self.recv(rx).await
}) })
} }
/// Execute the `query` method /// Execute the `query` method
pub(crate) fn execute_query( pub(crate) fn execute_query(&self, command: Command) -> BoxFuture<'_, Result<Response>> {
&self,
method: Method,
param: Param,
) -> BoxFuture<'_, Result<Response>> {
Box::pin(async move { Box::pin(async move {
let rx = self.send(method, param).await?; let rx = self.send(command).await?;
self.recv_query(rx).await self.recv_query(rx).await
}) })
} }
} }
/// The query method
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Method {
/// Sends an authentication token to the server
Authenticate,
/// Performs a merge update operation
Merge,
/// Creates a record in a table
Create,
/// Deletes a record from a table
Delete,
/// Exports a database
Export,
/// Checks the health of the server
Health,
/// Imports a database
Import,
/// Invalidates a session
Invalidate,
/// Inserts a record or records into a table
Insert,
/// Kills a live query
#[doc(hidden)] // Not supported yet
Kill,
/// Starts a live query
#[doc(hidden)] // Not supported yet
Live,
/// Performs a patch update operation
Patch,
/// Sends a raw query to the database
Query,
/// Selects a record or records from a table
Select,
/// Sets a parameter on the connection
Set,
/// Signs into the server
Signin,
/// Signs up on the server
Signup,
/// Removes a parameter from a connection
Unset,
/// Performs an update operation
Update,
/// Performs an upsert operation
Upsert,
/// Selects a namespace and database to use
Use,
/// Queries the version of the server
Version,
}
/// The database response sent from the router to the caller /// The database response sent from the router to the caller
#[derive(Debug)] #[derive(Debug)]
pub enum DbResponse { pub enum DbResponse {
@ -237,63 +176,13 @@ pub enum DbResponse {
Other(Value), Other(Value),
} }
#[derive(Debug)] #[derive(Debug, Clone)]
#[allow(dead_code)] // used by ML model import and export functions pub(crate) struct MlExportConfig {
pub(crate) enum MlConfig { // fields are used in http and local non-wasm with ml features
Import, #[allow(dead_code)]
Export { pub(crate) name: String,
name: String, #[allow(dead_code)]
version: String, pub(crate) version: String,
},
}
/// Holds the parameters given to the caller
#[derive(Debug, Default)]
#[allow(dead_code)] // used by the embedded and remote connections
pub struct Param {
pub(crate) query: Option<(Query, BTreeMap<String, Value>)>,
pub(crate) other: Vec<Value>,
pub(crate) file: Option<PathBuf>,
pub(crate) bytes_sender: Option<channel::Sender<Result<Vec<u8>>>>,
pub(crate) notification_sender: Option<channel::Sender<Notification>>,
pub(crate) ml_config: Option<MlConfig>,
}
impl Param {
pub(crate) fn new(other: Vec<Value>) -> Self {
Self {
other,
..Default::default()
}
}
pub(crate) fn query(query: Query, bindings: BTreeMap<String, Value>) -> Self {
Self {
query: Some((query, bindings)),
..Default::default()
}
}
pub(crate) fn file(file: PathBuf) -> Self {
Self {
file: Some(file),
..Default::default()
}
}
pub(crate) fn bytes_sender(send: channel::Sender<Result<Vec<u8>>>) -> Self {
Self {
bytes_sender: Some(send),
..Default::default()
}
}
pub(crate) fn notification_sender(send: channel::Sender<Notification>) -> Self {
Self {
notification_sender: Some(send),
..Default::default()
}
}
} }
/// Connection trait implemented by supported protocols /// Connection trait implemented by supported protocols

View file

@ -142,10 +142,7 @@ impl Connection for Any {
} }
let client = builder.build()?; let client = builder.build()?;
let base_url = address.url; let base_url = address.url;
engine::remote::http::health( engine::remote::http::health(client.get(base_url.join("health")?)).await?;
client.get(base_url.join(crate::api::conn::Method::Health.as_str())?),
)
.await?;
tokio::spawn(engine::remote::http::native::run_router( tokio::spawn(engine::remote::http::native::run_router(
base_url, client, route_rx, base_url, client, route_rx,
)); ));

View file

@ -26,78 +26,58 @@ pub(crate) mod native;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub(crate) mod wasm; pub(crate) mod wasm;
use crate::api::conn::DbResponse; use crate::{
use crate::api::conn::Method; api::{
#[cfg(not(target_arch = "wasm32"))] conn::{Command, DbResponse, RequestData},
use crate::api::conn::MlConfig; Connect, Response as QueryResponse, Result, Surreal,
use crate::api::conn::Param; },
use crate::api::engine::create_statement; method::Stats,
use crate::api::engine::delete_statement; opt::IntoEndpoint,
use crate::api::engine::insert_statement; };
use crate::api::engine::merge_statement; use channel::Sender;
use crate::api::engine::patch_statement; use indexmap::IndexMap;
use crate::api::engine::select_statement; use std::{
use crate::api::engine::update_statement; collections::{BTreeMap, HashMap},
use crate::api::engine::upsert_statement; marker::PhantomData,
mem,
sync::Arc,
time::Duration,
};
use surrealdb_core::{
dbs::{Notification, Response, Session},
kvs::Datastore,
sql::{
statements::{
CreateStatement, DeleteStatement, InsertStatement, KillStatement, SelectStatement,
UpdateStatement, UpsertStatement,
},
Data, Field, Output, Query, Statement, Uuid, Value,
},
};
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::Connect;
use crate::api::Response as QueryResponse;
use crate::api::Result;
use crate::api::Surreal;
use crate::dbs::Notification;
use crate::dbs::Response;
use crate::dbs::Session;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::check::check_ns_db;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::Action;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::iam::ResourceKind;
use crate::kvs::Datastore;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::kvs::{LockType, TransactionType};
use crate::method::Stats;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::ml::storage::surml_file::SurMlFile;
use crate::opt::IntoEndpoint;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineModelStatement;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::sql::statements::DefineStatement;
use crate::sql::statements::KillStatement;
use crate::sql::Query;
use crate::sql::Statement;
use crate::sql::Uuid;
use crate::sql::Value;
use channel::Sender;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use futures::StreamExt;
use indexmap::IndexMap;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use tokio::fs::OpenOptions; use tokio::{
#[cfg(not(target_arch = "wasm32"))] fs::OpenOptions,
use tokio::io; io::{self, AsyncReadExt, AsyncWriteExt},
#[cfg(not(target_arch = "wasm32"))] };
use tokio::io::AsyncReadExt;
#[cfg(not(target_arch = "wasm32"))] #[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use tokio::io::AsyncWriteExt; use crate::api::conn::MlExportConfig;
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use futures::StreamExt;
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
use surrealdb_core::{
iam::{check::check_ns_db, Action, ResourceKind},
kvs::{LockType, TransactionType},
ml::storage::surml_file::SurMlFile,
sql::statements::{DefineModelStatement, DefineStatement},
};
use super::value_to_values;
const DEFAULT_TICK_INTERVAL: Duration = Duration::from_secs(10); const DEFAULT_TICK_INTERVAL: Duration = Duration::from_secs(10);
@ -429,44 +409,42 @@ async fn take(one: bool, responses: Vec<Response>) -> Result<Value> {
} }
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
async fn export( async fn export_file(kvs: &Datastore, sess: &Session, chn: channel::Sender<Vec<u8>>) -> Result<()> {
if let Err(error) = kvs.export(sess, chn).await?.await {
if let crate::error::Db::Channel(message) = error {
// This is not really an error. Just logging it for improved visibility.
trace!("{message}");
return Ok(());
}
return Err(error.into());
}
Ok(())
}
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
async fn export_ml(
kvs: &Datastore, kvs: &Datastore,
sess: &Session, sess: &Session,
chn: channel::Sender<Vec<u8>>, chn: channel::Sender<Vec<u8>>,
ml_config: Option<MlConfig>, MlExportConfig {
name,
version,
}: MlExportConfig,
) -> Result<()> { ) -> Result<()> {
match ml_config { // Ensure a NS and DB are set
#[cfg(feature = "ml")] let (nsv, dbv) = check_ns_db(sess)?;
Some(MlConfig::Export { // Check the permissions level
name, kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?;
version, // Start a new readonly transaction
}) => { let tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?;
// Ensure a NS and DB are set // Attempt to get the model definition
let (nsv, dbv) = check_ns_db(sess)?; let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?;
// Check the permissions level // Export the file data in to the store
kvs.check(sess, Action::View, ResourceKind::Model.on_db(&nsv, &dbv))?; let mut data = crate::obs::stream(info.hash.to_owned()).await?;
// Start a new readonly transaction // Process all stream values
let tx = kvs.transaction(TransactionType::Read, LockType::Optimistic).await?; while let Some(Ok(bytes)) = data.next().await {
// Attempt to get the model definition if chn.send(bytes.to_vec()).await.is_err() {
let info = tx.get_db_model(&nsv, &dbv, &name, &version).await?; break;
// Export the file data in to the store
let mut data = crate::obs::stream(info.hash.to_owned()).await?;
// Process all stream values
while let Some(Ok(bytes)) = data.next().await {
if chn.send(bytes.to_vec()).await.is_err() {
break;
}
}
}
_ => {
if let Err(error) = kvs.export(sess, chn).await?.await {
if let crate::error::Db::Channel(message) = error {
// This is not really an error. Just logging it for improved visibility.
trace!("{message}");
return Ok(());
}
return Err(error.into());
}
} }
} }
Ok(()) Ok(())
@ -505,211 +483,376 @@ async fn kill_live_query(
} }
async fn router( async fn router(
(_, method, param): (i64, Method, Param), RequestData {
command,
..
}: RequestData,
kvs: &Arc<Datastore>, kvs: &Arc<Datastore>,
session: &mut Session, session: &mut Session,
vars: &mut BTreeMap<String, Value>, vars: &mut BTreeMap<String, Value>,
live_queries: &mut HashMap<Uuid, Sender<Notification>>, live_queries: &mut HashMap<Uuid, Sender<Notification>>,
) -> Result<DbResponse> { ) -> Result<DbResponse> {
let mut params = param.other; match command {
Command::Use {
match method { namespace,
Method::Use => { database,
match &mut params[..] { } => {
[Value::Strand(ns), Value::Strand(db)] => { if let Some(ns) = namespace {
session.ns = Some(mem::take(&mut ns.0)); session.ns = Some(ns);
session.db = Some(mem::take(&mut db.0)); }
} if let Some(db) = database {
[Value::Strand(ns), Value::None] => { session.db = Some(db);
session.ns = Some(mem::take(&mut ns.0));
}
[Value::None, Value::Strand(db)] => {
session.db = Some(mem::take(&mut db.0));
}
_ => unreachable!(),
} }
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Signup => { Command::Signup {
let credentials = match &mut params[..] { credentials,
[Value::Object(credentials)] => mem::take(credentials), } => {
_ => unreachable!(),
};
let response = crate::iam::signup::signup(kvs, session, credentials).await?; let response = crate::iam::signup::signup(kvs, session, credentials).await?;
Ok(DbResponse::Other(response.into())) Ok(DbResponse::Other(response.into()))
} }
Method::Signin => { Command::Signin {
let credentials = match &mut params[..] { credentials,
[Value::Object(credentials)] => mem::take(credentials), } => {
_ => unreachable!(),
};
let response = crate::iam::signin::signin(kvs, session, credentials).await?; let response = crate::iam::signin::signin(kvs, session, credentials).await?;
Ok(DbResponse::Other(response.into())) Ok(DbResponse::Other(response.into()))
} }
Method::Authenticate => { Command::Authenticate {
let token = match &mut params[..] { token,
[Value::Strand(token)] => mem::take(&mut token.0), } => {
_ => unreachable!(),
};
crate::iam::verify::token(kvs, session, &token).await?; crate::iam::verify::token(kvs, session, &token).await?;
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Invalidate => { Command::Invalidate => {
crate::iam::clear::clear(session)?; crate::iam::clear::clear(session)?;
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Create => { Command::Create {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let statement = create_statement(&mut params); let statement = {
let mut stmt = CreateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Create(statement)]; query.0 .0 = vec![Statement::Create(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(true, response).await?; let value = take(true, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Upsert => { Command::Upsert {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = upsert_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpsertStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Upsert(statement)]; query.0 .0 = vec![Statement::Upsert(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Update => { Command::Update {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = update_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)]; query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Insert => { Command::Insert {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = insert_statement(&mut params); let one = !data.is_array();
let statement = {
let mut stmt = InsertStatement::default();
stmt.into = what;
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Insert(statement)]; query.0 .0 = vec![Statement::Insert(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Patch => { Command::Patch {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = patch_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::PatchExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)]; query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Merge => { Command::Merge {
what,
data,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = merge_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::MergeExpression);
stmt.output = Some(Output::After);
stmt
};
query.0 .0 = vec![Statement::Update(statement)]; query.0 .0 = vec![Statement::Update(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Select => { Command::Select {
what,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = select_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = SelectStatement::default();
stmt.what = value_to_values(what);
stmt.expr.0 = vec![Field::All];
stmt
};
query.0 .0 = vec![Statement::Select(statement)]; query.0 .0 = vec![Statement::Select(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Delete => { Command::Delete {
what,
} => {
let mut query = Query::default(); let mut query = Query::default();
let (one, statement) = delete_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = DeleteStatement::default();
stmt.what = value_to_values(what);
stmt.output = Some(Output::Before);
stmt
};
query.0 .0 = vec![Statement::Delete(statement)]; query.0 .0 = vec![Statement::Delete(statement)];
let response = kvs.process(query, &*session, Some(vars.clone())).await?; let response = kvs.process(query, &*session, Some(vars.clone())).await?;
let value = take(one, response).await?; let value = take(one, response).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Query => { Command::Query {
let response = match param.query { query,
Some((query, mut bindings)) => { mut variables,
let mut vars = vars.clone(); } => {
vars.append(&mut bindings); let mut vars = vars.clone();
kvs.process(query, &*session, Some(vars)).await? vars.append(&mut variables);
} let response = kvs.process(query, &*session, Some(vars)).await?;
None => unreachable!(),
};
let response = process(response); let response = process(response);
Ok(DbResponse::Query(response)) Ok(DbResponse::Query(response))
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
Method::Export | Method::Import => unreachable!(), Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportFile {
..
} => Err(crate::api::Error::BackupsNotSupported.into()),
#[cfg(any(target_arch = "wasm32", not(feature = "ml")))]
Command::ExportMl {
..
}
| Command::ExportBytesMl {
..
}
| Command::ImportMl {
..
} => Err(crate::api::Error::BackupsNotSupported.into()),
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Method::Export => { Command::ExportFile {
path: file,
} => {
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
// Write to channel.
let export = export_file(kvs, session, tx);
// Read from channel and write to pipe.
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
}
Ok(())
};
// Output to stdout or file.
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&file)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path: file,
error,
}
.into());
}
};
// Copy from pipe to output.
let copy = copy(file, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(DbResponse::Other(Value::None))
}
#[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
Command::ExportMl {
path,
config,
} => {
let (tx, rx) = crate::channel::bounded(1);
let (mut writer, mut reader) = io::duplex(10_240);
// Write to channel.
let export = export_ml(kvs, session, tx, config);
// Read from channel and write to pipe.
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
}
Ok(())
};
// Output to stdout or file.
let mut output = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
// Copy from pipe to output.
let copy = copy(path, &mut reader, &mut output);
tokio::try_join!(export, bridge, copy)?;
Ok(DbResponse::Other(Value::None))
}
#[cfg(not(target_arch = "wasm32"))]
Command::ExportBytes {
bytes,
} => {
let (tx, rx) = crate::channel::bounded(1); let (tx, rx) = crate::channel::bounded(1);
match (param.file, param.bytes_sender) { let kvs = kvs.clone();
(Some(path), None) => { let session = session.clone();
let (mut writer, mut reader) = io::duplex(10_240); tokio::spawn(async move {
let export = async {
if let Err(error) = export_file(&kvs, &session, tx).await {
let _ = bytes.send(Err(error)).await;
}
};
// Write to channel. let bridge = async {
let export = export(kvs, session, tx, param.ml_config); while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
// Read from channel and write to pipe. break;
let bridge = async move {
while let Ok(value) = rx.recv().await {
if writer.write_all(&value).await.is_err() {
// Broken pipe. Let either side's error be propagated.
break;
}
} }
Ok(()) }
}; };
// Output to stdout or file. tokio::join!(export, bridge);
let mut output = match OpenOptions::new() });
.write(true)
.create(true) Ok(DbResponse::Other(Value::None))
.truncate(true) }
.open(&path) #[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
.await Command::ExportBytesMl {
{ bytes,
Ok(path) => path, config,
Err(error) => { } => {
return Err(Error::FileOpen { let (tx, rx) = crate::channel::bounded(1);
path,
error, let kvs = kvs.clone();
} let session = session.clone();
.into()); tokio::spawn(async move {
let export = async {
if let Err(error) = export_ml(&kvs, &session, tx, config).await {
let _ = bytes.send(Err(error)).await;
}
};
let bridge = async {
while let Ok(b) = rx.recv().await {
if bytes.send(Ok(b)).await.is_err() {
break;
} }
}; }
};
// Copy from pipe to output. tokio::join!(export, bridge);
let copy = copy(path, &mut reader, &mut output); });
tokio::try_join!(export, bridge, copy)?;
}
(None, Some(backup)) => {
let kvs = kvs.clone();
let session = session.clone();
tokio::spawn(async move {
let export = async {
if let Err(error) = export(&kvs, &session, tx, param.ml_config).await {
let _ = backup.send(Err(error)).await;
}
};
let bridge = async {
while let Ok(bytes) = rx.recv().await {
if backup.send(Ok(bytes)).await.is_err() {
break;
}
}
};
tokio::join!(export, bridge);
});
}
_ => unreachable!(),
}
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Method::Import => { Command::ImportFile {
let path = param.file.expect("file to import from"); path,
} => {
let mut file = match OpenOptions::new().read(true).open(&path).await { let mut file = match OpenOptions::new().read(true).open(&path).await {
Ok(path) => path, Ok(path) => path,
Err(error) => { Err(error) => {
@ -720,76 +863,93 @@ async fn router(
.into()); .into());
} }
}; };
let responses = match param.ml_config { let mut statements = String::new();
#[cfg(feature = "ml")] if let Err(error) = file.read_to_string(&mut statements).await {
Some(MlConfig::Import) => { return Err(Error::FileRead {
// Ensure a NS and DB are set path,
let (nsv, dbv) = check_ns_db(session)?; error,
// Check the permissions level
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(Error::FileRead {
path,
error: io::Error::new(
io::ErrorKind::InvalidData,
error.message.to_string(),
),
}
.into());
}
};
// Convert the file back in to raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = crate::obs::hash(&data);
// Insert the file data in to the store
crate::obs::put(&hash, data).await?;
// Insert the model in to the database
let mut model = DefineModelStatement::default();
model.name = file.header.name.to_string().into();
model.version = file.header.version.to_string();
model.comment = Some(file.header.description.to_string().into());
model.hash = hash;
let query = DefineStatement::Model(model).into();
kvs.process(query, session, Some(vars.clone())).await?
} }
_ => { .into());
let mut statements = String::new(); }
if let Err(error) = file.read_to_string(&mut statements).await {
return Err(Error::FileRead { let responses = kvs.execute(&statements, &*session, Some(vars.clone())).await?;
path,
error,
}
.into());
}
kvs.execute(&statements, &*session, Some(vars.clone())).await?
}
};
for response in responses { for response in responses {
response.result?; response.result?;
} }
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Health => Ok(DbResponse::Other(Value::None)), #[cfg(all(not(target_arch = "wasm32"), feature = "ml"))]
Method::Version => Ok(DbResponse::Other(crate::env::VERSION.into())), Command::ImportMl {
Method::Set => { path,
let (key, value) = match &mut params[..2] { } => {
[Value::Strand(key), value] => (mem::take(&mut key.0), mem::take(value)), let mut file = match OpenOptions::new().read(true).open(&path).await {
_ => unreachable!(), Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
}; };
// Ensure a NS and DB are set
let (nsv, dbv) = check_ns_db(session)?;
// Check the permissions level
kvs.check(session, Action::Edit, ResourceKind::Model.on_db(&nsv, &dbv))?;
// Create a new buffer
let mut buffer = Vec::new();
// Load all the uploaded file chunks
if let Err(error) = file.read_to_end(&mut buffer).await {
return Err(Error::FileRead {
path,
error,
}
.into());
}
// Check that the SurrealML file is valid
let file = match SurMlFile::from_bytes(buffer) {
Ok(file) => file,
Err(error) => {
return Err(Error::FileRead {
path,
error: io::Error::new(
io::ErrorKind::InvalidData,
error.message.to_string(),
),
}
.into());
}
};
// Convert the file back in to raw bytes
let data = file.to_bytes();
// Calculate the hash of the model file
let hash = crate::obs::hash(&data);
// Insert the file data in to the store
crate::obs::put(&hash, data).await?;
// Insert the model in to the database
let mut model = DefineModelStatement::default();
model.name = file.header.name.to_string().into();
model.version = file.header.version.to_string();
model.comment = Some(file.header.description.to_string().into());
model.hash = hash;
let query = DefineStatement::Model(model).into();
let responses = kvs.process(query, session, Some(vars.clone())).await?;
for response in responses {
response.result?;
}
Ok(DbResponse::Other(Value::None))
}
Command::Health => Ok(DbResponse::Other(Value::None)),
Command::Version => Ok(DbResponse::Other(crate::env::VERSION.into())),
Command::Set {
key,
value,
} => {
let var = Some(map! { let var = Some(map! {
key.clone() => Value::None, key.clone() => Value::None,
=> vars => vars
@ -800,27 +960,24 @@ async fn router(
}; };
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Unset => { Command::Unset {
if let [Value::Strand(key)] = &params[..1] { key,
vars.remove(&key.0); } => {
} vars.remove(&key);
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Live => { Command::SubscribeLive {
if let Some(sender) = param.notification_sender { uuid,
if let [Value::Uuid(id)] = &params[..1] { notification_sender,
live_queries.insert(*id, sender); } => {
} live_queries.insert(uuid.into(), notification_sender);
}
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Kill => { Command::Kill {
let id = match &params[..] { uuid,
[Value::Uuid(id)] => *id, } => {
_ => unreachable!(), live_queries.remove(&uuid.into());
}; let value = kill_live_query(kvs, uuid.into(), session, vars.clone()).await?;
live_queries.remove(&id);
let value = kill_live_query(kvs, id, session, vars.clone()).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
} }

View file

@ -15,19 +15,9 @@ pub mod remote;
#[doc(hidden)] #[doc(hidden)]
pub mod tasks; pub mod tasks;
use crate::sql::statements::CreateStatement;
use crate::sql::statements::DeleteStatement;
use crate::sql::statements::InsertStatement;
use crate::sql::statements::SelectStatement;
use crate::sql::statements::UpdateStatement;
use crate::sql::statements::UpsertStatement;
use crate::sql::Data;
use crate::sql::Field;
use crate::sql::Output;
use crate::sql::Value; use crate::sql::Value;
use crate::sql::Values; use crate::sql::Values;
use futures::Stream; use futures::Stream;
use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context; use std::task::Context;
use std::task::Poll; use std::task::Poll;
@ -40,133 +30,21 @@ use wasmtimer::std::Instant;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasmtimer::tokio::Interval; use wasmtimer::tokio::Interval;
#[allow(dead_code)] // used by the the embedded database and `http` // used in http and all local engines.
fn split_params(params: &mut [Value]) -> (bool, Values, Value) { #[allow(dead_code)]
let (what, data) = match params { fn value_to_values(v: Value) -> Values {
[what] => (mem::take(what), Value::None), match v {
[what, data] => (mem::take(what), mem::take(data)), Value::Array(x) => {
_ => unreachable!(),
};
let one = what.is_thing();
let what = match what {
Value::Array(vec) => {
let mut values = Values::default(); let mut values = Values::default();
values.0 = vec.0; values.0 = x.0;
values values
} }
value => { x => {
let mut values = Values::default(); let mut values = Values::default();
values.0 = vec![value]; values.0 = vec![x];
values values
} }
}; }
(one, what, data)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn create_statement(params: &mut [Value]) -> CreateStatement {
let (_, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = CreateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
stmt
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn upsert_statement(params: &mut [Value]) -> (bool, UpsertStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = UpsertStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn update_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::ContentExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn insert_statement(params: &mut [Value]) -> (bool, InsertStatement) {
let (what, data) = match params {
[what, data] => (mem::take(what), mem::take(data)),
_ => unreachable!(),
};
let one = !data.is_array();
let mut stmt = InsertStatement::default();
stmt.into = match what {
Value::None => None,
Value::Null => None,
what => Some(what),
};
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn patch_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::PatchExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn merge_statement(params: &mut [Value]) -> (bool, UpdateStatement) {
let (one, what, data) = split_params(params);
let data = match data {
Value::None | Value::Null => None,
value => Some(Data::MergeExpression(value)),
};
let mut stmt = UpdateStatement::default();
stmt.what = what;
stmt.data = data;
stmt.output = Some(Output::After);
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn select_statement(params: &mut [Value]) -> (bool, SelectStatement) {
let (one, what, _) = split_params(params);
let mut stmt = SelectStatement::default();
stmt.what = what;
stmt.expr.0 = vec![Field::All];
(one, stmt)
}
#[allow(dead_code)] // used by the the embedded database and `http`
fn delete_statement(params: &mut [Value]) -> (bool, DeleteStatement) {
let (one, what, _) = split_params(params);
let mut stmt = DeleteStatement::default();
stmt.what = what;
stmt.output = Some(Output::Before);
(one, stmt)
} }
struct IntervalStream { struct IntervalStream {

View file

@ -5,21 +5,10 @@ pub(crate) mod native;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub(crate) mod wasm; pub(crate) mod wasm;
use crate::api::conn::Command;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method; use crate::api::conn::RequestData;
#[cfg(feature = "ml")]
#[cfg(not(target_arch = "wasm32"))]
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::engine::create_statement;
use crate::api::engine::delete_statement;
use crate::api::engine::insert_statement;
use crate::api::engine::merge_statement;
use crate::api::engine::patch_statement;
use crate::api::engine::remote::duration_from_str; use crate::api::engine::remote::duration_from_str;
use crate::api::engine::select_statement;
use crate::api::engine::update_statement;
use crate::api::engine::upsert_statement;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::query::QueryResult; use crate::api::method::query::QueryResult;
use crate::api::Connect; use crate::api::Connect;
@ -27,6 +16,7 @@ use crate::api::Response as QueryResponse;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::dbs::Status; use crate::dbs::Status;
use crate::engine::value_to_values;
use crate::headers::AUTH_DB; use crate::headers::AUTH_DB;
use crate::headers::AUTH_NS; use crate::headers::AUTH_NS;
use crate::headers::DB; use crate::headers::DB;
@ -36,19 +26,29 @@ use crate::opt::IntoEndpoint;
use crate::sql::from_value; use crate::sql::from_value;
use crate::sql::serde::deserialize; use crate::sql::serde::deserialize;
use crate::sql::Value; use crate::sql::Value;
#[cfg(not(target_arch = "wasm32"))]
use futures::TryStreamExt; use futures::TryStreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue; use reqwest::header::HeaderValue;
use reqwest::header::ACCEPT; use reqwest::header::ACCEPT;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE;
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize; use serde::Serialize;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use surrealdb_core::sql::statements::CreateStatement;
use surrealdb_core::sql::statements::DeleteStatement;
use surrealdb_core::sql::statements::InsertStatement;
use surrealdb_core::sql::statements::SelectStatement;
use surrealdb_core::sql::statements::UpdateStatement;
use surrealdb_core::sql::statements::UpsertStatement;
use surrealdb_core::sql::Data;
use surrealdb_core::sql::Field;
use surrealdb_core::sql::Output;
use url::Url;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf; use std::path::PathBuf;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -57,7 +57,8 @@ use tokio::fs::OpenOptions;
use tokio::io; use tokio::io;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use tokio_util::compat::FuturesAsyncReadCompatExt; use tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url; #[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local;
const SQL_PATH: &str = "sql"; const SQL_PATH: &str = "sql";
@ -238,65 +239,61 @@ async fn take(one: bool, request: RequestBuilder) -> Result<Value> {
} }
} }
#[cfg(not(target_arch = "wasm32"))]
type BackupSender = channel::Sender<Result<Vec<u8>>>; type BackupSender = channel::Sender<Result<Vec<u8>>>;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
async fn export( async fn export_file(request: RequestBuilder, path: PathBuf) -> Result<Value> {
request: RequestBuilder, let mut response = request
(file, sender): (Option<PathBuf>, Option<BackupSender>), .send()
) -> Result<Value> { .await?
match (file, sender) { .error_for_status()?
(Some(path), None) => { .bytes_stream()
let mut response = request .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
.send() .into_async_read()
.await? .compat();
.error_for_status()? let mut file =
.bytes_stream() match OpenOptions::new().write(true).create(true).truncate(true).open(&path).await {
.map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) Ok(path) => path,
.into_async_read() Err(error) => {
.compat(); return Err(Error::FileOpen {
let mut file = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
.await
{
Ok(path) => path,
Err(error) => {
return Err(Error::FileOpen {
path,
error,
}
.into());
}
};
if let Err(error) = io::copy(&mut response, &mut file).await {
return Err(Error::FileRead {
path, path,
error, error,
} }
.into()); .into());
} }
};
if let Err(error) = io::copy(&mut response, &mut file).await {
return Err(Error::FileRead {
path,
error,
} }
(None, Some(tx)) => { .into());
let mut response = request.send().await?.error_for_status()?.bytes_stream();
tokio::spawn(async move {
while let Ok(Some(bytes)) = response.try_next().await {
if tx.send(Ok(bytes.to_vec())).await.is_err() {
break;
}
}
});
}
_ => unreachable!(),
} }
Ok(Value::None) Ok(Value::None)
} }
async fn export_bytes(request: RequestBuilder, bytes: BackupSender) -> Result<Value> {
let response = request.send().await?.error_for_status()?;
let future = async move {
let mut response = response.bytes_stream();
while let Ok(Some(b)) = response.try_next().await {
if bytes.send(Ok(b.to_vec())).await.is_err() {
break;
}
}
};
#[cfg(not(target_arch = "wasm32"))]
tokio::spawn(future);
#[cfg(target_arch = "wasm32")]
spawn_local(future);
Ok(Value::None)
}
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
async fn import(request: RequestBuilder, path: PathBuf) -> Result<Value> { async fn import(request: RequestBuilder, path: PathBuf) -> Result<Value> {
let file = match OpenOptions::new().read(true).open(&path).await { let file = match OpenOptions::new().read(true).open(&path).await {
@ -343,28 +340,24 @@ pub(crate) async fn health(request: RequestBuilder) -> Result<Value> {
} }
async fn router( async fn router(
(_, method, param): (i64, Method, Param), RequestData {
command,
..
}: RequestData,
base_url: &Url, base_url: &Url,
client: &reqwest::Client, client: &reqwest::Client,
headers: &mut HeaderMap, headers: &mut HeaderMap,
vars: &mut IndexMap<String, String>, vars: &mut IndexMap<String, String>,
auth: &mut Option<Auth>, auth: &mut Option<Auth>,
) -> Result<DbResponse> { ) -> Result<DbResponse> {
let mut params = param.other; match command {
Command::Use {
match method { namespace,
Method::Use => { database,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone()); let mut request = client.post(path).headers(headers.clone());
let (ns, db) = match &mut params[..] { let ns = match namespace {
[Value::Strand(ns), Value::Strand(db)] => {
(Some(mem::take(&mut ns.0)), Some(mem::take(&mut db.0)))
}
[Value::Strand(ns), Value::None] => (Some(mem::take(&mut ns.0)), None),
[Value::None, Value::Strand(db)] => (None, Some(mem::take(&mut db.0))),
_ => unreachable!(),
};
let ns = match ns {
Some(ns) => match HeaderValue::try_from(&ns) { Some(ns) => match HeaderValue::try_from(&ns) {
Ok(ns) => { Ok(ns) => {
request = request.header(&NS, &ns); request = request.header(&NS, &ns);
@ -376,7 +369,7 @@ async fn router(
}, },
None => None, None => None,
}; };
let db = match db { let db = match database {
Some(db) => match HeaderValue::try_from(&db) { Some(db) => match HeaderValue::try_from(&db) {
Ok(db) => { Ok(db) => {
request = request.header(&DB, &db); request = request.header(&DB, &db);
@ -398,52 +391,46 @@ async fn router(
} }
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Signin => { Command::Signin {
let path = base_url.join(Method::Signin.as_str())?; credentials,
let credentials = match &mut params[..] { } => {
[credentials] => credentials.to_string(), let path = base_url.join("signin")?;
_ => unreachable!(), let request =
}; client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string());
let request = client.post(path).headers(headers.clone()).auth(auth).body(credentials);
let value = submit_auth(request).await?; let value = submit_auth(request).await?;
if let [credentials] = &mut params[..] { if let Ok(Credentials {
if let Ok(Credentials { user,
pass,
ns,
db,
}) = from_value(credentials.into())
{
*auth = Some(Auth::Basic {
user, user,
pass, pass,
ns, ns,
db, db,
}) = from_value(mem::take(credentials)) });
{ } else {
*auth = Some(Auth::Basic { *auth = Some(Auth::Bearer {
user, token: value.to_raw_string(),
pass, });
ns,
db,
});
} else {
*auth = Some(Auth::Bearer {
token: value.to_raw_string(),
});
}
} }
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Signup => { Command::Signup {
let path = base_url.join(Method::Signup.as_str())?; credentials,
let credentials = match &mut params[..] { } => {
[credentials] => credentials.to_string(), let path = base_url.join("signup")?;
_ => unreachable!(), let request =
}; client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string());
let request = client.post(path).headers(headers.clone()).auth(auth).body(credentials);
let value = submit_auth(request).await?; let value = submit_auth(request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Authenticate => { Command::Authenticate {
token,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let token = match &mut params[..1] {
[Value::Strand(token)] => mem::take(&mut token.0),
_ => unreachable!(),
};
let request = let request =
client.post(path).headers(headers.clone()).bearer_auth(&token).body("RETURN true"); client.post(path).headers(headers.clone()).bearer_auth(&token).body("RETURN true");
take(true, request).await?; take(true, request).await?;
@ -452,142 +439,276 @@ async fn router(
}); });
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Invalidate => { Command::Invalidate => {
*auth = None; *auth = None;
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Create => { Command::Create {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let statement = create_statement(&mut params); let statement = {
let mut stmt = CreateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(true, request).await?; let value = take(true, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Upsert => { Command::Upsert {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = upsert_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpsertStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Update => { Command::Update {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = update_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Insert => { Command::Insert {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = insert_statement(&mut params); let one = !data.is_array();
let statement = {
let mut stmt = InsertStatement::default();
stmt.into = what;
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Patch => { Command::Patch {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = patch_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::PatchExpression);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Merge => { Command::Merge {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = merge_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::MergeExpression);
stmt.output = Some(Output::After);
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Select => { Command::Select {
what,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = select_statement(&mut params); let one = what.is_thing();
let statement = {
let mut stmt = SelectStatement::default();
stmt.what = value_to_values(what);
stmt.expr.0 = vec![Field::All];
stmt
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Delete => { Command::Delete {
what,
} => {
let one = what.is_thing();
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (one, statement) = delete_statement(&mut params); let (one, statement) = {
let mut stmt = DeleteStatement::default();
stmt.what = value_to_values(what);
stmt.output = Some(Output::Before);
(one, stmt)
};
let request = let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?; let value = take(one, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Query => { Command::Query {
query: q,
variables,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone()).query(&vars).auth(auth); let mut request = client.post(path).headers(headers.clone()).query(&vars).auth(auth);
match param.query { let bindings: Vec<_> =
Some((query, bindings)) => { variables.iter().map(|(key, value)| (key, value.to_string())).collect();
let bindings: Vec<_> = request = request.query(&bindings).body(q.to_string());
bindings.iter().map(|(key, value)| (key, value.to_string())).collect();
request = request.query(&bindings).body(query.to_string());
}
None => unreachable!(),
}
let values = query(request).await?; let values = query(request).await?;
Ok(DbResponse::Query(values)) Ok(DbResponse::Query(values))
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
Method::Export | Method::Import => unreachable!(), Command::ExportFile {
..
}
| Command::ExportMl {
..
}
| Command::ImportFile {
..
}
| Command::ImportMl {
..
} => {
// TODO: Better error message here, some backups are supported
Err(Error::BackupsNotSupported.into())
}
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Method::Export => { Command::ExportFile {
let path = match param.ml_config { path,
#[cfg(feature = "ml")] } => {
Some(MlConfig::Export { let req_path = base_url.join("export")?;
name,
version,
}) => base_url.join(&format!("ml/export/{name}/{version}"))?,
_ => base_url.join(Method::Export.as_str())?,
};
let request = client let request = client
.get(path) .get(req_path)
.headers(headers.clone()) .headers(headers.clone())
.auth(auth) .auth(auth)
.header(ACCEPT, "application/octet-stream"); .header(ACCEPT, "application/octet-stream");
let value = export(request, (param.file, param.bytes_sender)).await?; let value = export_file(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::ExportBytes {
bytes,
} => {
let req_path = base_url.join("export")?;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_bytes(request, bytes).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
Method::Import => { Command::ExportMl {
let path = match param.ml_config { path,
#[cfg(feature = "ml")] config,
Some(MlConfig::Import) => base_url.join("ml/import")?, } => {
_ => base_url.join(Method::Import.as_str())?, let req_path =
}; base_url.join("ml")?.join("export")?.join(&config.name)?.join(&config.version)?;
let file = param.file.expect("file to import from");
let request = client let request = client
.post(path) .get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_file(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::ExportBytesMl {
bytes,
config,
} => {
let req_path =
base_url.join("ml")?.join("export")?.join(&config.name)?.join(&config.version)?;
let request = client
.get(req_path)
.headers(headers.clone())
.auth(auth)
.header(ACCEPT, "application/octet-stream");
let value = export_bytes(request, bytes).await?;
Ok(DbResponse::Other(value))
}
#[cfg(not(target_arch = "wasm32"))]
Command::ImportFile {
path,
} => {
let req_path = base_url.join("import")?;
let request = client
.post(req_path)
.headers(headers.clone()) .headers(headers.clone())
.auth(auth) .auth(auth)
.header(CONTENT_TYPE, "application/octet-stream"); .header(CONTENT_TYPE, "application/octet-stream");
let value = import(request, file).await?; let value = import(request, path).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Health => { #[cfg(not(target_arch = "wasm32"))]
let path = base_url.join(Method::Health.as_str())?; Command::ImportMl {
path,
} => {
let req_path = base_url.join("ml")?.join("import")?;
let request = client
.post(req_path)
.headers(headers.clone())
.auth(auth)
.header(CONTENT_TYPE, "application/octet-stream");
let value = import(request, path).await?;
Ok(DbResponse::Other(value))
}
Command::Health => {
let path = base_url.join("health")?;
let request = client.get(path); let request = client.get(path);
let value = health(request).await?; let value = health(request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Version => { Command::Version => {
let path = base_url.join(method.as_str())?; let path = base_url.join("version")?;
let request = client.get(path); let request = client.get(path);
let value = version(request).await?; let value = version(request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Method::Set => { Command::Set {
key,
value,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let (key, value) = match &mut params[..2] { let value = value.to_string();
[Value::Strand(key), value] => (mem::take(&mut key.0), value.to_string()),
_ => unreachable!(),
};
let request = client let request = client
.post(path) .post(path)
.headers(headers.clone()) .headers(headers.clone())
@ -598,38 +719,24 @@ async fn router(
vars.insert(key, value); vars.insert(key, value);
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Unset => { Command::Unset {
if let [Value::Strand(key)] = &params[..1] { key,
vars.swap_remove(&key.0); } => {
} vars.shift_remove(&key);
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Method::Live => { Command::SubscribeLive {
..
} => Err(Error::LiveQueriesNotSupported.into()),
Command::Kill {
uuid,
} => {
let path = base_url.join(SQL_PATH)?; let path = base_url.join(SQL_PATH)?;
let table = match &params[..] {
[table] => table.to_string(),
_ => unreachable!(),
};
let request = client let request = client
.post(path) .post(path)
.headers(headers.clone()) .headers(headers.clone())
.auth(auth) .auth(auth)
.query(&[("table", table)]) .query(&[("id", uuid)])
.body("LIVE SELECT * FROM type::table($table)");
let value = take(true, request).await?;
Ok(DbResponse::Other(value))
}
Method::Kill => {
let path = base_url.join(SQL_PATH)?;
let id = match &params[..] {
[id] => id.to_string(),
_ => unreachable!(),
};
let request = client
.post(path)
.headers(headers.clone())
.auth(auth)
.query(&[("id", id)])
.body("KILL type::string($id)"); .body("KILL type::string($id)");
let value = take(true, request).await?; let value = take(true, request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))

View file

@ -1,6 +1,5 @@
use super::Client; use super::Client;
use crate::api::conn::Connection; use crate::api::conn::Connection;
use crate::api::conn::Method;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
@ -47,7 +46,7 @@ impl Connection for Client {
let base_url = address.url; let base_url = address.url;
super::health(client.get(base_url.join(Method::Health.as_str())?)).await?; super::health(client.get(base_url.join("health")?)).await?;
let (route_tx, route_rx) = match capacity { let (route_tx, route_rx) = match capacity {
0 => channel::unbounded(), 0 => channel::unbounded(),

View file

@ -1,6 +1,5 @@
use super::Client; use super::Client;
use crate::api::conn::Connection; use crate::api::conn::Connection;
use crate::api::conn::Method;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
@ -53,7 +52,7 @@ async fn client(base_url: &Url) -> Result<reqwest::Client> {
let headers = super::default_headers(); let headers = super::default_headers();
let builder = ClientBuilder::new().default_headers(headers); let builder = ClientBuilder::new().default_headers(headers);
let client = builder.build()?; let client = builder.build()?;
let health = base_url.join(Method::Health.as_str())?; let health = base_url.join("health")?;
super::health(client.get(health)).await?; super::health(client.get(health)).await?;
Ok(client) Ok(client)
} }

View file

@ -6,8 +6,8 @@ pub(crate) mod native;
pub(crate) mod wasm; pub(crate) mod wasm;
use crate::api; use crate::api;
use crate::api::conn::Command;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::engine::remote::duration_from_str; use crate::api::engine::remote::duration_from_str;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::query::QueryResult; use crate::api::method::query::QueryResult;
@ -20,15 +20,12 @@ use crate::dbs::Status;
use crate::method::Stats; use crate::method::Stats;
use crate::opt::IntoEndpoint; use crate::opt::IntoEndpoint;
use crate::sql::Value; use crate::sql::Value;
use bincode::Options as _;
use channel::Sender; use channel::Sender;
use indexmap::IndexMap; use indexmap::IndexMap;
use revision::revisioned; use revision::revisioned;
use revision::Revisioned; use revision::Revisioned;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::ser::SerializeMap;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Read; use std::io::Read;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -39,127 +36,64 @@ use uuid::Uuid;
pub(crate) const PATH: &str = "rpc"; pub(crate) const PATH: &str = "rpc";
const PING_INTERVAL: Duration = Duration::from_secs(5); const PING_INTERVAL: Duration = Duration::from_secs(5);
const PING_METHOD: &str = "ping";
const REVISION_HEADER: &str = "revision"; const REVISION_HEADER: &str = "revision";
/// A struct which will be serialized as a map to behave like the previously used BTreeMap. enum RequestEffect {
/// /// Completing this request sets a variable to a give value.
/// This struct serializes as if it is a surrealdb_core::sql::Value::Object. Set {
#[derive(Debug)] key: String,
struct RouterRequest { value: Value,
id: Option<Value>, },
method: Value, /// Completing this request sets a variable to a give value.
params: Option<Value>, Clear {
key: String,
},
/// Insert requests repsonses need to be flattened in an array.
Insert,
/// No effect
None,
} }
impl Serialize for RouterRequest { #[derive(Clone, Copy, Eq, PartialEq, Hash)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> enum ReplayMethod {
where Use,
S: serde::Serializer, Signup,
{ Signin,
struct InnerRequest<'a>(&'a RouterRequest); Invalidate,
Authenticate,
impl Serialize for InnerRequest<'_> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let size = 1 + self.0.id.is_some() as usize + self.0.params.is_some() as usize;
let mut map = serializer.serialize_map(Some(size))?;
if let Some(id) = self.0.id.as_ref() {
map.serialize_entry("id", id)?;
}
map.serialize_entry("method", &self.0.method)?;
if let Some(params) = self.0.params.as_ref() {
map.serialize_entry("params", params)?;
}
map.end()
}
}
serializer.serialize_newtype_variant("Value", 9, "Object", &InnerRequest(self))
}
} }
impl Revisioned for RouterRequest { struct PendingRequest {
fn revision() -> u16 { // Does resolving this request has some effects.
1 effect: RequestEffect,
} // The channel to send the result of the request into.
response_channel: Sender<Result<DbResponse>>,
fn serialize_revisioned<W: std::io::Write>(
&self,
w: &mut W,
) -> std::result::Result<(), revision::Error> {
// version
Revisioned::serialize_revisioned(&1u32, w)?;
// object variant
Revisioned::serialize_revisioned(&9u32, w)?;
// object wrapper version
Revisioned::serialize_revisioned(&1u32, w)?;
let size = 1 + self.id.is_some() as usize + self.params.is_some() as usize;
size.serialize_revisioned(w)?;
let serializer = bincode::options()
.with_no_limit()
.with_little_endian()
.with_varint_encoding()
.reject_trailing_bytes();
if let Some(x) = self.id.as_ref() {
serializer
.serialize_into(&mut *w, "id")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
serializer
.serialize_into(&mut *w, "method")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
self.method.serialize_revisioned(w)?;
if let Some(x) = self.params.as_ref() {
serializer
.serialize_into(&mut *w, "params")
.map_err(|err| revision::Error::Serialize(err.to_string()))?;
x.serialize_revisioned(w)?;
}
Ok(())
}
fn deserialize_revisioned<R: Read>(_: &mut R) -> std::result::Result<Self, revision::Error>
where
Self: Sized,
{
panic!("deliberately unimplemented");
}
} }
struct RouterState<Sink, Stream, Msg> { struct RouterState<Sink, Stream> {
var_stash: IndexMap<i64, (String, Value)>,
/// Vars currently set by the set method, /// Vars currently set by the set method,
vars: IndexMap<String, Value>, vars: IndexMap<String, Value>,
/// Messages which aught to be replayed on a reconnect. /// Messages which aught to be replayed on a reconnect.
replay: IndexMap<Method, Msg>, replay: IndexMap<ReplayMethod, Command>,
/// Pending live queries /// Pending live queries
live_queries: HashMap<Uuid, channel::Sender<CoreNotification>>, live_queries: HashMap<Uuid, channel::Sender<CoreNotification>>,
/// Send requests which are still awaiting an awnser.
routes: HashMap<i64, (Method, Sender<Result<DbResponse>>)>, pending_requests: HashMap<i64, PendingRequest>,
/// The last time a message was recieved from the server.
last_activity: Instant, last_activity: Instant,
/// The sink into which messages are send to surrealdb
sink: Sink, sink: Sink,
/// The stream from which messages are recieved from surrealdb
stream: Stream, stream: Stream,
} }
impl<Sink, Stream, Msg> RouterState<Sink, Stream, Msg> { impl<Sink, Stream> RouterState<Sink, Stream> {
pub fn new(sink: Sink, stream: Stream) -> Self { pub fn new(sink: Sink, stream: Stream) -> Self {
RouterState { RouterState {
var_stash: IndexMap::new(),
vars: IndexMap::new(), vars: IndexMap::new(),
replay: IndexMap::new(), replay: IndexMap::new(),
live_queries: HashMap::new(), live_queries: HashMap::new(),
routes: HashMap::new(), pending_requests: HashMap::new(),
last_activity: Instant::now(), last_activity: Instant::now(),
sink, sink,
stream, stream,
@ -317,67 +251,3 @@ where
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?; bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into())) crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
} }
#[cfg(test)]
mod test {
use std::io::Cursor;
use revision::Revisioned;
use surrealdb_core::sql::Value;
use super::RouterRequest;
fn assert_converts<S, D, I>(req: &RouterRequest, s: S, d: D)
where
S: FnOnce(&RouterRequest) -> I,
D: FnOnce(I) -> Value,
{
let ser = s(req);
let val = d(ser);
let Value::Object(obj) = val else {
panic!("not an object");
};
assert_eq!(obj.get("id").cloned(), req.id);
assert_eq!(obj.get("method").unwrap().clone(), req.method);
assert_eq!(obj.get("params").cloned(), req.params);
}
#[test]
fn router_request_value_conversion() {
let request = RouterRequest {
id: Some(Value::from(1234i64)),
method: Value::from("request"),
params: Some(vec![Value::from(1234i64), Value::from("request")].into()),
};
println!("test convert bincode");
assert_converts(
&request,
|i| bincode::serialize(i).unwrap(),
|b| bincode::deserialize(&b).unwrap(),
);
println!("test convert json");
assert_converts(
&request,
|i| serde_json::to_string(i).unwrap(),
|b| serde_json::from_str(&b).unwrap(),
);
println!("test convert revisioned");
assert_converts(
&request,
|i| {
let mut buf = Vec::new();
i.serialize_revisioned(&mut Cursor::new(&mut buf)).unwrap();
buf
},
|b| Value::deserialize_revisioned(&mut Cursor::new(b)).unwrap(),
);
println!("done");
}
}

View file

@ -1,15 +1,13 @@
use super::PATH; use super::{
use super::{deserialize, serialize}; deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
use super::{HandleResult, RouterRequest}; };
use crate::api::conn::Connection;
use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::conn::{Command, DbResponse};
use crate::api::conn::{Connection, RequestData};
use crate::api::engine::remote::ws::Client; use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response; use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL; use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::ws::PING_METHOD;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
@ -31,7 +29,6 @@ use revision::revisioned;
use serde::Deserialize; use serde::Deserialize;
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
use std::collections::HashSet; use std::collections::HashSet;
use std::mem;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
use std::sync::Arc; use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
@ -58,7 +55,7 @@ pub(crate) const NAGLE_ALG: bool = false;
type MessageSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>; type MessageSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
type MessageStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>; type MessageStream = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>; type RouterState = super::RouterState<MessageSink, MessageStream>;
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]
impl From<Tls> for Connector { impl From<Tls> for Connector {
@ -153,80 +150,106 @@ async fn router_handle_route(
state: &mut RouterState, state: &mut RouterState,
endpoint: &Endpoint, endpoint: &Endpoint,
) -> HandleResult { ) -> HandleResult {
let (id, method, param) = request; let RequestData {
let params = match param.query { id,
Some((query, bindings)) => { command,
vec![query.into(), bindings.into()] } = request;
// We probably shouldn't be sending duplicate id requests.
let entry = state.pending_requests.entry(id);
let Entry::Vacant(entry) = entry else {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
} }
None => param.other, return HandleResult::Ok;
}; };
match method {
Method::Set => { let mut effect = RequestEffect::None;
if let [Value::Strand(key), value] = &params[..2] {
state.var_stash.insert(id, (key.0.clone(), value.clone())); match command {
} Command::Set {
ref key,
ref value,
} => {
effect = RequestEffect::Set {
key: key.clone(),
value: value.clone(),
};
} }
Method::Unset => { Command::Unset {
if let [Value::Strand(key)] = &params[..1] { ref key,
state.vars.swap_remove(&key.0); } => {
} effect = RequestEffect::Clear {
key: key.clone(),
};
} }
Method::Live => { Command::Insert {
if let Some(sender) = param.notification_sender { ..
if let [Value::Uuid(id)] = &params[..1] { } => {
state.live_queries.insert(id.0, sender); effect = RequestEffect::Insert;
} }
} Command::SubscribeLive {
ref uuid,
ref notification_sender,
} => {
state.live_queries.insert(*uuid, notification_sender.clone());
if response.clone().send(Ok(DbResponse::Other(Value::None))).await.is_err() { if response.clone().send(Ok(DbResponse::Other(Value::None))).await.is_err() {
trace!("Receiver dropped"); trace!("Receiver dropped");
} }
// There is nothing to send to the server here // There is nothing to send to the server here
return HandleResult::Ok;
} }
Method::Kill => { Command::Kill {
if let [Value::Uuid(id)] = &params[..1] { ref uuid,
state.live_queries.remove(id); } => {
} state.live_queries.remove(uuid);
}
Command::Use {
..
} => {
state.replay.insert(ReplayMethod::Use, command.clone());
}
Command::Signup {
..
} => {
state.replay.insert(ReplayMethod::Signup, command.clone());
}
Command::Signin {
..
} => {
state.replay.insert(ReplayMethod::Signin, command.clone());
}
Command::Invalidate {
..
} => {
state.replay.insert(ReplayMethod::Invalidate, command.clone());
}
Command::Authenticate {
..
} => {
state.replay.insert(ReplayMethod::Authenticate, command.clone());
} }
_ => {} _ => {}
} }
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = {
let request = RouterRequest {
id: Some(Value::from(id)),
method: method_str.into(),
params: (!params.is_empty()).then(|| params.into()),
};
let message = {
let Some(request) = command.into_router_request(Some(id)) else {
let _ = response.send(Err(Error::BackupsNotSupported.into())).await;
return HandleResult::Ok;
};
trace!("Request {:?}", request); trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap(); let payload = serialize(&request, endpoint.supports_revision).unwrap();
Message::Binary(payload) Message::Binary(payload)
}; };
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
state.replay.insert(method, message.clone());
}
match state.sink.send(message).await { match state.sink.send(message).await {
Ok(_) => { Ok(_) => {
state.last_activity = Instant::now(); state.last_activity = Instant::now();
match state.routes.entry(id) { entry.insert(PendingRequest {
Entry::Vacant(entry) => { effect,
// Register query route response_channel: response,
entry.insert((method, response)); });
}
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
}
}
} }
Err(error) => { Err(error) => {
let error = Error::Ws(error.to_string()); let error = Error::Ws(error.to_string());
@ -254,23 +277,50 @@ async fn router_handle_response(
Some(id) => { Some(id) => {
if let Ok(id) = id.coerce_to_i64() { if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs // We can only route responses with IDs
if let Some((method, sender)) = state.routes.remove(&id) { if let Some(pending) = state.pending_requests.remove(&id) {
if matches!(method, Method::Set) { match pending.effect {
if let Some((key, value)) = state.var_stash.swap_remove(&id) { RequestEffect::None => {}
state.vars.insert(key, value); RequestEffect::Insert => {
} // For insert, we need to flatten single responses in an array
} if let Ok(Data::Other(Value::Array(value))) =
// Send the response back to the caller response.result
let mut response = response.result; {
if matches!(method, Method::Insert) { if value.len() == 1 {
// For insert, we need to flatten single responses in an array let _ = pending
if let Ok(Data::Other(Value::Array(value))) = &mut response { .response_channel
if let [value] = &mut value.0[..] { .send(DbResponse::from(Ok(Data::Other(
response = Ok(Data::Other(mem::take(value))); value.into_iter().next().unwrap(),
))))
.await;
} else {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
Value::Array(value),
))))
.await;
}
return HandleResult::Ok;
} }
} }
RequestEffect::Set {
key,
value,
} => {
state.vars.insert(key, value);
}
RequestEffect::Clear {
key,
} => {
state.vars.shift_remove(&key);
}
} }
let _res = sender.send(DbResponse::from(response)).await; let _res = pending
.response_channel
.send(DbResponse::from(response.result))
.await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
} }
} }
} }
@ -285,13 +335,11 @@ async fn router_handle_response(
if sender.send(notification).await.is_err() { if sender.send(notification).await.is_err() {
state.live_queries.remove(&live_query_id); state.live_queries.remove(&live_query_id);
let kill = { let kill = {
let request = RouterRequest { let request = Command::Kill {
id: None, uuid: *live_query_id,
method: Method::Kill.as_str().into(), }
params: Some( .into_router_request(None)
vec![Value::from(live_query_id)].into(), .unwrap();
),
};
let value = let value =
serialize(&request, endpoint.supports_revision) serialize(&request, endpoint.supports_revision)
.unwrap(); .unwrap();
@ -326,8 +374,10 @@ async fn router_handle_response(
{ {
// Return an error if an ID was returned // Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = state.routes.remove(&id) { if let Some(pending) = state.pending_requests.remove(&id) {
let _res = sender.send(Err(error)).await; let _res = pending.response_channel.send(Err(error)).await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
} }
} }
} else { } else {
@ -353,19 +403,27 @@ async fn router_reconnect(
let (new_sink, new_stream) = s.split(); let (new_sink, new_stream) = s.split();
state.sink = new_sink; state.sink = new_sink;
state.stream = new_stream; state.stream = new_stream;
for (_, message) in &state.replay { for commands in state.replay.values() {
if let Err(error) = state.sink.send(message.clone()).await { let request = commands
.clone()
.into_router_request(None)
.expect("replay commands should always convert to route requests");
let message = serialize(&request, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(message)).await {
trace!("{error}"); trace!("{error}");
time::sleep(time::Duration::from_secs(1)).await; time::sleep(time::Duration::from_secs(1)).await;
continue; continue;
} }
} }
for (key, value) in &state.vars { for (key, value) in &state.vars {
let request = RouterRequest { let request = Command::Set {
id: None, key: key.as_str().into(),
method: Method::Set.as_str().into(), value: value.clone(),
params: Some(vec![key.as_str().into(), value.clone()].into()), }
}; .into_router_request(None)
.unwrap();
trace!("Request {:?}", request); trace!("Request {:?}", request);
let payload = serialize(&request, endpoint.supports_revision).unwrap(); let payload = serialize(&request, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(payload)).await { if let Err(error) = state.sink.send(Message::Binary(payload)).await {
@ -394,11 +452,7 @@ pub(crate) async fn run_router(
route_rx: Receiver<Route>, route_rx: Receiver<Route>,
) { ) {
let ping = { let ping = {
let request = RouterRequest { let request = Command::Health.into_router_request(None).unwrap();
id: None,
method: PING_METHOD.into(),
params: None,
};
let value = serialize(&request, endpoint.supports_revision).unwrap(); let value = serialize(&request, endpoint.supports_revision).unwrap();
Message::Binary(value) Message::Binary(value)
}; };
@ -418,7 +472,7 @@ pub(crate) async fn run_router(
state.last_activity = Instant::now(); state.last_activity = Instant::now();
state.live_queries.clear(); state.live_queries.clear();
state.routes.clear(); state.pending_requests.clear();
loop { loop {
tokio::select! { tokio::select! {

View file

@ -1,14 +1,13 @@
use super::{deserialize, serialize}; use super::{
use super::{HandleResult, PATH}; deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
use crate::api::conn::Connection; };
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::conn::{Command, Connection, RequestData};
use crate::api::engine::remote::ws::Client; use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response; use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL; use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::ws::PING_METHOD;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
@ -16,7 +15,7 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt; use crate::api::OnceLockExt;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::engine::remote::ws::{Data, RouterRequest}; use crate::engine::remote::ws::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Value; use crate::sql::Value;
@ -34,7 +33,6 @@ use serde::Deserialize;
use std::collections::hash_map::Entry; use std::collections::hash_map::Entry;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::collections::HashSet; use std::collections::HashSet;
use std::mem;
use std::sync::atomic::AtomicI64; use std::sync::atomic::AtomicI64;
use std::sync::Arc; use std::sync::Arc;
use std::sync::OnceLock; use std::sync::OnceLock;
@ -50,7 +48,7 @@ use ws_stream_wasm::{WsEvent, WsStream};
type MessageStream = SplitStream<WsStream>; type MessageStream = SplitStream<WsStream>;
type MessageSink = SplitSink<WsStream, Message>; type MessageSink = SplitSink<WsStream, Message>;
type RouterState = super::RouterState<MessageSink, MessageStream, Message>; type RouterState = super::RouterState<MessageSink, MessageStream>;
impl crate::api::Connection for Client {} impl crate::api::Connection for Client {}
@ -96,79 +94,106 @@ async fn router_handle_request(
state: &mut RouterState, state: &mut RouterState,
endpoint: &Endpoint, endpoint: &Endpoint,
) -> HandleResult { ) -> HandleResult {
let (id, method, param) = request; let RequestData {
let params = match param.query { id,
Some((query, bindings)) => { command,
vec![query.into(), bindings.into()] } = request;
let entry = state.pending_requests.entry(id);
// We probably shouldn't be sending duplicate id requests.
let Entry::Vacant(entry) = entry else {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
} }
None => param.other, return HandleResult::Ok;
}; };
match method {
Method::Set => { let mut effect = RequestEffect::None;
if let [Value::Strand(key), value] = &params[..2] {
state.var_stash.insert(id, (key.0.clone(), value.clone())); match command {
} Command::Set {
ref key,
ref value,
} => {
effect = RequestEffect::Set {
key: key.clone(),
value: value.clone(),
};
} }
Method::Unset => { Command::Unset {
if let [Value::Strand(key)] = &params[..1] { ref key,
state.vars.swap_remove(&key.0); } => {
} effect = RequestEffect::Clear {
key: key.clone(),
};
} }
Method::Live => { Command::Insert {
if let Some(sender) = param.notification_sender { ..
if let [Value::Uuid(id)] = &params[..1] { } => {
state.live_queries.insert(id.0, sender); effect = RequestEffect::Insert;
} }
} Command::SubscribeLive {
ref uuid,
ref notification_sender,
} => {
state.live_queries.insert(*uuid, notification_sender.clone());
if response.send(Ok(DbResponse::Other(Value::None))).await.is_err() { if response.send(Ok(DbResponse::Other(Value::None))).await.is_err() {
trace!("Receiver dropped"); trace!("Receiver dropped");
} }
// There is nothing to send to the server here // There is nothing to send to the server here
return HandleResult::Ok; return HandleResult::Ok;
} }
Method::Kill => { Command::Kill {
if let [Value::Uuid(id)] = &params[..1] { ref uuid,
state.live_queries.remove(id); } => {
} state.live_queries.remove(uuid);
}
Command::Use {
..
} => {
state.replay.insert(ReplayMethod::Use, command.clone());
}
Command::Signup {
..
} => {
state.replay.insert(ReplayMethod::Signup, command.clone());
}
Command::Signin {
..
} => {
state.replay.insert(ReplayMethod::Signin, command.clone());
}
Command::Invalidate {
..
} => {
state.replay.insert(ReplayMethod::Invalidate, command.clone());
}
Command::Authenticate {
..
} => {
state.replay.insert(ReplayMethod::Authenticate, command.clone());
} }
_ => {} _ => {}
} }
let method_str = match method {
Method::Health => PING_METHOD,
_ => method.as_str(),
};
let message = { let message = {
let request = RouterRequest { let Some(req) = command.into_router_request(Some(id)) else {
id: Some(Value::from(id)), let _ = response.send(Err(Error::BackupsNotSupported.into())).await;
method: method_str.into(), return HandleResult::Ok;
params: (!params.is_empty()).then(|| params.into()),
}; };
trace!("Request {:?}", request); trace!("Request {:?}", req);
let payload = serialize(&request, endpoint.supports_revision).unwrap(); let payload = serialize(&req, endpoint.supports_revision).unwrap();
Message::Binary(payload) Message::Binary(payload)
}; };
if let Method::Authenticate
| Method::Invalidate
| Method::Signin
| Method::Signup
| Method::Use = method
{
state.replay.insert(method, message.clone());
}
match state.sink.send(message).await { match state.sink.send(message).await {
Ok(..) => { Ok(..) => {
state.last_activity = Instant::now(); state.last_activity = Instant::now();
match state.routes.entry(id) { entry.insert(PendingRequest {
Entry::Vacant(entry) => { effect,
entry.insert((method, response)); response_channel: response,
} });
Entry::Occupied(..) => {
let error = Error::DuplicateRequestId(id);
if response.send(Err(error.into())).await.is_err() {
trace!("Receiver dropped");
}
}
}
} }
Err(error) => { Err(error) => {
let error = Error::Ws(error.to_string()); let error = Error::Ws(error.to_string());
@ -196,23 +221,50 @@ async fn router_handle_response(
Some(id) => { Some(id) => {
if let Ok(id) = id.coerce_to_i64() { if let Ok(id) = id.coerce_to_i64() {
// We can only route responses with IDs // We can only route responses with IDs
if let Some((method, sender)) = state.routes.remove(&id) { if let Some(pending) = state.pending_requests.remove(&id) {
if matches!(method, Method::Set) { match pending.effect {
if let Some((key, value)) = state.var_stash.swap_remove(&id) { RequestEffect::None => {}
state.vars.insert(key, value); RequestEffect::Insert => {
} // For insert, we need to flatten single responses in an array
} if let Ok(Data::Other(Value::Array(value))) =
// Send the response back to the caller response.result
let mut response = response.result; {
if matches!(method, Method::Insert) { if value.len() == 1 {
// For insert, we need to flatten single responses in an array let _ = pending
if let Ok(Data::Other(Value::Array(value))) = &mut response { .response_channel
if let [value] = &mut value.0[..] { .send(DbResponse::from(Ok(Data::Other(
response = Ok(Data::Other(mem::take(value))); value.into_iter().next().unwrap(),
))))
.await;
} else {
let _ = pending
.response_channel
.send(DbResponse::from(Ok(Data::Other(
Value::Array(value),
))))
.await;
}
return HandleResult::Ok;
} }
} }
RequestEffect::Set {
key,
value,
} => {
state.vars.insert(key, value);
}
RequestEffect::Clear {
key,
} => {
state.vars.shift_remove(&key);
}
} }
let _res = sender.send(DbResponse::from(response)).await; let _res = pending
.response_channel
.send(DbResponse::from(response.result))
.await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
} }
} }
} }
@ -226,11 +278,10 @@ async fn router_handle_response(
if sender.send(notification).await.is_err() { if sender.send(notification).await.is_err() {
state.live_queries.remove(&live_query_id); state.live_queries.remove(&live_query_id);
let kill = { let kill = {
let request = RouterRequest { let request = Command::Kill {
id: None, uuid: live_query_id.0,
method: Method::Kill.as_str().into(), }
params: Some(vec![Value::from(live_query_id)].into()), .into_router_request(None);
};
let value = serialize(&request, endpoint.supports_revision) let value = serialize(&request, endpoint.supports_revision)
.unwrap(); .unwrap();
Message::Binary(value) Message::Binary(value)
@ -265,8 +316,10 @@ async fn router_handle_response(
{ {
// Return an error if an ID was returned // Return an error if an ID was returned
if let Some(Ok(id)) = id.map(Value::coerce_to_i64) { if let Some(Ok(id)) = id.map(Value::coerce_to_i64) {
if let Some((_method, sender)) = state.routes.remove(&id) { if let Some(req) = state.pending_requests.remove(&id) {
let _res = sender.send(Err(error)).await; let _res = req.response_channel.send(Err(error)).await;
} else {
warn!("got response for request with id '{id}', which was not in pending requests")
} }
} }
} else { } else {
@ -311,18 +364,21 @@ async fn router_reconnect(
} }
}; };
for (_, message) in &state.replay { for (_, message) in &state.replay {
if let Err(error) = state.sink.send(message.clone()).await { let message = message.clone().into_router_request(None);
let message = serialize(&message, endpoint.supports_revision).unwrap();
if let Err(error) = state.sink.send(Message::Binary(message)).await {
trace!("{error}"); trace!("{error}");
time::sleep(Duration::from_secs(1)).await; time::sleep(Duration::from_secs(1)).await;
continue; continue;
} }
} }
for (key, value) in &state.vars { for (key, value) in &state.vars {
let request = RouterRequest { let request = Command::Set {
id: None, key: key.as_str().into(),
method: Method::Set.as_str().into(), value: value.clone(),
params: Some(vec![key.as_str().into(), value.clone()].into()), }
}; .into_router_request(None);
trace!("Request {:?}", request); trace!("Request {:?}", request);
let serialize = serialize(&request, false).unwrap(); let serialize = serialize(&request, false).unwrap();
if let Err(error) = state.sink.send(Message::Binary(serialize)).await { if let Err(error) = state.sink.send(Message::Binary(serialize)).await {
@ -378,7 +434,7 @@ pub(crate) async fn run_router(
let ping = { let ping = {
let mut request = BTreeMap::new(); let mut request = BTreeMap::new();
request.insert("method".to_owned(), PING_METHOD.into()); request.insert("method".to_owned(), "ping".into());
let value = Value::from(request); let value = Value::from(request);
let value = serialize(&value, endpoint.supports_revision).unwrap(); let value = serialize(&value, endpoint.supports_revision).unwrap();
Message::Binary(value) Message::Binary(value)
@ -397,7 +453,7 @@ pub(crate) async fn run_router(
state.last_activity = Instant::now(); state.last_activity = Instant::now();
state.live_queries.clear(); state.live_queries.clear();
state.routes.clear(); state.pending_requests.clear();
loop { loop {
futures::select! { futures::select! {

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::OnceLockExt; use crate::api::method::OnceLockExt;
use crate::api::opt::auth::Jwt; use crate::api::opt::auth::Jwt;
@ -27,7 +26,11 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Authenticate, Param::new(vec![self.token.0.into()])).await router
.execute_unit(Command::Authenticate {
token: self.token.0,
})
.await
}) })
} }
} }

View file

@ -1,17 +1,11 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Range;
use crate::api::opt::Resource;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::sql::to_value;
use crate::sql::Id;
use crate::sql::Value; use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
@ -21,21 +15,29 @@ use std::marker::PhantomData;
/// Content inserts or replaces the contents of a record entirely /// Content inserts or replaces the contents of a record entirely
#[derive(Debug)] #[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Content<'r, C: Connection, D, R> { pub struct Content<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) method: Method, pub(super) command: Result<Command>,
pub(super) resource: Result<Resource>,
pub(super) range: Option<Range<Id>>,
pub(super) content: D,
pub(super) response_type: PhantomData<R>, pub(super) response_type: PhantomData<R>,
} }
impl<C, D, R> Content<'_, C, D, R> impl<'r, C, R> Content<'r, C, R>
where where
C: Connection, C: Connection,
{ {
pub(crate) fn from_closure<F>(client: Cow<'r, Surreal<C>>, f: F) -> Self
where
F: FnOnce() -> Result<Command>,
{
Content {
client,
command: f(),
response_type: PhantomData,
}
}
/// Converts to an owned type which can easily be moved to a different thread /// Converts to an owned type which can easily be moved to a different thread
pub fn into_owned(self) -> Content<'static, C, D, R> { pub fn into_owned(self) -> Content<'static, C, R> {
Content { Content {
client: Cow::Owned(self.client.into_owned()), client: Cow::Owned(self.client.into_owned()),
..self ..self
@ -48,33 +50,20 @@ macro_rules! into_future {
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
let Content { let Content {
client, client,
method, command,
resource,
range,
content,
.. ..
} = self; } = self;
let content = to_value(content);
Box::pin(async move { Box::pin(async move {
let param = match range {
Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(),
};
let params = match content? {
Value::None | Value::Null => vec![param],
content => vec![param, content],
};
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(method, Param::new(params)).await router.$method(command?).await
}) })
} }
}; };
} }
impl<'r, Client, D> IntoFuture for Content<'r, Client, D, Value> impl<'r, Client> IntoFuture for Content<'r, Client, Value>
where where
Client: Connection, Client: Connection,
D: Serialize,
{ {
type Output = Result<Value>; type Output = Result<Value>;
type IntoFuture = BoxFuture<'r, Self::Output>; type IntoFuture = BoxFuture<'r, Self::Output>;
@ -82,10 +71,9 @@ where
into_future! {execute_value} into_future! {execute_value}
} }
impl<'r, Client, D, R> IntoFuture for Content<'r, Client, D, Option<R>> impl<'r, Client, R> IntoFuture for Content<'r, Client, Option<R>>
where where
Client: Connection, Client: Connection,
D: Serialize,
R: DeserializeOwned, R: DeserializeOwned,
{ {
type Output = Result<Option<R>>; type Output = Result<Option<R>>;
@ -94,10 +82,9 @@ where
into_future! {execute_opt} into_future! {execute_opt}
} }
impl<'r, Client, D, R> IntoFuture for Content<'r, Client, D, Vec<R>> impl<'r, Client, R> IntoFuture for Content<'r, Client, Vec<R>>
where where
Client: Connection, Client: Connection,
D: Serialize,
R: DeserializeOwned, R: DeserializeOwned,
{ {
type Output = Result<Vec<R>>; type Output = Result<Vec<R>>;

View file

@ -1,7 +1,5 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::Content;
use crate::api::opt::Resource; use crate::api::opt::Resource;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
@ -13,6 +11,9 @@ use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
use super::Content;
/// A record create future /// A record create future
#[derive(Debug)] #[derive(Debug)]
@ -46,7 +47,11 @@ macro_rules! into_future {
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Create, Param::new(vec![resource?.into()])).await let cmd = Command::Create {
what: resource?.into(),
data: None,
};
router.$method(cmd).await
}) })
} }
}; };
@ -89,17 +94,22 @@ where
C: Connection, C: Connection,
{ {
/// Sets content of a record /// Sets content of a record
pub fn content<D>(self, data: D) -> Content<'r, C, D, R> pub fn content<D>(self, data: D) -> Content<'r, C, R>
where where
D: Serialize, D: Serialize,
{ {
Content { Content::from_closure(self.client, || {
client: self.client, let content = to_value(data)?;
method: Method::Create,
resource: self.resource, let data = match content {
range: None, Value::None | Value::Null => None,
content: data, content => Some(content),
response_type: PhantomData, };
}
Ok(Command::Create {
what: self.resource?.into(),
data,
})
})
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Range; use crate::api::opt::Range;
use crate::api::opt::Resource; use crate::api::opt::Resource;
@ -47,12 +46,16 @@ macro_rules! into_future {
.. ..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Delete, Param::new(vec![param])).await router
.$method(Command::Delete {
what: param,
})
.await
}) })
} }
}; };

View file

@ -1,6 +1,5 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::MlConfig; use crate::api::conn::MlExportConfig;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Error; use crate::api::Error;
@ -8,7 +7,6 @@ use crate::api::ExtraFeatures;
use crate::api::Result; use crate::api::Result;
use crate::method::Model; use crate::method::Model;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::opt::ExportDestination;
use crate::Surreal; use crate::Surreal;
use channel::Receiver; use channel::Receiver;
use futures::Stream; use futures::Stream;
@ -27,8 +25,8 @@ use std::task::Poll;
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Export<'r, C: Connection, R, T = ()> { pub struct Export<'r, C: Connection, R, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) target: ExportDestination, pub(super) target: R,
pub(super) ml_config: Option<MlConfig>, pub(super) ml_config: Option<MlExportConfig>,
pub(super) response: PhantomData<R>, pub(super) response: PhantomData<R>,
pub(super) export_type: PhantomData<T>, pub(super) export_type: PhantomData<T>,
} }
@ -42,7 +40,7 @@ where
Export { Export {
client: self.client, client: self.client,
target: self.target, target: self.target,
ml_config: Some(MlConfig::Export { ml_config: Some(MlExportConfig {
name: name.to_owned(), name: name.to_owned(),
version: version.to_string(), version: version.to_string(),
}), }),
@ -78,12 +76,21 @@ where
if !router.features.contains(&ExtraFeatures::Backup) { if !router.features.contains(&ExtraFeatures::Backup) {
return Err(Error::BackupsNotSupported.into()); return Err(Error::BackupsNotSupported.into());
} }
let mut param = match self.target {
ExportDestination::File(path) => Param::file(path), if let Some(config) = self.ml_config {
ExportDestination::Memory => unreachable!(), return router
}; .execute_unit(Command::ExportMl {
param.ml_config = self.ml_config; path: self.target,
router.execute_unit(Method::Export, param).await config,
})
.await;
}
router
.execute_unit(Command::ExportFile {
path: self.target,
})
.await
}) })
} }
} }
@ -102,12 +109,25 @@ where
return Err(Error::BackupsNotSupported.into()); return Err(Error::BackupsNotSupported.into());
} }
let (tx, rx) = crate::channel::bounded(1); let (tx, rx) = crate::channel::bounded(1);
let ExportDestination::Memory = self.target else {
unreachable!(); if let Some(config) = self.ml_config {
}; router
let mut param = Param::bytes_sender(tx); .execute_unit(Command::ExportBytesMl {
param.ml_config = self.ml_config; bytes: tx,
router.execute_unit(Method::Export, param).await?; config,
})
.await?;
return Ok(Backup {
rx,
});
}
router
.execute_unit(Command::ExportBytes {
bytes: tx,
})
.await?;
Ok(Backup { Ok(Backup {
rx, rx,
}) })

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
@ -37,7 +36,7 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Health, Param::new(Vec::new())).await router.execute_unit(Command::Health).await
}) })
} }
} }

View file

@ -1,6 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::MlConfig;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Error; use crate::api::Error;
@ -20,7 +18,7 @@ use std::path::PathBuf;
pub struct Import<'r, C: Connection, T = ()> { pub struct Import<'r, C: Connection, T = ()> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) file: PathBuf, pub(super) file: PathBuf,
pub(super) ml_config: Option<MlConfig>, pub(super) is_ml: bool,
pub(super) import_type: PhantomData<T>, pub(super) import_type: PhantomData<T>,
} }
@ -33,7 +31,7 @@ where
Import { Import {
client: self.client, client: self.client,
file: self.file, file: self.file,
ml_config: Some(MlConfig::Import), is_ml: true,
import_type: PhantomData, import_type: PhantomData,
} }
} }
@ -65,9 +63,20 @@ where
if !router.features.contains(&ExtraFeatures::Backup) { if !router.features.contains(&ExtraFeatures::Backup) {
return Err(Error::BackupsNotSupported.into()); return Err(Error::BackupsNotSupported.into());
} }
let mut param = Param::file(self.file);
param.ml_config = self.ml_config; if self.is_ml {
router.execute_unit(Method::Import, param).await return router
.execute_unit(Command::ImportMl {
path: self.file,
})
.await;
}
router
.execute_unit(Command::ImportFile {
path: self.file,
})
.await
}) })
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::Content; use crate::api::method::Content;
@ -60,9 +59,13 @@ macro_rules! into_future {
Resource::Array(arr) => return Err(Error::InsertOnArray(arr).into()), Resource::Array(arr) => return Err(Error::InsertOnArray(arr).into()),
Resource::Edges(edges) => return Err(Error::InsertOnEdges(edges).into()), Resource::Edges(edges) => return Err(Error::InsertOnEdges(edges).into()),
}; };
let param = vec![table, data]; let cmd = Command::Insert {
what: Some(table),
data,
};
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Insert, Param::new(param)).await router.$method(cmd).await
}) })
} }
}; };
@ -106,59 +109,41 @@ where
R: DeserializeOwned, R: DeserializeOwned,
{ {
/// Specifies the data to insert into the table /// Specifies the data to insert into the table
pub fn content<D>(self, data: D) -> Content<'r, C, Value, R> pub fn content<D>(self, data: D) -> Content<'r, C, R>
where where
D: Serialize, D: Serialize,
{ {
let mut content = Content { Content::from_closure(self.client, || {
client: self.client, let mut data = crate::sql::to_value(data)?;
method: Method::Insert, match self.resource? {
resource: self.resource, Resource::Table(table) => Ok(Command::Insert {
range: None, what: Some(table.into()),
content: Value::None, data,
response_type: PhantomData, }),
}; Resource::RecordId(thing) => {
match crate::sql::to_value(data) { if data.is_array() {
Ok(mut data) => match content.resource { Err(Error::InvalidParams(
Ok(Resource::Table(table)) => {
content.resource = Ok(table.into());
content.content = data;
}
Ok(Resource::RecordId(record_id)) => match data.is_array() {
true => {
content.resource = Err(Error::InvalidParams(
"Tried to insert multiple records on a record ID".to_owned(), "Tried to insert multiple records on a record ID".to_owned(),
) )
.into()); .into())
} } else {
false => {
let mut table = Table::default(); let mut table = Table::default();
table.0.clone_from(&record_id.tb); table.0.clone_from(&thing.tb);
content.resource = Ok(table.into()); let what = Value::Table(table);
let mut ident = Ident::default(); let mut ident = Ident::default();
"id".clone_into(&mut ident.0); "id".clone_into(&mut ident.0);
let id = Part::Field(ident); let id = Part::Field(ident);
data.put(&[id], record_id.into()); data.put(&[id], thing.into());
content.content = data; Ok(Command::Insert {
what: Some(what),
data,
})
} }
},
Ok(Resource::Object(obj)) => {
content.resource = Err(Error::InsertOnObject(obj).into());
} }
Ok(Resource::Array(arr)) => { Resource::Object(obj) => Err(Error::InsertOnObject(obj).into()),
content.resource = Err(Error::InsertOnArray(arr).into()); Resource::Array(arr) => Err(Error::InsertOnArray(arr).into()),
} Resource::Edges(edges) => Err(Error::InsertOnEdges(edges).into()),
Ok(Resource::Edges(edges)) => {
content.resource = Err(Error::InsertOnEdges(edges).into());
}
Err(error) => {
content.resource = Err(error);
}
},
Err(error) => {
content.resource = Err(error.into());
} }
}; })
content
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
@ -37,7 +36,7 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Invalidate, Param::new(Vec::new())).await router.execute_unit(Command::Invalidate).await
}) })
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
@ -33,12 +32,12 @@ use futures::StreamExt;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem;
use std::pin::Pin; use std::pin::Pin;
use std::task::Context; use std::task::Context;
use std::task::Poll; use std::task::Poll;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use tokio::spawn; use tokio::spawn;
use uuid::Uuid;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local as spawn; use wasm_bindgen_futures::spawn_local as spawn;
@ -99,11 +98,16 @@ macro_rules! into_future {
Default::default(), Default::default(),
false, false,
); );
let id: Value = query.await?.take(0)?; let Value::Uuid(id) = query.await?.take(0)? else {
let rx = register(router, id.clone()).await?; return Err(Error::InternalError(
"successufull live query didn't return a uuid".to_string(),
)
.into());
};
let rx = register(router, id.0).await?;
Ok(Stream::new( Ok(Stream::new(
Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone()), Surreal::new_from_router_waiter(client.router.clone(), client.waiter.clone()),
id, id.0,
Some(rx), Some(rx),
)) ))
}) })
@ -111,11 +115,14 @@ macro_rules! into_future {
}; };
} }
pub(crate) async fn register(router: &Router, id: Value) -> Result<Receiver<dbs::Notification>> { pub(crate) async fn register(router: &Router, id: Uuid) -> Result<Receiver<dbs::Notification>> {
let (tx, rx) = channel::unbounded(); let (tx, rx) = channel::unbounded();
let mut param = Param::notification_sender(tx); router
param.other = vec![id]; .execute_unit(Command::SubscribeLive {
router.execute_unit(Method::Live, param).await?; uuid: id,
notification_sender: tx,
})
.await?;
Ok(rx) Ok(rx)
} }
@ -158,7 +165,7 @@ pub struct Stream<R> {
pub(crate) client: Surreal<Any>, pub(crate) client: Surreal<Any>,
// We no longer need the lifetime and the type parameter // We no longer need the lifetime and the type parameter
// Leaving them in for backwards compatibility // Leaving them in for backwards compatibility
pub(crate) id: Value, pub(crate) id: Uuid,
pub(crate) rx: Option<Receiver<dbs::Notification>>, pub(crate) rx: Option<Receiver<dbs::Notification>>,
pub(crate) response_type: PhantomData<R>, pub(crate) response_type: PhantomData<R>,
} }
@ -166,7 +173,7 @@ pub struct Stream<R> {
impl<R> Stream<R> { impl<R> Stream<R> {
pub(crate) fn new( pub(crate) fn new(
client: Surreal<Any>, client: Surreal<Any>,
id: Value, id: Uuid,
rx: Option<Receiver<dbs::Notification>>, rx: Option<Receiver<dbs::Notification>>,
) -> Self { ) -> Self {
Self { Self {
@ -247,14 +254,19 @@ where
poll_next_and_convert! {} poll_next_and_convert! {}
} }
pub(crate) fn kill<Client>(client: &Surreal<Client>, id: Value) pub(crate) fn kill<Client>(client: &Surreal<Client>, uuid: Uuid)
where where
Client: Connection, Client: Connection,
{ {
let client = client.clone(); let client = client.clone();
spawn(async move { spawn(async move {
if let Ok(router) = client.router.extract() { if let Ok(router) = client.router.extract() {
router.execute_unit(Method::Kill, Param::new(vec![id.clone()])).await.ok(); router
.execute_unit(Command::Kill {
uuid,
})
.await
.ok();
} }
}); });
} }
@ -264,9 +276,8 @@ impl<R> Drop for Stream<R> {
/// ///
/// This kills the live query process responsible for this stream. /// This kills the live query process responsible for this stream.
fn drop(&mut self) { fn drop(&mut self) {
if !self.id.is_none() && self.rx.is_some() { if self.rx.is_some() {
let id = mem::take(&mut self.id); kill(&self.client, self.id);
kill(&self.client, id);
} }
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Range; use crate::api::opt::Range;
use crate::api::opt::Resource; use crate::api::opt::Resource;
@ -52,12 +51,22 @@ macro_rules! into_future {
} = self; } = self;
let content = to_value(content); let content = to_value(content);
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
let content = match content? {
Value::None | Value::Null => None,
x => Some(x),
};
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Merge, Param::new(vec![param, content?])).await let cmd = Command::Merge {
what: param,
data: content,
};
router.$method(cmd).await
}) })
} }
}; };

View file

@ -1,4 +1,27 @@
//! Methods to use when interacting with a SurrealDB instance //! Methods to use when interacting with a SurrealDB instance
use self::query::ValidQuery;
use crate::api::err::Error;
use crate::api::opt;
use crate::api::opt::auth;
use crate::api::opt::auth::Credentials;
use crate::api::opt::auth::Jwt;
use crate::api::opt::IntoEndpoint;
use crate::api::Connect;
use crate::api::Connection;
use crate::api::OnceLockExt;
use crate::api::Surreal;
use crate::opt::IntoExportDestination;
use crate::opt::WaitFor;
use crate::sql::to_value;
use crate::sql::Value;
use serde::Serialize;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
pub(crate) mod live; pub(crate) mod live;
pub(crate) mod query; pub(crate) mod query;
@ -43,8 +66,7 @@ pub use commit::Commit;
pub use content::Content; pub use content::Content;
pub use create::Create; pub use create::Create;
pub use delete::Delete; pub use delete::Delete;
pub use export::Backup; pub use export::{Backup, Export};
pub use export::Export;
use futures::Future; use futures::Future;
pub use health::Health; pub use health::Health;
pub use import::Import; pub use import::Import;
@ -67,31 +89,6 @@ pub use use_db::UseDb;
pub use use_ns::UseNs; pub use use_ns::UseNs;
pub use version::Version; pub use version::Version;
use crate::api::conn::Method;
use crate::api::opt;
use crate::api::opt::auth;
use crate::api::opt::auth::Credentials;
use crate::api::opt::auth::Jwt;
use crate::api::opt::IntoEndpoint;
use crate::api::Connect;
use crate::api::Connection;
use crate::api::OnceLockExt;
use crate::api::Surreal;
use crate::opt::IntoExportDestination;
use crate::opt::WaitFor;
use crate::sql::to_value;
use crate::sql::Value;
use serde::Serialize;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use self::query::ValidQuery;
/// A alias for an often used type of future returned by async methods in this library. /// A alias for an often used type of future returned by async methods in this library.
pub(crate) type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>; pub(crate) type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;
@ -113,36 +110,6 @@ pub struct Live;
#[derive(Debug)] #[derive(Debug)]
pub struct WithStats<T>(T); pub struct WithStats<T>(T);
impl Method {
#[allow(dead_code)] // used by `ws` and `http`
pub(crate) fn as_str(&self) -> &str {
match self {
Method::Authenticate => "authenticate",
Method::Create => "create",
Method::Delete => "delete",
Method::Export => "export",
Method::Health => "health",
Method::Import => "import",
Method::Invalidate => "invalidate",
Method::Insert => "insert",
Method::Kill => "kill",
Method::Live => "live",
Method::Merge => "merge",
Method::Patch => "patch",
Method::Query => "query",
Method::Select => "select",
Method::Set => "set",
Method::Signin => "signin",
Method::Signup => "signup",
Method::Unset => "unset",
Method::Update => "update",
Method::Upsert => "upsert",
Method::Use => "use",
Method::Version => "version",
}
}
}
impl<C> Surreal<C> impl<C> Surreal<C>
where where
C: Connection, C: Connection,
@ -318,7 +285,7 @@ where
pub fn use_db(&self, db: impl Into<String>) -> UseDb<C> { pub fn use_db(&self, db: impl Into<String>) -> UseDb<C> {
UseDb { UseDb {
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
ns: Value::None, ns: None,
db: db.into(), db: db.into(),
} }
} }
@ -457,7 +424,16 @@ where
pub fn signup<R>(&self, credentials: impl Credentials<auth::Signup, R>) -> Signup<C, R> { pub fn signup<R>(&self, credentials: impl Credentials<auth::Signup, R>) -> Signup<C, R> {
Signup { Signup {
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
credentials: to_value(credentials).map_err(Into::into), credentials: to_value(credentials).map_err(Into::into).and_then(|x| {
if let Value::Object(x) = x {
Ok(x)
} else {
Err(Error::InternalError(
"credentials did not serialize to an object".to_string(),
)
.into())
}
}),
response_type: PhantomData, response_type: PhantomData,
} }
} }
@ -576,7 +552,16 @@ where
pub fn signin<R>(&self, credentials: impl Credentials<auth::Signin, R>) -> Signin<C, R> { pub fn signin<R>(&self, credentials: impl Credentials<auth::Signin, R>) -> Signin<C, R> {
Signin { Signin {
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
credentials: to_value(credentials).map_err(Into::into), credentials: to_value(credentials).map_err(Into::into).and_then(|x| {
if let Value::Object(x) = x {
Ok(x)
} else {
Err(Error::InternalError(
"credentials did not serialize to an object".to_string(),
)
.into())
}
}),
response_type: PhantomData, response_type: PhantomData,
} }
} }
@ -1359,7 +1344,7 @@ where
Import { Import {
client: Cow::Borrowed(self), client: Cow::Borrowed(self),
file: file.as_ref().to_owned(), file: file.as_ref().to_owned(),
ml_config: None, is_ml: false,
import_type: PhantomData, import_type: PhantomData,
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::PatchOp; use crate::api::opt::PatchOp;
use crate::api::opt::Range; use crate::api::opt::Range;
@ -51,7 +50,7 @@ macro_rules! into_future {
.. ..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
@ -59,9 +58,14 @@ macro_rules! into_future {
for result in patches { for result in patches {
vec.push(result?); vec.push(result?);
} }
let patches = vec.into(); let patches = Value::from(vec);
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Patch, Param::new(vec![param, patches])).await let cmd = Command::Patch {
what: param,
data: Some(patches),
};
router.$method(cmd).await
}) })
} }
}; };

View file

@ -1,8 +1,7 @@
use super::live; use super::live;
use super::Stream; use super::Stream;
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt; use crate::api::opt;
@ -158,8 +157,12 @@ where
let mut query = sql::Query::default(); let mut query = sql::Query::default();
query.0 .0 = query_statements; query.0 .0 = query_statements;
let param = Param::query(query, bindings); let mut response = router
let mut response = router.execute_query(Method::Query, param).await?; .execute_query(Command::Query {
query,
variables: bindings,
})
.await?;
for idx in query_indicies { for idx in query_indicies {
let Some((_, result)) = response.results.get(&idx) else { let Some((_, result)) = response.results.get(&idx) else {
@ -169,16 +172,24 @@ where
// This is a live query. We are using this as a workaround to avoid // This is a live query. We are using this as a workaround to avoid
// creating another public error variant for this internal error. // creating another public error variant for this internal error.
let res = match result { let res = match result {
Ok(id) => live::register(router, id.clone()).await.map(|rx| { Ok(id) => {
Stream::new( let Value::Uuid(uuid) = id else {
Surreal::new_from_router_waiter( return Err(Error::InternalError(
client.router.clone(), "successfull live query did not return a uuid".to_string(),
client.waiter.clone(), )
), .into());
id.clone(), };
Some(rx), live::register(router, uuid.0).await.map(|rx| {
) Stream::new(
}), Surreal::new_from_router_waiter(
client.router.clone(),
client.waiter.clone(),
),
uuid.0,
Some(rx),
)
})
}
Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))), Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
}; };
response.live_queries.insert(idx, res); response.live_queries.insert(idx, res);

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::OnceLockExt; use crate::api::method::OnceLockExt;
use crate::api::opt::Range; use crate::api::opt::Range;
@ -49,12 +48,16 @@ macro_rules! into_future {
.. ..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Select, Param::new(vec![param])).await router
.$method(Command::Select {
what: param,
})
.await
}) })
} }
}; };

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
@ -41,7 +40,12 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Set, Param::new(vec![self.key.into(), self.value?])).await router
.execute_unit(Command::Set {
key: self.key,
value: self.value?,
})
.await
}) })
} }
} }

View file

@ -1,22 +1,21 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use surrealdb_core::sql::Object;
/// A signin future /// A signin future
#[derive(Debug)] #[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Signin<'r, C: Connection, R> { pub struct Signin<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) credentials: Result<Value>, pub(super) credentials: Result<Object>,
pub(super) response_type: PhantomData<R>, pub(super) response_type: PhantomData<R>,
} }
@ -49,7 +48,11 @@ where
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let router = client.router.extract()?; let router = client.router.extract()?;
router.execute(Method::Signin, Param::new(vec![credentials?])).await router
.execute(Command::Signin {
credentials: credentials?,
})
.await
}) })
} }
} }

View file

@ -1,22 +1,21 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use surrealdb_core::sql::Object;
/// A signup future /// A signup future
#[derive(Debug)] #[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Signup<'r, C: Connection, R> { pub struct Signup<'r, C: Connection, R> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) credentials: Result<Value>, pub(super) credentials: Result<Object>,
pub(super) response_type: PhantomData<R>, pub(super) response_type: PhantomData<R>,
} }
@ -49,7 +48,11 @@ where
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let router = client.router.extract()?; let router = client.router.extract()?;
router.execute(Method::Signup, Param::new(vec![credentials?])).await router
.execute(Command::Signup {
credentials: credentials?,
})
.await
}) })
} }
} }

View file

@ -1,12 +1,11 @@
use channel::Receiver;
use super::types::User; use super::types::User;
use crate::api::conn::Command;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Method;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::Response as QueryResponse; use crate::api::Response as QueryResponse;
use crate::sql::to_value; use crate::sql::to_value;
use crate::sql::Value; use crate::sql::Value;
use channel::Receiver;
pub(super) fn mock(route_rx: Receiver<Route>) { pub(super) fn mock(route_rx: Receiver<Route>) {
tokio::spawn(async move { tokio::spawn(async move {
@ -15,81 +14,111 @@ pub(super) fn mock(route_rx: Receiver<Route>) {
response, response,
}) = route_rx.recv().await }) = route_rx.recv().await
{ {
let (_, method, param) = request; let cmd = request.command;
let mut params = param.other;
let result = match method { let result = match cmd {
Method::Invalidate | Method::Health => match &params[..] { Command::Invalidate | Command::Health => Ok(DbResponse::Other(Value::None)),
[] => Ok(DbResponse::Other(Value::None)), Command::Authenticate {
_ => unreachable!(), ..
},
Method::Authenticate | Method::Kill | Method::Unset => match &params[..] {
[_] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Live => match &params[..] {
[_] => Ok(DbResponse::Other(
"c6c0e36c-e2cf-42cb-b2d5-75415249b261".to_owned().into(),
)),
_ => unreachable!(),
},
Method::Version => match &params[..] {
[] => Ok(DbResponse::Other("1.0.0".into())),
_ => unreachable!(),
},
Method::Use => match &params[..] {
[_] | [_, _] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Signup | Method::Signin => match &mut params[..] {
[_] => Ok(DbResponse::Other("jwt".to_owned().into())),
_ => unreachable!(),
},
Method::Set => match &params[..] {
[_, _] => Ok(DbResponse::Other(Value::None)),
_ => unreachable!(),
},
Method::Query => match param.query {
Some(_) => Ok(DbResponse::Query(QueryResponse::new())),
_ => unreachable!(),
},
Method::Create => match &params[..] {
[_] => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
[_, user] => Ok(DbResponse::Other(user.clone())),
_ => unreachable!(),
},
Method::Select | Method::Delete => match &params[..] {
[Value::Thing(..)] => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
[Value::Table(..) | Value::Array(..) | Value::Range(..)] => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
},
Method::Upsert | Method::Update | Method::Merge | Method::Patch => {
match &params[..] {
[Value::Thing(..)] | [Value::Thing(..), _] => {
Ok(DbResponse::Other(to_value(User::default()).unwrap()))
}
[Value::Table(..) | Value::Array(..) | Value::Range(..)]
| [Value::Table(..) | Value::Array(..) | Value::Range(..), _] => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
}
} }
Method::Insert => match &params[..] { | Command::Kill {
[Value::Table(..), Value::Array(..)] => { ..
}
| Command::Unset {
..
} => Ok(DbResponse::Other(Value::None)),
Command::SubscribeLive {
..
} => Ok(DbResponse::Other("c6c0e36c-e2cf-42cb-b2d5-75415249b261".to_owned().into())),
Command::Version => Ok(DbResponse::Other("1.0.0".into())),
Command::Use {
..
} => Ok(DbResponse::Other(Value::None)),
Command::Signup {
..
}
| Command::Signin {
..
} => Ok(DbResponse::Other("jwt".to_owned().into())),
Command::Set {
..
} => Ok(DbResponse::Other(Value::None)),
Command::Query {
..
} => Ok(DbResponse::Query(QueryResponse::new())),
Command::Create {
data,
..
} => match data {
None => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Some(user) => Ok(DbResponse::Other(user.clone())),
},
Command::Select {
what,
..
}
| Command::Delete {
what,
..
} => match what {
Value::Thing(..) => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Value::Table(..) | Value::Array(..) | Value::Range(..) => {
Ok(DbResponse::Other(Value::Array(Default::default()))) Ok(DbResponse::Other(Value::Array(Default::default())))
} }
[Value::Table(..), _] => { _ => unreachable!(),
},
Command::Upsert {
what,
..
}
| Command::Update {
what,
..
}
| Command::Merge {
what,
..
}
| Command::Patch {
what,
..
} => match what {
Value::Thing(..) => Ok(DbResponse::Other(to_value(User::default()).unwrap())),
Value::Table(..) | Value::Array(..) | Value::Range(..) => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
_ => unreachable!(),
},
Command::Insert {
what,
data,
} => match (what, data) {
(Some(Value::Table(..)), Value::Array(..)) => {
Ok(DbResponse::Other(Value::Array(Default::default())))
}
(Some(Value::Table(..)), _) => {
Ok(DbResponse::Other(to_value(User::default()).unwrap())) Ok(DbResponse::Other(to_value(User::default()).unwrap()))
} }
_ => unreachable!(), _ => unreachable!(),
}, },
Method::Export | Method::Import => match param.file { Command::ExportMl {
Some(_) => Ok(DbResponse::Other(Value::None)), ..
_ => unreachable!(), }
}, | Command::ExportBytesMl {
..
}
| Command::ExportFile {
..
}
| Command::ExportBytes {
..
}
| Command::ImportMl {
..
}
| Command::ImportFile {
..
} => Ok(DbResponse::Other(Value::None)),
}; };
if let Err(message) = response.send(result).await { if let Err(message) = response.send(result).await {

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
@ -39,7 +38,11 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Unset, Param::new(vec![self.key.into()])).await router
.execute_unit(Command::Unset {
key: self.key,
})
.await
}) })
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::Content; use crate::api::method::Content;
use crate::api::method::Merge; use crate::api::method::Merge;
@ -18,6 +17,7 @@ use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
/// An update future /// An update future
#[derive(Debug)] #[derive(Debug)]
@ -52,12 +52,17 @@ macro_rules! into_future {
.. ..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Upsert, Param::new(vec![param])).await router
.$method(Command::Update {
what: param,
data: None,
})
.await
}) })
} }
}; };
@ -123,18 +128,28 @@ where
R: DeserializeOwned, R: DeserializeOwned,
{ {
/// Replaces the current document / record data with the specified data /// Replaces the current document / record data with the specified data
pub fn content<D>(self, data: D) -> Content<'r, C, D, R> pub fn content<D>(self, data: D) -> Content<'r, C, R>
where where
D: Serialize, D: Serialize,
{ {
Content { Content::from_closure(self.client, || {
client: self.client, let data = to_value(data)?;
method: Method::Update,
resource: self.resource, let what: Value = match self.range {
range: self.range, Some(range) => self.resource?.with_range(range)?.into(),
content: data, None => self.resource?.into(),
response_type: PhantomData, };
}
let data = match data {
Value::None | Value::Null => None,
content => Some(content),
};
Ok(Command::Update {
what,
data,
})
})
} }
/// Merges the current document / record data with the specified data /// Merges the current document / record data with the specified data

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::Content; use crate::api::method::Content;
use crate::api::method::Merge; use crate::api::method::Merge;
@ -18,6 +17,7 @@ use serde::Serialize;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
use std::marker::PhantomData; use std::marker::PhantomData;
use surrealdb_core::sql::to_value;
/// An upsert future /// An upsert future
#[derive(Debug)] #[derive(Debug)]
@ -52,12 +52,17 @@ macro_rules! into_future {
.. ..
} = self; } = self;
Box::pin(async move { Box::pin(async move {
let param = match range { let param: Value = match range {
Some(range) => resource?.with_range(range)?.into(), Some(range) => resource?.with_range(range)?.into(),
None => resource?.into(), None => resource?.into(),
}; };
let router = client.router.extract()?; let router = client.router.extract()?;
router.$method(Method::Upsert, Param::new(vec![param])).await router
.$method(Command::Upsert {
what: param,
data: None,
})
.await
}) })
} }
}; };
@ -123,18 +128,28 @@ where
R: DeserializeOwned, R: DeserializeOwned,
{ {
/// Replaces the current document / record data with the specified data /// Replaces the current document / record data with the specified data
pub fn content<D>(self, data: D) -> Content<'r, C, D, R> pub fn content<D>(self, data: D) -> Content<'r, C, R>
where where
D: Serialize, D: Serialize,
{ {
Content { Content::from_closure(self.client, || {
client: self.client, let data = to_value(data)?;
method: Method::Upsert,
resource: self.resource, let what: Value = match self.range {
range: self.range, Some(range) => self.resource?.with_range(range)?.into(),
content: data, None => self.resource?.into(),
response_type: PhantomData, };
}
let data = match data {
Value::None | Value::Null => None,
content => Some(content),
};
Ok(Command::Upsert {
what,
data,
})
})
} }
/// Merges the current document / record data with the specified data /// Merges the current document / record data with the specified data

View file

@ -1,11 +1,9 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
@ -14,7 +12,7 @@ use std::future::IntoFuture;
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct UseDb<'r, C: Connection> { pub struct UseDb<'r, C: Connection> {
pub(super) client: Cow<'r, Surreal<C>>, pub(super) client: Cow<'r, Surreal<C>>,
pub(super) ns: Value, pub(super) ns: Option<String>,
pub(super) db: String, pub(super) db: String,
} }
@ -41,7 +39,12 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Use, Param::new(vec![self.ns, self.db.into()])).await?; router
.execute_unit(Command::Use {
namespace: self.ns,
database: Some(self.db),
})
.await?;
self.client.waiter.0.send(Some(WaitFor::Database)).ok(); self.client.waiter.0.send(Some(WaitFor::Database)).ok();
Ok(()) Ok(())
}) })

View file

@ -1,11 +1,9 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::method::UseDb; use crate::api::method::UseDb;
use crate::api::Connection; use crate::api::Connection;
use crate::api::Result; use crate::api::Result;
use crate::method::OnceLockExt; use crate::method::OnceLockExt;
use crate::sql::Value;
use crate::Surreal; use crate::Surreal;
use std::borrow::Cow; use std::borrow::Cow;
use std::future::IntoFuture; use std::future::IntoFuture;
@ -55,7 +53,12 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
router.execute_unit(Method::Use, Param::new(vec![self.ns.into(), Value::None])).await router
.execute_unit(Command::Use {
namespace: Some(self.ns),
database: None,
})
.await
}) })
} }
} }

View file

@ -1,5 +1,4 @@
use crate::api::conn::Method; use crate::api::conn::Command;
use crate::api::conn::Param;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::Connection; use crate::api::Connection;
@ -38,10 +37,7 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { Box::pin(async move {
let router = self.client.router.extract()?; let router = self.client.router.extract()?;
let version = router let version = router.execute_value(Command::Version).await?.convert_to_string()?;
.execute_value(Method::Version, Param::new(Vec::new()))
.await?
.convert_to_string()?;
let semantic = version.trim_start_matches("surrealdb-"); let semantic = version.trim_start_matches("surrealdb-");
semantic.parse().map_err(|_| Error::InvalidSemanticVersion(semantic.to_string()).into()) semantic.parse().map_err(|_| Error::InvalidSemanticVersion(semantic.to_string()).into())
}) })

View file

@ -3,28 +3,28 @@ use std::path::PathBuf;
#[derive(Debug)] #[derive(Debug)]
#[non_exhaustive] #[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub enum ExportDestination { pub enum ExportDestination {
File(PathBuf), File(PathBuf),
Memory, Memory,
} }
/// A trait for converting inputs into database export locations /// A trait for converting inputs into database export locations
#[cfg_attr(docsrs, doc(cfg(not(target_arch = "wasm32"))))]
pub trait IntoExportDestination<R> { pub trait IntoExportDestination<R> {
/// Converts an input into a database export location /// Converts an input into a database export location
fn into_export_destination(self) -> ExportDestination; fn into_export_destination(self) -> R;
} }
impl<T> IntoExportDestination<PathBuf> for T impl<T> IntoExportDestination<PathBuf> for T
where where
T: AsRef<Path>, T: AsRef<Path>,
{ {
fn into_export_destination(self) -> ExportDestination { fn into_export_destination(self) -> PathBuf {
ExportDestination::File(self.as_ref().to_path_buf()) self.as_ref().to_path_buf()
} }
} }
impl IntoExportDestination<()> for () { impl IntoExportDestination<()> for () {
fn into_export_destination(self) -> ExportDestination { fn into_export_destination(self) {}
ExportDestination::Memory
}
} }

View file

@ -1,4 +1,9 @@
//! The different options and types for use in API functions //! The different options and types for use in API functions
use crate::sql::to_value;
use crate::sql::Thing;
use crate::sql::Value;
use dmp::Diff;
use serde::Serialize;
pub mod auth; pub mod auth;
pub mod capabilities; pub mod capabilities;
@ -10,12 +15,6 @@ mod query;
mod resource; mod resource;
mod tls; mod tls;
use crate::sql::to_value;
use crate::sql::Thing;
use crate::sql::Value;
use dmp::Diff;
use serde::Serialize;
pub use config::*; pub use config::*;
pub use endpoint::*; pub use endpoint::*;
pub use export::*; pub use export::*;