headless_lms_server/domain/oauth/
oauth_validated.rs

1use actix_web::{
2    Error, FromRequest, HttpRequest,
3    dev::Payload,
4    http::{Method, header},
5    web,
6};
7use futures_util::future::LocalBoxFuture;
8use serde::de::DeserializeOwned;
9
10use super::oauth_validate::OAuthValidate;
11
12/// Wrapper for OAuth related requests.
13pub struct OAuthValidated<Raw: OAuthValidate>(pub <Raw as OAuthValidate>::Output);
14
15impl<Raw> FromRequest for OAuthValidated<Raw>
16where
17    Raw: DeserializeOwned + OAuthValidate + 'static,
18{
19    type Error = Error;
20    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
21
22    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
23        let req = req.clone();
24        let mut payload = payload.take();
25
26        Box::pin(async move {
27            let raw: Raw = match *req.method() {
28                Method::GET | Method::DELETE => {
29                    web::Query::<Raw>::from_query(req.query_string()).map(|q| q.into_inner())?
30                }
31                _ => {
32                    let ct = req
33                        .headers()
34                        .get(header::CONTENT_TYPE)
35                        .and_then(|v| v.to_str().ok())
36                        .unwrap_or("");
37
38                    if ct.starts_with("application/json") {
39                        web::Json::<Raw>::from_request(&req, &mut payload)
40                            .await
41                            .map(|j| j.into_inner())?
42                    } else {
43                        web::Form::<Raw>::from_request(&req, &mut payload)
44                            .await
45                            .map(|f| f.into_inner())?
46                    }
47                }
48            };
49
50            let out = <Raw as OAuthValidate>::validate(&raw)?;
51
52            Ok(OAuthValidated(out))
53        })
54    }
55}