surrealpatch/src/net/tracer.rs
2023-07-19 14:35:56 +00:00

235 lines
7.3 KiB
Rust

use std::{fmt, time::Duration};
use axum::{
body::{boxed, Body, BoxBody},
extract::MatchedPath,
headers::{
authorization::{Basic, Bearer},
Authorization, Origin,
},
Extension, RequestPartsExt, TypedHeader,
};
use futures_util::future::BoxFuture;
use http::{header, request::Parts, StatusCode};
use hyper::{Request, Response};
use surrealdb::{dbs::Session, iam::verify::token};
use tower_http::{
auth::AsyncAuthorizeRequest,
request_id::RequestId,
trace::{MakeSpan, OnFailure, OnRequest, OnResponse},
};
use tracing::{field, Level, Span};
use crate::{dbs::DB, err::Error, iam::verify::basic};
use super::{client_ip::ExtractClientIP, AppState};
/// Authorizer for incoming HTTP requests, implementing
/// `tower_http::auth::AsyncAuthorizeRequest`.
///
/// Requests are authenticated with either Basic or Bearer (token) credentials.
/// When authentication succeeds, the resulting `Session` is inserted into the
/// request extensions; when it fails, the request is rejected with a
/// `401 Unauthorized` response whose body is the error message.
///
/// It has to be used in conjunction with the
/// `tower_http::auth::AsyncRequireAuthorizationLayer` layer. The example below
/// is illustrative only (this type is `pub(super)`, so it cannot be imported
/// from a doctest):
///
/// ```rust,ignore
/// use axum::{routing::get, Router};
/// use tower_http::auth::AsyncRequireAuthorizationLayer;
///
/// let auth = AsyncRequireAuthorizationLayer::new(SurrealAuth);
///
/// let app = Router::new()
///     .route("/version", get(|| async { "0.1.0" }))
///     .layer(auth);
/// ```
#[derive(Clone, Copy)]
pub(super) struct SurrealAuth;
impl<B> AsyncAuthorizeRequest<B> for SurrealAuth
where
    B: Send + Sync + 'static,
{
    type RequestBody = B;
    type ResponseBody = BoxBody;
    type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

    /// Authenticates `request` and resolves to either the same request with
    /// the authenticated `Session` stored in its extensions, or a
    /// `401 Unauthorized` response carrying the error message as its body.
    fn authorize(&mut self, request: Request<B>) -> Self::Future {
        Box::pin(async move {
            // Only the head of the request is needed to authenticate; the body
            // is carried through untouched.
            let (mut head, payload) = request.into_parts();
            let outcome = check_auth(&mut head).await;
            match outcome {
                Err(reason) => {
                    let rejection = Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .body(boxed(Body::from(reason.to_string())))
                        .unwrap();
                    Err(rejection)
                }
                Ok(session) => {
                    // Make the session available to downstream handlers.
                    head.extensions.insert(session);
                    Ok(Request::from_parts(head, payload))
                }
            }
        })
    }
}
async fn check_auth(parts: &mut Parts) -> Result<Session, Error> {
let kvs = DB.get().unwrap();
let or = if let Ok(or) = parts.extract::<TypedHeader<Origin>>().await {
if !or.is_null() {
Some(or.to_string())
} else {
None
}
} else {
None
};
let id = parts.headers.get("id").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let ns = parts.headers.get("ns").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let db = parts.headers.get("db").map(|v| v.to_str().unwrap().to_string()); // TODO: Use a TypedHeader
let Extension(state) = parts.extract::<Extension<AppState>>().await.map_err(|err| {
tracing::error!("Error extracting the app state: {:?}", err);
Error::InvalidAuth
})?;
let ExtractClientIP(ip) =
parts.extract_with_state(&state).await.unwrap_or(ExtractClientIP(None));
// Create session
#[rustfmt::skip]
let mut session = Session { ip, or, id, ns, db, ..Default::default() };
// If Basic authentication data was supplied
if let Ok(au) = parts.extract::<TypedHeader<Authorization<Basic>>>().await {
basic(&mut session, au.username(), au.password()).await
} else if let Ok(au) = parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
token(kvs, &mut session, au.token().into()).await.map_err(|e| e.into())
} else {
Err(Error::InvalidAuth)
}?;
Ok(session)
}
/// Custom hooks for the `tower_http::trace::TraceLayer` layer.
///
/// This type implements the `MakeSpan`, `OnRequest`, `OnResponse` and
/// `OnFailure` hooks:
///
/// - `MakeSpan` creates a `request` span (target `surreal::http`) holding the
///   HTTP, URL, network, client and server fields of the request.
/// - `OnRequest` emits an event when processing of a request starts.
/// - `OnResponse` records the response status code and body size, plus the
///   latency for responses that are not server errors.
/// - `OnFailure` records the error message and the latency of failed requests.
///
/// Example (illustrative only; this type is `pub(crate)`):
///
/// ```rust,ignore
/// use axum::{routing::get, Router};
/// use tower_http::trace::TraceLayer;
///
/// let trace = TraceLayer::new_for_http()
///     .make_span_with(HttpTraceLayerHooks::default())
///     .on_request(HttpTraceLayerHooks::default())
///     .on_response(HttpTraceLayerHooks::default())
///     .on_failure(HttpTraceLayerHooks::default());
///
/// let app = Router::new()
///     .route("/version", get(|| async { "0.1.0" }))
///     .layer(trace);
/// ```
#[derive(Default, Clone)]
pub(crate) struct HttpTraceLayerHooks;
impl<B> MakeSpan<B> for HttpTraceLayerHooks {
fn make_span(&mut self, req: &Request<B>) -> Span {
// The fields follow the OTEL semantic conventions: https://github.com/open-telemetry/opentelemetry-specification/blob/v1.23.0/specification/trace/semantic_conventions/http.md
let span = tracing::info_span!(
target: "surreal::http",
"request",
otel.name = field::Empty,
otel.kind = "server",
http.route = field::Empty,
http.request.method = req.method().as_str(),
http.request.body.size = field::Empty,
url.path = req.uri().path(),
url.query = field::Empty,
url.scheme = field::Empty,
http.request.id = field::Empty,
user_agent.original = field::Empty,
network.protocol.name = "http",
network.protocol.version = format!("{:?}", req.version()).strip_prefix("HTTP/"),
client.address = field::Empty,
client.port = field::Empty,
client.socket.address = field::Empty,
server.address = field::Empty,
server.port = field::Empty,
// set on the response hook
http.latency.ms = field::Empty,
http.response.status_code = field::Empty,
http.response.body.size = field::Empty,
// set on the failure hook
error = field::Empty,
error_message = field::Empty,
);
req.uri().query().map(|v| span.record("url.query", v));
req.uri().scheme().map(|v| span.record("url.scheme", v.as_str()));
req.uri().host().map(|v| span.record("server.address", v));
req.uri().port_u16().map(|v| span.record("server.port", v));
req.headers()
.get(header::CONTENT_LENGTH)
.map(|v| v.to_str().map(|v| span.record("http.request.body.size", v)));
req.headers()
.get(header::USER_AGENT)
.map(|v| v.to_str().map(|v| span.record("user_agent.original", v)));
if let Some(path) = req.extensions().get::<MatchedPath>() {
span.record("otel.name", format!("{} {}", req.method(), path.as_str()));
span.record("http.route", path.as_str());
} else {
span.record("otel.name", format!("{} -", req.method()));
};
if let Some(req_id) = req.extensions().get::<RequestId>() {
match req_id.header_value().to_str() {
Err(err) => tracing::error!(error = %err, "failed to parse request id"),
Ok(request_id) => {
span.record("http.request.id", request_id);
}
}
}
if let Some(client_ip) = req.extensions().get::<ExtractClientIP>() {
if let Some(ref client_ip) = client_ip.0 {
span.record("client.address", client_ip);
}
}
span
}
}
impl<B> OnRequest<B> for HttpTraceLayerHooks {
    /// Emits an `INFO` event marking the start of request processing.
    ///
    /// Neither the request nor the span is inspected.
    fn on_request(&mut self, _request: &Request<B>, _span: &Span) {
        tracing::event!(Level::INFO, "started processing request");
    }
}
impl<B> OnResponse<B> for HttpTraceLayerHooks {
    /// Records the response details on the request span.
    ///
    /// The status code is always recorded. The body size is recorded when the
    /// response has a `Content-Length` header whose value is a valid string; an
    /// invalid value is skipped rather than causing a panic. The latency and
    /// the completion event are only emitted for responses that are not server
    /// errors.
    fn on_response(self, response: &Response<B>, latency: Duration, span: &Span) {
        let body_size =
            response.headers().get(header::CONTENT_LENGTH).and_then(|size| size.to_str().ok());
        if let Some(size) = body_size {
            span.record("http.response.body.size", size);
        }
        span.record("http.response.status_code", response.status().as_u16());
        // Server errors are handled by the OnFailure hook
        if !response.status().is_server_error() {
            span.record("http.latency.ms", latency.as_millis());
            tracing::event!(Level::INFO, "finished processing request");
        }
    }
}
impl<FailureClass> OnFailure<FailureClass> for HttpTraceLayerHooks
where
FailureClass: fmt::Display,
{
fn on_failure(&mut self, error: FailureClass, latency: Duration, span: &Span) {
span.record("error_message", &error.to_string());
span.record("http.latency.ms", latency.as_millis());
tracing::event!(Level::ERROR, error = error.to_string(), "response failed");
}
}