Store connection ID for web socket connections (#4536)
Co-authored-by: Gerard Guillemas Martos <gerard.guillemas@surrealdb.com>
This commit is contained in:
parent
208d6a897e
commit
883d4f48d9
5 changed files with 234 additions and 16 deletions
|
@ -13,6 +13,7 @@ use surrealdb::{
|
|||
iam::verify::{basic, token},
|
||||
};
|
||||
use tower_http::auth::AsyncAuthorizeRequest;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::err::Error;
|
||||
|
||||
|
@ -81,8 +82,21 @@ async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
|
|||
None
|
||||
};
|
||||
|
||||
// Extract the session id from the headers.
|
||||
let id = parse_typed_header::<SurrealId>(parts.extract::<TypedHeader<SurrealId>>().await)?;
|
||||
// Extract the session id from the headers or generate a new one.
|
||||
let id = match parse_typed_header::<SurrealId>(parts.extract::<TypedHeader<SurrealId>>().await)?
|
||||
{
|
||||
Some(id) => {
|
||||
// Attempt to parse the request id as a UUID.
|
||||
match Uuid::try_parse(&id) {
|
||||
// The specified request id was a valid UUID.
|
||||
Ok(id) => Some(id.to_string()),
|
||||
// The specified request id was not a valid UUID.
|
||||
Err(_) => return Err(Error::Request),
|
||||
}
|
||||
}
|
||||
// No request id was specified, create a new id.
|
||||
None => Some(Uuid::new_v4().to_string()),
|
||||
};
|
||||
|
||||
// Extract the namespace from the headers.
|
||||
let ns = parse_typed_header::<SurrealNamespace>(
|
||||
|
|
|
@ -2,6 +2,7 @@ use std::collections::BTreeMap;
|
|||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::headers::SurrealId;
|
||||
use crate::cnf;
|
||||
use crate::err::Error;
|
||||
use crate::rpc::connection::Connection;
|
||||
|
@ -17,8 +18,10 @@ use axum::{
|
|||
response::IntoResponse,
|
||||
Extension, Router,
|
||||
};
|
||||
use axum_extra::headers::Header;
|
||||
use axum_extra::TypedHeader;
|
||||
use bytes::Bytes;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use surrealdb::dbs::Session;
|
||||
use surrealdb::kvs::Datastore;
|
||||
|
@ -42,26 +45,49 @@ async fn get_handler(
|
|||
ws: WebSocketUpgrade,
|
||||
Extension(state): Extension<AppState>,
|
||||
Extension(id): Extension<RequestId>,
|
||||
Extension(sess): Extension<Session>,
|
||||
Extension(mut sess): Extension<Session>,
|
||||
State(rpc_state): State<Arc<RpcState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<impl IntoResponse, impl IntoResponse> {
|
||||
// Check if there is a request id header specified
|
||||
let id = match id.header_value().is_empty() {
|
||||
// No request id was specified so create a new id
|
||||
true => Uuid::new_v4(),
|
||||
// A request id was specified to try to parse it
|
||||
false => match id.header_value().to_str() {
|
||||
// Attempt to parse the request id as a UUID
|
||||
Ok(id) => match Uuid::try_parse(id) {
|
||||
// The specified request id was a valid UUID
|
||||
Ok(id) => id,
|
||||
// The specified request id was not a UUID
|
||||
// Check if there is a connection id header specified
|
||||
let id = match headers.get(SurrealId::name()) {
|
||||
// Use the specific SurrealDB id header when provided
|
||||
Some(id) => {
|
||||
match id.to_str() {
|
||||
Ok(id) => {
|
||||
// Attempt to parse the request id as a UUID
|
||||
match Uuid::try_parse(id) {
|
||||
// The specified request id was a valid UUID
|
||||
Ok(id) => id,
|
||||
// The specified request id was not a UUID
|
||||
Err(_) => return Err(Error::Request),
|
||||
}
|
||||
}
|
||||
Err(_) => return Err(Error::Request),
|
||||
}
|
||||
}
|
||||
// Otherwise, use the generic WebSocket connection id header
|
||||
None => match id.header_value().is_empty() {
|
||||
// No request id was specified so create a new id
|
||||
true => Uuid::new_v4(),
|
||||
// A request id was specified to try to parse it
|
||||
false => match id.header_value().to_str() {
|
||||
// Attempt to parse the request id as a UUID
|
||||
Ok(id) => match Uuid::try_parse(id) {
|
||||
// The specified request id was a valid UUID
|
||||
Ok(id) => id,
|
||||
// The specified request id was not a UUID
|
||||
Err(_) => return Err(Error::Request),
|
||||
},
|
||||
// The request id contained invalid characters
|
||||
Err(_) => return Err(Error::Request),
|
||||
},
|
||||
// The request id contained invalid characters
|
||||
Err(_) => return Err(Error::Request),
|
||||
},
|
||||
};
|
||||
|
||||
// Store connection id in session
|
||||
sess.id = Some(id.to_string());
|
||||
|
||||
// Check if a connection with this id already exists
|
||||
if rpc_state.web_sockets.read().await.contains_key(&id) {
|
||||
return Err(Error::Request);
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::format::Format;
|
|||
use crate::common::error::TestError;
|
||||
use futures::channel::oneshot::channel;
|
||||
use futures_util::{SinkExt, TryStreamExt};
|
||||
use http::header::{HeaderMap, HeaderValue};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
@ -104,6 +105,39 @@ impl Socket {
|
|||
})
|
||||
}
|
||||
|
||||
/// Connect to a WebSocket server using a specific format with custom headers
|
||||
pub async fn connect_with_headers(
|
||||
addr: &str,
|
||||
format: Option<Format>,
|
||||
msg_format: Format,
|
||||
headers: HeaderMap<HeaderValue>,
|
||||
) -> Result<Self> {
|
||||
let url = format!("ws://{addr}/rpc");
|
||||
let mut req = url.into_client_request().unwrap();
|
||||
if let Some(v) = format.map(|v| v.to_string()) {
|
||||
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
|
||||
}
|
||||
for (key, value) in headers.into_iter() {
|
||||
if let Some(key) = key {
|
||||
req.headers_mut().append(key, value);
|
||||
}
|
||||
}
|
||||
let (stream, _) = connect_async(req).await?;
|
||||
let (send, recv) = mpsc::channel(16);
|
||||
let (send_other, recv_other) = mpsc::channel(16);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::ws_task(recv, stream, send_other, msg_format).await {
|
||||
eprintln!("error in websocket task: {e}")
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
sender: send,
|
||||
other_messages: recv_other,
|
||||
})
|
||||
}
|
||||
|
||||
fn to_msg(format: Format, message: &serde_json::Value) -> Result<Message> {
|
||||
match format {
|
||||
Format::Json => Ok(Message::Text(serde_json::to_string(message)?)),
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::common::{self, Format, Socket, DB, NS, PASS, USER};
|
||||
use http::header::{HeaderMap, HeaderValue};
|
||||
use assert_fs::TempDir;
|
||||
use serde_json::json;
|
||||
use std::future::Future;
|
||||
|
@ -1783,3 +1784,91 @@ async fn temporary_directory() {
|
|||
// Cleanup
|
||||
temp_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn session_id_defined() {
|
||||
// Setup database server
|
||||
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||
// We specify a request identifier via a specific SurrealDB header
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("surreal-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
|
||||
// Connect to WebSocket
|
||||
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
|
||||
// Authenticate the connection
|
||||
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||
// Specify a namespace and database
|
||||
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||
|
||||
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
|
||||
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
|
||||
assert_eq!(res.remove(0)["result"], expected);
|
||||
|
||||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn session_id_defined_generic() {
|
||||
// Setup database server
|
||||
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||
// We specify a request identifier via a generic header
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-request-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
|
||||
// Connect to WebSocket
|
||||
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
|
||||
// Authenticate the connection
|
||||
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||
// Specify a namespace and database
|
||||
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||
|
||||
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
|
||||
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
|
||||
assert_eq!(res.remove(0)["result"], expected);
|
||||
|
||||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn session_id_defined_both() {
|
||||
// Setup database server
|
||||
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||
// We specify a request identifier via both headers
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("surreal-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
|
||||
headers.insert("x-request-id", HeaderValue::from_static("aaaaaaaa-aaaa-0000-0000-000000000000"));
|
||||
// Connect to WebSocket
|
||||
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
|
||||
// Authenticate the connection
|
||||
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||
// Specify a namespace and database
|
||||
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||
|
||||
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
|
||||
// The specific header should be used
|
||||
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
|
||||
assert_eq!(res.remove(0)["result"], expected);
|
||||
|
||||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn session_id_undefined() {
|
||||
// Setup database server
|
||||
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
|
||||
// Connect to WebSocket
|
||||
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
|
||||
// Authenticate the connection
|
||||
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
|
||||
// Specify a namespace and database
|
||||
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
|
||||
|
||||
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
|
||||
// The field is expected to be present even when not provided in the header
|
||||
let unexpected = json!([null]);
|
||||
assert_ne!(res.remove(0)["result"], unexpected);
|
||||
|
||||
// Test passed
|
||||
server.finish().unwrap();
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ mod common;
|
|||
mod http_integration {
|
||||
use std::time::Duration;
|
||||
|
||||
use http::header::HeaderValue;
|
||||
use http::{header, Method};
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
|
@ -295,6 +296,60 @@ mod http_integration {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn session_id() {
|
||||
let (addr, _server) = common::start_server_with_guests().await.unwrap();
|
||||
let url = &format!("http://{addr}/sql");
|
||||
|
||||
// Request without header, gives a randomly generated session identifier
|
||||
{
|
||||
// Prepare HTTP client without header
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
headers.insert("surreal-ns", ns.parse().unwrap());
|
||||
headers.insert("surreal-db", db.parse().unwrap());
|
||||
headers.insert(header::ACCEPT, "application/json".parse().unwrap());
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_millis(10))
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let res = client.post(url).body("SELECT VALUE id FROM $session").send().await.unwrap();
|
||||
assert_eq!(res.status(), 200);
|
||||
let body = res.text().await.unwrap();
|
||||
// Any randomly generated UUIDv4 will be in the format:
|
||||
// xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx
|
||||
assert!(body.contains("-4"), "body: {body}");
|
||||
}
|
||||
|
||||
// Request with header, gives a the session identifier specified in the header
|
||||
{
|
||||
// Prepare HTTP client with header
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let ns = Ulid::new().to_string();
|
||||
let db = Ulid::new().to_string();
|
||||
headers.insert("surreal-ns", ns.parse().unwrap());
|
||||
headers.insert("surreal-db", db.parse().unwrap());
|
||||
headers.insert(
|
||||
"surreal-id",
|
||||
HeaderValue::from_static("00000000-0000-0000-0000-000000000000"),
|
||||
);
|
||||
headers.insert(header::ACCEPT, "application/json".parse().unwrap());
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_millis(10))
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let res = client.post(url).body("SELECT VALUE id FROM $session").send().await.unwrap();
|
||||
assert_eq!(res.status(), 200);
|
||||
let body = res.text().await.unwrap();
|
||||
assert!(body.contains("00000000-0000-0000-0000-000000000000"), "body: {body}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test(tokio::test)]
|
||||
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
|
||||
|
|
Loading…
Reference in a new issue