actix_extensible_rate_limit/middleware/
mod.rs1pub mod builder;
2#[cfg(test)]
3mod tests;
4
5use crate::backend::Backend;
6use actix_web::body::EitherBody;
7use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
8use actix_web::http::header::HeaderMap;
9use actix_web::http::StatusCode;
10use actix_web::HttpResponse;
11use builder::RateLimiterBuilder;
12use futures::future::{ok, LocalBoxFuture, Ready};
13use std::cell::RefCell;
14use std::{future::Future, rc::Rc};
15
16type AllowedTransformation<BO> = dyn Fn(&mut HeaderMap, Option<&BO>, bool);
17type DeniedResponse<BO> = dyn Fn(&BO) -> HttpResponse;
18type RollbackCondition = dyn Fn(StatusCode) -> bool;
19
20pub struct RateLimiter<BA, BO, F> {
22 backend: BA,
23 input_fn: Rc<F>,
24 fail_open: bool,
25 allowed_mutation: Option<Rc<AllowedTransformation<BO>>>,
26 denied_response: Rc<DeniedResponse<BO>>,
27 rollback_condition: Option<Rc<RollbackCondition>>,
28}
29
30impl<BA, BI, BO, F, O> Clone for RateLimiter<BA, BO, F>
31where
32 BA: Backend<BI> + 'static,
33 BI: 'static,
34 F: Fn(&ServiceRequest) -> O + 'static,
35 O: Future<Output = Result<BI, actix_web::Error>>,
36{
37 fn clone(&self) -> Self {
38 Self {
39 backend: self.backend.clone(),
40 input_fn: self.input_fn.clone(),
41 fail_open: self.fail_open,
42 allowed_mutation: self.allowed_mutation.clone(),
43 denied_response: self.denied_response.clone(),
44 rollback_condition: self.rollback_condition.clone(),
45 }
46 }
47}
48
49impl<BA, BI, BO, F, O> RateLimiter<BA, BO, F>
50where
51 BA: Backend<BI, Output = BO> + 'static,
52 BI: 'static,
53 F: Fn(&ServiceRequest) -> O + 'static,
54 O: Future<Output = Result<BI, actix_web::Error>>,
55{
56 pub fn builder(backend: BA, input_fn: F) -> RateLimiterBuilder<BA, BO, F> {
61 RateLimiterBuilder::new(backend, input_fn)
62 }
63}
64
65impl<S, B, BA, BI, BO, BE, F, O> Transform<S, ServiceRequest> for RateLimiter<BA, BO, F>
66where
67 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
68 S::Future: 'static,
69 B: 'static,
70 BA: Backend<BI, Output = BO, Error = BE> + 'static,
71 BI: 'static,
72 BO: 'static,
73 BE: Into<actix_web::Error> + std::fmt::Display + 'static,
74 F: Fn(&ServiceRequest) -> O + 'static,
75 O: Future<Output = Result<BI, actix_web::Error>>,
76{
77 type Response = ServiceResponse<EitherBody<B>>;
78 type Error = actix_web::Error;
79 type Transform = RateLimiterMiddleware<S, BA, BO, F>;
80 type InitError = ();
81 type Future = Ready<Result<Self::Transform, Self::InitError>>;
82
83 fn new_transform(&self, service: S) -> Self::Future {
84 ok(RateLimiterMiddleware {
85 service: Rc::new(RefCell::new(service)),
86 backend: self.backend.clone(),
87 input_fn: Rc::clone(&self.input_fn),
88 fail_open: self.fail_open,
89 allowed_transformation: self.allowed_mutation.clone(),
90 denied_response: self.denied_response.clone(),
91 rollback_condition: self.rollback_condition.clone(),
92 })
93 }
94}
95
96pub struct RateLimiterMiddleware<S, BE, BO, F> {
97 service: Rc<RefCell<S>>,
98 backend: BE,
99 input_fn: Rc<F>,
100 fail_open: bool,
101 allowed_transformation: Option<Rc<AllowedTransformation<BO>>>,
102 denied_response: Rc<DeniedResponse<BO>>,
103 rollback_condition: Option<Rc<RollbackCondition>>,
104}
105
106impl<S, B, BA, BI, BO, BE, F, O> Service<ServiceRequest> for RateLimiterMiddleware<S, BA, BO, F>
107where
108 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
109 S::Future: 'static,
110 B: 'static,
111 BA: Backend<BI, Output = BO, Error = BE> + 'static,
112 BI: 'static,
113 BO: 'static,
114 BE: Into<actix_web::Error> + std::fmt::Display + 'static,
115 F: Fn(&ServiceRequest) -> O + 'static,
116 O: Future<Output = Result<BI, actix_web::Error>>,
117{
118 type Response = ServiceResponse<EitherBody<B>>;
119 type Error = actix_web::Error;
120 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
121
122 forward_ready!(service);
123
124 fn call(&self, req: ServiceRequest) -> Self::Future {
125 let service = self.service.clone();
126 let backend = self.backend.clone();
127 let input_fn = self.input_fn.clone();
128 let fail_open = self.fail_open;
129 let allowed_transformation = self.allowed_transformation.clone();
130 let denied_response = self.denied_response.clone();
131 let rollback_condition = self.rollback_condition.clone();
132
133 Box::pin(async move {
134 let input = match input_fn(&req).await {
135 Ok(input) => input,
136 Err(e) => {
137 log::error!("Rate limiter input function failed: {e}");
138 return Ok(req.into_response(e.error_response()).map_into_right_body());
139 }
140 };
141
142 let (output, rollback) = match backend.request(input).await {
143 Ok((decision, output, rollback)) => {
145 if decision.is_denied() {
146 let response: HttpResponse = denied_response(&output);
147 return Ok(req.into_response(response).map_into_right_body());
148 }
149 (Some(output), Some(rollback))
150 }
151 Err(e) => {
153 if fail_open {
154 log::warn!("Rate limiter failed: {}, allowing the request anyway", e);
155 (None, None)
156 } else {
157 log::error!("Rate limiter failed: {}", e);
158 return Ok(req
159 .into_response(e.into().error_response())
160 .map_into_right_body());
161 }
162 }
163 };
164
165 let mut service_response = service.call(req).await?;
166
167 let mut rolled_back = false;
168 if let Some(token) = rollback {
169 if let Some(rollback_condition) = rollback_condition {
170 let status = service_response.status();
171 if rollback_condition(status) {
172 if let Err(e) = backend.rollback(token).await {
173 log::error!("Unable to rollback rate-limit count for response: {:?}, error: {e}", status);
174 } else {
175 rolled_back = true;
176 };
177 }
178 }
179 }
180
181 if let Some(transformation) = allowed_transformation {
182 transformation(service_response.headers_mut(), output.as_ref(), rolled_back);
183 }
184
185 Ok(service_response.map_into_left_body())
186 })
187 }
188}