// actix_extensible_rate_limit/middleware/mod.rs

1pub mod builder;
2#[cfg(test)]
3mod tests;
4
5use crate::backend::Backend;
6use actix_web::body::EitherBody;
7use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
8use actix_web::http::header::HeaderMap;
9use actix_web::http::StatusCode;
10use actix_web::HttpResponse;
11use builder::RateLimiterBuilder;
12use futures::future::{ok, LocalBoxFuture, Ready};
13use std::cell::RefCell;
14use std::{future::Future, rc::Rc};
15
16type AllowedTransformation<BO> = dyn Fn(&mut HeaderMap, Option<&BO>, bool);
17type DeniedResponse<BO> = dyn Fn(&BO) -> HttpResponse;
18type RollbackCondition = dyn Fn(StatusCode) -> bool;
19
20/// Rate limit middleware.
21pub struct RateLimiter<BA, BO, F> {
22    backend: BA,
23    input_fn: Rc<F>,
24    fail_open: bool,
25    allowed_mutation: Option<Rc<AllowedTransformation<BO>>>,
26    denied_response: Rc<DeniedResponse<BO>>,
27    rollback_condition: Option<Rc<RollbackCondition>>,
28}
29
30impl<BA, BI, BO, F, O> Clone for RateLimiter<BA, BO, F>
31where
32    BA: Backend<BI> + 'static,
33    BI: 'static,
34    F: Fn(&ServiceRequest) -> O + 'static,
35    O: Future<Output = Result<BI, actix_web::Error>>,
36{
37    fn clone(&self) -> Self {
38        Self {
39            backend: self.backend.clone(),
40            input_fn: self.input_fn.clone(),
41            fail_open: self.fail_open,
42            allowed_mutation: self.allowed_mutation.clone(),
43            denied_response: self.denied_response.clone(),
44            rollback_condition: self.rollback_condition.clone(),
45        }
46    }
47}
48
49impl<BA, BI, BO, F, O> RateLimiter<BA, BO, F>
50where
51    BA: Backend<BI, Output = BO> + 'static,
52    BI: 'static,
53    F: Fn(&ServiceRequest) -> O + 'static,
54    O: Future<Output = Result<BI, actix_web::Error>>,
55{
56    /// # Arguments
57    ///
58    /// * `backend`: A rate limiting algorithm and store implementation.
59    /// * `input_fn`: A future that produces input to the backend based on the incoming request.
60    pub fn builder(backend: BA, input_fn: F) -> RateLimiterBuilder<BA, BO, F> {
61        RateLimiterBuilder::new(backend, input_fn)
62    }
63}
64
65impl<S, B, BA, BI, BO, BE, F, O> Transform<S, ServiceRequest> for RateLimiter<BA, BO, F>
66where
67    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
68    S::Future: 'static,
69    B: 'static,
70    BA: Backend<BI, Output = BO, Error = BE> + 'static,
71    BI: 'static,
72    BO: 'static,
73    BE: Into<actix_web::Error> + std::fmt::Display + 'static,
74    F: Fn(&ServiceRequest) -> O + 'static,
75    O: Future<Output = Result<BI, actix_web::Error>>,
76{
77    type Response = ServiceResponse<EitherBody<B>>;
78    type Error = actix_web::Error;
79    type Transform = RateLimiterMiddleware<S, BA, BO, F>;
80    type InitError = ();
81    type Future = Ready<Result<Self::Transform, Self::InitError>>;
82
83    fn new_transform(&self, service: S) -> Self::Future {
84        ok(RateLimiterMiddleware {
85            service: Rc::new(RefCell::new(service)),
86            backend: self.backend.clone(),
87            input_fn: Rc::clone(&self.input_fn),
88            fail_open: self.fail_open,
89            allowed_transformation: self.allowed_mutation.clone(),
90            denied_response: self.denied_response.clone(),
91            rollback_condition: self.rollback_condition.clone(),
92        })
93    }
94}
95
96pub struct RateLimiterMiddleware<S, BE, BO, F> {
97    service: Rc<RefCell<S>>,
98    backend: BE,
99    input_fn: Rc<F>,
100    fail_open: bool,
101    allowed_transformation: Option<Rc<AllowedTransformation<BO>>>,
102    denied_response: Rc<DeniedResponse<BO>>,
103    rollback_condition: Option<Rc<RollbackCondition>>,
104}
105
106impl<S, B, BA, BI, BO, BE, F, O> Service<ServiceRequest> for RateLimiterMiddleware<S, BA, BO, F>
107where
108    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
109    S::Future: 'static,
110    B: 'static,
111    BA: Backend<BI, Output = BO, Error = BE> + 'static,
112    BI: 'static,
113    BO: 'static,
114    BE: Into<actix_web::Error> + std::fmt::Display + 'static,
115    F: Fn(&ServiceRequest) -> O + 'static,
116    O: Future<Output = Result<BI, actix_web::Error>>,
117{
118    type Response = ServiceResponse<EitherBody<B>>;
119    type Error = actix_web::Error;
120    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
121
122    forward_ready!(service);
123
124    fn call(&self, req: ServiceRequest) -> Self::Future {
125        let service = self.service.clone();
126        let backend = self.backend.clone();
127        let input_fn = self.input_fn.clone();
128        let fail_open = self.fail_open;
129        let allowed_transformation = self.allowed_transformation.clone();
130        let denied_response = self.denied_response.clone();
131        let rollback_condition = self.rollback_condition.clone();
132
133        Box::pin(async move {
134            let input = match input_fn(&req).await {
135                Ok(input) => input,
136                Err(e) => {
137                    log::error!("Rate limiter input function failed: {e}");
138                    return Ok(req.into_response(e.error_response()).map_into_right_body());
139                }
140            };
141
142            let (output, rollback) = match backend.request(input).await {
143                // Able to successfully query rate limiter backend
144                Ok((decision, output, rollback)) => {
145                    if decision.is_denied() {
146                        let response: HttpResponse = denied_response(&output);
147                        return Ok(req.into_response(response).map_into_right_body());
148                    }
149                    (Some(output), Some(rollback))
150                }
151                // Unable to query rate limiter backend
152                Err(e) => {
153                    if fail_open {
154                        log::warn!("Rate limiter failed: {}, allowing the request anyway", e);
155                        (None, None)
156                    } else {
157                        log::error!("Rate limiter failed: {}", e);
158                        return Ok(req
159                            .into_response(e.into().error_response())
160                            .map_into_right_body());
161                    }
162                }
163            };
164
165            let mut service_response = service.call(req).await?;
166
167            let mut rolled_back = false;
168            if let Some(token) = rollback {
169                if let Some(rollback_condition) = rollback_condition {
170                    let status = service_response.status();
171                    if rollback_condition(status) {
172                        if let Err(e) = backend.rollback(token).await {
173                            log::error!("Unable to rollback rate-limit count for response: {:?}, error: {e}", status);
174                        } else {
175                            rolled_back = true;
176                        };
177                    }
178                }
179            }
180
181            if let Some(transformation) = allowed_transformation {
182                transformation(service_response.headers_mut(), output.as_ref(), rolled_back);
183            }
184
185            Ok(service_response.map_into_left_body())
186        })
187    }
188}