headless_lms_server/domain/rate_limit_middleware_builder.rs

1use actix_web::{
2    Error, HttpResponse,
3    body::{EitherBody, MessageBody},
4    dev::{Service, ServiceRequest, ServiceResponse, Transform},
5    http::{StatusCode, header},
6};
7use futures_util::future::{LocalBoxFuture, Ready, ready};
8use governor::{
9    Quota, RateLimiter,
10    clock::{Clock, DefaultClock},
11    state::keyed::DefaultKeyedStateStore,
12};
13use std::{
14    num::NonZeroU32,
15    sync::{
16        Arc,
17        atomic::{AtomicU64, Ordering},
18    },
19    task::{Context, Poll},
20    time::Duration,
21};
22
/// Per-client request limits for the rate-limit middleware.
///
/// Each field is an independent window; `None` disables that window. A value of `0`,
/// or one larger than `u32::MAX`, also leaves the window disabled.
#[derive(Clone, Debug, Default)]
pub struct RateLimitConfig {
    /// Maximum requests per minute.
    pub per_minute: Option<u64>,
    /// Maximum requests per hour.
    pub per_hour: Option<u64>,
    /// Maximum requests per day (24 hours).
    pub per_day: Option<u64>,
    /// Maximum requests per "month" (treated as 30 days).
    pub per_month: Option<u64>,
}
30
31type Key = String;
32type Limiter = RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock>;
33
34#[derive(Clone)]
35struct EndpointLimiters {
36    month: Option<Arc<Limiter>>,
37    day: Option<Arc<Limiter>>,
38    hour: Option<Arc<Limiter>>,
39    minute: Option<Arc<Limiter>>,
40}
41
42impl EndpointLimiters {
43    fn from_config(cfg: &RateLimitConfig) -> Self {
44        Self {
45            month: cfg.per_month.and_then(|n| {
46                build_custom_period_limiter(n, Duration::from_secs(30 * 24 * 60 * 60))
47            }),
48            day: cfg
49                .per_day
50                .and_then(|n| build_custom_period_limiter(n, Duration::from_secs(24 * 60 * 60))),
51            hour: cfg.per_hour.and_then(|n| build_limiter(n, Quota::per_hour)),
52            minute: cfg
53                .per_minute
54                .and_then(|n| build_limiter(n, Quota::per_minute)),
55        }
56    }
57
58    fn iter(&self) -> impl Iterator<Item = &Arc<Limiter>> {
59        self.month
60            .iter()
61            .chain(self.day.iter())
62            .chain(self.hour.iter())
63            .chain(self.minute.iter())
64    }
65
66    fn is_empty(&self) -> bool {
67        self.minute.is_none() && self.hour.is_none() && self.day.is_none() && self.month.is_none()
68    }
69}
70
71fn build_limiter<F>(n: u64, quota_fn: F) -> Option<Arc<Limiter>>
72where
73    F: FnOnce(NonZeroU32) -> Quota,
74{
75    let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
76    Some(Arc::new(RateLimiter::keyed(quota_fn(n32))))
77}
78
79fn build_custom_period_limiter(n: u64, period: Duration) -> Option<Arc<Limiter>> {
80    let n32 = NonZeroU32::new(u32::try_from(n).ok()?)?;
81    let quota = Quota::with_period(period)?.allow_burst(n32);
82    Some(Arc::new(RateLimiter::keyed(quota)))
83}
84
85#[derive(Clone)]
86pub struct RateLimit {
87    limiters: Arc<EndpointLimiters>,
88    calls: Arc<AtomicU64>,
89}
90
91impl RateLimit {
92    pub fn new(cfg: RateLimitConfig) -> Self {
93        Self {
94            limiters: Arc::new(EndpointLimiters::from_config(&cfg)),
95            calls: Arc::new(AtomicU64::new(0)),
96        }
97    }
98}
99
100impl<S, B> Transform<S, ServiceRequest> for RateLimit
101where
102    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
103    B: MessageBody + 'static,
104{
105    type Response = ServiceResponse<EitherBody<B>>;
106    type Error = Error;
107    type InitError = ();
108    type Transform = RateLimitInner<S>;
109    type Future = Ready<Result<Self::Transform, Self::InitError>>;
110
111    fn new_transform(&self, service: S) -> Self::Future {
112        ready(Ok(RateLimitInner {
113            service,
114            limiters: self.limiters.clone(),
115            calls: self.calls.clone(),
116        }))
117    }
118}
119
120pub struct RateLimitInner<S> {
121    service: S,
122    limiters: Arc<EndpointLimiters>,
123    calls: Arc<AtomicU64>,
124}
125
126impl<S, B> Service<ServiceRequest> for RateLimitInner<S>
127where
128    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
129    B: MessageBody + 'static,
130{
131    type Response = ServiceResponse<EitherBody<B>>;
132    type Error = Error;
133    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
134
135    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136        self.service.poll_ready(cx)
137    }
138
139    fn call(&self, req: ServiceRequest) -> Self::Future {
140        if self.limiters.is_empty() {
141            let fut = self.service.call(req);
142            return Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) });
143        }
144
145        const RETAIN_EVERY: u64 = 1024;
146        const SHRINK_EVERY: u64 = 65_536;
147
148        let n = self.calls.fetch_add(1, Ordering::Relaxed) + 1;
149        if n.is_multiple_of(RETAIN_EVERY) {
150            for limiter in self.limiters.iter() {
151                limiter.retain_recent();
152            }
153            if n.is_multiple_of(SHRINK_EVERY) {
154                for limiter in self.limiters.iter() {
155                    limiter.shrink_to_fit();
156                }
157            }
158        }
159
160        let clock = DefaultClock::default();
161        let key = extract_client_ip_key(&req);
162
163        let mut retry_after: Option<Duration> = None;
164        for limiter in self.limiters.iter() {
165            if let Err(negative) = limiter.check_key(&key) {
166                let wait = negative.wait_time_from(clock.now());
167                retry_after = Some(retry_after.map_or(wait, |cur| cur.max(wait)));
168            }
169        }
170
171        if let Some(wait) = retry_after {
172            let secs = wait.as_secs().max(1);
173            let resp = HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
174                .insert_header((header::RETRY_AFTER, secs.to_string()))
175                .content_type("application/json")
176                .body(format!(
177                    r#"{{"error":"too_many_requests","retry_after_seconds":{secs}}}"#
178                ));
179            return Box::pin(async move { Ok(req.into_response(resp).map_into_right_body()) });
180        }
181
182        let fut = self.service.call(req);
183        Box::pin(async move { fut.await.map(|r| r.map_into_left_body()) })
184    }
185}
186
187fn extract_client_ip_key(req: &ServiceRequest) -> String {
188    if let Some(s) = req.connection_info().realip_remote_addr() {
189        let s = s.trim();
190        if !s.is_empty() {
191            return s.to_string();
192        }
193    }
194
195    if let Some(sa) = req.peer_addr() {
196        return sa.ip().to_string();
197    }
198
199    format!("unknown:{}|{}", req.connection_info().host(), req.path())
200}
201
#[cfg(test)]
mod tests {
    // Tests for the rate-limit middleware: bucket-key derivation via
    // `extract_client_ip_key`, and end-to-end behaviour through an actix test service.
    // Each end-to-end test builds its own app, so limiter state is not shared between
    // tests, but requests within one test accumulate against the same buckets.
    use super::*;
    use actix_http::Request;
    use actix_web::{
        App, HttpResponse,
        body::{BoxBody, EitherBody},
        dev::{Service, ServiceResponse},
        http::header,
        test, web,
    };
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Shorthand for a `RateLimit`; argument order is minute, hour, day, month.
    fn mw(
        per_minute: Option<u64>,
        per_hour: Option<u64>,
        per_day: Option<u64>,
        per_month: Option<u64>,
    ) -> RateLimit {
        RateLimit::new(RateLimitConfig {
            per_minute,
            per_hour,
            per_day,
            per_month,
        })
    }

    /// Sends a GET to `uri`, optionally setting `X-Forwarded-For` and the socket peer address.
    async fn call_get<S>(
        app: &S,
        uri: &str,
        xff: Option<&str>,
        peer: Option<SocketAddr>,
    ) -> ServiceResponse<EitherBody<BoxBody>>
    where
        S: Service<Request, Response = ServiceResponse<EitherBody<BoxBody>>, Error = Error>,
    {
        let mut tr = test::TestRequest::get().uri(uri);
        if let Some(v) = xff {
            tr = tr.insert_header(("x-forwarded-for", v));
        }
        if let Some(p) = peer {
            tr = tr.peer_addr(p);
        }
        test::call_service(app, tr.to_request()).await
    }

    /// Parses the `Retry-After` header as whole seconds; `None` if missing or not an integer.
    fn retry_after_secs(resp: &ServiceResponse<EitherBody<BoxBody>>) -> Option<u64> {
        resp.headers()
            .get(header::RETRY_AFTER)
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
    }

    /// Builds a test service with routes `/` and `/other` (both 200 OK), wrapped in `mw`.
    async fn app(
        mw: RateLimit,
    ) -> impl Service<Request, Response = ServiceResponse<EitherBody<BoxBody>>, Error = Error> {
        test::init_service(
            App::new()
                .wrap(mw)
                .route("/", web::get().to(|| async { HttpResponse::Ok().finish() }))
                .route(
                    "/other",
                    web::get().to(|| async { HttpResponse::Ok().finish() }),
                ),
        )
        .await
    }

    // Surrounding whitespace in the forwarded address is not part of the key.
    #[actix_web::test]
    async fn key_trims_realip_value() {
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", " 9.9.9.9 "))
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "9.9.9.9");
    }

    // A blank forwarded value is ignored and the peer IP (without port) is used instead.
    #[actix_web::test]
    async fn key_empty_realip_falls_back_to_peer_ip() {
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 5555);
        let req = test::TestRequest::get()
            .uri("/x")
            .insert_header(("x-forwarded-for", "     "))
            .peer_addr(peer)
            .to_srv_request();

        assert_eq!(super::extract_client_ip_key(&req), "1.2.3.4");
    }

    // Without forwarding headers the key is the peer IP, port dropped.
    #[actix_web::test]
    async fn key_no_realip_uses_peer_ip() {
        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 9)), 1234);
        let req = test::TestRequest::get()
            .uri("/x")
            .peer_addr(peer)
            .to_srv_request();
        assert_eq!(super::extract_client_ip_key(&req), "10.0.0.9");
    }

    // With neither a forwarded address nor a peer, the key is "unknown:{host}|{path}".
    #[actix_web::test]
    async fn key_no_realip_and_no_peer_uses_unknown_host_and_path() {
        let req = test::TestRequest::get().uri("/path123").to_srv_request();
        let k = super::extract_client_ip_key(&req);
        assert!(k.starts_with("unknown:"), "key={k}");
        assert!(k.contains("|/path123"), "key={k}");
    }

    // No windows configured: the middleware never rejects and never sets Retry-After.
    #[actix_web::test]
    async fn passthrough_when_no_limits() {
        let app = app(mw(None, None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
        assert!(r2.headers().get(header::RETRY_AFTER).is_none());
    }

    // A limit of 0 cannot form a NonZeroU32 quota, so the window is disabled, not "block all".
    #[actix_web::test]
    async fn per_minute_zero_disables_window() {
        let app = app(mw(Some(0), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
    }

    // A limit that does not fit in u32 is likewise treated as disabled.
    #[actix_web::test]
    async fn per_minute_over_u32_disables_window() {
        let app = app(mw(Some(u64::from(u32::MAX) + 1), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);
    }

    // With 1/minute the second request from the same key is rejected with Retry-After >= 1.
    #[actix_web::test]
    async fn blocks_second_request_same_key() {
        let app = app(mw(Some(1), None, None, None)).await;

        let ok = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(ok.status(), StatusCode::OK);

        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra = retry_after_secs(&blocked).expect("missing Retry-After");
        assert!(ra >= 1);
    }

    // The 429 carries an integer Retry-After and a JSON body with error + retry_after_seconds.
    #[actix_web::test]
    async fn retry_after_is_integer_seconds_and_body_is_json() {
        let app = app(mw(Some(1), None, None, None)).await;

        let _ = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let blocked = call_get(&app, "/", Some("1.2.3.4"), None).await;

        assert_eq!(blocked.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra_hdr = blocked.headers().get(header::RETRY_AFTER).unwrap();
        let ra_str = ra_hdr.to_str().unwrap();
        assert!(
            ra_str.parse::<u64>().is_ok(),
            "Retry-After not int: {ra_str}"
        );

        let bytes = test::read_body(blocked).await;
        let body = std::str::from_utf8(&bytes).unwrap();
        assert!(
            body.contains(r#""error":"too_many_requests""#),
            "body={body}"
        );
        let v: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(v["error"], "too_many_requests");
        assert!(v["retry_after_seconds"].as_u64().is_some());
    }

    // Distinct keys get distinct buckets; exhausting one does not affect the other.
    #[actix_web::test]
    async fn different_keys_independent() {
        let app = app(mw(Some(1), None, None, None)).await;

        let a1 = call_get(&app, "/", Some("10.0.0.1"), None).await;
        let b1 = call_get(&app, "/", Some("10.0.0.2"), None).await;
        assert_eq!(a1.status(), StatusCode::OK);
        assert_eq!(b1.status(), StatusCode::OK);

        let a2 = call_get(&app, "/", Some("10.0.0.1"), None).await;
        assert_eq!(a2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // The bucket is per client key, not per route: one middleware covers all its routes.
    #[actix_web::test]
    async fn same_key_shared_across_routes_in_same_app() {
        let app = app(mw(Some(1), None, None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r1.status(), StatusCode::OK);

        let r2 = call_get(&app, "/other", Some("1.2.3.4"), None).await;
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // When several windows reject, Retry-After reports the longest wait.
    #[actix_web::test]
    async fn max_retry_after_prefers_longer_window() {
        let app = app(mw(Some(1), Some(1), None, None)).await;

        let r1 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r1.status(), StatusCode::OK);

        let r2 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);

        let ra = retry_after_secs(&r2).expect("missing Retry-After");
        // hour window should dominate minute window; be tolerant but meaningful
        // (a 1/minute wait is at most ~60s, so anything >= 120 must come from the hour window)
        assert!(ra >= 120, "expected hour-dominated Retry-After, got {ra}");
    }

    // Without forwarding headers the peer address identifies the client.
    #[actix_web::test]
    async fn peer_addr_used_when_no_forwarded_headers() {
        let app = app(mw(Some(1), None, None, None)).await;

        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(7, 7, 7, 7)), 9999);
        let r1 = call_get(&app, "/", None, Some(peer)).await;
        let r2 = call_get(&app, "/", None, Some(peer)).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // A blank X-Forwarded-For falls back to the peer IP rather than a shared "" bucket.
    #[actix_web::test]
    async fn empty_forwarded_header_does_not_create_empty_key_bucket() {
        let app = app(mw(Some(1), None, None, None)).await;

        let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 1111);
        let r1 = call_get(&app, "/", Some("   "), Some(peer)).await;
        let r2 = call_get(&app, "/", Some("   "), Some(peer)).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // Unidentifiable clients are bucketed per path, so "/" and "/other" don't collide.
    #[actix_web::test]
    async fn unknown_bucket_includes_path_to_reduce_collisions() {
        let app = app(mw(Some(1), None, None, None)).await;

        let r1 = call_get(&app, "/", None, None).await;
        let r2 = call_get(&app, "/other", None, None).await;

        assert_eq!(r1.status(), StatusCode::OK);
        assert_eq!(r2.status(), StatusCode::OK);

        let r1b = call_get(&app, "/", None, None).await;
        assert_eq!(r1b.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    // Smoke test: the housekeeping branch runs without disturbing normal responses.
    #[actix_web::test]
    async fn housekeeping_retain_recent_path_executes() {
        // Trigger retain_recent() at 1024 calls; keep limit huge to avoid 429.
        let app = app(mw(Some(1_000_000), None, None, None)).await;

        for i in 0..=1024 {
            let ip = format!("192.0.2.{}", (i % 250) + 1);
            let resp = call_get(&app, "/", Some(&ip), None).await;
            assert_eq!(resp.status(), StatusCode::OK, "i={i} ip={ip}");
        }
    }

    // Pins current behaviour: header values are used verbatim, so an "ip:port" form is a
    // different bucket from the bare IP.
    #[actix_web::test]
    async fn keys_are_exact_strings_no_normalization_means_ports_are_distinct() {
        let app = app(mw(Some(1), None, None, None)).await;

        let a = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let b = call_get(&app, "/", Some("1.2.3.4:12345"), None).await;

        assert_eq!(a.status(), StatusCode::OK);
        assert_eq!(b.status(), StatusCode::OK);

        let a2 = call_get(&app, "/", Some("1.2.3.4"), None).await;
        let b2 = call_get(&app, "/", Some("1.2.3.4:12345"), None).await;

        assert_eq!(a2.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(b2.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}