Store connection ID for web socket connections (#4536)

Co-authored-by: Gerard Guillemas Martos <gerard.guillemas@surrealdb.com>
This commit is contained in:
Dmitrii Blaginin 2024-08-22 17:25:08 +01:00 committed by GitHub
parent 208d6a897e
commit 883d4f48d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 234 additions and 16 deletions

View file

@ -13,6 +13,7 @@ use surrealdb::{
iam::verify::{basic, token},
};
use tower_http::auth::AsyncAuthorizeRequest;
use uuid::Uuid;
use crate::err::Error;
@ -81,8 +82,21 @@ async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
None
};
// Extract the session id from the headers.
let id = parse_typed_header::<SurrealId>(parts.extract::<TypedHeader<SurrealId>>().await)?;
// Extract the session id from the headers or generate a new one.
let id = match parse_typed_header::<SurrealId>(parts.extract::<TypedHeader<SurrealId>>().await)?
{
Some(id) => {
// Attempt to parse the request id as a UUID.
match Uuid::try_parse(&id) {
// The specified request id was a valid UUID.
Ok(id) => Some(id.to_string()),
// The specified request id was not a valid UUID.
Err(_) => return Err(Error::Request),
}
}
// No request id was specified, create a new id.
None => Some(Uuid::new_v4().to_string()),
};
// Extract the namespace from the headers.
let ns = parse_typed_header::<SurrealNamespace>(

View file

@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::Arc;
use super::headers::SurrealId;
use crate::cnf;
use crate::err::Error;
use crate::rpc::connection::Connection;
@ -17,8 +18,10 @@ use axum::{
response::IntoResponse,
Extension, Router,
};
use axum_extra::headers::Header;
use axum_extra::TypedHeader;
use bytes::Bytes;
use http::HeaderMap;
use http::HeaderValue;
use surrealdb::dbs::Session;
use surrealdb::kvs::Datastore;
@ -42,26 +45,49 @@ async fn get_handler(
ws: WebSocketUpgrade,
Extension(state): Extension<AppState>,
Extension(id): Extension<RequestId>,
Extension(sess): Extension<Session>,
Extension(mut sess): Extension<Session>,
State(rpc_state): State<Arc<RpcState>>,
headers: HeaderMap,
) -> Result<impl IntoResponse, impl IntoResponse> {
// Check if there is a request id header specified
let id = match id.header_value().is_empty() {
// No request id was specified so create a new id
true => Uuid::new_v4(),
// A request id was specified to try to parse it
false => match id.header_value().to_str() {
// Attempt to parse the request id as a UUID
Ok(id) => match Uuid::try_parse(id) {
// The specified request id was a valid UUID
Ok(id) => id,
// The specified request id was not a UUID
// Check if there is a connection id header specified
let id = match headers.get(SurrealId::name()) {
// Use the specific SurrealDB id header when provided
Some(id) => {
match id.to_str() {
Ok(id) => {
// Attempt to parse the request id as a UUID
match Uuid::try_parse(id) {
// The specified request id was a valid UUID
Ok(id) => id,
// The specified request id was not a UUID
Err(_) => return Err(Error::Request),
}
}
Err(_) => return Err(Error::Request),
}
}
// Otherwise, use the generic WebSocket connection id header
None => match id.header_value().is_empty() {
// No request id was specified so create a new id
true => Uuid::new_v4(),
// A request id was specified to try to parse it
false => match id.header_value().to_str() {
// Attempt to parse the request id as a UUID
Ok(id) => match Uuid::try_parse(id) {
// The specified request id was a valid UUID
Ok(id) => id,
// The specified request id was not a UUID
Err(_) => return Err(Error::Request),
},
// The request id contained invalid characters
Err(_) => return Err(Error::Request),
},
// The request id contained invalid characters
Err(_) => return Err(Error::Request),
},
};
// Store connection id in session
sess.id = Some(id.to_string());
// Check if a connection with this id already exists
if rpc_state.web_sockets.read().await.contains_key(&id) {
return Err(Error::Request);

View file

@ -2,6 +2,7 @@ use super::format::Format;
use crate::common::error::TestError;
use futures::channel::oneshot::channel;
use futures_util::{SinkExt, TryStreamExt};
use http::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
@ -104,6 +105,39 @@ impl Socket {
})
}
/// Connect to a WebSocket server using a specific format with custom headers
pub async fn connect_with_headers(
addr: &str,
format: Option<Format>,
msg_format: Format,
headers: HeaderMap<HeaderValue>,
) -> Result<Self> {
let url = format!("ws://{addr}/rpc");
let mut req = url.into_client_request().unwrap();
if let Some(v) = format.map(|v| v.to_string()) {
req.headers_mut().insert("Sec-WebSocket-Protocol", v.parse().unwrap());
}
for (key, value) in headers.into_iter() {
if let Some(key) = key {
req.headers_mut().append(key, value);
}
}
let (stream, _) = connect_async(req).await?;
let (send, recv) = mpsc::channel(16);
let (send_other, recv_other) = mpsc::channel(16);
tokio::spawn(async move {
if let Err(e) = Self::ws_task(recv, stream, send_other, msg_format).await {
eprintln!("error in websocket task: {e}")
}
});
Ok(Self {
sender: send,
other_messages: recv_other,
})
}
fn to_msg(format: Format, message: &serde_json::Value) -> Result<Message> {
match format {
Format::Json => Ok(Message::Text(serde_json::to_string(message)?)),

View file

@ -1,4 +1,5 @@
use super::common::{self, Format, Socket, DB, NS, PASS, USER};
use http::header::{HeaderMap, HeaderValue};
use assert_fs::TempDir;
use serde_json::json;
use std::future::Future;
@ -1783,3 +1784,91 @@ async fn temporary_directory() {
// Cleanup
temp_dir.close().unwrap();
}
#[test(tokio::test)]
async fn session_id_defined() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// We specify a request identifier via a specific SurrealDB header
let mut headers = HeaderMap::new();
headers.insert("surreal-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
// Connect to WebSocket
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
assert_eq!(res.remove(0)["result"], expected);
// Test passed
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_id_defined_generic() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// We specify a request identifier via a generic header
let mut headers = HeaderMap::new();
headers.insert("x-request-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
// Connect to WebSocket
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
assert_eq!(res.remove(0)["result"], expected);
// Test passed
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_id_defined_both() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// We specify a request identifier via both headers
let mut headers = HeaderMap::new();
headers.insert("surreal-id", HeaderValue::from_static("00000000-0000-0000-0000-000000000000"));
headers.insert("x-request-id", HeaderValue::from_static("aaaaaaaa-aaaa-0000-0000-000000000000"));
// Connect to WebSocket
let mut socket = Socket::connect_with_headers(&addr, SERVER, FORMAT, headers).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
// The specific header should be used
let expected = json!(["00000000-0000-0000-0000-000000000000"]);
assert_eq!(res.remove(0)["result"], expected);
// Test passed
server.finish().unwrap();
}
#[test(tokio::test)]
async fn session_id_undefined() {
// Setup database server
let (addr, mut server) = common::start_server_with_defaults().await.unwrap();
// Connect to WebSocket
let mut socket = Socket::connect(&addr, SERVER, FORMAT).await.unwrap();
// Authenticate the connection
socket.send_message_signin(USER, PASS, None, None, None).await.unwrap();
// Specify a namespace and database
socket.send_message_use(Some(NS), Some(DB)).await.unwrap();
let mut res = socket.send_message_query("SELECT VALUE id FROM $session").await.unwrap();
// The field is expected to be present even when not provided in the header
let unexpected = json!([null]);
assert_ne!(res.remove(0)["result"], unexpected);
// Test passed
server.finish().unwrap();
}

View file

@ -4,6 +4,7 @@ mod common;
mod http_integration {
use std::time::Duration;
use http::header::HeaderValue;
use http::{header, Method};
use reqwest::Client;
use serde_json::json;
@ -295,6 +296,60 @@ mod http_integration {
Ok(())
}
#[test(tokio::test)]
async fn session_id() {
let (addr, _server) = common::start_server_with_guests().await.unwrap();
let url = &format!("http://{addr}/sql");
// Request without header, gives a randomly generated session identifier
{
// Prepare HTTP client without header
let mut headers = reqwest::header::HeaderMap::new();
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
headers.insert("surreal-ns", ns.parse().unwrap());
headers.insert("surreal-db", db.parse().unwrap());
headers.insert(header::ACCEPT, "application/json".parse().unwrap());
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()
.unwrap();
let res = client.post(url).body("SELECT VALUE id FROM $session").send().await.unwrap();
assert_eq!(res.status(), 200);
let body = res.text().await.unwrap();
// Any randomly generated UUIDv4 will be in the format:
// xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx
assert!(body.contains("-4"), "body: {body}");
}
// Request with header, gives a the session identifier specified in the header
{
// Prepare HTTP client with header
let mut headers = reqwest::header::HeaderMap::new();
let ns = Ulid::new().to_string();
let db = Ulid::new().to_string();
headers.insert("surreal-ns", ns.parse().unwrap());
headers.insert("surreal-db", db.parse().unwrap());
headers.insert(
"surreal-id",
HeaderValue::from_static("00000000-0000-0000-0000-000000000000"),
);
headers.insert(header::ACCEPT, "application/json".parse().unwrap());
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_millis(10))
.default_headers(headers)
.build()
.unwrap();
let res = client.post(url).body("SELECT VALUE id FROM $session").send().await.unwrap();
assert_eq!(res.status(), 200);
let body = res.text().await.unwrap();
assert!(body.contains("00000000-0000-0000-0000-000000000000"), "body: {body}");
}
}
#[test(tokio::test)]
async fn export_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();