1use actix_web::{
2 Error, HttpResponse,
3 body::{EitherBody, MessageBody},
4 dev::{Service, ServiceRequest, ServiceResponse, Transform},
5 http::{StatusCode, header},
6};
7use futures_util::future::{LocalBoxFuture, Ready, ready};
8use governor::{
9 Quota, RateLimiter,
10 clock::{Clock, DefaultClock},
11 state::keyed::DefaultKeyedStateStore,
12};
13use std::{
14 num::NonZeroU32,
15 sync::{
16 Arc,
17 atomic::{AtomicU64, Ordering},
18 },
19 task::{Context, Poll},
20 time::Duration,
21};
22
/// Per-client request quotas for each supported time window.
///
/// Each field is the maximum number of requests a single client key may make
/// within that window. `None` disables the window. `Some(0)` and values larger
/// than `u32::MAX` also disable it, because the limiter builders reject them.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Requests allowed per second.
    pub per_second: Option<u64>,
    /// Requests allowed per minute.
    pub per_minute: Option<u64>,
    /// Requests allowed per hour.
    pub per_hour: Option<u64>,
    /// Requests allowed per 24-hour period.
    pub per_day: Option<u64>,
    /// Requests allowed per "month", approximated as a fixed 30-day period.
    pub per_month: Option<u64>,
}
31
32type Key = String;
33type Limiter = RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock>;
34
35#[derive(Clone)]
36struct EndpointLimiters {
37 month: Option<Arc<Limiter>>,
38 day: Option<Arc<Limiter>>,
39 hour: Option<Arc<Limiter>>,
40 minute: Option<Arc<Limiter>>,
41 second: Option<Arc<Limiter>>,
42}
43
44impl EndpointLimiters {
45 fn from_config(cfg: &RateLimitConfig) -> Self {
46 Self {
47 month: cfg.per_month.and_then(|n| {
48 build_custom_period_limiter(n, Duration::from_secs(30 * 24 * 60 * 60))
49 }),
50 day: cfg
51 .per_day
52 .and_then(|n| build_custom_period_limiter(n, Duration::from_secs(24 * 60 * 60))),
53 hour: cfg.per_hour.and_then(|n| build_limiter(n, Quota::per_hour)),
54 minute: cfg
55 .per_minute
56 .and_then(|n| build_limiter(n, Quota::per_minute)),
57 second: cfg
58 .per_second
59 .and_then(|n| build_limiter(n, Quota::per_second)),
60 }
61 }
62
63 fn iter(&self) -> impl Iterator<Item = &Arc<Limiter>> {
64 self.month
65 .iter()
66 .chain(self.day.iter())
67 .chain(self.hour.iter())
68 .chain(self.minute.iter())
69 .chain(self.second.iter())
70 }
71
72 fn is_empty(&self) -> bool {
73 self.second.is_none()
74 && self.minute.is_none()
75 && self.hour.is_none()
76 && self.day.is_none()
77 && self.month.is_none()
78 }
79}
80
81fn build_limiter<F>(n: u64, quota_fn: F) -> Option<Arc<Limiter>>
82where
83 F: FnOnce(NonZeroU32) -> Quota,
84{
85 let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
86 Some(Arc::new(RateLimiter::keyed(quota_fn(n32))))
87}
88
89fn build_custom_period_limiter(n: u64, period: Duration) -> Option<Arc<Limiter>> {
90 let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
91 let quota = Quota::with_period(period)?.allow_burst(n32);
92 Some(Arc::new(RateLimiter::keyed(quota)))
93}
94
95#[derive(Clone)]
96pub struct RateLimit {
97 limiters: Arc<EndpointLimiters>,
98 calls: Arc<AtomicU64>,
99}
100
101impl RateLimit {
102 pub fn global_api_rate_limit_config() -> RateLimitConfig {
104 if std::env::var("TEST_MODE").is_ok() {
105 RateLimitConfig {
106 per_second: Some(10000),
107 per_minute: Some(200000),
108 ..Default::default()
109 }
110 } else {
111 RateLimitConfig {
112 per_second: Some(20),
113 per_minute: Some(1000),
114 per_hour: Some(10000),
115 ..Default::default()
116 }
117 }
118 }
119
120 pub fn new(cfg: RateLimitConfig) -> Self {
121 Self {
122 limiters: Arc::new(EndpointLimiters::from_config(&cfg)),
123 calls: Arc::new(AtomicU64::new(0)),
124 }
125 }
126}
127
128impl<S, B> Transform<S, ServiceRequest> for RateLimit
129where
130 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
131 B: MessageBody + 'static,
132{
133 type Response = ServiceResponse<EitherBody<B>>;
134 type Error = Error;
135 type InitError = ();
136 type Transform = RateLimitInner<S>;
137 type Future = Ready<Result<Self::Transform, Self::InitError>>;
138
139 fn new_transform(&self, service: S) -> Self::Future {
140 ready(Ok(RateLimitInner {
141 service,
142 limiters: self.limiters.clone(),
143 calls: self.calls.clone(),
144 }))
145 }
146}
147
148pub struct RateLimitInner<S> {
149 service: S,
150 limiters: Arc<EndpointLimiters>,
151 calls: Arc<AtomicU64>,
152}
153
154impl<S, B> Service<ServiceRequest> for RateLimitInner<S>
155where
156 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
157 B: MessageBody + 'static,
158{
159 type Response = ServiceResponse<EitherBody<B>>;
160 type Error = Error;
161 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
162
163 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164 self.service.poll_ready(cx)
165 }
166
167 fn call(&self, req: ServiceRequest) -> Self::Future {
168 if self.limiters.is_empty() {
169 let fut = self.service.call(req);
170 return Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) });
171 }
172
173 const RETAIN_EVERY: u64 = 1024;
174 const SHRINK_EVERY: u64 = 65_536;
175
176 let n = self.calls.fetch_add(1, Ordering::Relaxed) + 1;
177 if n.is_multiple_of(RETAIN_EVERY) {
178 for limiter in self.limiters.iter() {
179 limiter.retain_recent();
180 }
181 if n.is_multiple_of(SHRINK_EVERY) {
182 for limiter in self.limiters.iter() {
183 limiter.shrink_to_fit();
184 }
185 }
186 }
187
188 let clock = DefaultClock::default();
189 let key = extract_client_ip_key(&req);
190
191 let mut retry_after: Option<Duration> = None;
192 for limiter in self.limiters.iter() {
193 if let Err(negative) = limiter.check_key(&key) {
194 let wait = negative.wait_time_from(clock.now());
195 retry_after = Some(retry_after.map_or(wait, |cur| cur.max(wait)));
196 }
197 }
198
199 if let Some(wait) = retry_after {
200 let secs = wait.as_secs().max(1);
201 let resp = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
202 .insert_header((header::RETRY_AFTER, secs.to_string()))
203 .content_type("application/json")
204 .body(r#"{"type":"rate_limit","message_key":"rate_limited","message":"Too many requests. Please try again later."}"#.to_string());
205 return Box::pin(async move { Ok(req.into_response(resp).map_into_right_body()) });
206 }
207
208 let fut = self.service.call(req);
209 Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) })
210 }
211}
212
213fn extract_client_ip_key(req: &ServiceRequest) -> String {
214 if let Some(s) = req.connection_info().realip_remote_addr() {
215 let s = s.trim();
216 if !s.is_empty() {
217 return s.to_string();
218 }
219 }
220
221 if let Some(sa) = req.peer_addr() {
222 return sa.ip().to_string();
223 }
224
225 format!("unknown:{}|{}", req.connection_info().host(), req.path())
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use actix_http::Request;
    use actix_web::{
        App, HttpResponse,
        body::{BoxBody, EitherBody},
        dev::{Service, ServiceResponse},
        http::header,
        test, web,
    };
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Response type produced by the middleware-wrapped test app.
    type WrappedResponse = ServiceResponse<EitherBody<BoxBody>>;

    /// Middleware with only the minute/hour/day/month windows configured.
    fn mw(
        per_minute: Option<u64>,
        per_hour: Option<u64>,
        per_day: Option<u64>,
        per_month: Option<u64>,
    ) -> RateLimit {
        let cfg = RateLimitConfig {
            per_month,
            per_day,
            per_hour,
            per_minute,
            ..RateLimitConfig::default()
        };
        RateLimit::new(cfg)
    }

    /// IPv4 socket address shorthand.
    fn v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    /// Sends `GET uri`, optionally with an `X-Forwarded-For` value and/or a peer address.
    async fn call_get<S>(
        app: &S,
        uri: &str,
        xff: Option<&str>,
        peer: Option<SocketAddr>,
    ) -> WrappedResponse
    where
        S: Service<Request, Response = WrappedResponse, Error = Error>,
    {
        let builder = test::TestRequest::get().uri(uri);
        let builder = match xff {
            Some(value) => builder.insert_header(("x-forwarded-for", value)),
            None => builder,
        };
        let builder = match peer {
            Some(addr) => builder.peer_addr(addr),
            None => builder,
        };
        test::call_service(app, builder.to_request()).await
    }

    /// Parses the `Retry-After` header as whole seconds, if present and numeric.
    fn retry_after_secs(resp: &WrappedResponse) -> Option<u64> {
        let raw = resp.headers().get(header::RETRY_AFTER)?.to_str().ok()?;
        raw.parse::<u64>().ok()
    }

    /// App with two trivial routes (`/` and `/other`) behind `middleware`.
    async fn init_app(
        middleware: RateLimit,
    ) -> impl Service<Request, Response = WrappedResponse, Error = Error> {
        let ok = || async { HttpResponse::Ok().finish() };
        test::init_service(
            App::new()
                .wrap(middleware)
                .route("/", web::get().to(ok))
                .route("/other", web::get().to(ok)),
        )
        .await
    }

    #[actix_web::test]
    async fn key_trims_realip_value() {
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", " 9.9.9.9 "))
            .to_srv_request();

        let key = super::extract_client_ip_key(&req);
        assert_eq!(key, "9.9.9.9");
    }

    #[actix_web::test]
    async fn key_empty_realip_falls_back_to_peer_ip() {
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", " "))
            .peer_addr(v4(1, 2, 3, 4, 5555))
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "1.2.3.4");
    }

    #[actix_web::test]
    async fn key_no_realip_uses_peer_ip() {
        let req = test::TestRequest::get()
            .uri("/x")
            .peer_addr(v4(10, 0, 0, 9, 1234))
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "10.0.0.9");
    }

    #[actix_web::test]
    async fn key_no_realip_and_no_peer_uses_unknown_host_and_path() {
        let req = test::TestRequest::get().uri("/path123").to_srv_request();
        let k = super::extract_client_ip_key(&req);
        assert!(k.starts_with("unknown:"), "key={k}");
        assert!(k.contains("|/path123"), "key={k}");
    }

    #[actix_web::test]
    async fn passthrough_when_no_limits() {
        let app = init_app(mw(None, None, None, None)).await;

        let first = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let second = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(first.status(), StatusCode::OK);
        assert_eq!(second.status(), StatusCode::OK);
        assert!(!second.headers().contains_key(header::RETRY_AFTER));
    }

    #[actix_web::test]
    async fn per_minute_zero_disables_window() {
        let app = init_app(mw(Some(0), None, None, None)).await;

        for _ in 0..2 {
            let resp = call_get(&app, "/", Some("1.2.3.4"), None).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    #[actix_web::test]
    async fn per_minute_over_u32_disables_window() {
        let too_big = u64::from(u32::MAX) + 1;
        let app = init_app(mw(Some(too_big), None, None, None)).await;

        for _ in 0..2 {
            let resp = call_get(&app, "/", Some("1.2.3.4"), None).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
    }

    #[actix_web::test]
    async fn blocks_second_request_same_key() {
        let app = init_app(mw(Some(1), None, None, None)).await;

        let first = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(first.status(), StatusCode::OK);

        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let secs = retry_after_secs(&blocked).expect("missing Retry-After");
        assert!(secs >= 1);
    }

    #[actix_web::test]
    async fn retry_after_is_integer_seconds_and_body_is_json() {
        let app = init_app(mw(Some(1), None, None, None)).await;

        let _ = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra_str = blocked
            .headers()
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ra_str.parse::<u64>().is_ok(),
            "Retry-After not int: {ra_str}"
        );

        let bytes = test::read_body(blocked).await;
        let body = std::str::from_utf8(&bytes).unwrap();
        assert!(body.contains(r#""type":"rate_limit""#), "body={body}");

        let json: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(json["type"], "rate_limit");
        assert_eq!(json["message_key"], "rate_limited");
        assert!(json.get("retry_after").is_none());
    }

    #[actix_web::test]
    async fn different_keys_independent() {
        let app = init_app(mw(Some(1), None, None, None)).await;

        let first_a = call_get(&app, "/", Some("10.0.0.1"), None).await;
        let first_b = call_get(&app, "/", Some("10.0.0.2"), None).await;
        assert_eq!(first_a.status(), StatusCode::OK);
        assert_eq!(first_b.status(), StatusCode::OK);

        let second_a = call_get(&app, "/", Some("10.0.0.1"), None).await;
        assert_eq!(second_a.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[actix_web::test]
    async fn same_key_shared_across_routes_in_same_app() {
        let app = init_app(mw(Some(1), None, None, None)).await;

        let root = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(root.status(), StatusCode::OK);

        let other = call_get(&app, "/other", Some("1.2.3.4"), None).await;
        assert_eq!(other.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[actix_web::test]
    async fn max_retry_after_prefers_longer_window() {
        let app = init_app(mw(Some(1), Some(1), None, None)).await;

        let first = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(first.status(), StatusCode::OK);

        let second = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra = retry_after_secs(&second).expect("missing Retry-After");
        assert!(ra >= 120, "expected hour-dominated Retry-After, got {ra}");
    }

    #[actix_web::test]
    async fn peer_addr_used_when_no_forwarded_headers() {
        let app = init_app(mw(Some(1), None, None, None)).await;
        let peer = v4(7, 7, 7, 7, 9999);

        let first = call_get(&app, "/", None, Some(peer)).await;
        let second = call_get(&app, "/", None, Some(peer)).await;

        assert_eq!(first.status(), StatusCode::OK);
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[actix_web::test]
    async fn empty_forwarded_header_does_not_create_empty_key_bucket() {
        let app = init_app(mw(Some(1), None, None, None)).await;
        let peer = v4(8, 8, 8, 8, 1111);

        let first = call_get(&app, "/", Some(" "), Some(peer)).await;
        let second = call_get(&app, "/", Some(" "), Some(peer)).await;

        assert_eq!(first.status(), StatusCode::OK);
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[actix_web::test]
    async fn unknown_bucket_includes_path_to_reduce_collisions() {
        let app = init_app(mw(Some(1), None, None, None)).await;

        let root = call_get(&app, "/", None, None).await;
        let other = call_get(&app, "/other", None, None).await;
        assert_eq!(root.status(), StatusCode::OK);
        assert_eq!(other.status(), StatusCode::OK);

        let root_again = call_get(&app, "/", None, None).await;
        assert_eq!(root_again.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[actix_web::test]
    async fn housekeeping_retain_recent_path_executes() {
        // 1025 calls so the counter crosses the 1024-call retain threshold.
        let app = init_app(mw(Some(1_000_000), None, None, None)).await;

        for i in 0..=1024 {
            let ip = format!("192.0.2.{}", (i % 250) + 1);
            let resp = call_get(&app, "/", Some(&ip), None).await;
            assert_eq!(resp.status(), StatusCode::OK, "i={i} ip={ip}");
        }
    }

    #[actix_web::test]
    async fn keys_are_exact_strings_no_normalization_means_ports_are_distinct() {
        let app = init_app(mw(Some(1), None, None, None)).await;
        let keys = ["1.2.3.4", "1.2.3.4:12345"];

        for key in keys {
            let resp = call_get(&app, "/", Some(key), None).await;
            assert_eq!(resp.status(), StatusCode::OK);
        }
        for key in keys {
            let resp = call_get(&app, "/", Some(key), None).await;
            assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
        }
    }
}
515}