headless_lms_server/domain/oauth/
authorize_query.rs

1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use std::collections::HashMap;
5
6#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
7pub struct AuthorizeQuery {
8    pub response_type: Option<String>,
9    pub client_id: Option<String>,
10    pub redirect_uri: Option<String>,
11    pub scope: Option<String>,
12    pub state: Option<String>,
13    pub nonce: Option<String>,
14    pub code_challenge: Option<String>,
15    pub code_challenge_method: Option<String>,
16
17    // OAuth2.0 spec requires that auth does not fail when there are unknown parameters present,
18    // see RFC 6749 3.1
19    #[serde(flatten)]
20    pub _extra: HashMap<String, String>,
21}
22
23#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
24pub struct AuthorizeParams {
25    pub response_type: String,
26    pub client_id: String,
27    pub redirect_uri: String,
28    pub scope: String,
29    pub state: Option<String>,
30    pub nonce: Option<String>,
31    pub code_challenge: Option<String>,
32    pub code_challenge_method: Option<String>,
33}
34
35// We need to make sure we don't return errors directly, instead we need to return
36// error as success request with error parameters to comply with OAuth.
37impl OAuthValidate for AuthorizeQuery {
38    type Output = AuthorizeParams;
39
40    fn validate(&self) -> Result<Self::Output, ControllerError> {
41        let rt_opt = self.response_type.as_deref();
42        let rt = rt_opt.unwrap_or_default();
43
44        let client_id = self.client_id.as_deref().unwrap_or_default();
45        let redirect_uri = self.redirect_uri.as_deref().unwrap_or_default();
46        let scope = self.scope.as_deref().unwrap_or_default();
47
48        // preserve original state (don't stringify empty -> Some(""))
49        let state_opt = self.state.clone();
50
51        // Required params check
52        if client_id.is_empty() || redirect_uri.is_empty() || scope.is_empty() {
53            return Err(ControllerError::new(
54                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
55                    error: OAuthErrorCode::InvalidRequest.as_str().into(),
56                    error_description: "client_id, redirect_uri, and scope are required".into(),
57                    redirect_uri: None,
58                    state: state_opt.clone(),
59                    nonce: None,
60                })),
61                "Missing required OAuth parameters",
62                None,
63            ));
64        }
65
66        // response_type presence check
67        if rt.is_empty() {
68            return Err(ControllerError::new(
69                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
70                    error: OAuthErrorCode::InvalidRequest.as_str().into(),
71                    error_description: "response_type is required".into(),
72                    redirect_uri: None,
73                    state: state_opt.clone(),
74                    nonce: None,
75                })),
76                "Missing response_type",
77                None::<anyhow::Error>,
78            ));
79        }
80
81        // Only "code" is supported
82        if rt != "code" {
83            return Err(ControllerError::new(
84                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
85                    error: OAuthErrorCode::UnsupportedResponseType.as_str().into(),
86                    error_description: "unsupported response_type".into(),
87                    redirect_uri: None, // add later after client+URI validation
88                    state: state_opt,
89                    nonce: None,
90                })),
91                "Unsupported response_type",
92                None::<anyhow::Error>,
93            ));
94        }
95
96        Ok(AuthorizeParams {
97            response_type: rt.to_string(),
98            client_id: client_id.to_string(),
99            redirect_uri: redirect_uri.to_string(),
100            scope: scope.to_string(),
101            state: self.state.clone(),
102            nonce: self.nonce.clone(),
103            code_challenge: self.code_challenge.clone(),
104            code_challenge_method: self.code_challenge_method.clone(),
105        })
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an `OAuthError` with the given error code and description.
    fn assert_oauth_error(
        result: Result<AuthorizeParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        let Err(err) = result else {
            panic!("Expected Err, got Ok(_)");
        };
        let data = match err.error_type() {
            ControllerErrorType::OAuthError(data) => data,
            other => panic!("Expected OAuthError, got {:?}", other),
        };
        assert_eq!(data.error, expected_error.as_str());
        assert_eq!(data.error_description, expected_description);
    }

    /// A query carrying exactly the parameters `validate` requires for the code flow;
    /// everything else is left at its default (`None` / empty).
    fn minimal_code_query() -> AuthorizeQuery {
        AuthorizeQuery {
            response_type: Some("code".into()),
            client_id: Some("cid".into()),
            redirect_uri: Some("http://localhost".into()),
            scope: Some("openid".into()),
            ..Default::default()
        }
    }

    #[test]
    fn authorize_missing_fields() {
        let query = AuthorizeQuery {
            response_type: Some("code".into()),
            state: Some("xyz".into()),
            ..Default::default()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
        );
    }

    #[test]
    fn authorize_unsupported_response_type() {
        let query = AuthorizeQuery {
            response_type: Some("token".into()),
            ..minimal_code_query()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::UnsupportedResponseType,
            "unsupported response_type",
        );
    }

    #[test]
    fn authorize_valid_code_flow_openid_without_nonce_is_ok() {
        // For pure code flow, nonce is not required by OIDC core.
        let query = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            ..minimal_code_query()
        };
        assert!(query.validate().is_ok());
    }

    #[test]
    fn authorize_unknown_params_are_captured_in_extra() {
        let raw: Value = json!({
            "response_type": "code",
            "client_id": "cid",
            "redirect_uri": "http://localhost",
            "scope": "openid",
            "state": "s",
            "foo": "bar",
            "x": "y"
        });
        let query: AuthorizeQuery = serde_json::from_value(raw).unwrap();
        assert_eq!(query._extra.get("foo").map(String::as_str), Some("bar"));
        assert_eq!(query._extra.get("x").map(String::as_str), Some("y"));
        assert!(query.validate().is_ok());
    }

    #[test]
    fn authorize_missing_response_type_is_invalid_request() {
        let query = AuthorizeQuery {
            response_type: None,
            ..minimal_code_query()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
        );
    }

    #[test]
    fn authorize_pkce_fields_passthrough() {
        let query = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: Some("n".into()),
            code_challenge: Some("abcDEF123-_".into()),
            code_challenge_method: Some("S256".into()),
            ..minimal_code_query()
        };
        let params = query.validate().expect("validate should pass");
        assert_eq!(params.code_challenge.as_deref(), Some("abcDEF123-_"));
        assert_eq!(params.code_challenge_method.as_deref(), Some("S256"));
    }
}