use /rpc for the rust sdk http connection (#4482)

This commit is contained in:
Raphael Darley 2024-08-20 01:14:32 -07:00 committed by GitHub
parent cb02db1477
commit c7457ffc56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 294 additions and 453 deletions

View file

@ -102,7 +102,7 @@ pub(crate) enum Command {
} }
impl Command { impl Command {
#[cfg(feature = "protocol-ws")] #[cfg(any(feature = "protocol-ws", feature = "protocol-http"))]
pub(crate) fn into_router_request(self, id: Option<i64>) -> Option<RouterRequest> { pub(crate) fn into_router_request(self, id: Option<i64>) -> Option<RouterRequest> {
let id = id.map(Value::from); let id = id.map(Value::from);
let res = match self { let res = match self {
@ -321,6 +321,39 @@ impl Command {
}; };
Some(res) Some(res)
} }
#[cfg(feature = "protocol-http")]
pub(crate) fn needs_one(&self) -> bool {
match self {
Command::Upsert {
what,
..
} => what.is_thing(),
Command::Update {
what,
..
} => what.is_thing(),
Command::Insert {
data,
..
} => !data.is_array(),
Command::Patch {
what,
..
} => what.is_thing(),
Command::Merge {
what,
..
} => what.is_thing(),
Command::Select {
what,
} => what.is_thing(),
Command::Delete {
what,
} => what.is_thing(),
_ => false,
}
}
} }
/// A struct which will be serialized as a map to behave like the previously used BTreeMap. /// A struct which will be serialized as a map to behave like the previously used BTreeMap.

View file

@ -17,6 +17,8 @@ use surrealdb_core::sql::{from_value, Value};
mod cmd; mod cmd;
pub(crate) use cmd::Command; pub(crate) use cmd::Command;
#[cfg(feature = "protocol-http")]
pub(crate) use cmd::RouterRequest;
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] // used by the embedded and remote connections #[allow(dead_code)] // used by the embedded and remote connections

View file

@ -8,47 +8,33 @@ pub(crate) mod wasm;
use crate::api::conn::Command; use crate::api::conn::Command;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::RequestData; use crate::api::conn::RequestData;
use crate::api::engine::remote::duration_from_str; use crate::api::conn::RouterRequest;
use crate::api::engine::remote::{deserialize, serialize};
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::query::QueryResult;
use crate::api::Connect; use crate::api::Connect;
use crate::api::Response as QueryResponse;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::dbs::Status; use crate::engine::remote::Response;
use crate::engine::value_to_values;
use crate::headers::AUTH_DB; use crate::headers::AUTH_DB;
use crate::headers::AUTH_NS; use crate::headers::AUTH_NS;
use crate::headers::DB; use crate::headers::DB;
use crate::headers::NS; use crate::headers::NS;
use crate::method::Stats;
use crate::opt::IntoEndpoint; use crate::opt::IntoEndpoint;
use crate::sql::from_value; use crate::sql::from_value;
use crate::sql::serde::deserialize;
use crate::sql::Value; use crate::sql::Value;
use futures::TryStreamExt; use futures::TryStreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue; use reqwest::header::HeaderValue;
use reqwest::header::ACCEPT; use reqwest::header::ACCEPT;
use reqwest::header::CONTENT_TYPE;
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use serde::Deserialize; use serde::Deserialize;
use serde::Serialize; use serde::Serialize;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use surrealdb_core::sql::Query;
use surrealdb_core::sql::statements::CreateStatement;
use surrealdb_core::sql::statements::DeleteStatement;
use surrealdb_core::sql::statements::InsertStatement;
use surrealdb_core::sql::statements::SelectStatement;
use surrealdb_core::sql::statements::UpdateStatement;
use surrealdb_core::sql::statements::UpsertStatement;
use surrealdb_core::sql::Data;
use surrealdb_core::sql::Field;
use surrealdb_core::sql::Output;
use url::Url; use url::Url;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use std::path::PathBuf; use std::path::PathBuf;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -60,9 +46,10 @@ use tokio_util::compat::FuturesAsyncReadCompatExt;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
use wasm_bindgen_futures::spawn_local; use wasm_bindgen_futures::spawn_local;
const SQL_PATH: &str = "sql"; // const SQL_PATH: &str = "sql";
const RPC_PATH: &str = "rpc";
/// The HTTP scheme used to connect to `http://` endpoints // The HTTP scheme used to connect to `http://` endpoints
#[derive(Debug)] #[derive(Debug)]
pub struct Http; pub struct Http;
@ -111,9 +98,11 @@ impl Surreal<Client> {
pub(crate) fn default_headers() -> HeaderMap { pub(crate) fn default_headers() -> HeaderMap {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("application/surrealdb")); headers.insert(ACCEPT, HeaderValue::from_static("application/surrealdb"));
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/surrealdb"));
headers headers
} }
#[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
enum Auth { enum Auth {
Basic { Basic {
@ -157,8 +146,6 @@ impl Authenticate for RequestBuilder {
} }
} }
type HttpQueryResponse = (String, Status, Value);
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct Credentials { struct Credentials {
user: String, user: String,
@ -175,70 +162,6 @@ struct AuthResponse {
token: Option<String>, token: Option<String>,
} }
async fn submit_auth(request: RequestBuilder) -> Result<Value> {
let response = request.send().await?.error_for_status()?;
let bytes = response.bytes().await?;
let response: AuthResponse =
deserialize(&bytes).map_err(|error| Error::ResponseFromBinary {
binary: bytes.to_vec(),
error,
})?;
Ok(response.token.into())
}
async fn query(request: RequestBuilder) -> Result<QueryResponse> {
let response = request.send().await?.error_for_status()?;
let bytes = response.bytes().await?;
let responses = deserialize::<Vec<HttpQueryResponse>>(&bytes).map_err(|error| {
Error::ResponseFromBinary {
binary: bytes.to_vec(),
error,
}
})?;
let mut map = IndexMap::<usize, (Stats, QueryResult)>::with_capacity(responses.len());
for (index, (execution_time, status, value)) in responses.into_iter().enumerate() {
let stats = Stats {
execution_time: duration_from_str(&execution_time),
};
match status {
Status::Ok => {
map.insert(index, (stats, Ok(value)));
}
Status::Err => {
map.insert(index, (stats, Err(Error::Query(value.as_raw_string()).into())));
}
_ => unreachable!(),
}
}
Ok(QueryResponse {
results: map,
..QueryResponse::new()
})
}
async fn take(one: bool, request: RequestBuilder) -> Result<Value> {
if let Some((_stats, result)) = query(request).await?.results.swap_remove(&0) {
let value = result?;
match one {
true => match value {
Value::Array(mut vec) => {
if let [value] = &mut vec.0[..] {
return Ok(mem::take(value));
}
}
Value::None | Value::Null => {}
value => return Ok(value),
},
false => return Ok(value),
}
}
match one {
true => Ok(Value::None),
false => Ok(Value::Array(Default::default())),
}
}
type BackupSender = channel::Sender<Result<Vec<u8>>>; type BackupSender = channel::Sender<Result<Vec<u8>>>;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
@ -328,17 +251,40 @@ async fn import(request: RequestBuilder, path: PathBuf) -> Result<Value> {
Ok(Value::None) Ok(Value::None)
} }
async fn version(request: RequestBuilder) -> Result<Value> {
let response = request.send().await?.error_for_status()?;
let version = response.text().await?;
Ok(version.into())
}
pub(crate) async fn health(request: RequestBuilder) -> Result<Value> { pub(crate) async fn health(request: RequestBuilder) -> Result<Value> {
request.send().await?.error_for_status()?; request.send().await?.error_for_status()?;
Ok(Value::None) Ok(Value::None)
} }
async fn process_req(
req: RouterRequest,
base_url: &Url,
client: &reqwest::Client,
headers: &HeaderMap,
auth: &Option<Auth>,
) -> Result<DbResponse> {
let url = base_url.join(RPC_PATH).unwrap();
let http_req =
client.post(url).headers(headers.clone()).auth(auth).body(serialize(&req, false)?);
let response = http_req.send().await?.error_for_status()?;
let bytes = response.bytes().await?;
let response: Response = deserialize(&mut &bytes[..], false)?;
DbResponse::from(response.result)
}
fn try_one(res: DbResponse, needed: bool) -> DbResponse {
if !needed {
return res;
}
match res {
DbResponse::Other(Value::Array(arr)) if arr.len() == 1 => {
DbResponse::Other(arr.into_iter().next().unwrap())
}
r => r,
}
}
async fn router( async fn router(
RequestData { RequestData {
command, command,
@ -347,57 +293,74 @@ async fn router(
base_url: &Url, base_url: &Url,
client: &reqwest::Client, client: &reqwest::Client,
headers: &mut HeaderMap, headers: &mut HeaderMap,
vars: &mut IndexMap<String, String>, vars: &mut IndexMap<String, Value>,
auth: &mut Option<Auth>, auth: &mut Option<Auth>,
) -> Result<DbResponse> { ) -> Result<DbResponse> {
match command { match command {
Command::Use { Command::Query {
namespace, query,
database, mut variables,
} => { } => {
let path = base_url.join(SQL_PATH)?; variables.extend(vars.clone());
let mut request = client.post(path).headers(headers.clone()); let req = Command::Query {
let ns = match namespace { query,
Some(ns) => match HeaderValue::try_from(&ns) { variables,
}
.into_router_request(None)
.expect("query should be valid request");
process_req(req, base_url, client, headers, auth).await
}
ref cmd @ Command::Use {
ref namespace,
ref database,
} => {
let req = cmd
.clone()
.into_router_request(None)
.expect("use should be a valid router request");
// process request to check permissions
let out = process_req(req, base_url, client, headers, auth).await?;
match namespace {
Some(ns) => match HeaderValue::try_from(ns) {
Ok(ns) => { Ok(ns) => {
request = request.header(&NS, &ns);
Some(ns)
}
Err(_) => {
return Err(Error::InvalidNsName(ns).into());
}
},
None => None,
};
let db = match database {
Some(db) => match HeaderValue::try_from(&db) {
Ok(db) => {
request = request.header(&DB, &db);
Some(db)
}
Err(_) => {
return Err(Error::InvalidDbName(db).into());
}
},
None => None,
};
request = request.auth(auth).body("RETURN true");
take(true, request).await?;
if let Some(ns) = ns {
headers.insert(&NS, ns); headers.insert(&NS, ns);
} }
if let Some(db) = db { Err(_) => {
return Err(Error::InvalidNsName(ns.to_owned()).into());
}
},
None => {}
};
match database {
Some(db) => match HeaderValue::try_from(db) {
Ok(db) => {
headers.insert(&DB, db); headers.insert(&DB, db);
} }
Ok(DbResponse::Other(Value::None)) Err(_) => {
return Err(Error::InvalidDbName(db.to_owned()).into());
}
},
None => {}
};
Ok(out)
} }
Command::Signin { Command::Signin {
credentials, credentials,
} => { } => {
let path = base_url.join("signin")?; let req = Command::Signin {
let request = credentials: credentials.clone(),
client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string()); }
let value = submit_auth(request).await?; .into_router_request(None)
.expect("signin should be a valid router request");
let DbResponse::Other(value) =
process_req(req, base_url, client, headers, auth).await?
else {
unreachable!("didn't make query")
};
if let Ok(Credentials { if let Ok(Credentials {
user, user,
pass, pass,
@ -416,24 +379,19 @@ async fn router(
token: value.to_raw_string(), token: value.to_raw_string(),
}); });
} }
Ok(DbResponse::Other(value))
}
Command::Signup {
credentials,
} => {
let path = base_url.join("signup")?;
let request =
client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string());
let value = submit_auth(request).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Command::Authenticate { Command::Authenticate {
token, token,
} => { } => {
let path = base_url.join(SQL_PATH)?; let req = Command::Authenticate {
let request = token: token.clone(),
client.post(path).headers(headers.clone()).bearer_auth(&token).body("RETURN true"); }
take(true, request).await?; .into_router_request(None)
.expect("authenticate should be a valid router request");
process_req(req, base_url, client, headers, auth).await?;
*auth = Some(Auth::Bearer { *auth = Some(Auth::Bearer {
token, token,
}); });
@ -443,156 +401,33 @@ async fn router(
*auth = None; *auth = None;
Ok(DbResponse::Other(Value::None)) Ok(DbResponse::Other(Value::None))
} }
Command::Create { Command::Set {
what, key,
data, value,
} => { } => {
let path = base_url.join(SQL_PATH)?; let query: Query = surrealdb_core::sql::parse(&format!("RETURN ${key};"))?;
let statement = { let req = Command::Query {
let mut stmt = CreateStatement::default(); query,
stmt.what = value_to_values(what); variables: [(key.clone(), value)].into(),
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(true, request).await?;
Ok(DbResponse::Other(value))
} }
Command::Upsert { .into_router_request(None)
what, .expect("query is valid request");
data, let DbResponse::Query(mut res) =
} => { process_req(req, base_url, client, headers, auth).await?
let path = base_url.join(SQL_PATH)?; else {
let one = what.is_thing(); unreachable!("made query request so response must be query")
let statement = {
let mut stmt = UpsertStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
}; };
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); let val: Value = res.take(0)?;
let value = take(one, request).await?;
Ok(DbResponse::Other(value)) vars.insert(key, val);
Ok(DbResponse::Other(Value::None))
} }
Command::Update { Command::Unset {
what, key,
data,
} => { } => {
let path = base_url.join(SQL_PATH)?; vars.shift_remove(&key);
let one = what.is_thing(); Ok(DbResponse::Other(Value::None))
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::ContentExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Insert {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let one = !data.is_array();
let statement = {
let mut stmt = InsertStatement::default();
stmt.into = what;
stmt.data = Data::SingleExpression(data);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Patch {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::PatchExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Merge {
what,
data,
} => {
let path = base_url.join(SQL_PATH)?;
let one = what.is_thing();
let statement = {
let mut stmt = UpdateStatement::default();
stmt.what = value_to_values(what);
stmt.data = data.map(Data::MergeExpression);
stmt.output = Some(Output::After);
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Select {
what,
} => {
let path = base_url.join(SQL_PATH)?;
let one = what.is_thing();
let statement = {
let mut stmt = SelectStatement::default();
stmt.what = value_to_values(what);
stmt.expr.0 = vec![Field::All];
stmt
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Delete {
what,
} => {
let one = what.is_thing();
let path = base_url.join(SQL_PATH)?;
let (one, statement) = {
let mut stmt = DeleteStatement::default();
stmt.what = value_to_values(what);
stmt.output = Some(Output::Before);
(one, stmt)
};
let request =
client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string());
let value = take(one, request).await?;
Ok(DbResponse::Other(value))
}
Command::Query {
query: q,
variables,
} => {
let path = base_url.join(SQL_PATH)?;
let mut request = client.post(path).headers(headers.clone()).query(&vars).auth(auth);
let bindings: Vec<_> =
variables.iter().map(|(key, value)| (key, value.to_string())).collect();
request = request.query(&bindings).body(q.to_string());
let values = query(request).await?;
Ok(DbResponse::Query(values))
} }
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
Command::ExportFile { Command::ExportFile {
@ -691,55 +526,16 @@ async fn router(
let value = import(request, path).await?; let value = import(request, path).await?;
Ok(DbResponse::Other(value)) Ok(DbResponse::Other(value))
} }
Command::Health => {
let path = base_url.join("health")?;
let request = client.get(path);
let value = health(request).await?;
Ok(DbResponse::Other(value))
}
Command::Version => {
let path = base_url.join("version")?;
let request = client.get(path);
let value = version(request).await?;
Ok(DbResponse::Other(value))
}
Command::Set {
key,
value,
} => {
let path = base_url.join(SQL_PATH)?;
let value = value.to_string();
let request = client
.post(path)
.headers(headers.clone())
.auth(auth)
.query(&[(key.as_str(), value.as_str())])
.body(format!("RETURN ${key}"));
take(true, request).await?;
vars.insert(key, value);
Ok(DbResponse::Other(Value::None))
}
Command::Unset {
key,
} => {
vars.shift_remove(&key);
Ok(DbResponse::Other(Value::None))
}
Command::SubscribeLive { Command::SubscribeLive {
.. ..
} => Err(Error::LiveQueriesNotSupported.into()), } => Err(Error::LiveQueriesNotSupported.into()),
Command::Kill {
uuid, cmd => {
} => { let one = cmd.needs_one();
let path = base_url.join(SQL_PATH)?; let req = cmd
let request = client .into_router_request(None)
.post(path) .expect("all invalid variants should have been caught");
.headers(headers.clone()) process_req(req, base_url, client, headers, auth).await.map(|r| try_one(r, one))
.auth(auth)
.query(&[("id", uuid)])
.body("KILL type::string($id)");
let value = take(true, request).await?;
Ok(DbResponse::Other(value))
} }
} }
} }

View file

@ -8,9 +8,25 @@ pub mod http;
#[cfg_attr(docsrs, doc(cfg(feature = "protocol-ws")))] #[cfg_attr(docsrs, doc(cfg(feature = "protocol-ws")))]
pub mod ws; pub mod ws;
use crate::api;
use crate::api::conn::DbResponse;
use crate::api::err::Error;
use crate::api::method::query::QueryResult;
use crate::api::Result;
use crate::dbs::Notification;
use crate::dbs::QueryMethodResponse;
use crate::dbs::Status;
use crate::method::Stats;
use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use rust_decimal::prelude::ToPrimitive; use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal; use rust_decimal::Decimal;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use std::io::Read;
use std::time::Duration; use std::time::Duration;
use surrealdb_core::sql::Value;
const NANOS_PER_SEC: i64 = 1_000_000_000; const NANOS_PER_SEC: i64 = 1_000_000_000;
const NANOS_PER_MILLI: i64 = 1_000_000; const NANOS_PER_MILLI: i64 = 1_000_000;
@ -66,3 +82,108 @@ mod tests {
} }
} }
} }
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct Failure {
pub(crate) code: i64,
pub(crate) message: String,
}
#[revisioned(revision = 1)]
#[derive(Debug, Deserialize)]
pub(crate) enum Data {
Other(Value),
Query(Vec<QueryMethodResponse>),
Live(Notification),
}
type ServerResult = std::result::Result<Data, Failure>;
impl From<Failure> for Error {
fn from(failure: Failure) -> Self {
match failure.code {
-32600 => Self::InvalidRequest(failure.message),
-32602 => Self::InvalidParams(failure.message),
-32603 => Self::InternalError(failure.message),
-32700 => Self::ParseError(failure.message),
_ => Self::Query(failure.message),
}
}
}
impl From<Failure> for crate::Error {
fn from(value: Failure) -> Self {
let api_err: Error = value.into();
api_err.into()
}
}
impl DbResponse {
fn from(result: ServerResult) -> Result<Self> {
match result.map_err(Error::from)? {
Data::Other(value) => Ok(DbResponse::Other(value)),
Data::Query(responses) => {
let mut map =
IndexMap::<usize, (Stats, QueryResult)>::with_capacity(responses.len());
for (index, response) in responses.into_iter().enumerate() {
let stats = Stats {
execution_time: duration_from_str(&response.time),
};
match response.status {
Status::Ok => {
map.insert(index, (stats, Ok(response.result)));
}
Status::Err => {
map.insert(
index,
(stats, Err(Error::Query(response.result.as_raw_string()).into())),
);
}
_ => unreachable!(),
}
}
Ok(DbResponse::Query(api::Response {
results: map,
..api::Response::new()
}))
}
// Live notifications don't call this method
Data::Live(..) => unreachable!(),
}
}
}
#[revisioned(revision = 1)]
#[derive(Debug, Deserialize)]
pub(crate) struct Response {
id: Option<Value>,
pub(crate) result: ServerResult,
}
fn serialize<V>(value: &V, revisioned: bool) -> Result<Vec<u8>>
where
V: serde::Serialize + Revisioned,
{
if revisioned {
let mut buf = Vec::new();
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
return Ok(buf);
}
crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
}
fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
where
A: Read,
T: Revisioned + DeserializeOwned,
{
if revisioned {
return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
}
let mut buf = Vec::new();
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}

View file

@ -5,29 +5,16 @@ pub(crate) mod native;
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
pub(crate) mod wasm; pub(crate) mod wasm;
use crate::api;
use crate::api::conn::Command; use crate::api::conn::Command;
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::engine::remote::duration_from_str;
use crate::api::err::Error;
use crate::api::method::query::QueryResult;
use crate::api::Connect; use crate::api::Connect;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::dbs::Notification;
use crate::dbs::QueryMethodResponse;
use crate::dbs::Status;
use crate::method::Stats;
use crate::opt::IntoEndpoint; use crate::opt::IntoEndpoint;
use crate::sql::Value; use crate::sql::Value;
use channel::Sender; use channel::Sender;
use indexmap::IndexMap; use indexmap::IndexMap;
use revision::revisioned;
use revision::Revisioned;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Read;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::time::Duration; use std::time::Duration;
use surrealdb_core::dbs::Notification as CoreNotification; use surrealdb_core::dbs::Notification as CoreNotification;
@ -153,101 +140,3 @@ impl Surreal<Client> {
} }
} }
} }
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct Failure {
pub(crate) code: i64,
pub(crate) message: String,
}
#[revisioned(revision = 1)]
#[derive(Debug, Deserialize)]
pub(crate) enum Data {
Other(Value),
Query(Vec<QueryMethodResponse>),
Live(Notification),
}
type ServerResult = std::result::Result<Data, Failure>;
impl From<Failure> for Error {
fn from(failure: Failure) -> Self {
match failure.code {
-32600 => Self::InvalidRequest(failure.message),
-32602 => Self::InvalidParams(failure.message),
-32603 => Self::InternalError(failure.message),
-32700 => Self::ParseError(failure.message),
_ => Self::Query(failure.message),
}
}
}
impl DbResponse {
fn from(result: ServerResult) -> Result<Self> {
match result.map_err(Error::from)? {
Data::Other(value) => Ok(DbResponse::Other(value)),
Data::Query(responses) => {
let mut map =
IndexMap::<usize, (Stats, QueryResult)>::with_capacity(responses.len());
for (index, response) in responses.into_iter().enumerate() {
let stats = Stats {
execution_time: duration_from_str(&response.time),
};
match response.status {
Status::Ok => {
map.insert(index, (stats, Ok(response.result)));
}
Status::Err => {
map.insert(
index,
(stats, Err(Error::Query(response.result.as_raw_string()).into())),
);
}
_ => unreachable!(),
}
}
Ok(DbResponse::Query(api::Response {
results: map,
..api::Response::new()
}))
}
// Live notifications don't call this method
Data::Live(..) => unreachable!(),
}
}
}
#[revisioned(revision = 1)]
#[derive(Debug, Deserialize)]
pub(crate) struct Response {
id: Option<Value>,
pub(crate) result: ServerResult,
}
fn serialize<V>(value: &V, revisioned: bool) -> Result<Vec<u8>>
where
V: serde::Serialize + Revisioned,
{
if revisioned {
let mut buf = Vec::new();
value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?;
return Ok(buf);
}
crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into()))
}
fn deserialize<A, T>(bytes: &mut A, revisioned: bool) -> Result<T>
where
A: Read,
T: Revisioned + DeserializeOwned,
{
if revisioned {
return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into()));
}
let mut buf = Vec::new();
bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?;
crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into()))
}

View file

@ -1,13 +1,12 @@
use super::{ use super::{HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH};
deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
};
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::conn::{Command, DbResponse}; use crate::api::conn::{Command, DbResponse};
use crate::api::conn::{Connection, RequestData}; use crate::api::conn::{Connection, RequestData};
use crate::api::engine::remote::ws::Client; use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL; use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::Response;
use crate::api::engine::remote::{deserialize, serialize};
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
@ -17,7 +16,7 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt; use crate::api::OnceLockExt;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Value; use crate::sql::Value;
@ -267,6 +266,9 @@ async fn router_handle_response(
state: &mut RouterState, state: &mut RouterState,
endpoint: &Endpoint, endpoint: &Endpoint,
) -> HandleResult { ) -> HandleResult {
if let Message::Binary(b) = &response {
error!(?b);
}
match Response::try_from(&response, endpoint.supports_revision) { match Response::try_from(&response, endpoint.supports_revision) {
Ok(option) => { Ok(option) => {
// We are only interested in responses that are not empty // We are only interested in responses that are not empty

View file

@ -1,13 +1,12 @@
use super::{ use super::{HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH};
deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH,
};
use crate::api::conn::DbResponse; use crate::api::conn::DbResponse;
use crate::api::conn::Route; use crate::api::conn::Route;
use crate::api::conn::Router; use crate::api::conn::Router;
use crate::api::conn::{Command, Connection, RequestData}; use crate::api::conn::{Command, Connection, RequestData};
use crate::api::engine::remote::ws::Client; use crate::api::engine::remote::ws::Client;
use crate::api::engine::remote::ws::Response;
use crate::api::engine::remote::ws::PING_INTERVAL; use crate::api::engine::remote::ws::PING_INTERVAL;
use crate::api::engine::remote::Response;
use crate::api::engine::remote::{deserialize, serialize};
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::method::BoxFuture; use crate::api::method::BoxFuture;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
@ -15,7 +14,7 @@ use crate::api::ExtraFeatures;
use crate::api::OnceLockExt; use crate::api::OnceLockExt;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::engine::remote::ws::Data; use crate::engine::remote::Data;
use crate::engine::IntervalStream; use crate::engine::IntervalStream;
use crate::opt::WaitFor; use crate::opt::WaitFor;
use crate::sql::Value; use crate::sql::Value;

View file

@ -83,8 +83,7 @@ impl HttpFormat for Format {
} }
fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError> { fn res_http(&self, res: Response) -> Result<AxumResponse, RpcError> {
let val = res.into_value(); let res = self.res(res)?;
let res = self.res(val)?;
if matches!(self, Format::Json) { if matches!(self, Format::Json) {
// If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) } // If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) }
// This would be safe as in the case of JSON res come from a call to Into::<Vec<u8>> for String // This would be safe as in the case of JSON res come from a call to Into::<Vec<u8>> for String