headless_lms_server/domain/
request_span_middleware.rs

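/*!
Middleware that wraps each HTTP request in a `tracing` span for debugging, attaches a generated request id to all log messages emitted while handling the request, and echoes that id back to the client in the `request-id` response header.
*/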
4
5use super::request_id::RequestId;
6use actix_http::{
7    HttpMessage,
8    header::{HeaderName, HeaderValue},
9};
10use actix_web::{
11    Error,
12    dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
13};
14use futures_util::future::LocalBoxFuture;
15use std::future::{Ready, ready};
16use tracing::Instrument;
17use uuid::Uuid;
18
/// Actix-web middleware factory that wraps every request in an `http_request` tracing span.
///
/// The factory carries no configuration; each service it wraps gets its own
/// `RequestSpanMiddleware` instance.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestSpan;
20
21impl<S, B> Transform<S, ServiceRequest> for RequestSpan
22where
23    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
24    S::Future: 'static,
25    B: 'static,
26{
27    type Response = ServiceResponse<B>;
28    type Error = Error;
29    type InitError = ();
30    type Transform = RequestSpanMiddleware<S>;
31    type Future = Ready<Result<Self::Transform, Self::InitError>>;
32
33    fn new_transform(&self, service: S) -> Self::Future {
34        ready(Ok(RequestSpanMiddleware { service }))
35    }
36}
37
/// Per-service middleware produced by the `RequestSpan` factory.
///
/// Every request that passes through it is handled inside its own `http_request` tracing
/// span. This makes it possible to correlate the log output of a single request while
/// debugging.
pub struct RequestSpanMiddleware<S> {
    /// The next service in the chain; requests are forwarded to it once the span is set up.
    service: S,
}
42
43impl<S, B> Service<ServiceRequest> for RequestSpanMiddleware<S>
44where
45    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
46    S::Future: 'static,
47    B: 'static,
48{
49    type Response = ServiceResponse<B>;
50    type Error = Error;
51    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
52
53    forward_ready!(service);
54
55    fn call(&self, req: ServiceRequest) -> Self::Future {
56        let request_id = Uuid::new_v4();
57        let request_id_string = request_id.to_string();
58        let request_span = tracing::info_span!("http_request", request_id = &request_id_string,);
59        request_span.in_scope(|| {
60            // Logging request start along with relevant information so that we can associate later log messages with the request context though the request id.
61            info!(
62                method =?req.method(),
63                path=?req.path(),
64                ip=?req.connection_info().realip_remote_addr(),
65                "Start handling request.",
66            );
67        });
68
69        // insert the generated id as an extension so that handlers can extract it if needed
70        req.extensions_mut().insert(RequestId(request_id));
71
72        let fut = self.service.call(req).instrument(request_span);
73
74        Box::pin(async move {
75            let mut res = fut.await?;
76            let headers = res.headers_mut();
77            headers.insert(
78                HeaderName::from_static("request-id"),
79                HeaderValue::from_str(&request_id_string)?,
80            );
81            Ok(res)
82        })
83    }
84}