actix_extensible_rate_limit/backend/
memory.rs1use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
2use actix_web::rt::task::JoinHandle;
3use actix_web::rt::time::Instant;
4use dashmap::DashMap;
5use std::convert::Infallible;
6use std::sync::Arc;
7use std::time::Duration;
8
/// Default time, in seconds, between garbage-collection sweeps of expired
/// entries: ten minutes.
pub const DEFAULT_GC_INTERVAL_SECONDS: u64 = 10 * 60;
10
/// Rate-limit backend that keeps per-key request counters in process memory.
///
/// Cloning is cheap: every clone shares the same counter map and the same
/// garbage-collector task handle through `Arc`.
#[derive(Clone)]
pub struct InMemoryBackend {
    /// Per-key counters, shared between clones and with the garbage-collector task.
    map: Arc<DashMap<String, Value>>,
    /// Handle of the background garbage-collector task; `None` when GC is disabled.
    gc_handle: Option<Arc<JoinHandle<()>>>,
}
18
/// Counter state for a single key.
struct Value {
    /// Instant at which the current window expires. Despite the name this is an
    /// absolute deadline (`now + interval`), not a duration.
    ttl: Instant,
    /// Number of requests counted in the current window.
    count: u64,
}
23
24impl InMemoryBackend {
25 pub fn builder() -> Builder {
26 Builder {
27 gc_interval: Some(Duration::from_secs(DEFAULT_GC_INTERVAL_SECONDS)),
28 }
29 }
30
31 fn garbage_collector(map: Arc<DashMap<String, Value>>, interval: Duration) -> JoinHandle<()> {
32 assert!(
33 interval.as_secs_f64() > 0f64,
34 "GC interval must be non-zero"
35 );
36 actix_web::rt::spawn(async move {
37 loop {
38 let now = Instant::now();
39 map.retain(|_k, v| v.ttl > now);
40 actix_web::rt::time::sleep_until(now + interval).await;
41 }
42 })
43 }
44}
45
/// Builder for [`InMemoryBackend`], obtained from [`InMemoryBackend::builder`].
pub struct Builder {
    /// Time between garbage-collection sweeps; `None` means no collector task is spawned.
    gc_interval: Option<Duration>,
}
49
50impl Builder {
51 pub fn with_gc_interval(mut self, interval: Option<Duration>) -> Self {
57 self.gc_interval = interval;
58 self
59 }
60
61 pub fn build(self) -> InMemoryBackend {
62 let map = Arc::new(DashMap::<String, Value>::new());
63 let gc_handle = self.gc_interval.map(|gc_interval| {
64 Arc::new(InMemoryBackend::garbage_collector(map.clone(), gc_interval))
65 });
66 InMemoryBackend { map, gc_handle }
67 }
68}
69
70impl Backend<SimpleInput> for InMemoryBackend {
71 type Output = SimpleOutput;
72 type RollbackToken = String;
73 type Error = Infallible;
74
75 async fn request(
76 &self,
77 input: SimpleInput,
78 ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
79 let now = Instant::now();
80 let mut count = 1;
81 let mut expiry = now
82 .checked_add(input.interval)
83 .expect("Interval unexpectedly large");
84 self.map
85 .entry(input.key.clone())
86 .and_modify(|v| {
87 if v.ttl > now {
89 v.count += 1;
90 count = v.count;
91 expiry = v.ttl;
92 } else {
93 v.ttl = expiry;
95 v.count = count;
96 }
97 })
98 .or_insert_with(|| Value {
99 ttl: expiry,
101 count,
102 });
103 let allow = count <= input.max_requests;
104 let output = SimpleOutput {
105 limit: input.max_requests,
106 remaining: input.max_requests.saturating_sub(count),
107 reset: expiry,
108 };
109 Ok((Decision::from_allowed(allow), output, input.key))
110 }
111
112 async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
113 self.map.entry(token).and_modify(|v| {
114 v.count = v.count.saturating_sub(1);
115 });
116 Ok(())
117 }
118}
119
120impl SimpleBackend for InMemoryBackend {
121 async fn remove_key(&self, key: &str) -> Result<(), Self::Error> {
122 self.map.remove(key);
123 Ok(())
124 }
125}
126
127impl Drop for InMemoryBackend {
128 fn drop(&mut self) {
129 if let Some(handle) = &self.gc_handle {
130 handle.abort();
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    const MINUTE: Duration = Duration::from_secs(60);

    /// With a limit of 5, the first five requests in one window are allowed and the sixth
    /// is denied.
    #[actix_web::test]
    async fn test_allow_deny() {
        // Pause the tokio clock so time only moves when the test advances it.
        tokio::time::pause();
        let backend = InMemoryBackend::builder().build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "KEY1".to_string(),
        };
        for _ in 0..5 {
            let (allow, _, _) = backend.request(input.clone()).await.unwrap();
            assert!(allow.is_allowed());
        }
        let (allow, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(!allow.is_allowed());
    }

    /// Once the window elapses, the stale entry is reset in place and requests are allowed
    /// again. GC is disabled here, so only the reset path can explain the second allow.
    #[actix_web::test]
    async fn test_reset() {
        tokio::time::pause();
        let backend = InMemoryBackend::builder().with_gc_interval(None).build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 1,
            key: "KEY1".to_string(),
        };
        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_denied());
        tokio::time::advance(MINUTE).await;
        // Without a collector the expired entry is still present...
        assert!(backend.map.contains_key("KEY1"));
        // ...but a deadline equal to "now" counts as expired, so a new window starts.
        let (decision, _, _) = backend.request(input).await.unwrap();
        assert!(decision.is_allowed());
    }

    /// A sweep removes entries whose window has elapsed and keeps the ones still live.
    #[actix_web::test]
    async fn test_garbage_collection() {
        tokio::time::pause();
        let backend = InMemoryBackend::builder()
            .with_gc_interval(Some(MINUTE))
            .build();
        // KEY1 expires after one minute...
        backend
            .request(SimpleInput {
                interval: MINUTE,
                max_requests: 1,
                key: "KEY1".to_string(),
            })
            .await
            .unwrap();
        // ...KEY2 after two.
        backend
            .request(SimpleInput {
                interval: MINUTE * 2,
                max_requests: 1,
                key: "KEY2".to_string(),
            })
            .await
            .unwrap();
        assert!(backend.map.contains_key("KEY1"));
        assert!(backend.map.contains_key("KEY2"));
        // NOTE(review): relies on `advance` giving the GC task a chance to run its sweep
        // before the assertions below — TODO confirm against tokio's `advance` semantics.
        tokio::time::advance(MINUTE).await;
        assert!(!backend.map.contains_key("KEY1"));
        assert!(backend.map.contains_key("KEY2"));
    }

    /// `remaining` counts down to zero and saturates there, and `reset` stays at the
    /// window's deadline. The clock is paused, so `Instant::now()` equals the start time.
    #[actix_web::test]
    async fn test_output() {
        tokio::time::pause();
        let backend = InMemoryBackend::builder().build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 2,
            key: "KEY1".to_string(),
        };
        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        assert_eq!(output.remaining, 1);
        assert_eq!(output.limit, 2);
        assert_eq!(output.reset, Instant::now() + MINUTE);
        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 2);
        assert_eq!(output.reset, Instant::now() + MINUTE);
        let (decision, output, _) = backend.request(input).await.unwrap();
        assert!(decision.is_denied());
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 2);
        assert_eq!(output.reset, Instant::now() + MINUTE);
    }

    /// Rolling back a request un-counts it, so the next request reports the same remaining.
    #[actix_web::test]
    async fn test_rollback() {
        tokio::time::pause();
        let backend = InMemoryBackend::builder().build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "KEY1".to_string(),
        };
        let (_, output, rollback) = backend.request(input.clone()).await.unwrap();
        assert_eq!(output.remaining, 4);
        backend.rollback(rollback).await.unwrap();
        let (_, output, _) = backend.request(input).await.unwrap();
        assert_eq!(output.remaining, 4);
    }

    /// Removing a key clears its counter, so a denied key is allowed again immediately.
    #[actix_web::test]
    async fn test_remove_key() {
        tokio::time::pause();
        let backend = InMemoryBackend::builder().with_gc_interval(None).build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 1,
            key: "KEY1".to_string(),
        };
        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        let (decision, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_denied());
        backend.remove_key("KEY1").await.unwrap();
        let (decision, _, _) = backend.request(input).await.unwrap();
        assert!(decision.is_allowed());
    }
}