1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
/*!
Middleware that wraps HTTP requests in tokio tracing spans for debugging and attaches a request id to all log messages.
*/

use super::request_id::RequestId;
use actix_http::{
    header::{HeaderName, HeaderValue},
    HttpMessage,
};
use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};
use tracing::Instrument;
use uuid::Uuid;

/// Middleware factory for request tracing spans.
///
/// Its `Transform` implementation wraps the given service in a
/// [`RequestSpanMiddleware`]; the type itself carries no state.
pub struct RequestSpan;

/// Factory half of the middleware: turns an inner service `S` into a
/// [`RequestSpanMiddleware`] wrapping it.
impl<S, B> Transform<S, ServiceRequest> for RequestSpan
where
    B: 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = RequestSpanMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Builds the middleware around `service`.
    ///
    /// Construction cannot fail, so the returned future is immediately ready with `Ok`.
    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = RequestSpanMiddleware { service };
        ready(Ok(middleware))
    }
}

/// Wraps HTTP requests into tokio tracing spans, helps with debugging.
///
/// For every request the `Service` implementation generates a fresh UUID v4 request id,
/// opens an `http_request` span carrying that id, logs the start of the request, stores the
/// id in the request extensions as a `RequestId`, and adds it to successful responses in the
/// `request-id` header.
pub struct RequestSpanMiddleware<S> {
    /// The wrapped inner service that the request is forwarded to.
    service: S,
}

/// Per-request half of the middleware: runs the inner service inside an `http_request`
/// tracing span and echoes the generated request id back to the client.
impl<S, B> Service<ServiceRequest> for RequestSpanMiddleware<S>
where
    B: 'static,
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    // Readiness is delegated to the wrapped service.
    forward_ready!(service);

    /// Handles one request.
    ///
    /// Steps, in order:
    /// 1. generate a UUID v4 request id and open an `http_request` span with a `request_id` field;
    /// 2. log method, path and remote address inside that span;
    /// 3. store the id in the request extensions as a [`RequestId`];
    /// 4. call the inner service with its future instrumented by the span;
    /// 5. on success, add the id to the response as the `request-id` header.
    ///
    /// If the inner service returns an error it is propagated unchanged, and no header is added.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let id = Uuid::new_v4();
        let id_text = id.to_string();
        let span = tracing::info_span!("http_request", request_id = &id_text,);

        {
            // Enter the span only for the start-of-request log line. Logging the request
            // details together with the id lets later log messages, which carry only the
            // id, be associated with this request. The guard is dropped at the end of this
            // block, before the span is moved into the instrumented future.
            let _entered = span.enter();
            info!(
                method = ?req.method(),
                path = ?req.path(),
                ip = ?req.connection_info().realip_remote_addr(),
                "Start handling request.",
            );
        }

        // Make the generated id available to handlers through the request extensions.
        req.extensions_mut().insert(RequestId(id));

        let instrumented = self.service.call(req).instrument(span);

        Box::pin(async move {
            let mut response = instrumented.await?;
            let header_value = HeaderValue::from_str(&id_text)?;
            response
                .headers_mut()
                .insert(HeaderName::from_static("request-id"), header_value);
            Ok(response)
        })
    }
}