1use crate::{
2 AuthType, ClientId, ClientSecret, ErrorResponse, RedirectUrl, RequestTokenError, Scope,
3 CONTENT_TYPE_FORMENCODED, CONTENT_TYPE_JSON,
4};
5
6use base64::prelude::*;
7use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
8use http::{HeaderValue, StatusCode};
9use serde::de::DeserializeOwned;
10use url::{form_urlencoded, Url};
11
12use std::borrow::Cow;
13use std::error::Error;
14use std::future::Future;
15
/// An HTTP request whose body is a raw byte buffer (`Vec<u8>`).
pub type HttpRequest = http::Request<Vec<u8>>;

/// An HTTP response whose body is a raw byte buffer (`Vec<u8>`).
pub type HttpResponse = http::Response<Vec<u8>>;
21
22pub trait AsyncHttpClient<'c> {
24 type Error: Error + 'static;
26
27 type Future: Future<Output = Result<HttpResponse, Self::Error>> + 'c;
29
30 fn call(&'c self, request: HttpRequest) -> Self::Future;
32}
33impl<'c, E, F, T> AsyncHttpClient<'c> for T
34where
35 E: Error + 'static,
36 F: Future<Output = Result<HttpResponse, E>> + 'c,
37 T: Fn(HttpRequest) -> F,
40{
41 type Error = E;
42 type Future = F;
43
44 fn call(&'c self, request: HttpRequest) -> Self::Future {
45 self(request)
46 }
47}
48
49pub trait SyncHttpClient {
51 type Error: Error + 'static;
53
54 fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error>;
56}
57impl<E, T> SyncHttpClient for T
58where
59 E: Error + 'static,
60 T: Fn(HttpRequest) -> Result<HttpResponse, E>,
63{
64 type Error = E;
65
66 fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error> {
67 self(request)
68 }
69}
70
71#[allow(clippy::too_many_arguments)]
72pub(crate) fn endpoint_request<'a>(
73 auth_type: &'a AuthType,
74 client_id: &'a ClientId,
75 client_secret: Option<&'a ClientSecret>,
76 extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)],
77 redirect_url: Option<Cow<'a, RedirectUrl>>,
78 scopes: Option<&'a Vec<Cow<'a, Scope>>>,
79 url: &'a Url,
80 params: Vec<(&'a str, &'a str)>,
81) -> Result<HttpRequest, http::Error> {
82 let mut builder = http::Request::builder()
83 .uri(url.to_string())
84 .method(http::Method::POST)
85 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
86 .header(
87 CONTENT_TYPE,
88 HeaderValue::from_static(CONTENT_TYPE_FORMENCODED),
89 );
90
91 let scopes_opt = scopes.and_then(|scopes| {
92 if !scopes.is_empty() {
93 Some(
94 scopes
95 .iter()
96 .map(|s| s.to_string())
97 .collect::<Vec<_>>()
98 .join(" "),
99 )
100 } else {
101 None
102 }
103 });
104
105 let mut params: Vec<(&str, &str)> = params;
106 if let Some(ref scopes) = scopes_opt {
107 params.push(("scope", scopes));
108 }
109
110 match (auth_type, client_secret) {
112 (AuthType::BasicAuth, Some(secret)) => {
115 let urlencoded_id: String =
119 form_urlencoded::byte_serialize(client_id.as_bytes()).collect();
120 let urlencoded_secret: String =
121 form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect();
122 let b64_credential =
123 BASE64_STANDARD.encode(format!("{}:{}", &urlencoded_id, urlencoded_secret));
124 builder = builder.header(
125 AUTHORIZATION,
126 HeaderValue::from_str(&format!("Basic {}", &b64_credential)).unwrap(),
127 );
128 }
129 (AuthType::RequestBody, _) | (AuthType::BasicAuth, None) => {
130 params.push(("client_id", client_id));
131 if let Some(client_secret) = client_secret {
132 params.push(("client_secret", client_secret.secret()));
133 }
134 }
135 }
136
137 if let Some(ref redirect_url) = redirect_url {
138 params.push(("redirect_uri", redirect_url.as_str()));
139 }
140
141 params.extend_from_slice(
142 extra_params
143 .iter()
144 .map(|(k, v)| (k.as_ref(), v.as_ref()))
145 .collect::<Vec<_>>()
146 .as_slice(),
147 );
148
149 let body = form_urlencoded::Serializer::new(String::new())
150 .extend_pairs(params)
151 .finish()
152 .into_bytes();
153
154 builder.body(body)
155}
156
157pub(crate) fn endpoint_response<RE, TE, DO>(
158 http_response: HttpResponse,
159) -> Result<DO, RequestTokenError<RE, TE>>
160where
161 RE: Error,
162 TE: ErrorResponse,
163 DO: DeserializeOwned,
164{
165 check_response_status(&http_response)?;
166
167 check_response_body(&http_response)?;
168
169 let response_body = http_response.body().as_slice();
170 serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_slice(response_body))
171 .map_err(|e| RequestTokenError::Parse(e, response_body.to_vec()))
172}
173
174pub(crate) fn endpoint_response_status_only<RE, TE>(
175 http_response: HttpResponse,
176) -> Result<(), RequestTokenError<RE, TE>>
177where
178 RE: Error + 'static,
179 TE: ErrorResponse,
180{
181 check_response_status(&http_response)
182}
183
184fn check_response_status<RE, TE>(
185 http_response: &HttpResponse,
186) -> Result<(), RequestTokenError<RE, TE>>
187where
188 RE: Error + 'static,
189 TE: ErrorResponse,
190{
191 if http_response.status() != StatusCode::OK {
192 let reason = http_response.body().as_slice();
193 if reason.is_empty() {
194 Err(RequestTokenError::Other(
195 "server returned empty error response".to_string(),
196 ))
197 } else {
198 let error = match serde_path_to_error::deserialize::<_, TE>(
199 &mut serde_json::Deserializer::from_slice(reason),
200 ) {
201 Ok(error) => RequestTokenError::ServerResponse(error),
202 Err(error) => RequestTokenError::Parse(error, reason.to_vec()),
203 };
204 Err(error)
205 }
206 } else {
207 Ok(())
208 }
209}
210
211fn check_response_body<RE, TE>(
212 http_response: &HttpResponse,
213) -> Result<(), RequestTokenError<RE, TE>>
214where
215 RE: Error + 'static,
216 TE: ErrorResponse,
217{
218 http_response
220 .headers()
221 .get(CONTENT_TYPE)
222 .map_or(Ok(()), |content_type|
223 if content_type.to_str().ok().filter(|ct| ct.to_lowercase().starts_with(CONTENT_TYPE_JSON)).is_none() {
227 Err(
228 RequestTokenError::Other(
229 format!(
230 "unexpected response Content-Type: {content_type:?}, should be `{CONTENT_TYPE_JSON}`",
231 )
232 )
233 )
234 } else {
235 Ok(())
236 }
237 )?;
238
239 if http_response.body().is_empty() {
240 return Err(RequestTokenError::Other(
241 "server returned empty response body".to_string(),
242 ));
243 }
244
245 Ok(())
246}
247
#[cfg(test)]
mod tests {
    use crate::tests::{new_client, FakeError};
    use crate::{AuthorizationCode, TokenResponse};

    use http::{Response, StatusCode};

    /// A plain async closure can serve as the HTTP client passed to `request_async`.
    #[tokio::test]
    async fn test_async_client_closure() {
        let body = br#"{"access_token": "12/34", "token_type": "BEARER"}"#.to_vec();
        let http_response = Response::builder()
            .status(StatusCode::OK)
            .body(body)
            .unwrap();

        let client = new_client();
        let token = client
            .exchange_code(AuthorizationCode::new("ccc".to_string()))
            .request_async(&|_| async { Ok::<_, FakeError>(http_response.clone()) })
            .await
            .unwrap();

        assert_eq!("12/34", token.access_token().secret());
    }
}