oauth2/
types.rs

1use base64::prelude::*;
2use rand::{thread_rng, Rng};
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use url::Url;
6
7use std::fmt::Error as FormatterError;
8use std::fmt::{Debug, Formatter};
9#[cfg(feature = "timing-resistant-secret-traits")]
10use std::hash::{Hash, Hasher};
11use std::ops::Deref;
12
macro_rules! new_type {
    // Shorthand without extra methods: forward to the main arm with an empty `impl` body.
    (
        $(#[$attr:meta])*
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        )
    ) => {
        new_type![
            $(#[$attr])*
            $name(
                $(#[$type_attr])*
                $type
            )
            impl {}
        ];
    };
    // Main entry point: build the constructor doc string, then hand off to the `@new_type` arm.
    (
        $(#[$attr:meta])*
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        )
        impl {
            $($body:tt)*
        }
    ) => {
        new_type![
            @new_type $(#[$attr])*,
            $name(
                $(#[$type_attr])*
                $type
            ),
            concat!(
                "Create a new `",
                stringify!($name),
                "` to wrap the given `",
                stringify!($type),
                "`."
            ),
            impl {
                $($body)*
            }
        ];
    };
    // Internal arm: emits the tuple struct, its inherent impl (caller-supplied items plus a
    // `const` constructor), and the `From`/`Deref` conversions to the wrapped type.
    (
        @new_type $(#[$attr:meta])*,
        $name:ident(
            $(#[$type_attr:meta])*
            $type:ty
        ),
        $ctor_doc:expr,
        impl {
            $($body:tt)*
        }
    ) => {
        $(#[$attr])*
        #[derive(Clone, Debug, PartialEq)]
        pub struct $name(
            $(#[$type_attr])*
            $type
        );
        impl $name {
            $($body)*

            #[doc = $ctor_doc]
            pub const fn new(s: $type) -> Self {
                Self(s)
            }
        }
        // Unwrap back into the inner value.
        impl From<$name> for $type {
            fn from(wrapper: $name) -> $type {
                wrapper.0
            }
        }
        // Borrow the inner value transparently.
        impl Deref for $name {
            type Target = $type;
            fn deref(&self) -> &$type {
                &self.0
            }
        }
    };
}
106
/// Creates a new type wrapping a secret value.
///
/// The generated type never prints its contents through `Debug`; the secret is only reachable
/// through the explicit `secret()` / `into_secret()` accessors. With the
/// `timing-resistant-secret-traits` feature, `PartialEq`/`Eq`/`Hash` are implemented by
/// comparing/hashing SHA-256 digests of the value instead of the raw value.
macro_rules! new_secret_type {
    // Convenience pattern without an impl: forward with an empty `impl` body.
    (
        $(#[$attr:meta])*
        $name:ident($type:ty)
    ) => {
        new_secret_type![
            $(#[$attr])*
            $name($type)
            impl {}
        ];
    };
    // Main entry point with an impl: build the generated doc strings.
    (
        $(#[$attr:meta])*
        $name:ident($type:ty)
        impl {
            $($item:tt)*
        }
    ) => {
        new_secret_type![
            $(#[$attr])*,
            $name($type),
            concat!(
                "Create a new `",
                stringify!($name),
                "` to wrap the given `",
                stringify!($type),
                "`."
            ),
            concat!("Get the secret contained within this `", stringify!($name), "`."),
            impl {
                $($item)*
            }
        ];
    };
    // Actual implementation, after building the doc strings.
    (
        $(#[$attr:meta])*,
        $name:ident($type:ty),
        $new_doc:expr,
        $secret_doc:expr,
        impl {
            $($item:tt)*
        }
    ) => {
        $(
            #[$attr]
        )*
        #[cfg_attr(feature = "timing-resistant-secret-traits", derive(Eq))]
        pub struct $name($type);
        impl $name {
            $($item)*

            #[doc = $new_doc]
            pub fn new(s: $type) -> Self {
                $name(s)
            }

            #[doc = $secret_doc]
            ///
            /// # Security Warning
            ///
            /// Leaking this value may compromise the security of the OAuth2 flow.
            pub fn secret(&self) -> &$type { &self.0 }

            #[doc = concat!(
                "Consume this `",
                stringify!($name),
                "` and return the secret it contains."
            )]
            ///
            /// # Security Warning
            ///
            /// Leaking this value may compromise the security of the OAuth2 flow.
            pub fn into_secret(self) -> $type { self.0 }
        }
        impl Debug for $name {
            // Only the type name is printed; the secret itself is always redacted.
            fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
                f.write_str(concat!(stringify!($name), "([redacted])"))
            }
        }

        // Compare digests rather than raw secrets.
        #[cfg(feature = "timing-resistant-secret-traits")]
        impl PartialEq for $name {
            fn eq(&self, other: &Self) -> bool {
                Sha256::digest(&self.0) == Sha256::digest(&other.0)
            }
        }

        // Hash the digest so `Hash` stays consistent with the digest-based `PartialEq`.
        #[cfg(feature = "timing-resistant-secret-traits")]
        impl Hash for $name {
            fn hash<H: Hasher>(&self, state: &mut H) {
                Sha256::digest(&self.0).hash(state)
            }
        }
    };
}
199
/// Creates a URL-specific new type.
///
/// Types created by this macro enforce during construction that the contained value represents a
/// syntactically valid URL. However, comparisons and hashes of these types are based on the string
/// representation given during construction, disregarding any canonicalization performed by the
/// underlying `Url` struct. OpenID Connect requires certain URLs (e.g., ID token issuers) to be
/// compared exactly, without canonicalization.
///
/// In addition to the raw string representation, these types include a `url` method to retrieve a
/// parsed `Url` struct.
macro_rules! new_url_type {
    // Convenience pattern without an impl.
    (
        $(#[$attr:meta])*
        $name:ident
    ) => {
        new_url_type![
            @new_type_pub $(#[$attr])*,
            $name,
            concat!("Create a new `", stringify!($name), "` from a `String` to wrap a URL."),
            concat!("Create a new `", stringify!($name), "` from a `Url` to wrap a URL."),
            concat!("Return this `", stringify!($name), "` as a parsed `Url`."),
            impl {}
        ];
    };
    // Main entry point with an impl.
    (
        $(#[$attr:meta])*
        $name:ident
        impl {
            $($item:tt)*
        }
    ) => {
        new_url_type![
            @new_type_pub $(#[$attr])*,
            $name,
            concat!("Create a new `", stringify!($name), "` from a `String` to wrap a URL."),
            concat!("Create a new `", stringify!($name), "` from a `Url` to wrap a URL."),
            concat!("Return this `", stringify!($name), "` as a parsed `Url`."),
            impl {
                $($item)*
            }
        ];
    };
    // Actual implementation, after stringifying the #[doc] attr.
    (
        @new_type_pub $(#[$attr:meta])*,
        $name:ident,
        $new_doc:expr,
        $from_url_doc:expr,
        $url_doc:expr,
        impl {
            $($item:tt)*
        }
    ) => {
        $(#[$attr])*
        #[derive(Clone)]
        // Field 0 is the parsed URL; field 1 is the exact string it was built from, which is
        // what every comparison, hash, display and serialization below uses.
        pub struct $name(Url, String);
        impl $name {
            #[doc = $new_doc]
            pub fn new(url: String) -> Result<Self, ::url::ParseError> {
                Ok($name(Url::parse(&url)?, url))
            }
            #[doc = $from_url_doc]
            pub fn from_url(url: Url) -> Self {
                let s = url.to_string();
                Self(url, s)
            }
            #[doc = $url_doc]
            pub fn url(&self) -> &Url {
                &self.0
            }
            $($item)*
        }
        impl Deref for $name {
            type Target = String;
            fn deref(&self) -> &String {
                &self.1
            }
        }
        impl ::std::fmt::Display for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
                write!(f, "{}", self.1)
            }
        }
        impl ::std::fmt::Debug for $name {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
                f.debug_tuple(stringify!($name)).field(&self.1).finish()
            }
        }
        impl<'de> ::serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: ::serde::de::Deserializer<'de>,
            {
                struct UrlVisitor;
                impl<'de> ::serde::de::Visitor<'de> for UrlVisitor {
                    type Value = $name;

                    fn expecting(
                        &self,
                        formatter: &mut ::std::fmt::Formatter
                    ) -> ::std::fmt::Result {
                        formatter.write_str(stringify!($name))
                    }

                    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                    where
                        E: ::serde::de::Error,
                    {
                        $name::new(v.to_string()).map_err(E::custom)
                    }

                    // Take ownership of an already-owned string instead of copying it again.
                    fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
                    where
                        E: ::serde::de::Error,
                    {
                        $name::new(v).map_err(E::custom)
                    }
                }
                deserializer.deserialize_str(UrlVisitor)
            }
        }
        impl ::serde::Serialize for $name {
            fn serialize<SE>(&self, serializer: SE) -> Result<SE::Ok, SE::Error>
            where
                SE: ::serde::Serializer,
            {
                serializer.serialize_str(&self.1)
            }
        }
        impl ::std::hash::Hash for $name {
            fn hash<H: ::std::hash::Hasher>(&self, state: &mut H) {
                ::std::hash::Hash::hash(&self.1, state);
            }
        }
        impl Ord for $name {
            fn cmp(&self, other: &$name) -> ::std::cmp::Ordering {
                self.1.cmp(&other.1)
            }
        }
        impl PartialOrd for $name {
            fn partial_cmp(&self, other: &$name) -> Option<::std::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }
        impl PartialEq for $name {
            fn eq(&self, other: &$name) -> bool {
                self.1 == other.1
            }
        }
        impl Eq for $name {}
    };
}
349
350new_type![
351    /// Client identifier issued to the client during the registration process described by
352    /// [Section 2.2](https://tools.ietf.org/html/rfc6749#section-2.2).
353    #[derive(Deserialize, Serialize, Eq, Hash)]
354    ClientId(String)
355];
356
357new_url_type![
358    /// URL of the authorization server's authorization endpoint.
359    AuthUrl
360];
361new_url_type![
362    /// URL of the authorization server's token endpoint.
363    TokenUrl
364];
365new_url_type![
366    /// URL of the client's redirection endpoint.
367    RedirectUrl
368];
369new_url_type![
370    /// URL of the client's [RFC 7662 OAuth 2.0 Token Introspection](https://tools.ietf.org/html/rfc7662) endpoint.
371    IntrospectionUrl
372];
373new_url_type![
374    /// URL of the authorization server's RFC 7009 token revocation endpoint.
375    RevocationUrl
376];
377new_url_type![
378    /// URL of the client's device authorization endpoint.
379    DeviceAuthorizationUrl
380];
381new_url_type![
382    /// URL of the end-user verification URI on the authorization server.
383    EndUserVerificationUrl
384];
385new_type![
386    /// Authorization endpoint response (grant) type defined in
387    /// [Section 3.1.1](https://tools.ietf.org/html/rfc6749#section-3.1.1).
388    #[derive(Deserialize, Serialize, Eq, Hash)]
389    ResponseType(String)
390];
391new_type![
392    /// Resource owner's username used directly as an authorization grant to obtain an access
393    /// token.
394    #[derive(Deserialize, Serialize, Eq, Hash)]
395    ResourceOwnerUsername(String)
396];
397
398new_type![
399    /// Access token scope, as defined by the authorization server.
400    #[derive(Deserialize, Serialize, Eq, Hash)]
401    Scope(String)
402];
403impl AsRef<str> for Scope {
404    fn as_ref(&self) -> &str {
405        self
406    }
407}
408
409new_type![
410    /// Code Challenge Method used for [PKCE](https://tools.ietf.org/html/rfc7636) protection
411    /// via the `code_challenge_method` parameter.
412    #[derive(Deserialize, Serialize, Eq, Hash)]
413    PkceCodeChallengeMethod(String)
414];
415// This type intentionally does not implement Clone in order to make it difficult to reuse PKCE
416// challenges across multiple requests.
417new_secret_type![
418    /// Code Verifier used for [PKCE](https://tools.ietf.org/html/rfc7636) protection via the
419    /// `code_verifier` parameter. The value must have a minimum length of 43 characters and a
420    /// maximum length of 128 characters.  Each character must be ASCII alphanumeric or one of
421    /// the characters "-" / "." / "_" / "~".
422    #[derive(Deserialize, Serialize)]
423    PkceCodeVerifier(String)
424];
425
426/// Code Challenge used for [PKCE](https://tools.ietf.org/html/rfc7636) protection via the
427/// `code_challenge` parameter.
428#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)]
429pub struct PkceCodeChallenge {
430    code_challenge: String,
431    code_challenge_method: PkceCodeChallengeMethod,
432}
433impl PkceCodeChallenge {
434    /// Generate a new random, base64-encoded SHA-256 PKCE code.
435    pub fn new_random_sha256() -> (Self, PkceCodeVerifier) {
436        Self::new_random_sha256_len(32)
437    }
438
439    /// Generate a new random, base64-encoded SHA-256 PKCE challenge code and verifier.
440    ///
441    /// # Arguments
442    ///
443    /// * `num_bytes` - Number of random bytes to generate, prior to base64-encoding.
444    ///   The value must be in the range 32 to 96 inclusive in order to generate a verifier
445    ///   with a suitable length.
446    ///
447    /// # Panics
448    ///
449    /// This method panics if the resulting PKCE code verifier is not of a suitable length
450    /// to comply with [RFC 7636](https://tools.ietf.org/html/rfc7636).
451    pub fn new_random_sha256_len(num_bytes: u32) -> (Self, PkceCodeVerifier) {
452        let code_verifier = Self::new_random_len(num_bytes);
453        (
454            Self::from_code_verifier_sha256(&code_verifier),
455            code_verifier,
456        )
457    }
458
459    /// Generate a new random, base64-encoded PKCE code verifier.
460    ///
461    /// # Arguments
462    ///
463    /// * `num_bytes` - Number of random bytes to generate, prior to base64-encoding.
464    ///   The value must be in the range 32 to 96 inclusive in order to generate a verifier
465    ///   with a suitable length.
466    ///
467    /// # Panics
468    ///
469    /// This method panics if the resulting PKCE code verifier is not of a suitable length
470    /// to comply with [RFC 7636](https://tools.ietf.org/html/rfc7636).
471    fn new_random_len(num_bytes: u32) -> PkceCodeVerifier {
472        // The RFC specifies that the code verifier must have "a minimum length of 43
473        // characters and a maximum length of 128 characters".
474        // This implies 32-96 octets of random data to be base64 encoded.
475        assert!((32..=96).contains(&num_bytes));
476        let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
477        PkceCodeVerifier::new(BASE64_URL_SAFE_NO_PAD.encode(random_bytes))
478    }
479
480    /// Generate a SHA-256 PKCE code challenge from the supplied PKCE code verifier.
481    ///
482    /// # Panics
483    ///
484    /// This method panics if the supplied PKCE code verifier is not of a suitable length
485    /// to comply with [RFC 7636](https://tools.ietf.org/html/rfc7636).
486    pub fn from_code_verifier_sha256(code_verifier: &PkceCodeVerifier) -> Self {
487        // The RFC specifies that the code verifier must have "a minimum length of 43
488        // characters and a maximum length of 128 characters".
489        assert!(code_verifier.secret().len() >= 43 && code_verifier.secret().len() <= 128);
490
491        let digest = Sha256::digest(code_verifier.secret().as_bytes());
492        let code_challenge = BASE64_URL_SAFE_NO_PAD.encode(digest);
493
494        Self {
495            code_challenge,
496            code_challenge_method: PkceCodeChallengeMethod::new("S256".to_string()),
497        }
498    }
499
500    /// Generate a new random, base64-encoded PKCE code.
501    /// Use is discouraged unless the endpoint does not support SHA-256.
502    ///
503    /// # Panics
504    ///
505    /// This method panics if the supplied PKCE code verifier is not of a suitable length
506    /// to comply with [RFC 7636](https://tools.ietf.org/html/rfc7636).
507    #[cfg(feature = "pkce-plain")]
508    pub fn new_random_plain() -> (Self, PkceCodeVerifier) {
509        let code_verifier = Self::new_random_len(32);
510        (
511            Self::from_code_verifier_plain(&code_verifier),
512            code_verifier,
513        )
514    }
515
516    /// Generate a plain PKCE code challenge from the supplied PKCE code verifier.
517    /// Use is discouraged unless the endpoint does not support SHA-256.
518    ///
519    /// # Panics
520    ///
521    /// This method panics if the supplied PKCE code verifier is not of a suitable length
522    /// to comply with [RFC 7636](https://tools.ietf.org/html/rfc7636).
523    #[cfg(feature = "pkce-plain")]
524    pub fn from_code_verifier_plain(code_verifier: &PkceCodeVerifier) -> Self {
525        // The RFC specifies that the code verifier must have "a minimum length of 43
526        // characters and a maximum length of 128 characters".
527        assert!(code_verifier.secret().len() >= 43 && code_verifier.secret().len() <= 128);
528
529        let code_challenge = code_verifier.secret().clone();
530
531        Self {
532            code_challenge,
533            code_challenge_method: PkceCodeChallengeMethod::new("plain".to_string()),
534        }
535    }
536
537    /// Returns the PKCE code challenge as a string.
538    pub fn as_str(&self) -> &str {
539        &self.code_challenge
540    }
541
542    /// Returns the PKCE code challenge method as a string.
543    pub fn method(&self) -> &PkceCodeChallengeMethod {
544        &self.code_challenge_method
545    }
546}
547
548new_secret_type![
549    /// Client password issued to the client during the registration process described by
550    /// [Section 2.2](https://tools.ietf.org/html/rfc6749#section-2.2).
551    #[derive(Clone, Deserialize, Serialize)]
552    ClientSecret(String)
553];
554new_secret_type![
555    /// Value used for [CSRF](https://tools.ietf.org/html/rfc6749#section-10.12) protection
556    /// via the `state` parameter.
557    #[must_use]
558    #[derive(Clone, Deserialize, Serialize)]
559    CsrfToken(String)
560    impl {
561        /// Generate a new random, base64-encoded 128-bit CSRF token.
562        pub fn new_random() -> Self {
563            CsrfToken::new_random_len(16)
564        }
565        /// Generate a new random, base64-encoded CSRF token of the specified length.
566        ///
567        /// # Arguments
568        ///
569        /// * `num_bytes` - Number of random bytes to generate, prior to base64-encoding.
570        pub fn new_random_len(num_bytes: u32) -> Self {
571            let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
572            CsrfToken::new(BASE64_URL_SAFE_NO_PAD.encode(random_bytes))
573        }
574    }
575];
576new_secret_type![
577    /// Authorization code returned from the authorization endpoint.
578    #[derive(Clone, Deserialize, Serialize)]
579    AuthorizationCode(String)
580];
581new_secret_type![
582    /// Refresh token used to obtain a new access token (if supported by the authorization server).
583    #[derive(Clone, Deserialize, Serialize)]
584    RefreshToken(String)
585];
586new_secret_type![
587    /// Access token returned by the token endpoint and used to access protected resources.
588    #[derive(Clone, Deserialize, Serialize)]
589    AccessToken(String)
590];
591new_secret_type![
592    /// Resource owner's password used directly as an authorization grant to obtain an access
593    /// token.
594    #[derive(Clone)]
595    ResourceOwnerPassword(String)
596];
597new_secret_type![
598    /// Device code returned by the device authorization endpoint and used to query the token endpoint.
599    #[derive(Clone, Deserialize, Serialize)]
600    DeviceCode(String)
601];
602new_secret_type![
603    /// Verification URI returned by the device authorization endpoint and visited by the user
604    /// to authorize.  Contains the user code.
605    #[derive(Clone, Deserialize, Serialize)]
606    VerificationUriComplete(String)
607];
608new_secret_type![
609    /// User code returned by the device authorization endpoint and used by the user to authorize at
610    /// the verification URI.
611    #[derive(Clone, Deserialize, Serialize)]
612    UserCode(String)
613];
614
#[cfg(test)]
mod tests {
    use crate::{ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier};

    #[test]
    fn test_secret_conversion() {
        let token = CsrfToken::new(String::from("top_secret"));
        let unwrapped: Box<str> = token.into_secret().into_boxed_str();
        assert_eq!(unwrapped, Box::<str>::from("top_secret"));
    }

    #[test]
    fn test_secret_redaction() {
        let client_secret = ClientSecret::new(String::from("top_secret"));
        let rendered = format!("{client_secret:?}");
        assert_eq!("ClientSecret([redacted])", rendered);
    }

    #[test]
    #[should_panic]
    fn test_code_verifier_too_short() {
        // One byte below the lower bound.
        let _ = PkceCodeChallenge::new_random_sha256_len(31);
    }

    #[test]
    #[should_panic]
    fn test_code_verifier_too_long() {
        // One byte above the upper bound.
        let _ = PkceCodeChallenge::new_random_sha256_len(97);
    }

    #[test]
    fn test_code_verifier_min() {
        let (_challenge, verifier) = PkceCodeChallenge::new_random_sha256_len(32);
        assert_eq!(verifier.secret().len(), 43);
    }

    #[test]
    fn test_code_verifier_max() {
        let (_challenge, verifier) = PkceCodeChallenge::new_random_sha256_len(96);
        assert_eq!(verifier.secret().len(), 128);
    }

    #[test]
    fn test_code_verifier_challenge() {
        // Test vector from https://tools.ietf.org/html/rfc7636#appendix-B
        let verifier =
            PkceCodeVerifier::new(String::from("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"));
        let challenge = PkceCodeChallenge::from_code_verifier_sha256(&verifier);
        assert_eq!(
            challenge.as_str(),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
        );
    }
}