use crate::{
AuthType, ClientId, ClientSecret, ErrorResponse, RedirectUrl, RequestTokenError, Scope,
CONTENT_TYPE_FORMENCODED, CONTENT_TYPE_JSON,
};
use base64::prelude::*;
use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use http::{HeaderValue, StatusCode};
use serde::de::DeserializeOwned;
use url::{form_urlencoded, Url};
use std::borrow::Cow;
use std::error::Error;
use std::future::Future;
/// An HTTP request whose body is a raw byte buffer; this is what HTTP clients are given to send.
pub type HttpRequest = http::Request<Vec<u8>>;
/// An HTTP response whose body is a raw byte buffer; this is what HTTP clients hand back.
pub type HttpResponse = http::Response<Vec<u8>>;
pub trait AsyncHttpClient<'c> {
type Error: Error + 'static;
type Future: Future<Output = Result<HttpResponse, Self::Error>> + 'c;
fn call(&'c self, request: HttpRequest) -> Self::Future;
}
/// Lets any function or closure of shape `Fn(HttpRequest) -> impl Future<Output =
/// Result<HttpResponse, E>>` be used directly as an [`AsyncHttpClient`].
impl<'c, E, F, T> AsyncHttpClient<'c> for T
where
    T: Fn(HttpRequest) -> F,
    F: Future<Output = Result<HttpResponse, E>> + 'c,
    E: Error + 'static,
{
    type Error = E;
    type Future = F;

    fn call(&'c self, req: HttpRequest) -> Self::Future {
        // Simply invoke the wrapped callable; the future it returns is handed back unchanged.
        (*self)(req)
    }
}
/// A blocking HTTP client that can execute an [`HttpRequest`].
pub trait SyncHttpClient {
    /// Error produced when the request could not be performed.
    type Error: 'static + Error;

    /// Sends `request` and blocks until the [`HttpResponse`] (or an error) is available.
    fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error>;
}
impl<E, T> SyncHttpClient for T
where
E: Error + 'static,
T: Fn(HttpRequest) -> Result<HttpResponse, E>,
{
type Error = E;
fn call(&self, request: HttpRequest) -> Result<HttpResponse, Self::Error> {
self(request)
}
}
/// Builds an HTTP `POST` request to `url` whose body is the form-urlencoded set of parameters.
///
/// Headers: `Accept` is set to `CONTENT_TYPE_JSON` and `Content-Type` to
/// `CONTENT_TYPE_FORMENCODED`.
///
/// Body parameters are emitted in this order:
/// 1. the caller-supplied `params`;
/// 2. `scope`, the space-joined `scopes` (omitted when `scopes` is `None` or empty);
/// 3. `client_id` / `client_secret`, unless the credentials are sent via HTTP Basic auth;
/// 4. `redirect_uri`, when `redirect_url` is given;
/// 5. every pair of `extra_params`.
///
/// Client authentication:
/// * `AuthType::BasicAuth` with a secret: an `Authorization: Basic ...` header is added and the
///   credentials are not placed in the body.
/// * `AuthType::RequestBody`, or `AuthType::BasicAuth` without a secret: `client_id` (and the
///   secret, if any) are sent as body parameters.
///
/// # Errors
///
/// Returns the `http::Error` reported by `http::request::Builder` if the request cannot be
/// built (for example, if the `http` crate rejects `url` as a URI).
// Every argument is a distinct part of the request; bundling them would only move the list.
#[allow(clippy::too_many_arguments)]
pub(crate) fn endpoint_request<'a>(
    auth_type: &'a AuthType,
    client_id: &'a ClientId,
    client_secret: Option<&'a ClientSecret>,
    extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)],
    redirect_url: Option<Cow<'a, RedirectUrl>>,
    scopes: Option<&'a Vec<Cow<'a, Scope>>>,
    url: &'a Url,
    params: Vec<(&'a str, &'a str)>,
) -> Result<HttpRequest, http::Error> {
    let mut builder = http::Request::builder()
        .uri(url.to_string())
        .method(http::Method::POST)
        .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
        .header(
            CONTENT_TYPE,
            HeaderValue::from_static(CONTENT_TYPE_FORMENCODED),
        );

    // An empty scope list is treated like no scopes at all: no `scope` parameter is sent.
    let scopes_opt = scopes
        .filter(|scopes| !scopes.is_empty())
        .map(|scopes| {
            scopes
                .iter()
                .map(|s| s.to_string())
                .collect::<Vec<_>>()
                .join(" ")
        });

    // Re-bind with a shorter lifetime so borrows of locals (e.g. `scopes_opt`) can be pushed.
    let mut params: Vec<(&str, &str)> = params;
    if let Some(ref scopes) = scopes_opt {
        params.push(("scope", scopes));
    }

    match (auth_type, client_secret) {
        (AuthType::BasicAuth, Some(secret)) => {
            // RFC 6749 §2.3.1: the client id and secret are form-urlencoded before being
            // used as the Basic auth user name and password.
            let urlencoded_id: String =
                form_urlencoded::byte_serialize(client_id.as_bytes()).collect();
            let urlencoded_secret: String =
                form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect();
            let b64_credential =
                BASE64_STANDARD.encode(format!("{urlencoded_id}:{urlencoded_secret}"));
            builder = builder.header(
                AUTHORIZATION,
                HeaderValue::from_str(&format!("Basic {b64_credential}"))
                    .expect("`Basic ` followed by Base64 output is always a valid header value"),
            );
        }
        // Without a secret there is nothing to put in a Basic auth header, so fall back to
        // identifying the client in the request body.
        (AuthType::RequestBody, _) | (AuthType::BasicAuth, None) => {
            params.push(("client_id", client_id));
            if let Some(client_secret) = client_secret {
                params.push(("client_secret", client_secret.secret()));
            }
        }
    }

    if let Some(ref redirect_url) = redirect_url {
        params.push(("redirect_uri", redirect_url.as_str()));
    }

    // Append the extra parameters directly; no intermediate collection is needed.
    params.extend(extra_params.iter().map(|(k, v)| (k.as_ref(), v.as_ref())));

    let body = form_urlencoded::Serializer::new(String::new())
        .extend_pairs(params)
        .finish()
        .into_bytes();

    builder.body(body)
}
/// Validates an endpoint response and deserializes its JSON body into `DO`.
///
/// The status is checked first (see `check_response_status`), then the `Content-Type` and
/// non-empty body (see `check_response_body`); any failure there is returned as-is.
///
/// # Errors
///
/// Besides errors from those checks, returns `RequestTokenError::Parse`, carrying the
/// `serde_path_to_error` error and a copy of the raw body, if the body does not deserialize
/// into `DO`.
pub(crate) fn endpoint_response<RE, TE, DO>(
    http_response: HttpResponse,
) -> Result<DO, RequestTokenError<RE, TE>>
where
    RE: Error,
    TE: ErrorResponse,
    DO: DeserializeOwned,
{
    check_response_status(&http_response)?;
    check_response_body(&http_response)?;

    let body: &[u8] = http_response.body();
    let mut json = serde_json::Deserializer::from_slice(body);
    serde_path_to_error::deserialize(&mut json)
        .map_err(|parse_err| RequestTokenError::Parse(parse_err, body.to_vec()))
}
/// Validates only the status of an endpoint response, ignoring any body on success.
///
/// # Errors
///
/// Returns whatever `check_response_status` reports for a non-`200 OK` response.
pub(crate) fn endpoint_response_status_only<RE, TE>(
    http_response: HttpResponse,
) -> Result<(), RequestTokenError<RE, TE>>
where
    RE: Error + 'static,
    TE: ErrorResponse,
{
    check_response_status(&http_response)?;
    Ok(())
}
/// Succeeds only if the response status is exactly `200 OK`.
///
/// For any other status (including other 2xx codes), the body is inspected:
/// * an empty body yields `RequestTokenError::Other`;
/// * a body that deserializes as JSON into `TE` yields `RequestTokenError::ServerResponse`;
/// * otherwise `RequestTokenError::Parse` is returned with a copy of the raw body.
fn check_response_status<RE, TE>(
    http_response: &HttpResponse,
) -> Result<(), RequestTokenError<RE, TE>>
where
    RE: Error + 'static,
    TE: ErrorResponse,
{
    if http_response.status() == StatusCode::OK {
        return Ok(());
    }

    let reason: &[u8] = http_response.body();
    if reason.is_empty() {
        return Err(RequestTokenError::Other(
            "server returned empty error response".to_string(),
        ));
    }

    let mut json = serde_json::Deserializer::from_slice(reason);
    let failure = match serde_path_to_error::deserialize::<_, TE>(&mut json) {
        Ok(server_error) => RequestTokenError::ServerResponse(server_error),
        Err(parse_err) => RequestTokenError::Parse(parse_err, reason.to_vec()),
    };
    Err(failure)
}
/// Checks that a successful response carries a JSON body.
///
/// * If a `Content-Type` header is present, its media type (the part before any `;`
///   parameters, trimmed, compared ASCII case-insensitively) must equal `CONTENT_TYPE_JSON`.
///   A missing header is accepted.
/// * The body must be non-empty.
///
/// # Errors
///
/// Returns `RequestTokenError::Other` describing the first check that failed.
fn check_response_body<RE, TE>(
    http_response: &HttpResponse,
) -> Result<(), RequestTokenError<RE, TE>>
where
    RE: Error + 'static,
    TE: ErrorResponse,
{
    if let Some(content_type) = http_response.headers().get(CONTENT_TYPE) {
        // RFC 7231 §3.1.1.1: media types are case-insensitive and may be followed by optional
        // whitespace and `;`-separated parameters (e.g. `charset`). Compare only the
        // `type/subtype` part, and compare it exactly, so that e.g. `application/jsonp` is not
        // mistaken for JSON (a plain prefix match would accept it).
        let media_type = content_type
            .to_str()
            .ok()
            .and_then(|ct| ct.split(';').next())
            .map(str::trim);
        let is_json =
            matches!(media_type, Some(mt) if mt.eq_ignore_ascii_case(CONTENT_TYPE_JSON));
        if !is_json {
            return Err(RequestTokenError::Other(format!(
                "unexpected response Content-Type: {content_type:?}, should be `{CONTENT_TYPE_JSON}`",
            )));
        }
    }

    if http_response.body().is_empty() {
        return Err(RequestTokenError::Other(
            "server returned empty response body".to_string(),
        ));
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use crate::tests::{new_client, FakeError};
    use crate::{AuthorizationCode, TokenResponse};

    use http::{Response, StatusCode};

    /// A plain async closure can act as the HTTP client for a token exchange.
    #[tokio::test]
    async fn test_async_client_closure() {
        let client = new_client();

        let json_body = br#"{"access_token": "12/34", "token_type": "BEARER"}"#.to_vec();
        let http_response = Response::builder()
            .status(StatusCode::OK)
            .body(json_body)
            .unwrap();

        let token = client
            .exchange_code(AuthorizationCode::new("ccc".to_string()))
            .request_async(&|_| async { Ok::<_, FakeError>(http_response.clone()) })
            .await
            .unwrap();

        assert_eq!("12/34", token.access_token().secret());
    }
}