headless_lms_server/domain/oauth/
authorize_query.rs

1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use std::collections::HashMap;
5
6#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
7pub struct AuthorizeQuery {
8    pub response_type: Option<String>,
9    pub client_id: Option<String>,
10    pub redirect_uri: Option<String>,
11    pub scope: Option<String>,
12    pub state: Option<String>,
13    pub nonce: Option<String>,
14    pub code_challenge: Option<String>,
15    pub code_challenge_method: Option<String>,
16
17    pub prompt: Option<String>,
18    // Request is not supported, but we capture it so we can return "not supported" error
19    pub request: Option<String>,
20
21    // OAuth2.0 spec requires that auth does not fail when there are unknown parameters present,
22    // see RFC 6749 3.1
23    #[serde(flatten)]
24    pub _extra: HashMap<String, String>,
25}
26
27#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
28pub struct AuthorizeParams {
29    pub response_type: String,
30    pub client_id: String,
31    pub redirect_uri: String,
32    pub scope: String,
33    pub state: Option<String>,
34    pub nonce: Option<String>,
35    pub code_challenge: Option<String>,
36    pub code_challenge_method: Option<String>,
37    pub prompt: Option<String>,
38    pub request: Option<String>,
39}
40
41// We need to make sure we don't return errors directly, instead we need to return
42// error as success request with error parameters to comply with OAuth.
43impl OAuthValidate for AuthorizeQuery {
44    type Output = AuthorizeParams;
45
46    fn validate(&self) -> Result<Self::Output, ControllerError> {
47        let rt_opt = self.response_type.as_deref();
48        let rt = rt_opt.unwrap_or_default();
49
50        let client_id = self.client_id.as_deref().unwrap_or_default();
51        let redirect_uri = self.redirect_uri.as_deref().unwrap_or_default();
52        let scope = self.scope.as_deref().unwrap_or_default();
53
54        // preserve original state (don't stringify empty -> Some(""))
55        let state_opt = self.state.clone();
56
57        // Required params check
58        if client_id.is_empty() || redirect_uri.is_empty() || scope.is_empty() {
59            return Err(ControllerError::new(
60                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
61                    error: OAuthErrorCode::InvalidRequest.as_str().into(),
62                    error_description: "client_id, redirect_uri, and scope are required".into(),
63                    redirect_uri: None,
64                    state: state_opt.clone(),
65                    nonce: None,
66                })),
67                "Missing required OAuth parameters",
68                None,
69            ));
70        }
71
72        // response_type presence check
73        if rt.is_empty() {
74            return Err(ControllerError::new(
75                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
76                    error: OAuthErrorCode::InvalidRequest.as_str().into(),
77                    error_description: "response_type is required".into(),
78                    redirect_uri: None,
79                    state: state_opt.clone(),
80                    nonce: None,
81                })),
82                "Missing response_type",
83                None::<anyhow::Error>,
84            ));
85        }
86
87        // Only "code" is supported
88        if rt != "code" {
89            return Err(ControllerError::new(
90                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
91                    error: OAuthErrorCode::UnsupportedResponseType.as_str().into(),
92                    error_description: "unsupported response_type".into(),
93                    redirect_uri: None, // add later after client+URI validation
94                    state: state_opt,
95                    nonce: None,
96                })),
97                "Unsupported response_type",
98                None::<anyhow::Error>,
99            ));
100        }
101
102        Ok(AuthorizeParams {
103            response_type: rt.to_string(),
104            client_id: client_id.to_string(),
105            redirect_uri: redirect_uri.to_string(),
106            scope: scope.to_string(),
107            state: self.state.clone(),
108            nonce: self.nonce.clone(),
109            code_challenge: self.code_challenge.clone(),
110            code_challenge_method: self.code_challenge_method.clone(),
111            prompt: self.prompt.clone(),
112            request: self.request.clone(),
113        })
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// A minimal authorization-code request that passes validation. Individual tests
    /// override fields with struct-update syntax (`..valid_query()`).
    fn valid_query() -> AuthorizeQuery {
        AuthorizeQuery {
            response_type: Some("code".into()),
            client_id: Some("cid".into()),
            redirect_uri: Some("http://localhost".into()),
            scope: Some("openid".into()),
            ..Default::default()
        }
    }

    /// Asserts that `result` is an `OAuthError` with the expected code and description.
    /// Also checks that the request `state` is echoed back and that no redirect URI is set.
    fn assert_oauth_error(
        result: Result<AuthorizeParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
        expected_state: Option<&str>,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                    assert_eq!(data.state.as_deref(), expected_state);
                    assert!(data.redirect_uri.is_none());
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            Ok(_) => panic!("Expected Err, got Ok(_)"),
        }
    }

    #[test]
    fn authorize_missing_fields() {
        let q = AuthorizeQuery {
            client_id: None,
            redirect_uri: None,
            scope: None,
            state: Some("xyz".into()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
            Some("xyz"),
        );
    }

    #[test]
    fn authorize_each_required_field_is_checked() {
        for q in [
            AuthorizeQuery { client_id: None, ..valid_query() },
            AuthorizeQuery { redirect_uri: None, ..valid_query() },
            AuthorizeQuery { scope: None, ..valid_query() },
        ] {
            assert_oauth_error(
                q.validate(),
                OAuthErrorCode::InvalidRequest,
                "client_id, redirect_uri, and scope are required",
                None,
            );
        }
    }

    #[test]
    fn authorize_empty_required_field_is_treated_as_missing() {
        let q = AuthorizeQuery {
            client_id: Some(String::new()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
            None,
        );
    }

    #[test]
    fn authorize_required_fields_are_checked_before_response_type() {
        let q = AuthorizeQuery {
            response_type: None,
            scope: None,
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
            None,
        );
    }

    #[test]
    fn authorize_unsupported_response_type() {
        let q = AuthorizeQuery {
            response_type: Some("token".into()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedResponseType,
            "unsupported response_type",
            None,
        );
    }

    #[test]
    fn authorize_valid_code_flow_openid_without_nonce_is_ok() {
        // For pure code flow, nonce is not required by OIDC core.
        let q = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            ..valid_query()
        };
        let p = q.validate().expect("validate should pass");
        assert_eq!(p.response_type, "code");
        assert_eq!(p.client_id, "cid");
        assert_eq!(p.redirect_uri, "http://localhost");
        assert_eq!(p.scope, "openid profile");
        assert_eq!(p.state.as_deref(), Some("s"));
        assert_eq!(p.nonce, None);
    }

    #[test]
    fn authorize_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "response_type": "code",
            "client_id": "cid",
            "redirect_uri": "http://localhost",
            "scope": "openid",
            "state": "s",
            "foo": "bar",
            "x": "y"
        });
        let q: AuthorizeQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("foo").map(String::as_str), Some("bar"));
        assert_eq!(q._extra.get("x").map(String::as_str), Some("y"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn authorize_missing_response_type_is_invalid_request() {
        let q = AuthorizeQuery {
            response_type: None,
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
            None,
        );
    }

    #[test]
    fn authorize_empty_response_type_is_invalid_request() {
        let q = AuthorizeQuery {
            response_type: Some(String::new()),
            state: Some("s".into()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
            Some("s"),
        );
    }

    #[test]
    fn authorize_pkce_fields_passthrough() {
        let q = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: Some("n".into()),
            code_challenge: Some("abcDEF123-_".into()),
            code_challenge_method: Some("S256".into()),
            ..valid_query()
        };
        let p = q.validate().expect("validate should pass");
        assert_eq!(p.nonce.as_deref(), Some("n"));
        assert_eq!(p.code_challenge.as_deref(), Some("abcDEF123-_"));
        assert_eq!(p.code_challenge_method.as_deref(), Some("S256"));
    }
}