1use base64::prelude::*;
2use rand::{thread_rng, Rng};
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use url::Url;
6
7use std::fmt::Error as FormatterError;
8use std::fmt::{Debug, Formatter};
9#[cfg(feature = "timing-resistant-secret-traits")]
10use std::hash::{Hash, Hasher};
11use std::ops::Deref;
12
/// Declares a public newtype wrapper around a single value of type `$type`.
///
/// The generated struct derives `Clone`, `Debug` and `PartialEq` (in addition to any
/// attributes passed in) and gets:
/// - `const fn new(value) -> Self`, documented as "Create a new `Name` to wrap the given `Type`.",
/// - `Deref<Target = $type>` for read-only access to the wrapped value,
/// - `From<Name> for $type` to unwrap it.
///
/// Accepted forms:
/// - `new_type![#[attrs] Name(#[field_attrs] Type)]`
/// - `new_type![#[attrs] Name(#[field_attrs] Type) impl { items }]` — the extra items are
///   placed inside the generated inherent `impl Name` block.
///
/// The `@new_type` arm is internal: it receives the precomputed doc string for `new`.
macro_rules! new_type {
    // No extra items: forward to the `impl { ... }` form with an empty item list so the
    // doc string for `new` is built in exactly one place.
    (
        $(#[$attr:meta])*
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        )
    ) => {
        new_type![
            $(#[$attr])*
            $name(
                $(#[$type_attr])*
                $type
            )
            impl {}
        ];
    };
    // Extra items: compute the constructor's doc string and hand off to the internal arm.
    (
        $(#[$attr:meta])*
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        )
        impl {
            $($item:tt)*
        }
    ) => {
        new_type![
            @new_type $(#[$attr])*,
            $name(
                $(#[$type_attr])*
                $type
            ),
            concat!(
                "Create a new `",
                stringify!($name),
                "` to wrap the given `",
                stringify!($type),
                "`."
            ),
            impl {
                $($item)*
            }
        ];
    };
    // Internal arm: emits the struct and its trait impls.
    (
        @new_type $(#[$attr:meta])*,
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        ),
        $new_doc:expr,
        impl {
            $($item:tt)*
        }
    ) => {
        $(#[$attr])*
        #[derive(Clone, Debug, PartialEq)]
        pub struct $name(
            $(#[$type_attr])*
            $type
        );
        impl $name {
            $($item)*

            // `const` so wrappers can also be constructed in constant contexts.
            #[doc = $new_doc]
            pub const fn new(s: $type) -> Self {
                $name(s)
            }
        }
        impl Deref for $name {
            type Target = $type;
            fn deref(&self) -> &$type {
                &self.0
            }
        }
        impl From<$name> for $type {
            fn from(t: $name) -> $type {
                t.0
            }
        }
    };
}
106
/// Declares a public newtype wrapper around a secret value of type `$type`.
///
/// The generated struct:
/// - exposes the wrapped value only through the explicit `secret()` (borrow) and
///   `into_secret()` (consume) accessors; no `Deref` impl is generated,
/// - implements `Debug` to print `Name([redacted])` instead of the value,
/// - with the `timing-resistant-secret-traits` feature, derives `Eq` and implements
///   `PartialEq` and `Hash` over the SHA-256 digest of the value instead of the value
///   itself (per the feature name, intended to make comparisons timing-resistant).
///
/// Accepted forms:
/// - `new_secret_type![#[attrs] Name(Type)]`
/// - `new_secret_type![#[attrs] Name(Type) impl { items }]` — the extra items are placed
///   inside the generated inherent `impl Name` block.
///
/// The third arm is internal: it receives the precomputed doc strings.
macro_rules! new_secret_type {
    // No extra items: forward with an empty item list.
    (
        $(#[$attr:meta])*
        $name:ident($type:ty)
    ) => {
        new_secret_type![
            $(#[$attr])*
            $name($type)
            impl {}
        ];
    };
    // Extra items: compute doc strings and hand off to the internal arm.
    (
        $(#[$attr:meta])*
        $name:ident($type:ty)
        impl {
            $($item:tt)*
        }
    ) => {
        new_secret_type![
            $(#[$attr])*,
            $name($type),
            concat!(
                "Create a new `",
                stringify!($name),
                "` to wrap the given `",
                stringify!($type),
                "`."
            ),
            concat!("Get the secret contained within this `", stringify!($name), "`."),
            impl {
                $($item)*
            }
        ];
    };
    // Internal arm: emits the struct, its accessors and its trait impls.
    (
        $(#[$attr:meta])*,
        $name:ident($type:ty),
        $new_doc:expr,
        $secret_doc:expr,
        impl {
            $($item:tt)*
        }
    ) => {
        $(
            #[$attr]
        )*
        #[cfg_attr(feature = "timing-resistant-secret-traits", derive(Eq))]
        pub struct $name($type);
        impl $name {
            $($item)*

            #[doc = $new_doc]
            pub fn new(s: $type) -> Self {
                $name(s)
            }

            #[doc = $secret_doc]
            pub fn secret(&self) -> &$type {
                &self.0
            }

            // Separate doc from `secret()`: this consumes `self` and moves the value out.
            #[doc = concat!(
                "Consume this `",
                stringify!($name),
                "` and return the secret it contains."
            )]
            pub fn into_secret(self) -> $type {
                self.0
            }
        }
        impl Debug for $name {
            // Never print the wrapped value; only the type name is shown.
            fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
                write!(f, concat!(stringify!($name), "([redacted])"))
            }
        }

        #[cfg(feature = "timing-resistant-secret-traits")]
        impl PartialEq for $name {
            // Compare digests rather than the raw secrets.
            fn eq(&self, other: &Self) -> bool {
                Sha256::digest(&self.0) == Sha256::digest(&other.0)
            }
        }

        #[cfg(feature = "timing-resistant-secret-traits")]
        impl Hash for $name {
            // Hash the digest so `Hash` stays consistent with the digest-based `PartialEq`.
            fn hash<H: Hasher>(&self, state: &mut H) {
                Sha256::digest(&self.0).hash(state)
            }
        }
    };
}
199
/// Declares a public URL newtype that stores both the parsed `Url` and a `String` form.
///
/// `new(String)` keeps the caller's exact string; `from_url(Url)` stores `url.to_string()`.
/// Equality, ordering, hashing, `Display`, `Debug`, serialization and
/// `Deref<Target = String>` all operate on that stored string, not on the parsed `Url`.
/// As a result, two values wrapping the same parsed URL can compare unequal if their
/// strings differ. Deserialization goes through `new`, so it fails whenever `Url::parse`
/// rejects the input.
///
/// Accepted forms:
/// - `new_url_type![#[attrs] Name]`
/// - `new_url_type![#[attrs] Name impl { items }]` — the extra items are placed inside
///   the generated inherent `impl Name` block.
///
/// The `@new_type_pub` arm is internal: it receives the precomputed doc strings.
macro_rules! new_url_type {
    // No extra items: forward with an empty item list so the doc strings are built once.
    (
        $(#[$attr:meta])*
        $name:ident
    ) => {
        new_url_type![
            $(#[$attr])*
            $name
            impl {}
        ];
    };
    // Extra items: compute doc strings and hand off to the internal arm.
    (
        $(#[$attr:meta])*
        $name:ident
        impl {
            $($item:tt)*
        }
    ) => {
        new_url_type![
            @new_type_pub $(#[$attr])*,
            $name,
            concat!("Create a new `", stringify!($name), "` from a `String` to wrap a URL."),
            concat!("Create a new `", stringify!($name), "` from a `Url` to wrap a URL."),
            concat!("Return this `", stringify!($name), "` as a parsed `Url`."),
            impl {
                $($item)*
            }
        ];
    };
    // Internal arm: emits the struct and its trait impls.
    (
        @new_type_pub $(#[$attr:meta])*,
        $name:ident,
        $new_doc:expr,
        $from_url_doc:expr,
        $url_doc:expr,
        impl {
            $($item:tt)*
        }
    ) => {
        $(#[$attr])*
        #[derive(Clone)]
        pub struct $name(Url, String);
        impl $name {
            #[doc = $new_doc]
            pub fn new(url: String) -> Result<Self, ::url::ParseError> {
                Ok($name(Url::parse(&url)?, url))
            }
            #[doc = $from_url_doc]
            pub fn from_url(url: Url) -> Self {
                let s = url.to_string();
                Self(url, s)
            }
            #[doc = $url_doc]
            pub fn url(&self) -> &Url {
                &self.0
            }
            $($item)*
        }
        impl Deref for $name {
            type Target = String;
            fn deref(&self) -> &String {
                &self.1
            }
        }
        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
                write!(f, "{}", self.1)
            }
        }
        impl ::std::fmt::Debug for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
                f.debug_tuple(stringify!($name)).field(&self.1).finish()
            }
        }
        impl<'de> ::serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: ::serde::de::Deserializer<'de>,
            {
                struct UrlVisitor;
                impl<'de> ::serde::de::Visitor<'de> for UrlVisitor {
                    type Value = $name;

                    fn expecting(
                        &self,
                        formatter: &mut ::std::fmt::Formatter
                    ) -> ::std::fmt::Result {
                        formatter.write_str(stringify!($name))
                    }

                    // Parse via `new` so invalid URLs are rejected during deserialization.
                    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                    where
                        E: ::serde::de::Error,
                    {
                        $name::new(v.to_string()).map_err(E::custom)
                    }
                }
                deserializer.deserialize_str(UrlVisitor)
            }
        }
        impl ::serde::Serialize for $name {
            fn serialize<SE>(&self, serializer: SE) -> Result<SE::Ok, SE::Error>
            where
                SE: ::serde::Serializer,
            {
                serializer.serialize_str(&self.1)
            }
        }
        impl ::std::hash::Hash for $name {
            fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
                ::std::hash::Hash::hash(&self.1, state);
            }
        }
        impl Ord for $name {
            fn cmp(&self, other: &$name) -> ::std::cmp::Ordering {
                self.1.cmp(&other.1)
            }
        }
        impl PartialOrd for $name {
            fn partial_cmp(&self, other: &$name) -> Option<::std::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }
        impl PartialEq for $name {
            fn eq(&self, other: &$name) -> bool {
                self.1 == other.1
            }
        }
        impl Eq for $name {}
    };
}
349
new_type![
    /// Client identifier, wrapped as a `String` newtype.
    #[derive(Deserialize, Serialize, Eq, Hash)]
    ClientId(String)
];
356
new_url_type![
    /// Authorization endpoint URL.
    AuthUrl
];
new_url_type![
    /// Token endpoint URL.
    TokenUrl
];
new_url_type![
    /// Redirect URL.
    RedirectUrl
];
new_url_type![
    /// Token introspection endpoint URL.
    IntrospectionUrl
];
new_url_type![
    /// Token revocation endpoint URL.
    RevocationUrl
];
new_url_type![
    /// Device authorization endpoint URL.
    DeviceAuthorizationUrl
];
new_url_type![
    /// End-user verification URL.
    EndUserVerificationUrl
];
new_type![
    /// Response type, wrapped as a `String` newtype.
    #[derive(Deserialize, Serialize, Eq, Hash)]
    ResponseType(String)
];
new_type![
    /// Resource owner's username, wrapped as a `String` newtype.
    ///
    /// Unlike `ResourceOwnerPassword`, this is not a secret type, so its `Debug` output is
    /// not redacted.
    #[derive(Deserialize, Serialize, Eq, Hash)]
    ResourceOwnerUsername(String)
];

new_type![
    /// A single access scope, wrapped as a `String` newtype.
    #[derive(Deserialize, Serialize, Eq, Hash)]
    Scope(String)
];
403impl AsRef<str> for Scope {
404 fn as_ref(&self) -> &str {
405 self
406 }
407}
408
new_type![
    /// Method used to derive a PKCE code challenge from its verifier.
    ///
    /// `PkceCodeChallenge` constructs the values `"S256"` and `"plain"`.
    #[derive(Deserialize, Serialize, Eq, Hash)]
    PkceCodeChallengeMethod(String)
];
new_secret_type![
    /// Secret PKCE code verifier.
    ///
    /// `Debug` output is redacted, but note that it derives `Serialize`/`Deserialize`
    /// (the redaction applies to `Debug` only) and does not derive `Clone`.
    #[derive(Deserialize, Serialize)]
    PkceCodeVerifier(String)
];
425
/// A PKCE code challenge together with the method used to derive it from a verifier.
///
/// Built by `new_random_sha256`, `new_random_sha256_len` and `from_code_verifier_sha256`,
/// or, with the `pkce-plain` feature, `new_random_plain` and `from_code_verifier_plain`.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
pub struct PkceCodeChallenge {
    // Encoded challenge value; exposed through `as_str()`.
    code_challenge: String,
    // Method used to derive `code_challenge`; exposed through `method()`.
    code_challenge_method: PkceCodeChallengeMethod,
}
433impl PkceCodeChallenge {
434 pub fn new_random_sha256() -> (Self, PkceCodeVerifier) {
436 Self::new_random_sha256_len(32)
437 }
438
439 pub fn new_random_sha256_len(num_bytes: u32) -> (Self, PkceCodeVerifier) {
452 let code_verifier = Self::new_random_len(num_bytes);
453 (
454 Self::from_code_verifier_sha256(&code_verifier),
455 code_verifier,
456 )
457 }
458
459 fn new_random_len(num_bytes: u32) -> PkceCodeVerifier {
472 assert!((32..=96).contains(&num_bytes));
476 let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
477 PkceCodeVerifier::new(BASE64_URL_SAFE_NO_PAD.encode(random_bytes))
478 }
479
480 pub fn from_code_verifier_sha256(code_verifier: &PkceCodeVerifier) -> Self {
487 assert!(code_verifier.secret().len() >= 43 && code_verifier.secret().len() <= 128);
490
491 let digest = Sha256::digest(code_verifier.secret().as_bytes());
492 let code_challenge = BASE64_URL_SAFE_NO_PAD.encode(digest);
493
494 Self {
495 code_challenge,
496 code_challenge_method: PkceCodeChallengeMethod::new("S256".to_string()),
497 }
498 }
499
500 #[cfg(feature = "pkce-plain")]
508 pub fn new_random_plain() -> (Self, PkceCodeVerifier) {
509 let code_verifier = Self::new_random_len(32);
510 (
511 Self::from_code_verifier_plain(&code_verifier),
512 code_verifier,
513 )
514 }
515
516 #[cfg(feature = "pkce-plain")]
524 pub fn from_code_verifier_plain(code_verifier: &PkceCodeVerifier) -> Self {
525 assert!(code_verifier.secret().len() >= 43 && code_verifier.secret().len() <= 128);
528
529 let code_challenge = code_verifier.secret().clone();
530
531 Self {
532 code_challenge,
533 code_challenge_method: PkceCodeChallengeMethod::new("plain".to_string()),
534 }
535 }
536
537 pub fn as_str(&self) -> &str {
539 &self.code_challenge
540 }
541
542 pub fn method(&self) -> &PkceCodeChallengeMethod {
544 &self.code_challenge_method
545 }
546}
547
new_secret_type![
    /// Client secret (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    ClientSecret(String)
];
554new_secret_type![
555 #[must_use]
558 #[derive(Clone, Deserialize, Serialize)]
559 CsrfToken(String)
560 impl {
561 pub fn new_random() -> Self {
563 CsrfToken::new_random_len(16)
564 }
565 pub fn new_random_len(num_bytes: u32) -> Self {
571 let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
572 CsrfToken::new(BASE64_URL_SAFE_NO_PAD.encode(random_bytes))
573 }
574 }
575];
new_secret_type![
    /// Authorization code (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    AuthorizationCode(String)
];
new_secret_type![
    /// Refresh token (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    RefreshToken(String)
];
new_secret_type![
    /// Access token (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    AccessToken(String)
];
new_secret_type![
    /// Resource owner's password (secret type: `Debug` output is redacted).
    ///
    /// Unlike the neighbouring secret types, this one derives neither `Serialize` nor
    /// `Deserialize`.
    #[derive(Clone)]
    ResourceOwnerPassword(String)
];
new_secret_type![
    /// Device code (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    DeviceCode(String)
];
new_secret_type![
    /// Complete verification URI, kept as a secret `String` rather than a URL type
    /// (`Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    VerificationUriComplete(String)
];
new_secret_type![
    /// User code (secret type: `Debug` output is redacted).
    #[derive(Clone, Deserialize, Serialize)]
    UserCode(String)
];
614
#[cfg(test)]
mod tests {
    use crate::{ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier};

    /// `into_secret` hands back the wrapped value unchanged.
    #[test]
    fn test_secret_conversion() {
        let token = CsrfToken::new(String::from("top_secret"));
        let inner: String = token.into_secret();
        assert_eq!(inner, "top_secret");
    }

    /// `Debug` output must contain only the type name, never the secret itself.
    #[test]
    fn test_secret_redaction() {
        let client_secret = ClientSecret::new(String::from("top_secret"));
        let rendered = format!("{:?}", client_secret);
        assert_eq!(rendered, "ClientSecret([redacted])");
    }

    /// 31 bytes would encode to only 42 characters, so generation must panic.
    #[test]
    #[should_panic]
    fn test_code_verifier_too_short() {
        let _ = PkceCodeChallenge::new_random_sha256_len(31);
    }

    /// 97 bytes would encode to 130 characters, so generation must panic.
    #[test]
    #[should_panic]
    fn test_code_verifier_too_long() {
        let _ = PkceCodeChallenge::new_random_sha256_len(97);
    }

    /// The smallest allowed input (32 bytes) yields a 43-character verifier.
    #[test]
    fn test_code_verifier_min() {
        let (_challenge, verifier) = PkceCodeChallenge::new_random_sha256_len(32);
        assert_eq!(verifier.secret().len(), 43);
    }

    /// The largest allowed input (96 bytes) yields a 128-character verifier.
    #[test]
    fn test_code_verifier_max() {
        let (_challenge, verifier) = PkceCodeChallenge::new_random_sha256_len(96);
        assert_eq!(verifier.secret().len(), 128);
    }

    /// Known-answer check: the verifier/challenge pair from RFC 7636, Appendix B.
    #[test]
    fn test_code_verifier_challenge() {
        let verifier = PkceCodeVerifier::new(String::from(
            "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
        ));
        let challenge = PkceCodeChallenge::from_code_verifier_sha256(&verifier);
        assert_eq!(
            challenge.as_str(),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
        );
    }
}