feature-2924: Add an option to suppress server identification headers (#3770)

Co-authored-by: Gerard Guillemas Martos <gguillemas@users.noreply.github.com>
Co-authored-by: Micha de Vries <micha@devrie.sh>
This commit is contained in:
Brian Yarr 2024-06-10 16:39:38 +01:00 committed by GitHub
parent 2913917284
commit a11f1bc82f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 54 additions and 9 deletions

View file

@ -18,4 +18,5 @@ pub struct Config {
pub key: Option<PathBuf>,
pub tick_interval: Duration,
pub engine: Option<EngineOptions>,
pub no_identification_headers: bool,
}

View file

@ -96,7 +96,10 @@ pub struct StartCommandArguments {
#[arg(env = "SURREAL_BIND", short = 'b', long = "bind")]
#[arg(default_value = "127.0.0.1:8000")]
listen_addresses: Vec<SocketAddr>,
#[arg(help = "Whether to suppress the server name and version headers")]
#[arg(env = "SURREAL_NO_IDENTIFICATION_HEADERS", long)]
#[arg(default_value_t = false)]
no_identification_headers: bool,
//
// Database options
//
@ -142,6 +145,7 @@ pub async fn init(
log,
tick_interval,
no_banner,
no_identification_headers,
..
}: StartCommandArguments,
) -> Result<(), Error> {
@ -171,6 +175,7 @@ pub async fn init(
user,
pass,
tick_interval,
no_identification_headers,
crt: web.as_ref().and_then(|x| x.web_crt.clone()),
key: web.as_ref().and_then(|x| x.web_key.clone()),
engine: None,

View file

@ -27,13 +27,25 @@ pub use db::SurrealDatabase;
pub use id::SurrealId;
pub use ns::SurrealNamespace;
pub fn add_version_header() -> SetResponseHeaderLayer<HeaderValue> {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), HeaderValue::try_from(val).unwrap())
pub fn add_version_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
Some(HeaderValue::try_from(val).unwrap())
} else {
None
};
SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), header_value)
}
pub fn add_server_header() -> SetResponseHeaderLayer<HeaderValue> {
SetResponseHeaderLayer::if_not_present(SERVER, HeaderValue::try_from(SERVER_NAME).unwrap())
pub fn add_server_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
Some(HeaderValue::try_from(SERVER_NAME).unwrap())
} else {
None
};
SetResponseHeaderLayer::if_not_present(SERVER, header_value)
}
// Parse a TypedHeader, returning None if the header is missing and an error if the header is invalid.

View file

@ -137,8 +137,8 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.layer(HttpMetricsLayer)
.layer(SetSensitiveResponseHeadersLayer::from_shared(headers))
.layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth))
.layer(headers::add_server_header())
.layer(headers::add_version_header())
.layer(headers::add_server_header(!opt.no_identification_headers))
.layer(headers::add_version_header(!opt.no_identification_headers))
.layer(
CorsLayer::new()
.allow_methods([

View file

@ -11,7 +11,7 @@ mod http_integration {
use test_log::test;
use ulid::Ulid;
use super::common::{self, PASS, USER};
use super::common::{self, StartServerArguments, PASS, USER};
#[test(tokio::test)]
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
@ -352,6 +352,33 @@ mod http_integration {
Ok(())
}
#[test(tokio::test)]
async fn no_server_id_headers() -> Result<(), Box<dyn std::error::Error>> {
// default server has the id headers
{
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let url = &format!("http://{addr}/health");
let res = Client::default().get(url).send().await?;
assert!(res.headers().contains_key("server"));
assert!(res.headers().contains_key("surreal-version"));
}
// turn on the no-identification-headers option to suppress headers
{
let mut start_server_arguments = StartServerArguments::default();
start_server_arguments.args.push_str(" --no-identification-headers");
let (addr, _server) = common::start_server(start_server_arguments).await.unwrap();
let url = &format!("http://{addr}/health");
let res = Client::default().get(url).send().await?;
assert!(!res.headers().contains_key("server"));
assert!(!res.headers().contains_key("surreal-version"));
}
Ok(())
}
#[test(tokio::test)]
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();