From c7457ffc5600ea30c122c874689cb457a6b763b0 Mon Sep 17 00:00:00 2001 From: Raphael Darley Date: Tue, 20 Aug 2024 01:14:32 -0700 Subject: [PATCH] use /rpc for the rust sdk http connection (#4482) --- lib/src/api/conn/cmd.rs | 35 +- lib/src/api/conn/mod.rs | 2 + lib/src/api/engine/remote/http/mod.rs | 454 +++++++------------------ lib/src/api/engine/remote/mod.rs | 121 +++++++ lib/src/api/engine/remote/ws/mod.rs | 111 ------ lib/src/api/engine/remote/ws/native.rs | 12 +- lib/src/api/engine/remote/ws/wasm.rs | 9 +- src/rpc/format.rs | 3 +- 8 files changed, 294 insertions(+), 453 deletions(-) diff --git a/lib/src/api/conn/cmd.rs b/lib/src/api/conn/cmd.rs index ba264bfc..8ae63d9d 100644 --- a/lib/src/api/conn/cmd.rs +++ b/lib/src/api/conn/cmd.rs @@ -102,7 +102,7 @@ pub(crate) enum Command { } impl Command { - #[cfg(feature = "protocol-ws")] + #[cfg(any(feature = "protocol-ws", feature = "protocol-http"))] pub(crate) fn into_router_request(self, id: Option) -> Option { let id = id.map(Value::from); let res = match self { @@ -321,6 +321,39 @@ impl Command { }; Some(res) } + + #[cfg(feature = "protocol-http")] + pub(crate) fn needs_one(&self) -> bool { + match self { + Command::Upsert { + what, + .. + } => what.is_thing(), + Command::Update { + what, + .. + } => what.is_thing(), + Command::Insert { + data, + .. + } => !data.is_array(), + Command::Patch { + what, + .. + } => what.is_thing(), + Command::Merge { + what, + .. + } => what.is_thing(), + Command::Select { + what, + } => what.is_thing(), + Command::Delete { + what, + } => what.is_thing(), + _ => false, + } + } } /// A struct which will be serialized as a map to behave like the previously used BTreeMap. diff --git a/lib/src/api/conn/mod.rs b/lib/src/api/conn/mod.rs index 13b65326..4ee675f5 100644 --- a/lib/src/api/conn/mod.rs +++ b/lib/src/api/conn/mod.rs @@ -17,6 +17,8 @@ use surrealdb_core::sql::{from_value, Value}; mod cmd; pub(crate) use cmd::Command; +#[cfg(feature = "protocol-http")] +pub(crate) use cmd::RouterRequest; #[derive(Debug)] #[allow(dead_code)] // used by the embedded and remote connections diff --git a/lib/src/api/engine/remote/http/mod.rs b/lib/src/api/engine/remote/http/mod.rs index c3595214..f88a94fd 100644 --- a/lib/src/api/engine/remote/http/mod.rs +++ b/lib/src/api/engine/remote/http/mod.rs @@ -8,47 +8,33 @@ pub(crate) mod wasm; use crate::api::conn::Command; use crate::api::conn::DbResponse; use crate::api::conn::RequestData; -use crate::api::engine::remote::duration_from_str; +use crate::api::conn::RouterRequest; +use crate::api::engine::remote::{deserialize, serialize}; use crate::api::err::Error; -use crate::api::method::query::QueryResult; use crate::api::Connect; -use crate::api::Response as QueryResponse; use crate::api::Result; use crate::api::Surreal; -use crate::dbs::Status; -use crate::engine::value_to_values; +use crate::engine::remote::Response; use crate::headers::AUTH_DB; use crate::headers::AUTH_NS; use crate::headers::DB; use crate::headers::NS; -use crate::method::Stats; use crate::opt::IntoEndpoint; use crate::sql::from_value; -use crate::sql::serde::deserialize; use crate::sql::Value; use futures::TryStreamExt; use indexmap::IndexMap; use reqwest::header::HeaderMap; use reqwest::header::HeaderValue; use reqwest::header::ACCEPT; +use reqwest::header::CONTENT_TYPE; use reqwest::RequestBuilder; use serde::Deserialize; use serde::Serialize; use std::marker::PhantomData; -use std::mem; -use surrealdb_core::sql::statements::CreateStatement; -use surrealdb_core::sql::statements::DeleteStatement; -use surrealdb_core::sql::statements::InsertStatement; -use surrealdb_core::sql::statements::SelectStatement; -use surrealdb_core::sql::statements::UpdateStatement; -use surrealdb_core::sql::statements::UpsertStatement; -use surrealdb_core::sql::Data; -use surrealdb_core::sql::Field; -use surrealdb_core::sql::Output; +use surrealdb_core::sql::Query; use url::Url; -#[cfg(not(target_arch = "wasm32"))] -use reqwest::header::CONTENT_TYPE; #[cfg(not(target_arch = "wasm32"))] use std::path::PathBuf; #[cfg(not(target_arch = "wasm32"))] @@ -60,9 +46,10 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; #[cfg(target_arch = "wasm32")] use wasm_bindgen_futures::spawn_local; -const SQL_PATH: &str = "sql"; +// const SQL_PATH: &str = "sql"; +const RPC_PATH: &str = "rpc"; -/// The HTTP scheme used to connect to `http://` endpoints +// The HTTP scheme used to connect to `http://` endpoints #[derive(Debug)] pub struct Http; @@ -111,9 +98,11 @@ impl Surreal { pub(crate) fn default_headers() -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert(ACCEPT, HeaderValue::from_static("application/surrealdb")); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/surrealdb")); headers } +#[allow(dead_code)] #[derive(Debug)] enum Auth { Basic { @@ -157,8 +146,6 @@ impl Authenticate for RequestBuilder { } } -type HttpQueryResponse = (String, Status, Value); - #[derive(Debug, Serialize, Deserialize)] struct Credentials { user: String, @@ -175,70 +162,6 @@ struct AuthResponse { token: Option, } -async fn submit_auth(request: RequestBuilder) -> Result { - let response = request.send().await?.error_for_status()?; - let bytes = response.bytes().await?; - let response: AuthResponse = - deserialize(&bytes).map_err(|error| Error::ResponseFromBinary { - binary: bytes.to_vec(), - error, - })?; - Ok(response.token.into()) -} - -async fn query(request: RequestBuilder) -> Result { - let response = request.send().await?.error_for_status()?; - let bytes = response.bytes().await?; - let responses = deserialize::>(&bytes).map_err(|error| { - Error::ResponseFromBinary { - binary: bytes.to_vec(), - error, - } - })?; - let mut map = IndexMap::::with_capacity(responses.len()); - for (index, (execution_time, status, value)) in responses.into_iter().enumerate() { - let stats = Stats { - execution_time: duration_from_str(&execution_time), - }; - match status { - Status::Ok => { - map.insert(index, (stats, Ok(value))); - } - Status::Err => { - map.insert(index, (stats, Err(Error::Query(value.as_raw_string()).into()))); - } - _ => unreachable!(), - } - } - - Ok(QueryResponse { - results: map, - ..QueryResponse::new() - }) -} - -async fn take(one: bool, request: RequestBuilder) -> Result { - if let Some((_stats, result)) = query(request).await?.results.swap_remove(&0) { - let value = result?; - match one { - true => match value { - Value::Array(mut vec) => { - if let [value] = &mut vec.0[..] { - return Ok(mem::take(value)); - } - } - Value::None | Value::Null => {} - value => return Ok(value), - }, - false => return Ok(value), - } - } - match one { - true => Ok(Value::None), - false => Ok(Value::Array(Default::default())), - } -} - type BackupSender = channel::Sender>>; #[cfg(not(target_arch = "wasm32"))] @@ -328,17 +251,40 @@ async fn import(request: RequestBuilder, path: PathBuf) -> Result { Ok(Value::None) } -async fn version(request: RequestBuilder) -> Result { - let response = request.send().await?.error_for_status()?; - let version = response.text().await?; - Ok(version.into()) -} - pub(crate) async fn health(request: RequestBuilder) -> Result { request.send().await?.error_for_status()?; Ok(Value::None) } +async fn process_req( + req: RouterRequest, + base_url: &Url, + client: &reqwest::Client, + headers: &HeaderMap, + auth: &Option, +) -> Result { + let url = base_url.join(RPC_PATH).unwrap(); + let http_req = + client.post(url).headers(headers.clone()).auth(auth).body(serialize(&req, false)?); + let response = http_req.send().await?.error_for_status()?; + let bytes = response.bytes().await?; + + let response: Response = deserialize(&mut &bytes[..], false)?; + DbResponse::from(response.result) +} + +fn try_one(res: DbResponse, needed: bool) -> DbResponse { + if !needed { + return res; + } + match res { + DbResponse::Other(Value::Array(arr)) if arr.len() == 1 => { + DbResponse::Other(arr.into_iter().next().unwrap()) + } + r => r, + } +} + async fn router( RequestData { command, @@ -347,57 +293,74 @@ async fn router( base_url: &Url, client: &reqwest::Client, headers: &mut HeaderMap, - vars: &mut IndexMap, + vars: &mut IndexMap, auth: &mut Option, ) -> Result { match command { - Command::Use { - namespace, - database, + Command::Query { + query, + mut variables, } => { - let path = base_url.join(SQL_PATH)?; - let mut request = client.post(path).headers(headers.clone()); - let ns = match namespace { - Some(ns) => match HeaderValue::try_from(&ns) { + variables.extend(vars.clone()); + let req = Command::Query { + query, + variables, + } + .into_router_request(None) + .expect("query should be valid request"); + process_req(req, base_url, client, headers, auth).await + } + ref cmd @ Command::Use { + ref namespace, + ref database, + } => { + let req = cmd + .clone() + .into_router_request(None) + .expect("use should be a valid router request"); + // process request to check permissions + let out = process_req(req, base_url, client, headers, auth).await?; + match namespace { + Some(ns) => match HeaderValue::try_from(ns) { Ok(ns) => { - request = request.header(&NS, &ns); - Some(ns) + headers.insert(&NS, ns); } Err(_) => { - return Err(Error::InvalidNsName(ns).into()); + return Err(Error::InvalidNsName(ns.to_owned()).into()); } }, - None => None, + None => {} }; - let db = match database { - Some(db) => match HeaderValue::try_from(&db) { + + match database { + Some(db) => match HeaderValue::try_from(db) { Ok(db) => { - request = request.header(&DB, &db); - Some(db) + headers.insert(&DB, db); } Err(_) => { - return Err(Error::InvalidDbName(db).into()); + return Err(Error::InvalidDbName(db.to_owned()).into()); } }, - None => None, + None => {} }; - request = request.auth(auth).body("RETURN true"); - take(true, request).await?; - if let Some(ns) = ns { - headers.insert(&NS, ns); - } - if let Some(db) = db { - headers.insert(&DB, db); - } - Ok(DbResponse::Other(Value::None)) + + Ok(out) } Command::Signin { credentials, } => { - let path = base_url.join("signin")?; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string()); - let value = submit_auth(request).await?; + let req = Command::Signin { + credentials: credentials.clone(), + } + .into_router_request(None) + .expect("signin should be a valid router request"); + + let DbResponse::Other(value) = + process_req(req, base_url, client, headers, auth).await? + else { + unreachable!("didn't make query") + }; + if let Ok(Credentials { user, pass, @@ -416,24 +379,19 @@ async fn router( token: value.to_raw_string(), }); } - Ok(DbResponse::Other(value)) - } - Command::Signup { - credentials, - } => { - let path = base_url.join("signup")?; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(credentials.to_string()); - let value = submit_auth(request).await?; + Ok(DbResponse::Other(value)) } Command::Authenticate { token, } => { - let path = base_url.join(SQL_PATH)?; - let request = - client.post(path).headers(headers.clone()).bearer_auth(&token).body("RETURN true"); - take(true, request).await?; + let req = Command::Authenticate { + token: token.clone(), + } + .into_router_request(None) + .expect("authenticate should be a valid router request"); + process_req(req, base_url, client, headers, auth).await?; + *auth = Some(Auth::Bearer { token, }); @@ -443,156 +401,33 @@ async fn router( *auth = None; Ok(DbResponse::Other(Value::None)) } - Command::Create { - what, - data, + Command::Set { + key, + value, } => { - let path = base_url.join(SQL_PATH)?; - let statement = { - let mut stmt = CreateStatement::default(); - stmt.what = value_to_values(what); - stmt.data = data.map(Data::ContentExpression); - stmt.output = Some(Output::After); - stmt + let query: Query = surrealdb_core::sql::parse(&format!("RETURN ${key};"))?; + let req = Command::Query { + query, + variables: [(key.clone(), value)].into(), + } + .into_router_request(None) + .expect("query is valid request"); + let DbResponse::Query(mut res) = + process_req(req, base_url, client, headers, auth).await? + else { + unreachable!("made query request so response must be query") }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(true, request).await?; - Ok(DbResponse::Other(value)) + + let val: Value = res.take(0)?; + + vars.insert(key, val); + Ok(DbResponse::Other(Value::None)) } - Command::Upsert { - what, - data, + Command::Unset { + key, } => { - let path = base_url.join(SQL_PATH)?; - let one = what.is_thing(); - let statement = { - let mut stmt = UpsertStatement::default(); - stmt.what = value_to_values(what); - stmt.data = data.map(Data::ContentExpression); - stmt.output = Some(Output::After); - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Update { - what, - data, - } => { - let path = base_url.join(SQL_PATH)?; - let one = what.is_thing(); - let statement = { - let mut stmt = UpdateStatement::default(); - stmt.what = value_to_values(what); - stmt.data = data.map(Data::ContentExpression); - stmt.output = Some(Output::After); - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Insert { - what, - data, - } => { - let path = base_url.join(SQL_PATH)?; - let one = !data.is_array(); - let statement = { - let mut stmt = InsertStatement::default(); - stmt.into = what; - stmt.data = Data::SingleExpression(data); - stmt.output = Some(Output::After); - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Patch { - what, - data, - } => { - let path = base_url.join(SQL_PATH)?; - let one = what.is_thing(); - let statement = { - let mut stmt = UpdateStatement::default(); - stmt.what = value_to_values(what); - stmt.data = data.map(Data::PatchExpression); - stmt.output = Some(Output::After); - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Merge { - what, - data, - } => { - let path = base_url.join(SQL_PATH)?; - let one = what.is_thing(); - let statement = { - let mut stmt = UpdateStatement::default(); - stmt.what = value_to_values(what); - stmt.data = data.map(Data::MergeExpression); - stmt.output = Some(Output::After); - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Select { - what, - } => { - let path = base_url.join(SQL_PATH)?; - let one = what.is_thing(); - let statement = { - let mut stmt = SelectStatement::default(); - stmt.what = value_to_values(what); - stmt.expr.0 = vec![Field::All]; - stmt - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Delete { - what, - } => { - let one = what.is_thing(); - let path = base_url.join(SQL_PATH)?; - let (one, statement) = { - let mut stmt = DeleteStatement::default(); - stmt.what = value_to_values(what); - stmt.output = Some(Output::Before); - (one, stmt) - }; - let request = - client.post(path).headers(headers.clone()).auth(auth).body(statement.to_string()); - let value = take(one, request).await?; - Ok(DbResponse::Other(value)) - } - Command::Query { - query: q, - variables, - } => { - let path = base_url.join(SQL_PATH)?; - let mut request = client.post(path).headers(headers.clone()).query(&vars).auth(auth); - let bindings: Vec<_> = - variables.iter().map(|(key, value)| (key, value.to_string())).collect(); - request = request.query(&bindings).body(q.to_string()); - let values = query(request).await?; - Ok(DbResponse::Query(values)) + vars.shift_remove(&key); + Ok(DbResponse::Other(Value::None)) } #[cfg(target_arch = "wasm32")] Command::ExportFile { @@ -691,55 +526,16 @@ async fn router( let value = import(request, path).await?; Ok(DbResponse::Other(value)) } - Command::Health => { - let path = base_url.join("health")?; - let request = client.get(path); - let value = health(request).await?; - Ok(DbResponse::Other(value)) - } - Command::Version => { - let path = base_url.join("version")?; - let request = client.get(path); - let value = version(request).await?; - Ok(DbResponse::Other(value)) - } - Command::Set { - key, - value, - } => { - let path = base_url.join(SQL_PATH)?; - let value = value.to_string(); - let request = client - .post(path) - .headers(headers.clone()) - .auth(auth) - .query(&[(key.as_str(), value.as_str())]) - .body(format!("RETURN ${key}")); - take(true, request).await?; - vars.insert(key, value); - Ok(DbResponse::Other(Value::None)) - } - Command::Unset { - key, - } => { - vars.shift_remove(&key); - Ok(DbResponse::Other(Value::None)) - } Command::SubscribeLive { .. } => Err(Error::LiveQueriesNotSupported.into()), - Command::Kill { - uuid, - } => { - let path = base_url.join(SQL_PATH)?; - let request = client - .post(path) - .headers(headers.clone()) - .auth(auth) - .query(&[("id", uuid)]) - .body("KILL type::string($id)"); - let value = take(true, request).await?; - Ok(DbResponse::Other(value)) + + cmd => { + let one = cmd.needs_one(); + let req = cmd + .into_router_request(None) + .expect("all invalid variants should have been caught"); + process_req(req, base_url, client, headers, auth).await.map(|r| try_one(r, one)) } } } diff --git a/lib/src/api/engine/remote/mod.rs b/lib/src/api/engine/remote/mod.rs index d9c25cbc..e026366a 100644 --- a/lib/src/api/engine/remote/mod.rs +++ b/lib/src/api/engine/remote/mod.rs @@ -8,9 +8,25 @@ pub mod http; #[cfg_attr(docsrs, doc(cfg(feature = "protocol-ws")))] pub mod ws; +use crate::api; +use crate::api::conn::DbResponse; +use crate::api::err::Error; +use crate::api::method::query::QueryResult; +use crate::api::Result; +use crate::dbs::Notification; +use crate::dbs::QueryMethodResponse; +use crate::dbs::Status; +use crate::method::Stats; +use indexmap::IndexMap; +use revision::revisioned; +use revision::Revisioned; use rust_decimal::prelude::ToPrimitive; use rust_decimal::Decimal; +use serde::de::DeserializeOwned; +use serde::Deserialize; +use std::io::Read; use std::time::Duration; +use surrealdb_core::sql::Value; const NANOS_PER_SEC: i64 = 1_000_000_000; const NANOS_PER_MILLI: i64 = 1_000_000; @@ -66,3 +82,108 @@ mod tests { } } } + +#[revisioned(revision = 1)] +#[derive(Clone, Debug, Deserialize)] +pub(crate) struct Failure { + pub(crate) code: i64, + pub(crate) message: String, +} + +#[revisioned(revision = 1)] +#[derive(Debug, Deserialize)] +pub(crate) enum Data { + Other(Value), + Query(Vec), + Live(Notification), +} + +type ServerResult = std::result::Result; + +impl From for Error { + fn from(failure: Failure) -> Self { + match failure.code { + -32600 => Self::InvalidRequest(failure.message), + -32602 => Self::InvalidParams(failure.message), + -32603 => Self::InternalError(failure.message), + -32700 => Self::ParseError(failure.message), + _ => Self::Query(failure.message), + } + } +} + +impl From for crate::Error { + fn from(value: Failure) -> Self { + let api_err: Error = value.into(); + api_err.into() + } +} + +impl DbResponse { + fn from(result: ServerResult) -> Result { + match result.map_err(Error::from)? { + Data::Other(value) => Ok(DbResponse::Other(value)), + Data::Query(responses) => { + let mut map = + IndexMap::::with_capacity(responses.len()); + + for (index, response) in responses.into_iter().enumerate() { + let stats = Stats { + execution_time: duration_from_str(&response.time), + }; + match response.status { + Status::Ok => { + map.insert(index, (stats, Ok(response.result))); + } + Status::Err => { + map.insert( + index, + (stats, Err(Error::Query(response.result.as_raw_string()).into())), + ); + } + _ => unreachable!(), + } + } + + Ok(DbResponse::Query(api::Response { + results: map, + ..api::Response::new() + })) + } + // Live notifications don't call this method + Data::Live(..) => unreachable!(), + } + } +} + +#[revisioned(revision = 1)] +#[derive(Debug, Deserialize)] +pub(crate) struct Response { + id: Option, + pub(crate) result: ServerResult, +} + +fn serialize(value: &V, revisioned: bool) -> Result> +where + V: serde::Serialize + Revisioned, +{ + if revisioned { + let mut buf = Vec::new(); + value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?; + return Ok(buf); + } + crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into())) +} + +fn deserialize(bytes: &mut A, revisioned: bool) -> Result +where + A: Read, + T: Revisioned + DeserializeOwned, +{ + if revisioned { + return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into())); + } + let mut buf = Vec::new(); + bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?; + crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into())) +} diff --git a/lib/src/api/engine/remote/ws/mod.rs b/lib/src/api/engine/remote/ws/mod.rs index e8890139..487306bd 100644 --- a/lib/src/api/engine/remote/ws/mod.rs +++ b/lib/src/api/engine/remote/ws/mod.rs @@ -5,29 +5,16 @@ pub(crate) mod native; #[cfg(target_arch = "wasm32")] pub(crate) mod wasm; -use crate::api; use crate::api::conn::Command; use crate::api::conn::DbResponse; -use crate::api::engine::remote::duration_from_str; -use crate::api::err::Error; -use crate::api::method::query::QueryResult; use crate::api::Connect; use crate::api::Result; use crate::api::Surreal; -use crate::dbs::Notification; -use crate::dbs::QueryMethodResponse; -use crate::dbs::Status; -use crate::method::Stats; use crate::opt::IntoEndpoint; use crate::sql::Value; use channel::Sender; use indexmap::IndexMap; -use revision::revisioned; -use revision::Revisioned; -use serde::de::DeserializeOwned; -use serde::Deserialize; use std::collections::HashMap; -use std::io::Read; use std::marker::PhantomData; use std::time::Duration; use surrealdb_core::dbs::Notification as CoreNotification; @@ -153,101 +140,3 @@ impl Surreal { } } } - -#[revisioned(revision = 1)] -#[derive(Clone, Debug, Deserialize)] -pub(crate) struct Failure { - pub(crate) code: i64, - pub(crate) message: String, -} - -#[revisioned(revision = 1)] -#[derive(Debug, Deserialize)] -pub(crate) enum Data { - Other(Value), - Query(Vec), - Live(Notification), -} - -type ServerResult = std::result::Result; - -impl From for Error { - fn from(failure: Failure) -> Self { - match failure.code { - -32600 => Self::InvalidRequest(failure.message), - -32602 => Self::InvalidParams(failure.message), - -32603 => Self::InternalError(failure.message), - -32700 => Self::ParseError(failure.message), - _ => Self::Query(failure.message), - } - } -} - -impl DbResponse { - fn from(result: ServerResult) -> Result { - match result.map_err(Error::from)? { - Data::Other(value) => Ok(DbResponse::Other(value)), - Data::Query(responses) => { - let mut map = - IndexMap::::with_capacity(responses.len()); - - for (index, response) in responses.into_iter().enumerate() { - let stats = Stats { - execution_time: duration_from_str(&response.time), - }; - match response.status { - Status::Ok => { - map.insert(index, (stats, Ok(response.result))); - } - Status::Err => { - map.insert( - index, - (stats, Err(Error::Query(response.result.as_raw_string()).into())), - ); - } - _ => unreachable!(), - } - } - - Ok(DbResponse::Query(api::Response { - results: map, - ..api::Response::new() - })) - } - // Live notifications don't call this method - Data::Live(..) => unreachable!(), - } - } -} - -#[revisioned(revision = 1)] -#[derive(Debug, Deserialize)] -pub(crate) struct Response { - id: Option, - pub(crate) result: ServerResult, -} - -fn serialize(value: &V, revisioned: bool) -> Result> -where - V: serde::Serialize + Revisioned, -{ - if revisioned { - let mut buf = Vec::new(); - value.serialize_revisioned(&mut buf).map_err(|error| crate::Error::Db(error.into()))?; - return Ok(buf); - } - crate::sql::serde::serialize(value).map_err(|error| crate::Error::Db(error.into())) -} - -fn deserialize(bytes: &mut A, revisioned: bool) -> Result -where - A: Read, - T: Revisioned + DeserializeOwned, -{ - if revisioned { - return T::deserialize_revisioned(bytes).map_err(|x| crate::Error::Db(x.into())); - } - let mut buf = Vec::new(); - bytes.read_to_end(&mut buf).map_err(crate::err::Error::Io)?; - crate::sql::serde::deserialize(&buf).map_err(|error| crate::Error::Db(error.into())) -} diff --git a/lib/src/api/engine/remote/ws/native.rs b/lib/src/api/engine/remote/ws/native.rs index ad0daff4..bc99fa82 100644 --- a/lib/src/api/engine/remote/ws/native.rs +++ b/lib/src/api/engine/remote/ws/native.rs @@ -1,13 +1,12 @@ -use super::{ - deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH, -}; +use super::{HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH}; use crate::api::conn::Route; use crate::api::conn::Router; use crate::api::conn::{Command, DbResponse}; use crate::api::conn::{Connection, RequestData}; use crate::api::engine::remote::ws::Client; -use crate::api::engine::remote::ws::Response; use crate::api::engine::remote::ws::PING_INTERVAL; +use crate::api::engine::remote::Response; +use crate::api::engine::remote::{deserialize, serialize}; use crate::api::err::Error; use crate::api::method::BoxFuture; use crate::api::opt::Endpoint; @@ -17,7 +16,7 @@ use crate::api::ExtraFeatures; use crate::api::OnceLockExt; use crate::api::Result; use crate::api::Surreal; -use crate::engine::remote::ws::Data; +use crate::engine::remote::Data; use crate::engine::IntervalStream; use crate::opt::WaitFor; use crate::sql::Value; @@ -267,6 +266,9 @@ async fn router_handle_response( state: &mut RouterState, endpoint: &Endpoint, ) -> HandleResult { + if let Message::Binary(b) = &response { + error!(?b); + } match Response::try_from(&response, endpoint.supports_revision) { Ok(option) => { // We are only interested in responses that are not empty diff --git a/lib/src/api/engine/remote/ws/wasm.rs b/lib/src/api/engine/remote/ws/wasm.rs index 198fa874..3ad23ba8 100644 --- a/lib/src/api/engine/remote/ws/wasm.rs +++ b/lib/src/api/engine/remote/ws/wasm.rs @@ -1,13 +1,12 @@ -use super::{ - deserialize, serialize, HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH, -}; +use super::{HandleResult, PendingRequest, ReplayMethod, RequestEffect, PATH}; use crate::api::conn::DbResponse; use crate::api::conn::Route; use crate::api::conn::Router; use crate::api::conn::{Command, Connection, RequestData}; use crate::api::engine::remote::ws::Client; -use crate::api::engine::remote::ws::Response; use crate::api::engine::remote::ws::PING_INTERVAL; +use crate::api::engine::remote::Response; +use crate::api::engine::remote::{deserialize, serialize}; use crate::api::err::Error; use crate::api::method::BoxFuture; use crate::api::opt::Endpoint; @@ -15,7 +14,7 @@ use crate::api::ExtraFeatures; use crate::api::OnceLockExt; use crate::api::Result; use crate::api::Surreal; -use crate::engine::remote::ws::Data; +use crate::engine::remote::Data; use crate::engine::IntervalStream; use crate::opt::WaitFor; use crate::sql::Value; diff --git a/src/rpc/format.rs b/src/rpc/format.rs index 075acec4..c5150a91 100644 --- a/src/rpc/format.rs +++ b/src/rpc/format.rs @@ -83,8 +83,7 @@ impl HttpFormat for Format { } fn res_http(&self, res: Response) -> Result { - let val = res.into_value(); - let res = self.res(val)?; + let res = self.res(res)?; if matches!(self, Format::Json) { // If this has significant performance overhead it could be replaced with unsafe { String::from_utf8_unchecked(res) } // This would be safe as in the case of JSON res come from a call to Into::> for String