Switch the HTTP engine to a binary protocol (#1751)

This commit is contained in:
Rushmore Mushambi 2023-03-31 19:15:15 +02:00 committed by GitHub
parent 725b03729b
commit 6e6621565d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 43 deletions

View file

@ -6,6 +6,8 @@ use crate::api::conn::Router;
#[allow(unused_imports)] // used by the DB engines #[allow(unused_imports)] // used by the DB engines
use crate::api::engine; use crate::api::engine;
use crate::api::engine::any::Any; use crate::api::engine::any::Any;
#[cfg(feature = "protocol-http")]
use crate::api::engine::remote::http;
use crate::api::err::Error; use crate::api::err::Error;
use crate::api::opt::from_value; use crate::api::opt::from_value;
use crate::api::opt::Endpoint; use crate::api::opt::Endpoint;
@ -21,12 +23,6 @@ use crate::api::Surreal;
use flume::Receiver; use flume::Receiver;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
#[cfg(feature = "protocol-http")] #[cfg(feature = "protocol-http")]
use reqwest::header::HeaderMap;
#[cfg(feature = "protocol-http")]
use reqwest::header::HeaderValue;
#[cfg(feature = "protocol-http")]
use reqwest::header::ACCEPT;
#[cfg(feature = "protocol-http")]
use reqwest::ClientBuilder; use reqwest::ClientBuilder;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::collections::HashSet; use std::collections::HashSet;
@ -105,8 +101,7 @@ impl Connection for Any {
"http" | "https" => { "http" | "https" => {
features.insert(ExtraFeatures::Auth); features.insert(ExtraFeatures::Auth);
features.insert(ExtraFeatures::Backup); features.insert(ExtraFeatures::Backup);
let mut headers = HeaderMap::new(); let headers = http::default_headers();
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
#[allow(unused_mut)] #[allow(unused_mut)]
let mut builder = ClientBuilder::new().default_headers(headers); let mut builder = ClientBuilder::new().default_headers(headers);
#[cfg(any(feature = "native-tls", feature = "rustls"))] #[cfg(any(feature = "native-tls", feature = "rustls"))]

View file

@ -23,8 +23,6 @@ use crate::api::Response as QueryResponse;
use crate::api::Result; use crate::api::Result;
use crate::api::Surreal; use crate::api::Surreal;
use crate::opt::IntoEndpoint; use crate::opt::IntoEndpoint;
use crate::sql;
use crate::sql::to_value;
use crate::sql::Array; use crate::sql::Array;
use crate::sql::Strand; use crate::sql::Strand;
use crate::sql::Value; use crate::sql::Value;
@ -33,7 +31,6 @@ use futures::TryStreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue; use reqwest::header::HeaderValue;
#[cfg(not(target_arch = "wasm32"))]
use reqwest::header::ACCEPT; use reqwest::header::ACCEPT;
#[cfg(not(target_arch = "wasm32"))] #[cfg(not(target_arch = "wasm32"))]
use reqwest::header::CONTENT_TYPE; use reqwest::header::CONTENT_TYPE;
@ -105,6 +102,12 @@ impl Surreal<Client> {
} }
} }
pub(crate) fn default_headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(ACCEPT, HeaderValue::from_static("application/cork"));
headers
}
#[derive(Debug)] #[derive(Debug)]
enum Auth { enum Auth {
Basic { Basic {
@ -136,10 +139,14 @@ impl Authenticate for RequestBuilder {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct HttpQueryResponse { struct HttpQueryResponse {
time: String,
status: Status, status: Status,
result: Option<serde_json::Value>, #[serde(default)]
detail: Option<String>, result: Value,
#[serde(default)]
detail: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -149,42 +156,46 @@ struct Root {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct AuthResponse { struct AuthResponse {
code: u16,
details: String,
#[serde(default)]
token: Option<String>, token: Option<String>,
} }
async fn submit_auth(request: RequestBuilder) -> Result<Value> { async fn submit_auth(request: RequestBuilder) -> Result<Value> {
let response = request.send().await?.error_for_status()?; let response = request.send().await?.error_for_status()?;
let text = response.text().await?; let bytes = response.bytes().await?;
info!(target: LOG, "Response {text}"); let response: AuthResponse =
let value = sql::json(&text)?; msgpack::from_slice(&bytes).map_err(|error| Error::ResponseFromBinary {
let response: AuthResponse = from_value(value)?; binary: bytes.to_vec(),
error,
})?;
Ok(response.token.into()) Ok(response.token.into())
} }
async fn query(request: RequestBuilder) -> Result<QueryResponse> { async fn query(request: RequestBuilder) -> Result<QueryResponse> {
info!(target: LOG, "{request:?}"); info!(target: LOG, "{request:?}");
let response = request.send().await?.error_for_status()?; let response = request.send().await?.error_for_status()?;
let text = response.text().await?; let bytes = response.bytes().await?;
info!(target: LOG, "Response {text}"); let responses: Vec<HttpQueryResponse> =
let value = sql::json(&text)?; msgpack::from_slice(&bytes).map_err(|error| Error::ResponseFromBinary {
let responses: Vec<HttpQueryResponse> = from_value(value)?; binary: bytes.to_vec(),
error,
})?;
let mut map = IndexMap::<usize, QueryResult>::with_capacity(responses.len()); let mut map = IndexMap::<usize, QueryResult>::with_capacity(responses.len());
for (index, response) in responses.into_iter().enumerate() { for (index, response) in responses.into_iter().enumerate() {
match response.status { match response.status {
Status::Ok => { Status::Ok => {
if let Some(value) = response.result { match response.result {
match to_value(value)? { Value::Array(Array(array)) => map.insert(index, Ok(array)),
Value::Array(Array(array)) => map.insert(index, Ok(array)), Value::None | Value::Null => map.insert(index, Ok(vec![])),
Value::None | Value::Null => map.insert(index, Ok(vec![])), value => map.insert(index, Ok(vec![value])),
value => map.insert(index, Ok(vec![value])), };
};
}
} }
Status::Err => { Status::Err => {
if let Some(error) = response.detail { map.insert(index, Err(Error::Query(response.detail).into()));
map.insert(index, Err(Error::Query(error).into()));
}
} }
} }
} }
@ -281,10 +292,15 @@ async fn import(request: RequestBuilder, path: PathBuf) -> Result<Value> {
} }
.into()); .into());
} }
// ideally we should pass `file` directly into the body request
// but currently that results in .header(ACCEPT, "application/octet-stream")
// "HTTP status client error (405 Method Not Allowed) for url" // ideally we should pass `file` directly into the body
request.body(contents).send().await?.error_for_status()?; // but currently that results in
// "HTTP status client error (405 Method Not Allowed) for url"
.body(contents)
.send()
.await?
.error_for_status()?;
Ok(Value::None) Ok(Value::None)
} }

View file

@ -19,8 +19,6 @@ use futures::StreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::header::ACCEPT;
use reqwest::ClientBuilder; use reqwest::ClientBuilder;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::collections::HashSet; use std::collections::HashSet;
@ -45,8 +43,7 @@ impl Connection for Client {
capacity: usize, capacity: usize,
) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Surreal<Self>>> + Send + Sync + 'static>> {
Box::pin(async move { Box::pin(async move {
let mut headers = HeaderMap::new(); let headers = super::default_headers();
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
#[allow(unused_mut)] #[allow(unused_mut)]
let mut builder = ClientBuilder::new().default_headers(headers); let mut builder = ClientBuilder::new().default_headers(headers);

View file

@ -18,8 +18,6 @@ use futures::StreamExt;
use indexmap::IndexMap; use indexmap::IndexMap;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use reqwest::header::HeaderMap; use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::header::ACCEPT;
use reqwest::ClientBuilder; use reqwest::ClientBuilder;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::collections::HashSet; use std::collections::HashSet;
@ -122,8 +120,7 @@ impl Connection for Client {
} }
async fn client(base_url: &Url) -> Result<reqwest::Client> { async fn client(base_url: &Url) -> Result<reqwest::Client> {
let mut headers = HeaderMap::new(); let headers = super::default_headers();
headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
let builder = ClientBuilder::new().default_headers(headers); let builder = ClientBuilder::new().default_headers(headers);
let client = builder.build()?; let client = builder.build()?;
let health = base_url.join(Method::Health.as_str())?; let health = base_url.join(Method::Health.as_str())?;