oauth2/
devicecode.rs

1use crate::basic::BasicErrorResponseType;
2use crate::endpoint::{endpoint_request, endpoint_response};
3use crate::types::VerificationUriComplete;
4use crate::{
5    AsyncHttpClient, AuthType, Client, ClientId, ClientSecret, DeviceAuthorizationUrl, DeviceCode,
6    EndUserVerificationUrl, EndpointState, ErrorResponse, ErrorResponseType, HttpRequest,
7    HttpResponse, RequestTokenError, RevocableToken, Scope, StandardErrorResponse, SyncHttpClient,
8    TokenIntrospectionResponse, TokenResponse, TokenUrl, UserCode,
9};
10
11use chrono::{DateTime, Utc};
12use serde::de::DeserializeOwned;
13use serde::{Deserialize, Serialize};
14
15use std::borrow::Cow;
16use std::error::Error;
17use std::fmt::Error as FormatterError;
18use std::fmt::{Debug, Display, Formatter};
19use std::future::Future;
20use std::marker::PhantomData;
21use std::sync::Arc;
22use std::time::Duration;
23
24impl<
25        TE,
26        TR,
27        TIR,
28        RT,
29        TRE,
30        HasAuthUrl,
31        HasDeviceAuthUrl,
32        HasIntrospectionUrl,
33        HasRevocationUrl,
34        HasTokenUrl,
35    >
36    Client<
37        TE,
38        TR,
39        TIR,
40        RT,
41        TRE,
42        HasAuthUrl,
43        HasDeviceAuthUrl,
44        HasIntrospectionUrl,
45        HasRevocationUrl,
46        HasTokenUrl,
47    >
48where
49    TE: ErrorResponse + 'static,
50    TR: TokenResponse,
51    TIR: TokenIntrospectionResponse,
52    RT: RevocableToken,
53    TRE: ErrorResponse + 'static,
54    HasAuthUrl: EndpointState,
55    HasDeviceAuthUrl: EndpointState,
56    HasIntrospectionUrl: EndpointState,
57    HasRevocationUrl: EndpointState,
58    HasTokenUrl: EndpointState,
59{
60    pub(crate) fn exchange_device_code_impl<'a>(
61        &'a self,
62        device_authorization_url: &'a DeviceAuthorizationUrl,
63    ) -> DeviceAuthorizationRequest<'a, TE> {
64        DeviceAuthorizationRequest {
65            auth_type: &self.auth_type,
66            client_id: &self.client_id,
67            client_secret: self.client_secret.as_ref(),
68            extra_params: Vec::new(),
69            scopes: Vec::new(),
70            device_authorization_url,
71            _phantom: PhantomData,
72        }
73    }
74
75    pub(crate) fn exchange_device_access_token_impl<'a, EF>(
76        &'a self,
77        token_url: &'a TokenUrl,
78        auth_response: &'a DeviceAuthorizationResponse<EF>,
79    ) -> DeviceAccessTokenRequest<'a, 'static, TR, EF>
80    where
81        EF: ExtraDeviceAuthorizationFields,
82    {
83        DeviceAccessTokenRequest {
84            auth_type: &self.auth_type,
85            client_id: &self.client_id,
86            client_secret: self.client_secret.as_ref(),
87            extra_params: Vec::new(),
88            token_url,
89            dev_auth_resp: auth_response,
90            time_fn: Arc::new(Utc::now),
91            max_backoff_interval: None,
92            _phantom: PhantomData,
93        }
94    }
95}
96
97/// The request for a set of verification codes from the authorization server.
98///
99/// See <https://tools.ietf.org/html/rfc8628#section-3.1>.
100#[derive(Debug)]
101pub struct DeviceAuthorizationRequest<'a, TE>
102where
103    TE: ErrorResponse,
104{
105    pub(crate) auth_type: &'a AuthType,
106    pub(crate) client_id: &'a ClientId,
107    pub(crate) client_secret: Option<&'a ClientSecret>,
108    pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>,
109    pub(crate) scopes: Vec<Cow<'a, Scope>>,
110    pub(crate) device_authorization_url: &'a DeviceAuthorizationUrl,
111    pub(crate) _phantom: PhantomData<TE>,
112}
113
114impl<'a, TE> DeviceAuthorizationRequest<'a, TE>
115where
116    TE: ErrorResponse + 'static,
117{
118    /// Appends an extra param to the token request.
119    ///
120    /// This method allows extensions to be used without direct support from
121    /// this crate. If `name` conflicts with a parameter managed by this crate, the
122    /// behavior is undefined. In particular, do not set parameters defined by
123    /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or
124    /// [RFC 7636](https://tools.ietf.org/html/rfc7636).
125    ///
126    /// # Security Warning
127    ///
128    /// Callers should follow the security recommendations for any OAuth2 extensions used with
129    /// this function, which are beyond the scope of
130    /// [RFC 6749](https://tools.ietf.org/html/rfc6749).
131    pub fn add_extra_param<N, V>(mut self, name: N, value: V) -> Self
132    where
133        N: Into<Cow<'a, str>>,
134        V: Into<Cow<'a, str>>,
135    {
136        self.extra_params.push((name.into(), value.into()));
137        self
138    }
139
140    /// Appends a new scope to the token request.
141    pub fn add_scope(mut self, scope: Scope) -> Self {
142        self.scopes.push(Cow::Owned(scope));
143        self
144    }
145
146    /// Appends a collection of scopes to the token request.
147    pub fn add_scopes<I>(mut self, scopes: I) -> Self
148    where
149        I: IntoIterator<Item = Scope>,
150    {
151        self.scopes.extend(scopes.into_iter().map(Cow::Owned));
152        self
153    }
154
155    fn prepare_request<RE>(self) -> Result<HttpRequest, RequestTokenError<RE, TE>>
156    where
157        RE: Error + 'static,
158    {
159        endpoint_request(
160            self.auth_type,
161            self.client_id,
162            self.client_secret,
163            &self.extra_params,
164            None,
165            Some(&self.scopes),
166            self.device_authorization_url.url(),
167            vec![],
168        )
169        .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}")))
170    }
171
172    /// Synchronously sends the request to the authorization server and awaits a response.
173    pub fn request<C, EF>(
174        self,
175        http_client: &C,
176    ) -> Result<DeviceAuthorizationResponse<EF>, RequestTokenError<<C as SyncHttpClient>::Error, TE>>
177    where
178        C: SyncHttpClient,
179        EF: ExtraDeviceAuthorizationFields,
180    {
181        endpoint_response(http_client.call(self.prepare_request()?)?)
182    }
183
184    /// Asynchronously sends the request to the authorization server and returns a Future.
185    pub fn request_async<'c, C, EF>(
186        self,
187        http_client: &'c C,
188    ) -> impl Future<
189        Output = Result<
190            DeviceAuthorizationResponse<EF>,
191            RequestTokenError<<C as AsyncHttpClient<'c>>::Error, TE>,
192        >,
193    > + 'c
194    where
195        Self: 'c,
196        C: AsyncHttpClient<'c>,
197        EF: ExtraDeviceAuthorizationFields,
198    {
199        Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) })
200    }
201}
202
203/// The request for a device access token from the authorization server.
204///
205/// See <https://tools.ietf.org/html/rfc8628#section-3.4>.
206#[derive(Clone)]
207pub struct DeviceAccessTokenRequest<'a, 'b, TR, EF>
208where
209    TR: TokenResponse,
210    EF: ExtraDeviceAuthorizationFields,
211{
212    pub(crate) auth_type: &'a AuthType,
213    pub(crate) client_id: &'a ClientId,
214    pub(crate) client_secret: Option<&'a ClientSecret>,
215    pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>,
216    pub(crate) token_url: &'a TokenUrl,
217    pub(crate) dev_auth_resp: &'a DeviceAuthorizationResponse<EF>,
218    pub(crate) time_fn: Arc<dyn Fn() -> DateTime<Utc> + Send + Sync + 'b>,
219    pub(crate) max_backoff_interval: Option<Duration>,
220    pub(crate) _phantom: PhantomData<(TR, EF)>,
221}
222
223impl<'a, 'b, TR, EF> DeviceAccessTokenRequest<'a, 'b, TR, EF>
224where
225    TR: TokenResponse,
226    EF: ExtraDeviceAuthorizationFields,
227{
228    /// Appends an extra param to the token request.
229    ///
230    /// This method allows extensions to be used without direct support from
231    /// this crate. If `name` conflicts with a parameter managed by this crate, the
232    /// behavior is undefined. In particular, do not set parameters defined by
233    /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or
234    /// [RFC 7636](https://tools.ietf.org/html/rfc7636).
235    ///
236    /// # Security Warning
237    ///
238    /// Callers should follow the security recommendations for any OAuth2 extensions used with
239    /// this function, which are beyond the scope of
240    /// [RFC 6749](https://tools.ietf.org/html/rfc6749).
241    pub fn add_extra_param<N, V>(mut self, name: N, value: V) -> Self
242    where
243        N: Into<Cow<'a, str>>,
244        V: Into<Cow<'a, str>>,
245    {
246        self.extra_params.push((name.into(), value.into()));
247        self
248    }
249
250    /// Specifies a function for returning the current time.
251    ///
252    /// This function is used while polling the authorization server.
253    pub fn set_time_fn<'t, T>(self, time_fn: T) -> DeviceAccessTokenRequest<'a, 't, TR, EF>
254    where
255        T: Fn() -> DateTime<Utc> + Send + Sync + 't,
256    {
257        DeviceAccessTokenRequest {
258            auth_type: self.auth_type,
259            client_id: self.client_id,
260            client_secret: self.client_secret,
261            extra_params: self.extra_params,
262            token_url: self.token_url,
263            dev_auth_resp: self.dev_auth_resp,
264            time_fn: Arc::new(time_fn),
265            max_backoff_interval: self.max_backoff_interval,
266            _phantom: PhantomData,
267        }
268    }
269
270    /// Sets the upper limit of the sleep interval to use for polling the token endpoint when the
271    /// HTTP client returns an error (e.g., in case of connection timeout).
272    pub fn set_max_backoff_interval(mut self, interval: Duration) -> Self {
273        self.max_backoff_interval = Some(interval);
274        self
275    }
276
277    /// Synchronously polls the authorization server for a response, waiting
278    /// using a user defined sleep function.
279    pub fn request<C, S>(
280        self,
281        http_client: &C,
282        sleep_fn: S,
283        timeout: Option<Duration>,
284    ) -> Result<TR, RequestTokenError<<C as SyncHttpClient>::Error, DeviceCodeErrorResponse>>
285    where
286        C: SyncHttpClient,
287        S: Fn(Duration),
288    {
289        // Get the request timeout and starting interval
290        let timeout_dt = self.compute_timeout(timeout)?;
291        let mut interval = self.dev_auth_resp.interval();
292
293        // Loop while requesting a token.
294        loop {
295            let now = (*self.time_fn)();
296            if now > timeout_dt {
297                break Err(RequestTokenError::ServerResponse(
298                    DeviceCodeErrorResponse::new(
299                        DeviceCodeErrorResponseType::ExpiredToken,
300                        Some(String::from("This device code has expired.")),
301                        None,
302                    ),
303                ));
304            }
305
306            match self.process_response(http_client.call(self.prepare_request()?), interval) {
307                DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => {
308                    interval = new_interval
309                }
310                DeviceAccessTokenPollResult::Done(res) => break res,
311            }
312
313            // Sleep here using the provided sleep function.
314            sleep_fn(interval);
315        }
316    }
317
318    /// Asynchronously sends the request to the authorization server and awaits a response.
319    pub fn request_async<'c, C, S, SF>(
320        self,
321        http_client: &'c C,
322        sleep_fn: S,
323        timeout: Option<Duration>,
324    ) -> impl Future<
325        Output = Result<
326            TR,
327            RequestTokenError<<C as AsyncHttpClient<'c>>::Error, DeviceCodeErrorResponse>,
328        >,
329    > + 'c
330    where
331        Self: 'c,
332        C: AsyncHttpClient<'c>,
333        S: Fn(Duration) -> SF + 'c,
334        SF: Future<Output = ()>,
335    {
336        Box::pin(async move {
337            // Get the request timeout and starting interval
338            let timeout_dt = self.compute_timeout(timeout)?;
339            let mut interval = self.dev_auth_resp.interval();
340
341            // Loop while requesting a token.
342            loop {
343                let now = (*self.time_fn)();
344                if now > timeout_dt {
345                    break Err(RequestTokenError::ServerResponse(
346                        DeviceCodeErrorResponse::new(
347                            DeviceCodeErrorResponseType::ExpiredToken,
348                            Some(String::from("This device code has expired.")),
349                            None,
350                        ),
351                    ));
352                }
353
354                match self
355                    .process_response(http_client.call(self.prepare_request()?).await, interval)
356                {
357                    DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => {
358                        interval = new_interval
359                    }
360                    DeviceAccessTokenPollResult::Done(res) => break res,
361                }
362
363                // Sleep here using the provided sleep function.
364                sleep_fn(interval).await;
365            }
366        })
367    }
368
369    fn prepare_request<RE, TE>(&self) -> Result<HttpRequest, RequestTokenError<RE, TE>>
370    where
371        RE: Error + 'static,
372        TE: ErrorResponse + 'static,
373    {
374        endpoint_request(
375            self.auth_type,
376            self.client_id,
377            self.client_secret,
378            &self.extra_params,
379            None,
380            None,
381            self.token_url.url(),
382            vec![
383                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
384                ("device_code", self.dev_auth_resp.device_code().secret()),
385            ],
386        )
387        .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}")))
388    }
389
390    fn process_response<RE>(
391        &self,
392        res: Result<HttpResponse, RE>,
393        current_interval: Duration,
394    ) -> DeviceAccessTokenPollResult<TR, RE, DeviceCodeErrorResponse>
395    where
396        RE: Error + 'static,
397    {
398        let http_response = match res {
399            Ok(inner) => inner,
400            Err(_) => {
401                // RFC 8628 requires a backoff in cases of connection timeout, but we can't
402                // distinguish between connection timeouts and other HTTP client request errors
403                // here. Set a maximum backoff so that the client doesn't effectively backoff
404                // infinitely when there are network issues unrelated to server load.
405                const DEFAULT_MAX_BACKOFF_INTERVAL: Duration = Duration::from_secs(10);
406                let new_interval = std::cmp::min(
407                    current_interval.checked_mul(2).unwrap_or(current_interval),
408                    self.max_backoff_interval
409                        .unwrap_or(DEFAULT_MAX_BACKOFF_INTERVAL),
410                );
411                return DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval);
412            }
413        };
414
415        // Explicitly process the response with a DeviceCodeErrorResponse
416        let res = endpoint_response::<RE, DeviceCodeErrorResponse, TR>(http_response);
417        match res {
418            // On a ServerResponse error, the error needs inspecting as a DeviceCodeErrorResponse
419            // to work out whether a retry needs to happen.
420            Err(RequestTokenError::ServerResponse(dcer)) => {
421                match dcer.error() {
422                    // On AuthorizationPending, a retry needs to happen with the same poll interval.
423                    DeviceCodeErrorResponseType::AuthorizationPending => {
424                        DeviceAccessTokenPollResult::ContinueWithNewPollInterval(current_interval)
425                    }
426                    // On SlowDown, a retry needs to happen with a larger poll interval.
427                    DeviceCodeErrorResponseType::SlowDown => {
428                        DeviceAccessTokenPollResult::ContinueWithNewPollInterval(
429                            current_interval + Duration::from_secs(5),
430                        )
431                    }
432
433                    // On any other error, just return the error.
434                    _ => DeviceAccessTokenPollResult::Done(Err(RequestTokenError::ServerResponse(
435                        dcer,
436                    ))),
437                }
438            }
439
440            // On any other success or failure, return the failure.
441            res => DeviceAccessTokenPollResult::Done(res),
442        }
443    }
444
445    fn compute_timeout<RE>(
446        &self,
447        timeout: Option<Duration>,
448    ) -> Result<DateTime<Utc>, RequestTokenError<RE, DeviceCodeErrorResponse>>
449    where
450        RE: Error + 'static,
451    {
452        // Calculate the request timeout - if the user specified a timeout,
453        // use that, otherwise use the value given by the device authorization
454        // response.
455        let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in());
456        let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| {
457            RequestTokenError::Other(format!(
458                "failed to convert `{timeout_dur:?}` to `chrono::Duration`: {e}"
459            ))
460        })?;
461
462        // Calculate the DateTime at which the request times out.
463        let timeout_dt = (*self.time_fn)()
464            .checked_add_signed(chrono_timeout)
465            .ok_or_else(|| RequestTokenError::Other("failed to calculate timeout".to_string()))?;
466
467        Ok(timeout_dt)
468    }
469}
470
/// Default polling interval, in seconds, for when the device authorization response carries
/// no usable `interval`: RFC 8628 says that clients MUST use 5 as the default.
fn default_devicecode_interval() -> u64 {
    const RFC8628_DEFAULT_INTERVAL_SECS: u64 = 5;
    RFC8628_DEFAULT_INTERVAL_SECS
}
477
478fn deserialize_devicecode_interval<'de, D>(deserializer: D) -> Result<u64, D::Error>
479where
480    D: serde::de::Deserializer<'de>,
481{
482    struct NumOrNull;
483
484    impl<'de> serde::de::Visitor<'de> for NumOrNull {
485        type Value = u64;
486
487        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
488            formatter.write_str("non-negative integer or null")
489        }
490
491        fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
492        where
493            E: Error,
494        {
495            Ok(v)
496        }
497
498        fn visit_unit<E>(self) -> Result<Self::Value, E>
499        where
500            E: serde::de::Error,
501        {
502            Ok(default_devicecode_interval())
503        }
504    }
505
506    deserializer.deserialize_any(NumOrNull)
507}
508
509/// Trait for adding extra fields to the `DeviceAuthorizationResponse`.
510pub trait ExtraDeviceAuthorizationFields: DeserializeOwned + Debug + Serialize {}
511
512#[derive(Clone, Debug, Deserialize, Serialize)]
513/// Empty (default) extra token fields.
514pub struct EmptyExtraDeviceAuthorizationFields {}
515impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {}
516
517/// Standard OAuth2 device authorization response.
518#[derive(Clone, Debug, Deserialize, Serialize)]
519pub struct DeviceAuthorizationResponse<EF>
520where
521    EF: ExtraDeviceAuthorizationFields,
522{
523    /// The device verification code.
524    device_code: DeviceCode,
525
526    /// The end-user verification code.
527    user_code: UserCode,
528
529    /// The end-user verification URI on the authorization The URI should be
530    /// short and easy to remember as end users will be asked to manually type
531    /// it into their user agent.
532    ///
533    /// The `verification_url` alias here is a deviation from the RFC, as
534    /// implementations of device authorization flow predate RFC 8628.
535    #[serde(alias = "verification_url")]
536    verification_uri: EndUserVerificationUrl,
537
538    /// A verification URI that includes the "user_code" (or other information
539    /// with the same function as the "user_code"), which is designed for
540    /// non-textual transmission.
541    #[serde(skip_serializing_if = "Option::is_none")]
542    verification_uri_complete: Option<VerificationUriComplete>,
543
544    /// The lifetime in seconds of the "device_code" and "user_code".
545    expires_in: u64,
546
547    /// The minimum amount of time in seconds that the client SHOULD wait
548    /// between polling requests to the token endpoint.  If no value is
549    /// provided, clients MUST use 5 as the default.
550    #[serde(
551        default = "default_devicecode_interval",
552        deserialize_with = "deserialize_devicecode_interval"
553    )]
554    interval: u64,
555
556    #[serde(bound = "EF: ExtraDeviceAuthorizationFields", flatten)]
557    extra_fields: EF,
558}
559
560impl<EF> DeviceAuthorizationResponse<EF>
561where
562    EF: ExtraDeviceAuthorizationFields,
563{
564    /// The device verification code.
565    pub fn device_code(&self) -> &DeviceCode {
566        &self.device_code
567    }
568
569    /// The end-user verification code.
570    pub fn user_code(&self) -> &UserCode {
571        &self.user_code
572    }
573
574    /// The end-user verification URI on the authorization The URI should be
575    /// short and easy to remember as end users will be asked to manually type
576    /// it into their user agent.
577    pub fn verification_uri(&self) -> &EndUserVerificationUrl {
578        &self.verification_uri
579    }
580
581    /// A verification URI that includes the "user_code" (or other information
582    /// with the same function as the "user_code"), which is designed for
583    /// non-textual transmission.
584    pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> {
585        self.verification_uri_complete.as_ref()
586    }
587
588    /// The lifetime in seconds of the "device_code" and "user_code".
589    pub fn expires_in(&self) -> Duration {
590        Duration::from_secs(self.expires_in)
591    }
592
593    /// The minimum amount of time in seconds that the client SHOULD wait
594    /// between polling requests to the token endpoint.  If no value is
595    /// provided, clients MUST use 5 as the default.
596    pub fn interval(&self) -> Duration {
597        Duration::from_secs(self.interval)
598    }
599
600    /// Any extra fields returned on the response.
601    pub fn extra_fields(&self) -> &EF {
602        &self.extra_fields
603    }
604}
605
606/// Standard implementation of DeviceAuthorizationResponse which throws away
607/// extra received response fields.
608pub type StandardDeviceAuthorizationResponse =
609    DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields>;
610
611/// Basic access token error types.
612///
613/// These error types are defined in
614/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2) and
615/// [Section 3.5 of RFC 6749](https://tools.ietf.org/html/rfc8628#section-3.5)
616#[derive(Clone, PartialEq, Eq)]
617pub enum DeviceCodeErrorResponseType {
618    /// The authorization request is still pending as the end user hasn't
619    /// yet completed the user-interaction steps.  The client SHOULD repeat the
620    /// access token request to the token endpoint.  Before each new request,
621    /// the client MUST wait at least the number of seconds specified by the
622    /// "interval" parameter of the device authorization response, or 5 seconds
623    /// if none was provided, and respect any increase in the polling interval
624    /// required by the "slow_down" error.
625    AuthorizationPending,
626    /// A variant of "authorization_pending", the authorization request is
627    /// still pending and polling should continue, but the interval MUST be
628    /// increased by 5 seconds for this and all subsequent requests.
629    SlowDown,
630    /// The authorization request was denied.
631    AccessDenied,
632    /// The "device_code" has expired, and the device authorization session has
633    /// concluded.  The client MAY commence a new device authorization request
634    /// but SHOULD wait for user interaction before restarting to avoid
635    /// unnecessary polling.
636    ExpiredToken,
637    /// A Basic response type
638    Basic(BasicErrorResponseType),
639}
640impl DeviceCodeErrorResponseType {
641    fn from_str(s: &str) -> Self {
642        match BasicErrorResponseType::from_str(s) {
643            BasicErrorResponseType::Extension(ext) => match ext.as_str() {
644                "authorization_pending" => DeviceCodeErrorResponseType::AuthorizationPending,
645                "slow_down" => DeviceCodeErrorResponseType::SlowDown,
646                "access_denied" => DeviceCodeErrorResponseType::AccessDenied,
647                "expired_token" => DeviceCodeErrorResponseType::ExpiredToken,
648                _ => DeviceCodeErrorResponseType::Basic(BasicErrorResponseType::Extension(ext)),
649            },
650            basic => DeviceCodeErrorResponseType::Basic(basic),
651        }
652    }
653}
654impl AsRef<str> for DeviceCodeErrorResponseType {
655    fn as_ref(&self) -> &str {
656        match self {
657            DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending",
658            DeviceCodeErrorResponseType::SlowDown => "slow_down",
659            DeviceCodeErrorResponseType::AccessDenied => "access_denied",
660            DeviceCodeErrorResponseType::ExpiredToken => "expired_token",
661            DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(),
662        }
663    }
664}
665impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType {
666    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
667    where
668        D: serde::de::Deserializer<'de>,
669    {
670        let variant_str = String::deserialize(deserializer)?;
671        Ok(Self::from_str(&variant_str))
672    }
673}
674impl serde::ser::Serialize for DeviceCodeErrorResponseType {
675    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
676    where
677        S: serde::ser::Serializer,
678    {
679        serializer.serialize_str(self.as_ref())
680    }
681}
682impl ErrorResponseType for DeviceCodeErrorResponseType {}
683impl Debug for DeviceCodeErrorResponseType {
684    fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
685        Display::fmt(self, f)
686    }
687}
688
689impl Display for DeviceCodeErrorResponseType {
690    fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
691        write!(f, "{}", self.as_ref())
692    }
693}
694
695/// Error response specialization for device code OAuth2 implementation.
696pub type DeviceCodeErrorResponse = StandardErrorResponse<DeviceCodeErrorResponseType>;
697
698pub(crate) enum DeviceAccessTokenPollResult<TR, RE, TE>
699where
700    TE: ErrorResponse + 'static,
701    TR: TokenResponse,
702    RE: Error + 'static,
703{
704    ContinueWithNewPollInterval(Duration),
705    Done(Result<TR, RequestTokenError<RE, TE>>),
706}
707
708#[cfg(test)]
709mod tests {
710    use crate::basic::BasicTokenType;
711    use crate::devicecode::default_devicecode_interval;
712    use crate::tests::{mock_http_client, mock_http_client_success_fail, new_client};
713    use crate::{
714        DeviceAuthorizationResponse, DeviceAuthorizationUrl, DeviceCodeErrorResponse,
715        DeviceCodeErrorResponseType, EmptyExtraDeviceAuthorizationFields, RequestTokenError, Scope,
716        StandardDeviceAuthorizationResponse, TokenResponse,
717    };
718
719    use chrono::{DateTime, Utc};
720    use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
721    use http::{HeaderValue, Response, StatusCode};
722
723    use std::time::Duration;
724
725    fn new_device_auth_details(expires_in: u32) -> StandardDeviceAuthorizationResponse {
726        let body = format!(
727            "{{\
728        \"device_code\": \"12345\", \
729        \"verification_uri\": \"https://verify/here\", \
730        \"user_code\": \"abcde\", \
731        \"verification_uri_complete\": \"https://verify/here?abcde\", \
732        \"expires_in\": {expires_in}, \
733        \"interval\": 1 \
734        }}"
735        );
736
737        let device_auth_url =
738            DeviceAuthorizationUrl::new("https://deviceauth/here".to_string()).unwrap();
739
740        let client = new_client().set_device_authorization_url(device_auth_url.clone());
741        client
742            .exchange_device_code()
743            .add_extra_param("foo", "bar")
744            .add_scope(Scope::new("openid".to_string()))
745            .request(&mock_http_client(
746                vec![
747                    (ACCEPT, "application/json"),
748                    (CONTENT_TYPE, "application/x-www-form-urlencoded"),
749                    (AUTHORIZATION, "Basic YWFhOmJiYg=="),
750                ],
751                "scope=openid&foo=bar",
752                Some(device_auth_url.url().to_owned()),
753                Response::builder()
754                    .status(StatusCode::OK)
755                    .header(
756                        CONTENT_TYPE,
757                        HeaderValue::from_str("application/json").unwrap(),
758                    )
759                    .body(body.into_bytes())
760                    .unwrap(),
761            ))
762            .unwrap()
763    }
764
765    #[test]
766    fn test_device_token_pending_then_success() {
767        let details = new_device_auth_details(20);
768        assert_eq!("12345", details.device_code().secret());
769        assert_eq!("https://verify/here", details.verification_uri().as_str());
770        assert_eq!("abcde", details.user_code().secret().as_str());
771        assert_eq!(
772            "https://verify/here?abcde",
773            details
774                .verification_uri_complete()
775                .unwrap()
776                .secret()
777                .as_str()
778        );
779        assert_eq!(Duration::from_secs(20), details.expires_in());
780        assert_eq!(Duration::from_secs(1), details.interval());
781
782        let token = new_client()
783          .exchange_device_access_token(&details)
784          .set_time_fn(mock_time_fn())
785          .request(
786              &mock_http_client_success_fail(
787                  None,
788                  vec![
789                      (ACCEPT, "application/json"),
790                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
791                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
792                  ],
793                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
794                  Response::builder()
795                    .status(StatusCode::BAD_REQUEST)
796                    .header(
797                        CONTENT_TYPE,
798                        HeaderValue::from_str("application/json").unwrap(),
799                    )
800                    .body("{\
801                    \"error\": \"authorization_pending\", \
802                    \"error_description\": \"Still waiting for user\"\
803                    }"
804                      .to_string()
805                      .into_bytes())
806                    .unwrap(),
807                  5,
808                  Response::builder()
809                    .status(StatusCode::OK)
810                    .header(
811                        CONTENT_TYPE,
812                        HeaderValue::from_str("application/json").unwrap(),
813                    )
814                    .body("{\
815                    \"access_token\": \"12/34\", \
816                    \"token_type\": \"bearer\", \
817                    \"scope\": \"openid\"\
818                    }"
819                      .to_string()
820                      .into_bytes())
821                    .unwrap(),
822              ),
823              mock_sleep_fn,
824              None)
825          .unwrap();
826
827        assert_eq!("12/34", token.access_token().secret());
828        assert_eq!(BasicTokenType::Bearer, *token.token_type());
829        assert_eq!(
830            Some(&vec![Scope::new("openid".to_string()),]),
831            token.scopes()
832        );
833        assert_eq!(None, token.expires_in());
834        assert!(token.refresh_token().is_none());
835    }
836
837    #[test]
838    fn test_device_token_slowdown_then_success() {
839        let details = new_device_auth_details(3600);
840        assert_eq!("12345", details.device_code().secret());
841        assert_eq!("https://verify/here", details.verification_uri().as_str());
842        assert_eq!("abcde", details.user_code().secret().as_str());
843        assert_eq!(
844            "https://verify/here?abcde",
845            details
846                .verification_uri_complete()
847                .unwrap()
848                .secret()
849                .as_str()
850        );
851        assert_eq!(Duration::from_secs(3600), details.expires_in());
852        assert_eq!(Duration::from_secs(1), details.interval());
853
854        let token = new_client()
855          .exchange_device_access_token(&details)
856          .set_time_fn(mock_time_fn())
857          .request(
858              &mock_http_client_success_fail(
859                  None,
860                  vec![
861                      (ACCEPT, "application/json"),
862                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
863                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
864                  ],
865                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
866                  Response::builder()
867                    .status(StatusCode::BAD_REQUEST)
868                    .header(
869                        CONTENT_TYPE,
870                        HeaderValue::from_str("application/json").unwrap(),
871                    )
872                    .body("{\
873                    \"error\": \"slow_down\", \
874                    \"error_description\": \"Woah there partner\"\
875                    }"
876                      .to_string()
877                      .into_bytes())
878                    .unwrap(),
879                  5,
880                  Response::builder()
881                    .status(StatusCode::OK)
882                    .header(
883                        CONTENT_TYPE,
884                        HeaderValue::from_str("application/json").unwrap(),
885                    )
886                    .body("{\
887                    \"access_token\": \"12/34\", \
888                    \"token_type\": \"bearer\", \
889                    \"scope\": \"openid\"\
890                    }"
891                      .to_string()
892                      .into_bytes())
893                    .unwrap(),
894              ),
895              mock_sleep_fn,
896              None)
897          .unwrap();
898
899        assert_eq!("12/34", token.access_token().secret());
900        assert_eq!(BasicTokenType::Bearer, *token.token_type());
901        assert_eq!(
902            Some(&vec![Scope::new("openid".to_string()),]),
903            token.scopes()
904        );
905        assert_eq!(None, token.expires_in());
906        assert!(token.refresh_token().is_none());
907    }
908
909    struct IncreasingTime {
910        times: std::ops::RangeFrom<i64>,
911    }
912
913    impl IncreasingTime {
914        fn new() -> Self {
915            Self { times: (0..) }
916        }
917        fn next(&mut self) -> DateTime<Utc> {
918            let next_value = self.times.next().unwrap();
919            DateTime::from_timestamp(next_value, 0).unwrap()
920        }
921    }
922
923    /// Creates a time function that increments by one second each time.
924    fn mock_time_fn() -> impl Fn() -> DateTime<Utc> + Send + Sync {
925        let timer = std::sync::Mutex::new(IncreasingTime::new());
926        move || timer.lock().unwrap().next()
927    }
928
929    /// Mock sleep function that doesn't actually sleep.
930    fn mock_sleep_fn(_: Duration) {}
931
932    #[test]
933    fn test_exchange_device_code_and_token() {
934        let details = new_device_auth_details(3600);
935        assert_eq!("12345", details.device_code().secret());
936        assert_eq!("https://verify/here", details.verification_uri().as_str());
937        assert_eq!("abcde", details.user_code().secret().as_str());
938        assert_eq!(
939            "https://verify/here?abcde",
940            details
941                .verification_uri_complete()
942                .unwrap()
943                .secret()
944                .as_str()
945        );
946        assert_eq!(Duration::from_secs(3600), details.expires_in());
947        assert_eq!(Duration::from_secs(1), details.interval());
948
949        let token = new_client()
950          .exchange_device_access_token(&details)
951          .set_time_fn(mock_time_fn())
952          .request(
953              &mock_http_client(
954                  vec![
955                      (ACCEPT, "application/json"),
956                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
957                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
958                  ],
959                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
960                  None,
961                  Response::builder()
962                    .status(StatusCode::OK)
963                    .header(
964                        CONTENT_TYPE,
965                        HeaderValue::from_str("application/json").unwrap(),
966                    )
967                    .body("{\
968                    \"access_token\": \"12/34\", \
969                    \"token_type\": \"bearer\", \
970                    \"scope\": \"openid\"\
971                    }"
972                      .to_string()
973                      .into_bytes())
974                    .unwrap(),
975              ),
976              mock_sleep_fn,
977              None)
978          .unwrap();
979
980        assert_eq!("12/34", token.access_token().secret());
981        assert_eq!(BasicTokenType::Bearer, *token.token_type());
982        assert_eq!(
983            Some(&vec![Scope::new("openid".to_string()),]),
984            token.scopes()
985        );
986        assert_eq!(None, token.expires_in());
987        assert!(token.refresh_token().is_none());
988    }
989
990    #[test]
991    fn test_device_token_authorization_timeout() {
992        let details = new_device_auth_details(2);
993        assert_eq!("12345", details.device_code().secret());
994        assert_eq!("https://verify/here", details.verification_uri().as_str());
995        assert_eq!("abcde", details.user_code().secret().as_str());
996        assert_eq!(
997            "https://verify/here?abcde",
998            details
999                .verification_uri_complete()
1000                .unwrap()
1001                .secret()
1002                .as_str()
1003        );
1004        assert_eq!(Duration::from_secs(2), details.expires_in());
1005        assert_eq!(Duration::from_secs(1), details.interval());
1006
1007        let token = new_client()
1008          .exchange_device_access_token(&details)
1009          .set_time_fn(mock_time_fn())
1010          .request(
1011              &mock_http_client(
1012                  vec![
1013                      (ACCEPT, "application/json"),
1014                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
1015                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
1016                  ],
1017                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
1018                  None,
1019                  Response::builder()
1020                    .status(StatusCode::BAD_REQUEST)
1021                    .header(
1022                        CONTENT_TYPE,
1023                        HeaderValue::from_str("application/json").unwrap(),
1024                    )
1025                    .body("{\
1026                    \"error\": \"authorization_pending\", \
1027                    \"error_description\": \"Still waiting for user\"\
1028                    }"
1029                      .to_string()
1030                      .into_bytes())
1031                    .unwrap(),
1032              ),
1033              mock_sleep_fn,
1034              None)
1035          .err()
1036          .unwrap();
1037        match token {
1038            RequestTokenError::ServerResponse(msg) => assert_eq!(
1039                msg,
1040                DeviceCodeErrorResponse::new(
1041                    DeviceCodeErrorResponseType::ExpiredToken,
1042                    Some(String::from("This device code has expired.")),
1043                    None,
1044                )
1045            ),
1046            _ => unreachable!("Error should be an expiry"),
1047        }
1048    }
1049
1050    #[test]
1051    fn test_device_token_access_denied() {
1052        let details = new_device_auth_details(2);
1053        assert_eq!("12345", details.device_code().secret());
1054        assert_eq!("https://verify/here", details.verification_uri().as_str());
1055        assert_eq!("abcde", details.user_code().secret().as_str());
1056        assert_eq!(
1057            "https://verify/here?abcde",
1058            details
1059                .verification_uri_complete()
1060                .unwrap()
1061                .secret()
1062                .as_str()
1063        );
1064        assert_eq!(Duration::from_secs(2), details.expires_in());
1065        assert_eq!(Duration::from_secs(1), details.interval());
1066
1067        let token = new_client()
1068          .exchange_device_access_token(&details)
1069          .set_time_fn(mock_time_fn())
1070          .request(
1071              &mock_http_client(
1072                  vec![
1073                      (ACCEPT, "application/json"),
1074                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
1075                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
1076                  ],
1077                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
1078                  None,
1079                  Response::builder()
1080                    .status(StatusCode::BAD_REQUEST)
1081                    .header(
1082                        CONTENT_TYPE,
1083                        HeaderValue::from_str("application/json").unwrap(),
1084                    )
1085                    .body("{\
1086                    \"error\": \"access_denied\", \
1087                    \"error_description\": \"Access Denied\"\
1088                    }"
1089                      .to_string()
1090                      .into_bytes())
1091                    .unwrap(),
1092              ),
1093              mock_sleep_fn,
1094              None)
1095          .err()
1096          .unwrap();
1097        match token {
1098            RequestTokenError::ServerResponse(msg) => {
1099                assert_eq!(msg.error(), &DeviceCodeErrorResponseType::AccessDenied)
1100            }
1101            _ => unreachable!("Error should be Access Denied"),
1102        }
1103    }
1104
1105    #[test]
1106    fn test_device_token_expired() {
1107        let details = new_device_auth_details(2);
1108        assert_eq!("12345", details.device_code().secret());
1109        assert_eq!("https://verify/here", details.verification_uri().as_str());
1110        assert_eq!("abcde", details.user_code().secret().as_str());
1111        assert_eq!(
1112            "https://verify/here?abcde",
1113            details
1114                .verification_uri_complete()
1115                .unwrap()
1116                .secret()
1117                .as_str()
1118        );
1119        assert_eq!(Duration::from_secs(2), details.expires_in());
1120        assert_eq!(Duration::from_secs(1), details.interval());
1121
1122        let token = new_client()
1123          .exchange_device_access_token(&details)
1124          .set_time_fn(mock_time_fn())
1125          .request(
1126              &mock_http_client(
1127                  vec![
1128                      (ACCEPT, "application/json"),
1129                      (CONTENT_TYPE, "application/x-www-form-urlencoded"),
1130                      (AUTHORIZATION, "Basic YWFhOmJiYg=="),
1131                  ],
1132                  "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345",
1133                  None,
1134                  Response::builder()
1135                    .status(StatusCode::BAD_REQUEST)
1136                    .header(
1137                        CONTENT_TYPE,
1138                        HeaderValue::from_str("application/json").unwrap(),
1139                    )
1140                    .body("{\
1141                    \"error\": \"expired_token\", \
1142                    \"error_description\": \"Token has expired\"\
1143                    }"
1144                      .to_string()
1145                      .into_bytes())
1146                    .unwrap(),
1147              ),
1148              mock_sleep_fn,
1149              None)
1150          .err()
1151          .unwrap();
1152        match token {
1153            RequestTokenError::ServerResponse(msg) => {
1154                assert_eq!(msg.error(), &DeviceCodeErrorResponseType::ExpiredToken)
1155            }
1156            _ => unreachable!("Error should be ExpiredToken"),
1157        }
1158    }
1159
1160    #[test]
1161    fn test_device_auth_response_default_interval() {
1162        let response: DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields> =
1163            serde_json::from_str(
1164                r#"{
1165                    "device_code": "12345",
1166                    "verification_uri": "https://verify/here",
1167                    "user_code": "abcde",
1168                    "expires_in": 300
1169                }"#,
1170            )
1171            .unwrap();
1172
1173        assert_eq!(response.interval, default_devicecode_interval());
1174    }
1175
1176    #[test]
1177    fn test_device_auth_response_non_default_interval() {
1178        let response: DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields> =
1179            serde_json::from_str(
1180                r#"{
1181                    "device_code": "12345",
1182                    "verification_uri": "https://verify/here",
1183                    "user_code": "abcde",
1184                    "expires_in": 300,
1185                    "interval": 10
1186                }"#,
1187            )
1188            .unwrap();
1189
1190        assert_eq!(response.interval, 10);
1191    }
1192
1193    #[test]
1194    fn test_device_auth_response_null_interval() {
1195        let response: DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields> =
1196            serde_json::from_str(
1197                r#"{
1198                    "device_code": "12345",
1199                    "verification_uri": "https://verify/here",
1200                    "user_code": "abcde",
1201                    "expires_in": 300,
1202                    "interval": null
1203                }"#,
1204            )
1205            .unwrap();
1206
1207        assert_eq!(response.interval, default_devicecode_interval());
1208    }
1209}