1use std::collections::HashSet;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use serde_json::map::Map;
5use serde_json::{from_value, Value};
6
7use crate::algorithms::Algorithm;
8use crate::errors::{new_error, ErrorKind, Result};
9
10#[derive(Debug, Clone, PartialEq)]
29pub struct Validation {
30 pub leeway: u64,
35 pub validate_exp: bool,
41 pub validate_nbf: bool,
47 pub aud: Option<HashSet<String>>,
52 pub iss: Option<String>,
57 pub sub: Option<String>,
62 pub algorithms: Vec<Algorithm>,
67}
68
69impl Validation {
70 pub fn new(alg: Algorithm) -> Validation {
72 let mut validation = Validation::default();
73 validation.algorithms = vec![alg];
74 validation
75 }
76
77 pub fn set_audience<T: ToString>(&mut self, items: &[T]) {
79 self.aud = Some(items.iter().map(|x| x.to_string()).collect())
80 }
81}
82
83impl Default for Validation {
84 fn default() -> Validation {
85 Validation {
86 leeway: 0,
87
88 validate_exp: true,
89 validate_nbf: false,
90
91 iss: None,
92 sub: None,
93 aud: None,
94
95 algorithms: vec![Algorithm::HS256],
96 }
97 }
98}
99
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn get_current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_secs()
}
104
105pub fn validate(claims: &Map<String, Value>, options: &Validation) -> Result<()> {
106 let now = get_current_timestamp();
107
108 if options.validate_exp {
109 if let Some(exp) = claims.get("exp") {
110 if from_value::<u64>(exp.clone())? < now - options.leeway {
111 return Err(new_error(ErrorKind::ExpiredSignature));
112 }
113 } else {
114 return Err(new_error(ErrorKind::ExpiredSignature));
115 }
116 }
117
118 if options.validate_nbf {
119 if let Some(nbf) = claims.get("nbf") {
120 if from_value::<u64>(nbf.clone())? > now + options.leeway {
121 return Err(new_error(ErrorKind::ImmatureSignature));
122 }
123 } else {
124 return Err(new_error(ErrorKind::ImmatureSignature));
125 }
126 }
127
128 if let Some(ref correct_iss) = options.iss {
129 if let Some(iss) = claims.get("iss") {
130 if from_value::<String>(iss.clone())? != *correct_iss {
131 return Err(new_error(ErrorKind::InvalidIssuer));
132 }
133 } else {
134 return Err(new_error(ErrorKind::InvalidIssuer));
135 }
136 }
137
138 if let Some(ref correct_sub) = options.sub {
139 if let Some(sub) = claims.get("sub") {
140 if from_value::<String>(sub.clone())? != *correct_sub {
141 return Err(new_error(ErrorKind::InvalidSubject));
142 }
143 } else {
144 return Err(new_error(ErrorKind::InvalidSubject));
145 }
146 }
147
148 if let Some(ref correct_aud) = options.aud {
149 if let Some(aud) = claims.get("aud") {
150 match aud {
151 Value::String(aud_found) => {
152 if !correct_aud.contains(aud_found) {
153 return Err(new_error(ErrorKind::InvalidAudience));
154 }
155 }
156 Value::Array(_) => {
157 let provided_aud: HashSet<String> = from_value(aud.clone())?;
158 if provided_aud.intersection(correct_aud).count() == 0 {
159 return Err(new_error(ErrorKind::InvalidAudience));
160 }
161 }
162 _ => return Err(new_error(ErrorKind::InvalidAudience)),
163 };
164 } else {
165 return Err(new_error(ErrorKind::InvalidAudience));
166 }
167 }
168
169 Ok(())
170}
171
#[cfg(test)]
mod tests {
    use serde_json::map::Map;
    use serde_json::to_value;

    use super::{get_current_timestamp, validate, Validation};

    use crate::errors::ErrorKind;

    #[test]
    fn exp_in_future_ok() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap());
        let res = validate(&claims, &Validation::default());
        assert!(res.is_ok());
    }

    #[test]
    fn exp_in_past_fails() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() - 100000).unwrap());
        let res = validate(&claims, &Validation::default());
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::ExpiredSignature => (),
            other => panic!("expected ExpiredSignature, got {:?}", other),
        };
    }

    #[test]
    fn exp_in_past_but_in_leeway_ok() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() - 500).unwrap());
        let validation = Validation { leeway: 1000 * 60, ..Default::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    // A missing `exp` must still fail when `validate_exp` is on.
    #[test]
    fn validation_called_even_if_field_is_empty() {
        let claims = Map::new();
        let res = validate(&claims, &Validation::default());
        assert!(res.is_err());
        match res.unwrap_err().kind() {
            &ErrorKind::ExpiredSignature => (),
            other => panic!("expected ExpiredSignature, got {:?}", other),
        };
    }

    #[test]
    fn nbf_in_past_ok() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() - 10000).unwrap());
        let validation =
            Validation { validate_exp: false, validate_nbf: true, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_in_future_fails() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 100000).unwrap());
        let validation =
            Validation { validate_exp: false, validate_nbf: true, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::ImmatureSignature => (),
            other => panic!("expected ImmatureSignature, got {:?}", other),
        };
    }

    #[test]
    fn nbf_in_future_but_in_leeway_ok() {
        let mut claims = Map::new();
        claims.insert("nbf".to_string(), to_value(get_current_timestamp() + 500).unwrap());
        let validation = Validation {
            leeway: 1000 * 60,
            validate_nbf: true,
            validate_exp: false,
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_ok() {
        let mut claims = Map::new();
        claims.insert("iss".to_string(), to_value("Keats").unwrap());
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("iss".to_string(), to_value("Hacked").unwrap());
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidIssuer => (),
            other => panic!("expected InvalidIssuer, got {:?}", other),
        };
    }

    #[test]
    fn iss_missing_fails() {
        let claims = Map::new();
        let validation = Validation {
            validate_exp: false,
            iss: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidIssuer => (),
            other => panic!("expected InvalidIssuer, got {:?}", other),
        };
    }

    #[test]
    fn sub_ok() {
        let mut claims = Map::new();
        claims.insert("sub".to_string(), to_value("Keats").unwrap());
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn sub_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("sub".to_string(), to_value("Hacked").unwrap());
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidSubject => (),
            other => panic!("expected InvalidSubject, got {:?}", other),
        };
    }

    #[test]
    fn sub_missing_fails() {
        let claims = Map::new();
        let validation = Validation {
            validate_exp: false,
            sub: Some("Keats".to_string()),
            ..Default::default()
        };
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidSubject => (),
            other => panic!("expected InvalidSubject, got {:?}", other),
        };
    }

    // Exercises the plain-string form of the `aud` claim (not a one-element array).
    #[test]
    fn aud_string_ok() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value("Everyone").unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["Everyone"]);
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_array_of_string_ok() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["UserA", "UserB"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    // A single shared value is enough for an array `aud` to be accepted.
    #[test]
    fn aud_array_partial_overlap_ok() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["UserA", "Other"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA"]);
        let res = validate(&claims, &validation);
        assert!(res.is_ok());
    }

    // An `aud` that is neither a string nor an array is rejected outright.
    #[test]
    fn aud_type_mismatch_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(42).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidAudience => (),
            other => panic!("expected InvalidAudience, got {:?}", other),
        };
    }

    #[test]
    fn aud_array_not_intersecting_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidAudience => (),
            other => panic!("expected InvalidAudience, got {:?}", other),
        };
    }

    #[test]
    fn aud_correct_type_not_matching_fails() {
        let mut claims = Map::new();
        claims.insert("aud".to_string(), to_value(["Everyone"]).unwrap());
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["None"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidAudience => (),
            other => panic!("expected InvalidAudience, got {:?}", other),
        };
    }

    #[test]
    fn aud_missing_fails() {
        let claims = Map::new();
        let mut validation = Validation { validate_exp: false, ..Validation::default() };
        validation.set_audience(&["None"]);
        let res = validate(&claims, &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidAudience => (),
            other => panic!("expected InvalidAudience, got {:?}", other),
        };
    }

    // With `iss` and `sub` both configured and missing, `iss` is reported first.
    #[test]
    fn does_validation_in_right_order() {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(get_current_timestamp() + 10000).unwrap());
        let v = Validation {
            leeway: 5,
            validate_exp: true,
            iss: Some("iss no check".to_string()),
            sub: Some("sub no check".to_string()),
            ..Validation::default()
        };
        let res = validate(&claims, &v);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            &ErrorKind::InvalidIssuer => (),
            other => panic!("expected InvalidIssuer, got {:?}", other),
        };
    }

    // Audience set built directly rather than through `set_audience`.
    #[test]
    fn aud_use_validation_struct() {
        let mut claims = Map::new();
        claims.insert(
            "aud".to_string(),
            to_value("my-googleclientid1234.apps.googleusercontent.com").unwrap(),
        );

        let aud = "my-googleclientid1234.apps.googleusercontent.com".to_string();
        let mut aud_hashset = std::collections::HashSet::new();
        aud_hashset.insert(aud);

        let validation =
            Validation { aud: Some(aud_hashset), validate_exp: false, ..Validation::default() };
        let res = validate(&claims, &validation);
        assert!(res.is_ok(), "unexpected result: {:?}", res);
    }
}