// actix_extensible_rate_limit/backend/input_builder.rs
use crate::backend::SimpleInput;
use actix_web::dev::ServiceRequest;
use actix_web::ResponseError;
use std::future::{ready, Ready};
use std::net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr};
use std::time::Duration;
use thiserror::Error;

9type CustomFn = Box<dyn Fn(&ServiceRequest) -> Result<String, actix_web::Error>>;
10
11pub type SimpleInputFuture = Ready<Result<SimpleInput, actix_web::Error>>;
12
13pub struct SimpleInputFunctionBuilder {
20 interval: Duration,
21 max_requests: u64,
22 real_ip_key: bool,
23 peer_ip_key: bool,
24 path_key: bool,
25 custom_key: Option<String>,
26 custom_fn: Option<CustomFn>,
27}
28
29impl SimpleInputFunctionBuilder {
30 pub fn new(interval: Duration, max_requests: u64) -> Self {
31 Self {
32 interval,
33 max_requests,
34 real_ip_key: false,
35 peer_ip_key: false,
36 path_key: false,
37 custom_key: None,
38 custom_fn: None,
39 }
40 }
41
42 pub fn real_ip_key(mut self) -> Self {
55 self.real_ip_key = true;
56 self
57 }
58
59 pub fn peer_ip_key(mut self) -> Self {
67 self.peer_ip_key = true;
68 self
69 }
70
71 pub fn path_key(mut self) -> Self {
73 self.path_key = true;
74 self
75 }
76
77 pub fn custom_key(mut self, key: &str) -> Self {
79 self.custom_key = Some(key.to_owned());
80 self
81 }
82
83 pub fn custom_fn<F>(mut self, f: F) -> Self
85 where
86 F: Fn(&ServiceRequest) -> Result<String, actix_web::Error> + 'static,
87 {
88 self.custom_fn = Some(Box::new(f));
89 self
90 }
91
92 pub fn build(self) -> impl Fn(&ServiceRequest) -> SimpleInputFuture + 'static {
93 move |req| {
94 ready((|| {
95 let mut components = Vec::new();
96 let info = req.connection_info();
97 if let Some(custom) = &self.custom_key {
98 components.push(custom.clone());
99 }
100 if self.real_ip_key {
101 components.push(ip_key(info.realip_remote_addr().unwrap())?)
102 }
103 if self.peer_ip_key {
104 components.push(ip_key(info.peer_addr().unwrap())?)
105 }
106 if self.path_key {
107 components.push(req.path().to_owned());
108 }
109 if let Some(f) = &self.custom_fn {
110 components.push(f(req)?)
111 }
112 let key = components.join("-");
113
114 Ok(SimpleInput {
115 interval: self.interval,
116 max_requests: self.max_requests,
117 key,
118 })
119 })())
120 }
121 }
122}
123
124#[derive(Debug, Error)]
125enum Error {
126 #[error("Unable to parse remote IP address: {0}")]
127 InvalidIpError(
128 #[source]
129 #[from]
130 AddrParseError,
131 ),
132}
133
134impl ResponseError for Error {}
135
136fn ip_key(ip_str: &str) -> Result<String, Error> {
140 let ip = ip_str.parse::<IpAddr>()?;
141 Ok(match ip {
142 IpAddr::V4(v4) => v4.to_string(),
143 IpAddr::V6(v6) => {
144 if let Some(v4) = v6.to_ipv4() {
145 return Ok(v4.to_string());
146 }
147 let zeroes = [0u16; 4];
148 let concat = [&v6.segments()[0..4], &zeroes].concat();
149 let concat: [u16; 8] = concat.try_into().unwrap();
150 let subnet = Ipv6Addr::from(concat);
151 format!("{}/64", subnet)
152 }
153 })
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ip_key() {
        // IPv4 addresses are preserved.
        assert_eq!(ip_key("142.250.187.206").unwrap(), "142.250.187.206");
        // IPv4-mapped addresses are reduced to the embedded IPv4 address.
        assert_eq!(ip_key("::FFFF:142.250.187.206").unwrap(), "142.250.187.206");
        // IPv6 addresses are grouped into /64 subnets.
        assert_eq!(
            ip_key("2a00:1450:4009:81f::200e").unwrap(),
            "2a00:1450:4009:81f::/64"
        );
    }

    #[test]
    fn test_ip_key_groups_same_subnet() {
        // Two hosts inside the same /64 must share a key.
        assert_eq!(
            ip_key("2a00:1450:4009:81f::1").unwrap(),
            ip_key("2a00:1450:4009:81f:ffff:ffff:ffff:ffff").unwrap()
        );
    }

    #[test]
    fn test_ip_key_rejects_invalid_input() {
        assert!(ip_key("").is_err());
        assert!(ip_key("not an ip").is_err());
        // A socket address (IP plus port) is not a bare IP address.
        assert!(ip_key("142.250.187.206:443").is_err());
    }
}