oauth2/
endpoint.rs

1use crate::{
2    AuthType, ClientId, ClientSecret, ErrorResponse, RedirectUrl, RequestTokenError, Scope,
3    CONTENT_TYPE_FORMENCODED, CONTENT_TYPE_JSON,
4};
5
6use base64::prelude::*;
7use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
8use http::{HeaderValue, StatusCode};
9use serde::de::DeserializeOwned;
10use url::{form_urlencoded, Url};
11
12use std::borrow::Cow;
13use std::error::Error;
14use std::future::Future;
15
16/// An HTTP request.
17pub type HttpRequest = http::Request<Vec<u8>>;
18
19/// An HTTP response.
20pub type HttpResponse = http::Response<Vec<u8>>;
21
22/// An asynchronous (future-based) HTTP client.
23pub trait AsyncHttpClient<'c> {
24    /// Error type returned by HTTP client.
25    type Error: Error + 'static;
26
27    /// Future type returned by HTTP client.
28    type Future: Future<Output = Result<HttpResponse, Self::Error>> + 'c;
29
30    /// Perform a single HTTP request.
31    fn call(&'c self, request: HttpRequest) -> Self::Future;
32}
33impl<'c, E, F, T> AsyncHttpClient<'c> for T
34where
35    E: Error + 'static,
36    F: Future<Output = Result<HttpResponse, E>> + 'c,
37    // We can't implement this for FnOnce because the device authorization flow requires clients to
38    // supportmultiple calls.
39    T: Fn(HttpRequest) -> F,
40{
41    type Error = E;
42    type Future = F;
43
44    fn call(&'c self, request: HttpRequest) -> Self::Future {
45        self(request)
46    }
47}
48
49/// A synchronous (blocking) HTTP client.
50pub trait SyncHttpClient {
51    /// Error type returned by HTTP client.
52    type Error: Error + 'static;
53
54    /// Perform a single HTTP request.
55    fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error>;
56}
57impl<E, T> SyncHttpClient for T
58where
59    E: Error + 'static,
60    // We can't implement this for FnOnce because the device authorization flow requires clients to
61    // support multiple calls.
62    T: Fn(HttpRequest) -> Result<HttpResponse, E>,
63{
64    type Error = E;
65
66    fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error> {
67        self(request)
68    }
69}
70
71#[allow(clippy::too_many_arguments)]
72pub(crate) fn endpoint_request<'a>(
73    auth_type: &'a AuthType,
74    client_id: &'a ClientId,
75    client_secret: Option<&'a ClientSecret>,
76    extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)],
77    redirect_url: Option<Cow<'a, RedirectUrl>>,
78    scopes: Option<&'a Vec<Cow<'a, Scope>>>,
79    url: &'a Url,
80    params: Vec<(&'a str, &'a str)>,
81) -> Result<HttpRequest, http::Error> {
82    let mut builder = http::Request::builder()
83        .uri(url.to_string())
84        .method(http::Method::POST)
85        .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
86        .header(
87            CONTENT_TYPE,
88            HeaderValue::from_static(CONTENT_TYPE_FORMENCODED),
89        );
90
91    let scopes_opt = scopes.and_then(|scopes| {
92        if !scopes.is_empty() {
93            Some(
94                scopes
95                    .iter()
96                    .map(|s| s.to_string())
97                    .collect::<Vec<_>>()
98                    .join(" "),
99            )
100        } else {
101            None
102        }
103    });
104
105    let mut params: Vec<(&str, &str)> = params;
106    if let Some(ref scopes) = scopes_opt {
107        params.push(("scope", scopes));
108    }
109
110    // FIXME: add support for auth extensions? e.g., client_secret_jwt and private_key_jwt
111    match (auth_type, client_secret) {
112        // Basic auth only makes sense when a client secret is provided. Otherwise, always pass the
113        // client ID in the request body.
114        (AuthType::BasicAuth, Some(secret)) => {
115            // Section 2.3.1 of RFC 6749 requires separately url-encoding the id and secret
116            // before using them as HTTP Basic auth username and password. Note that this is
117            // not standard for ordinary Basic auth, so curl won't do it for us.
118            let urlencoded_id: String =
119                form_urlencoded::byte_serialize(client_id.as_bytes()).collect();
120            let urlencoded_secret: String =
121                form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect();
122            let b64_credential =
123                BASE64_STANDARD.encode(format!("{}:{}", &urlencoded_id, urlencoded_secret));
124            builder = builder.header(
125                AUTHORIZATION,
126                HeaderValue::from_str(&format!("Basic {}", &b64_credential)).unwrap(),
127            );
128        }
129        (AuthType::RequestBody, _) | (AuthType::BasicAuth, None) => {
130            params.push(("client_id", client_id));
131            if let Some(client_secret) = client_secret {
132                params.push(("client_secret", client_secret.secret()));
133            }
134        }
135    }
136
137    if let Some(ref redirect_url) = redirect_url {
138        params.push(("redirect_uri", redirect_url.as_str()));
139    }
140
141    params.extend_from_slice(
142        extra_params
143            .iter()
144            .map(|(k, v)| (k.as_ref(), v.as_ref()))
145            .collect::<Vec<_>>()
146            .as_slice(),
147    );
148
149    let body = form_urlencoded::Serializer::new(String::new())
150        .extend_pairs(params)
151        .finish()
152        .into_bytes();
153
154    builder.body(body)
155}
156
157pub(crate) fn endpoint_response<RE, TE, DO>(
158    http_response: HttpResponse,
159) -> Result<DO, RequestTokenError<RE, TE>>
160where
161    RE: Error,
162    TE: ErrorResponse,
163    DO: DeserializeOwned,
164{
165    check_response_status(&http_response)?;
166
167    check_response_body(&http_response)?;
168
169    let response_body = http_response.body().as_slice();
170    serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_slice(response_body))
171        .map_err(|e| RequestTokenError::Parse(e, response_body.to_vec()))
172}
173
174pub(crate) fn endpoint_response_status_only<RE, TE>(
175    http_response: HttpResponse,
176) -> Result<(), RequestTokenError<RE, TE>>
177where
178    RE: Error + 'static,
179    TE: ErrorResponse,
180{
181    check_response_status(&http_response)
182}
183
184fn check_response_status<RE, TE>(
185    http_response: &HttpResponse,
186) -> Result<(), RequestTokenError<RE, TE>>
187where
188    RE: Error + 'static,
189    TE: ErrorResponse,
190{
191    if http_response.status() != StatusCode::OK {
192        let reason = http_response.body().as_slice();
193        if reason.is_empty() {
194            Err(RequestTokenError::Other(
195                "server returned empty error response".to_string(),
196            ))
197        } else {
198            let error = match serde_path_to_error::deserialize::<_, TE>(
199                &mut serde_json::Deserializer::from_slice(reason),
200            ) {
201                Ok(error) => RequestTokenError::ServerResponse(error),
202                Err(error) => RequestTokenError::Parse(error, reason.to_vec()),
203            };
204            Err(error)
205        }
206    } else {
207        Ok(())
208    }
209}
210
211fn check_response_body<RE, TE>(
212    http_response: &HttpResponse,
213) -> Result<(), RequestTokenError<RE, TE>>
214where
215    RE: Error + 'static,
216    TE: ErrorResponse,
217{
218    // Validate that the response Content-Type is JSON.
219    http_response
220    .headers()
221    .get(CONTENT_TYPE)
222    .map_or(Ok(()), |content_type|
223      // Section 3.1.1.1 of RFC 7231 indicates that media types are case-insensitive and
224      // may be followed by optional whitespace and/or a parameter (e.g., charset).
225      // See https://tools.ietf.org/html/rfc7231#section-3.1.1.1.
226      if content_type.to_str().ok().filter(|ct| ct.to_lowercase().starts_with(CONTENT_TYPE_JSON)).is_none() {
227        Err(
228          RequestTokenError::Other(
229            format!(
230              "unexpected response Content-Type: {content_type:?}, should be `{CONTENT_TYPE_JSON}`",
231            )
232          )
233        )
234      } else {
235        Ok(())
236      }
237    )?;
238
239    if http_response.body().is_empty() {
240        return Err(RequestTokenError::Other(
241            "server returned empty response body".to_string(),
242        ));
243    }
244
245    Ok(())
246}
247
#[cfg(test)]
mod tests {
    use crate::tests::{new_client, FakeError};
    use crate::{AuthorizationCode, TokenResponse};

    use http::{Response, StatusCode};

    /// An async closure that merely borrows its canned response can serve as an HTTP client.
    #[tokio::test]
    async fn test_async_client_closure() {
        let token_json = "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}";
        let canned_response = Response::builder()
            .status(StatusCode::OK)
            .body(token_json.as_bytes().to_vec())
            .unwrap();

        let client = new_client();
        let token = client
            .exchange_code(AuthorizationCode::new("ccc".to_string()))
            // NB: This tests that the closure doesn't require a static lifetime: the non-`move`
            // async block only borrows `canned_response`.
            .request_async(&|_| async { Ok::<_, FakeError>(canned_response.clone()) })
            .await
            .unwrap();

        assert_eq!("12/34", token.access_token().secret());
    }
}