// actix_extensible_rate_limit/backend/memory.rs

1use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
2use actix_web::rt::task::JoinHandle;
3use actix_web::rt::time::Instant;
4use dashmap::DashMap;
5use std::convert::Infallible;
6use std::sync::Arc;
7use std::time::Duration;
/// Default period, in seconds, between two passes of the garbage collector (ten minutes).
pub const DEFAULT_GC_INTERVAL_SECONDS: u64 = 10 * 60;

11/// A Fixed Window rate limiter [Backend] that uses [Dashmap](dashmap::DashMap) to store keys
12/// in memory.
13#[derive(Clone)]
14pub struct InMemoryBackend {
15    map: Arc<DashMap<String, Value>>,
16    gc_handle: Option<Arc<JoinHandle<()>>>,
17}
18
/// State of a single fixed-window bucket.
struct Value {
    /// Number of requests counted in the current window.
    count: u64,
    /// Instant at which the current window expires.
    ttl: Instant,
}

24impl InMemoryBackend {
25    pub fn builder() -> Builder {
26        Builder {
27            gc_interval: Some(Duration::from_secs(DEFAULT_GC_INTERVAL_SECONDS)),
28        }
29    }
30
31    fn garbage_collector(map: Arc<DashMap<String, Value>>, interval: Duration) -> JoinHandle<()> {
32        assert!(
33            interval.as_secs_f64() > 0f64,
34            "GC interval must be non-zero"
35        );
36        actix_web::rt::spawn(async move {
37            loop {
38                let now = Instant::now();
39                map.retain(|_k, v| v.ttl > now);
40                actix_web::rt::time::sleep_until(now + interval).await;
41            }
42        })
43    }
44}
45
/// Configures and creates an [`InMemoryBackend`]; obtain one from
/// [`InMemoryBackend::builder`].
pub struct Builder {
    /// Period between garbage collector passes, or `None` to disable garbage collection.
    gc_interval: Option<Duration>,
}

50impl Builder {
51    /// Override the default garbage collector interval.
52    ///
53    /// Set to None to disable garbage collection.
54    ///
55    /// The garbage collector periodically scans the internal map, removing expired buckets.
56    pub fn with_gc_interval(mut self, interval: Option<Duration>) -> Self {
57        self.gc_interval = interval;
58        self
59    }
60
61    pub fn build(self) -> InMemoryBackend {
62        let map = Arc::new(DashMap::<String, Value>::new());
63        let gc_handle = self.gc_interval.map(|gc_interval| {
64            Arc::new(InMemoryBackend::garbage_collector(map.clone(), gc_interval))
65        });
66        InMemoryBackend { map, gc_handle }
67    }
68}
69
70impl Backend<SimpleInput> for InMemoryBackend {
71    type Output = SimpleOutput;
72    type RollbackToken = String;
73    type Error = Infallible;
74
75    async fn request(
76        &self,
77        input: SimpleInput,
78    ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
79        let now = Instant::now();
80        let mut count = 1;
81        let mut expiry = now
82            .checked_add(input.interval)
83            .expect("Interval unexpectedly large");
84        self.map
85            .entry(input.key.clone())
86            .and_modify(|v| {
87                // If this bucket hasn't yet expired, increment and extract the count/expiry
88                if v.ttl > now {
89                    v.count += 1;
90                    count = v.count;
91                    expiry = v.ttl;
92                } else {
93                    // If this bucket has expired we will reset the count to 1 and set a new TTL.
94                    v.ttl = expiry;
95                    v.count = count;
96                }
97            })
98            .or_insert_with(|| Value {
99                // If the bucket doesn't exist, create it with a count of 1, and set the TTL.
100                ttl: expiry,
101                count,
102            });
103        let allow = count <= input.max_requests;
104        let output = SimpleOutput {
105            limit: input.max_requests,
106            remaining: input.max_requests.saturating_sub(count),
107            reset: expiry,
108        };
109        Ok((Decision::from_allowed(allow), output, input.key))
110    }
111
112    async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
113        self.map.entry(token).and_modify(|v| {
114            v.count = v.count.saturating_sub(1);
115        });
116        Ok(())
117    }
118}
119
120impl SimpleBackend for InMemoryBackend {
121    async fn remove_key(&self, key: &str) -> Result<(), Self::Error> {
122        self.map.remove(key);
123        Ok(())
124    }
125}
126
127impl Drop for InMemoryBackend {
128    fn drop(&mut self) {
129        if let Some(handle) = &self.gc_handle {
130            handle.abort();
131        }
132    }
133}
134
135#[cfg(test)]
136mod tests {
137    use super::*;
138
139    const MINUTE: Duration = Duration::from_secs(60);
140
141    #[actix_web::test]
142    async fn test_allow_deny() {
143        tokio::time::pause();
144        let backend = InMemoryBackend::builder().build();
145        let input = SimpleInput {
146            interval: MINUTE,
147            max_requests: 5,
148            key: "KEY1".to_string(),
149        };
150        for _ in 0..5 {
151            // First 5 should be allowed
152            let (allow, _, _) = backend.request(input.clone()).await.unwrap();
153            assert!(allow.is_allowed());
154        }
155        // Sixth should be denied
156        let (allow, _, _) = backend.request(input.clone()).await.unwrap();
157        assert!(!allow.is_allowed());
158    }
159
160    #[actix_web::test]
161    async fn test_reset() {
162        tokio::time::pause();
163        let backend = InMemoryBackend::builder().with_gc_interval(None).build();
164        let input = SimpleInput {
165            interval: MINUTE,
166            max_requests: 1,
167            key: "KEY1".to_string(),
168        };
169        // Make first request, should be allowed
170        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
171        assert!(decision.is_allowed());
172        // Request again, should be denied
173        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
174        assert!(decision.is_denied());
175        // Advance time and try again, should now be allowed
176        tokio::time::advance(MINUTE).await;
177        // We want to be sure the key hasn't been garbage collected, and we are testing the expiry logic
178        assert!(backend.map.contains_key("KEY1"));
179        let (decision, _, _) = backend.request(input).await.unwrap();
180        assert!(decision.is_allowed());
181    }
182
183    #[actix_web::test]
184    async fn test_garbage_collection() {
185        tokio::time::pause();
186        let backend = InMemoryBackend::builder()
187            .with_gc_interval(Some(MINUTE))
188            .build();
189        backend
190            .request(SimpleInput {
191                interval: MINUTE,
192                max_requests: 1,
193                key: "KEY1".to_string(),
194            })
195            .await
196            .unwrap();
197        backend
198            .request(SimpleInput {
199                interval: MINUTE * 2,
200                max_requests: 1,
201                key: "KEY2".to_string(),
202            })
203            .await
204            .unwrap();
205        assert!(backend.map.contains_key("KEY1"));
206        assert!(backend.map.contains_key("KEY2"));
207        // Advance time such that the garbage collector runs,
208        // expired KEY1 should be cleaned, but KEY2 should remain.
209        tokio::time::advance(MINUTE).await;
210        assert!(!backend.map.contains_key("KEY1"));
211        assert!(backend.map.contains_key("KEY2"));
212    }
213
214    #[actix_web::test]
215    async fn test_output() {
216        tokio::time::pause();
217        let backend = InMemoryBackend::builder().build();
218        let input = SimpleInput {
219            interval: MINUTE,
220            max_requests: 2,
221            key: "KEY1".to_string(),
222        };
223        // First of 2 should be allowed.
224        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
225        assert!(decision.is_allowed());
226        assert_eq!(output.remaining, 1);
227        assert_eq!(output.limit, 2);
228        assert_eq!(output.reset, Instant::now() + MINUTE);
229        // Second of 2 should be allowed.
230        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
231        assert!(decision.is_allowed());
232        assert_eq!(output.remaining, 0);
233        assert_eq!(output.limit, 2);
234        assert_eq!(output.reset, Instant::now() + MINUTE);
235        // Should be denied
236        let (decision, output, _) = backend.request(input).await.unwrap();
237        assert!(decision.is_denied());
238        assert_eq!(output.remaining, 0);
239        assert_eq!(output.limit, 2);
240        assert_eq!(output.reset, Instant::now() + MINUTE);
241    }
242
243    #[actix_web::test]
244    async fn test_rollback() {
245        tokio::time::pause();
246        let backend = InMemoryBackend::builder().build();
247        let input = SimpleInput {
248            interval: MINUTE,
249            max_requests: 5,
250            key: "KEY1".to_string(),
251        };
252        let (_, output, rollback) = backend.request(input.clone()).await.unwrap();
253        assert_eq!(output.remaining, 4);
254        backend.rollback(rollback).await.unwrap();
255        // Remaining requests should still be the same, since the previous call was excluded
256        let (_, output, _) = backend.request(input).await.unwrap();
257        assert_eq!(output.remaining, 4);
258    }
259
260    #[actix_web::test]
261    async fn test_remove_key() {
262        tokio::time::pause();
263        let backend = InMemoryBackend::builder().with_gc_interval(None).build();
264        let input = SimpleInput {
265            interval: MINUTE,
266            max_requests: 1,
267            key: "KEY1".to_string(),
268        };
269        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
270        assert!(decision.is_allowed());
271        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
272        assert!(decision.is_denied());
273        backend.remove_key("KEY1").await.unwrap();
274        // Counter should have been reset
275        let (decision, _, _) = backend.request(input).await.unwrap();
276        assert!(decision.is_allowed());
277    }
278}