Skip to main content

headless_lms_server/domain/
rate_limit_middleware_builder.rs

1use actix_web::{
2    Error, HttpResponse,
3    body::{EitherBody, MessageBody},
4    dev::{Service, ServiceRequest, ServiceResponse, Transform},
5    http::{StatusCode, header},
6};
7use futures_util::future::{LocalBoxFuture, Ready, ready};
8use governor::{
9    Quota, RateLimiter,
10    clock::{Clock, DefaultClock},
11    state::keyed::DefaultKeyedStateStore,
12};
13use std::{
14    num::NonZeroU32,
15    sync::{
16        Arc,
17        atomic::{AtomicU64, Ordering},
18    },
19    task::{Context, Poll},
20    time::Duration,
21};
22
/// Request budgets for one rate-limited scope, one optional budget per time window.
///
/// A window left as `None` is not enforced. A value of `0`, or one that does not fit
/// in a `u32`, also results in no limiter being built for that window.
#[derive(Debug, Clone, Default)]
pub struct RateLimitConfig {
    /// Allowed requests per second.
    pub per_second: Option<u64>,
    /// Allowed requests per minute.
    pub per_minute: Option<u64>,
    /// Allowed requests per hour.
    pub per_hour: Option<u64>,
    /// Allowed requests per day (a 24-hour window).
    pub per_day: Option<u64>,
    /// Allowed requests per month (a 30-day window).
    pub per_month: Option<u64>,
}
31
/// Bucket key used by the keyed limiters; the middleware derives one per request
/// via `extract_client_ip_key`.
type Key = String;
/// Keyed `governor` rate limiter over [`Key`], using the default keyed state store and
/// the default clock.
type Limiter = RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock>;
34
35#[derive(Clone)]
36struct EndpointLimiters {
37    month: Option<Arc<Limiter>>,
38    day: Option<Arc<Limiter>>,
39    hour: Option<Arc<Limiter>>,
40    minute: Option<Arc<Limiter>>,
41    second: Option<Arc<Limiter>>,
42}
43
44impl EndpointLimiters {
45    fn from_config(cfg: &RateLimitConfig) -> Self {
46        Self {
47            month: cfg.per_month.and_then(|n| {
48                build_custom_period_limiter(n, Duration::from_secs(30 * 24 * 60 * 60))
49            }),
50            day: cfg
51                .per_day
52                .and_then(|n| build_custom_period_limiter(n, Duration::from_secs(24 * 60 * 60))),
53            hour: cfg.per_hour.and_then(|n| build_limiter(n, Quota::per_hour)),
54            minute: cfg
55                .per_minute
56                .and_then(|n| build_limiter(n, Quota::per_minute)),
57            second: cfg
58                .per_second
59                .and_then(|n| build_limiter(n, Quota::per_second)),
60        }
61    }
62
63    fn iter(&self) -> impl Iterator<Item = &Arc<Limiter>> {
64        self.month
65            .iter()
66            .chain(self.day.iter())
67            .chain(self.hour.iter())
68            .chain(self.minute.iter())
69            .chain(self.second.iter())
70    }
71
72    fn is_empty(&self) -> bool {
73        self.second.is_none()
74            && self.minute.is_none()
75            && self.hour.is_none()
76            && self.day.is_none()
77            && self.month.is_none()
78    }
79}
80
81fn build_limiter<F>(n: u64, quota_fn: F) -> Option<Arc<Limiter>>
82where
83    F: FnOnce(NonZeroU32) -> Quota,
84{
85    let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
86    Some(Arc::new(RateLimiter::keyed(quota_fn(n32))))
87}
88
89fn build_custom_period_limiter(n: u64, period: Duration) -> Option<Arc<Limiter>> {
90    let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
91    let quota = Quota::with_period(period)?.allow_burst(n32);
92    Some(Arc::new(RateLimiter::keyed(quota)))
93}
94
95#[derive(Clone)]
96pub struct RateLimit {
97    limiters: Arc<EndpointLimiters>,
98    calls: Arc<AtomicU64>,
99}
100
101impl RateLimit {
102    /// Global `/api/v0` limits aligned with nginx ingress `limit-rps` and `limit-rpm`; relaxed when `TEST_MODE` is set.
103    pub fn global_api_rate_limit_config(test_mode: bool) -> RateLimitConfig {
104        if test_mode {
105            RateLimitConfig {
106                per_second: Some(10000),
107                per_minute: Some(200000),
108                ..Default::default()
109            }
110        } else {
111            RateLimitConfig {
112                per_second: Some(20),
113                per_minute: Some(1000),
114                per_hour: Some(10000),
115                ..Default::default()
116            }
117        }
118    }
119
120    pub fn new(cfg: RateLimitConfig) -> Self {
121        Self {
122            limiters: Arc::new(EndpointLimiters::from_config(&cfg)),
123            calls: Arc::new(AtomicU64::new(0)),
124        }
125    }
126}
127
128impl<S, B> Transform<S, ServiceRequest> for RateLimit
129where
130    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
131    B: MessageBody + 'static,
132{
133    type Response = ServiceResponse<EitherBody<B>>;
134    type Error = Error;
135    type InitError = ();
136    type Transform = RateLimitInner<S>;
137    type Future = Ready<Result<Self::Transform, Self::InitError>>;
138
139    fn new_transform(&self, service: S) -> Self::Future {
140        ready(Ok(RateLimitInner {
141            service,
142            limiters: self.limiters.clone(),
143            calls: self.calls.clone(),
144        }))
145    }
146}
147
148pub struct RateLimitInner<S> {
149    service: S,
150    limiters: Arc<EndpointLimiters>,
151    calls: Arc<AtomicU64>,
152}
153
154impl<S, B> Service<ServiceRequest> for RateLimitInner<S>
155where
156    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
157    B: MessageBody + 'static,
158{
159    type Response = ServiceResponse<EitherBody<B>>;
160    type Error = Error;
161    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
162
163    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164        self.service.poll_ready(cx)
165    }
166
167    fn call(&self, req: ServiceRequest) -> Self::Future {
168        if self.limiters.is_empty() {
169            let fut = self.service.call(req);
170            return Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) });
171        }
172
173        const RETAIN_EVERY: u64 = 1024;
174        const SHRINK_EVERY: u64 = 65_536;
175
176        let n = self.calls.fetch_add(1, Ordering::Relaxed) + 1;
177        if n.is_multiple_of(RETAIN_EVERY) {
178            for limiter in self.limiters.iter() {
179                limiter.retain_recent();
180            }
181            if n.is_multiple_of(SHRINK_EVERY) {
182                for limiter in self.limiters.iter() {
183                    limiter.shrink_to_fit();
184                }
185            }
186        }
187
188        let clock = DefaultClock::default();
189        let key = extract_client_ip_key(&req);
190
191        let mut retry_after: Option<Duration> = None;
192        for limiter in self.limiters.iter() {
193            if let Err(negative) = limiter.check_key(&key) {
194                let wait = negative.wait_time_from(clock.now());
195                retry_after = Some(retry_after.map_or(wait, |cur| cur.max(wait)));
196            }
197        }
198
199        if let Some(wait) = retry_after {
200            let secs = wait.as_secs().max(1);
201            let resp = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
202                .insert_header((header::RETRY_AFTER, secs.to_string()))
203                .content_type("application/json")
204                .body(r#"{"type":"rate_limit","message_key":"rate_limited","message":"Too many requests. Please try again later."}"#.to_string());
205            return Box::pin(async move { Ok(req.into_response(resp).map_into_right_body()) });
206        }
207
208        let fut = self.service.call(req);
209        Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) })
210    }
211}
212
213fn extract_client_ip_key(req: &ServiceRequest) -> String {
214    if let Some(s) = req.connection_info().realip_remote_addr() {
215        let s = s.trim();
216        if !s.is_empty() {
217            return s.to_string();
218        }
219    }
220
221    if let Some(sa) = req.peer_addr() {
222        return sa.ip().to_string();
223    }
224
225    format!("unknown:{}|{}", req.connection_info().host(), req.path())
226}
227
// Unit tests for key extraction and for the middleware's limiting behaviour, run
// against an in-memory actix-web test service.
#[cfg(test)]
mod tests {
    use super::*;
    use actix_http::Request;
    use actix_web::{
        App, HttpResponse,
        body::{BoxBody, EitherBody},
        dev::{Service, ServiceResponse},
        http::header,
        test, web,
    };
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    // Builds a middleware with the given minute/hour/day/month budgets; the per-second
    // window is left unset.
    fn mw(
        per_minute: Option<u64>,
        per_hour: Option<u64>,
        per_day: Option<u64>,
        per_month: Option<u64>,
    ) -> RateLimit {
        RateLimit::new(RateLimitConfig {
            per_minute,
            per_hour,
            per_day,
            per_month,
            ..Default::default()
        })
    }

    // The global config must depend only on the `test_mode` argument.
    #[actix_web::test]
    async fn global_api_rate_limit_config_uses_test_mode_argument() {
        let test_cfg = RateLimit::global_api_rate_limit_config(true);
        assert_eq!(test_cfg.per_second, Some(10000));
        assert_eq!(test_cfg.per_minute, Some(200000));
        assert_eq!(test_cfg.per_hour, None);

        let production_cfg = RateLimit::global_api_rate_limit_config(false);
        assert_eq!(production_cfg.per_second, Some(20));
        assert_eq!(production_cfg.per_minute, Some(1000));
        assert_eq!(production_cfg.per_hour, Some(10000));
    }

    // Sends a GET to `uri`, optionally with an `x-forwarded-for` header and/or a peer
    // socket address, and returns the service response.
    async fn call_get<S>(
        app: &S,
        uri: &str,
        xff: Option<&str>,
        peer: Option<SocketAddr>,
    ) -> ServiceResponse<EitherBody<BoxBody>>
    where
        S: Service<Request, Response = ServiceResponse<EitherBody<BoxBody>>, Error = Error>,
    {
        let mut tr = test::TestRequest::get().uri(uri);
        if let Some(v) = xff {
            tr = tr.insert_header(("x-forwarded-for", v));
        }
        if let Some(p) = peer {
            tr = tr.peer_addr(p);
        }
        test::call_service(app, tr.to_request()).await
    }

    // Parses the `Retry-After` header as whole seconds; `None` if absent or not an integer.
    fn retry_after_secs(resp: &ServiceResponse<EitherBody<BoxBody>>) -> Option<u64> {
        resp.headers()
            .get(header::RETRY_AFTER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
    }

    // Test application with two trivial routes (`/` and `/other`) wrapped by `mw`.
    async fn app(
        mw: RateLimit,
    ) -> impl Service<Request, Response = ServiceResponse<EitherBody<BoxBody>>, Error = Error> {
        test::init_service(
            App::new()
                .wrap(mw)
                .route("/", web::get().to(|| async { HttpResponse::Ok().finish() }))
                .route(
                    "/other",
                    web::get().to(|| async { HttpResponse::Ok().finish() }),
                ),
        )
        .await
    }

    // Surrounding whitespace in the forwarded address must not end up in the key.
    #[actix_web::test]
    async fn key_trims_realip_value() {
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", " 9.9.9.9 "))
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "9.9.9.9");
    }

    // A whitespace-only forwarded address is ignored in favour of the peer IP.
    #[actix_web::test]
    async fn key_empty_realip_falls_back_to_peer_ip() {
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 5555);
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", "     "))
            .peer_addr(peer)
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "1.2.3.4");
    }

    // Without forwarding headers the key is the peer IP, without the port.
    #[actix_web::test]
    async fn key_no_realip_uses_peer_ip() {
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)), 1234);
        let req = test::TestRequest::get()
            .uri("/x")
            .peer_addr(peer)
            .to_srv_request();
        assert_eq!(super::extract_client_ip_key(&req), "10.0.0.9");
    }

    // With neither a forwarded address nor a peer, the synthetic `unknown:` key is used.
    #[actix_web::test]
    async fn key_no_realip_and_no_peer_uses_unknown_host_and_path() {
        let req = test::TestRequest::get().uri("/path123").to_srv_request();
        let k = super::extract_client_ip_key(&req);
        assert!(k.starts_with("unknown:"), "key={k}");
        assert!(k.contains("|/path123"), "key={k}");
    }

    // No configured window means requests are never limited.
    #[actix_web::test]
    async fn passthrough_when_no_limits() {
        let app = app(mw(None, None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
        assert!(r2.headers().get(header::RETRY_AFTER).is_none());
    }

    // A budget of zero produces no limiter, so the window is effectively disabled.
    #[actix_web::test]
    async fn per_minute_zero_disables_window() {
        let app = app(mw(Some(0), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
    }

    // A budget that does not fit in `u32` also produces no limiter.
    #[actix_web::test]
    async fn per_minute_over_u32_disables_window() {
        let app = app(mw(Some(u64::from(u32::MAX) + 1), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
    }

    // With a budget of one, the second request from the same key is rejected with 429.
    #[actix_web::test]
    async fn blocks_second_request_same_key() {
        let app = app(mw(Some(1), None, None, None)).await;

        let ok = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(ok.status(), StatusCode::OK);

        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra = retry_after_secs(&blocked).expect("missing Retry-After");
        assert!(ra >= 1);
    }

    // Pins the shape of the 429 response: integer `Retry-After` and the JSON error body.
    #[actix_web::test]
    async fn retry_after_is_integer_seconds_and_body_is_json() {
        let app = app(mw(Some(1), None, None, None)).await;

        let _ = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra_hdr = blocked.headers().get(header::RETRY_AFTER).unwrap();
        let ra_str = ra_hdr.to_str().unwrap();
        assert!(
            ra_str.parse::<u64>().is_ok(),
            "Retry-After not int: {ra_str}"
        );

        let bytes = test::read_body(blocked).await;
        let body = std::str::from_utf8(&bytes).unwrap();
        assert!(body.contains(r#""type":"rate_limit""#), "body={body}");
        let v: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(v["type"], "rate_limit");
        assert_eq!(v["message_key"], "rate_limited");
        assert!(v.get("retry_after").is_none());
    }

    // Each key has its own budget: exhausting one does not affect another.
    #[actix_web::test]
    async fn different_keys_independent() {
        let app = app(mw(Some(1), None, None, None)).await;

        let a1 = call_get(&app, "/", Some("10.0.0.1"), None).await;
        let b1 = call_get(&app, "/", Some("10.0.0.2"), None).await;
        assert_eq!(a1.status(), StatusCode::OK);
        assert_eq!(b1.status(), StatusCode::OK);

        let a2 = call_get(&app, "/", Some("10.0.0.1"), None).await;
        assert_eq!(a2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // The budget is per key, not per route: a different path does not get a fresh bucket.
    #[actix_web::test]
    async fn same_key_shared_across_routes_in_same_app() {
        let app = app(mw(Some(1), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r1.status(), StatusCode::OK);

        let r2 = call_get(&app, "/other", Some("1.2.3.4"), None).await;
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // When several windows reject a request, `Retry-After` reflects the longest wait.
    #[actix_web::test]
    async fn max_retry_after_prefers_longer_window() {
        let app = app(mw(Some(1), Some(1), None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r1.status(), StatusCode::OK);

        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra = retry_after_secs(&r2).expect("missing Retry-After");
        // hour window should dominate minute window; be tolerant but meaningful
        assert!(ra >= 120, "expected hour-dominated Retry-After, got {ra}");
    }

    // Requests without forwarding headers are bucketed by their peer address.
    #[actix_web::test]
    async fn peer_addr_used_when_no_forwarded_headers() {
        let app = app(mw(Some(1), None, None, None)).await;

        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(7, 7, 7, 7)), 9999);
        let r1 = call_get(&app, "/", None, Some(peer)).await;
        let r2 = call_get(&app, "/", None, Some(peer)).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // A blank forwarded header still limits the client, via the peer-address bucket.
    #[actix_web::test]
    async fn empty_forwarded_header_does_not_create_empty_key_bucket() {
        let app = app(mw(Some(1), None, None, None)).await;

        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 1111);
        let r1 = call_get(&app, "/", Some("   "), Some(peer)).await;
        let r2 = call_get(&app, "/", Some("   "), Some(peer)).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // Unidentifiable clients are bucketed per path, so `/` and `/other` do not collide.
    #[actix_web::test]
    async fn unknown_bucket_includes_path_to_reduce_collisions() {
        let app = app(mw(Some(1), None, None, None)).await;

        let r1 = call_get(&app, "/", None, None).await;
        let r2 = call_get(&app, "/other", None, None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);

        let r1b = call_get(&app, "/", None, None).await;
        assert_eq!(r1b.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // Smoke test: crossing the housekeeping threshold must not break request handling.
    #[actix_web::test]
    async fn housekeeping_retain_recent_path_executes() {
        // Trigger retain_recent() at 1024 calls; keep limit huge to avoid 429.
        let app = app(mw(Some(1_000_000), None, None, None)).await;

        for i in 0..=1024 {
            let ip = format!("192.0.2.{}", (i % 250) + 1);
            let resp = call_get(&app, "/", Some(&ip), None).await;
            assert_eq!(resp.status(), StatusCode::OK, "i={i} ip={ip}");
        }
    }

    // Documents that keys are compared as raw strings: `ip` and `ip:port` are separate
    // buckets, each with its own budget.
    #[actix_web::test]
    async fn keys_are_exact_strings_no_normalization_means_ports_are_distinct() {
        let app = app(mw(Some(1), None, None, None)).await;

        let a = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let b = call_get(&app, "/", Some("1.2.3.4:12345"), None).await;

        assert_eq!(a.status(), StatusCode::OK);
        assert_eq!(b.status(), StatusCode::OK);

        let a2 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let b2 = call_get(&app, "/", Some("1.2.3.4:12345"), None).await;

        assert_eq!(a2.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(b2.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}
528}