Skip to main content

headless_lms_server/domain/oauth/
token_query.rs

1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use secrecy::{ExposeSecret, SecretString};
6use serde::Deserialize;
7use std::collections::HashMap;
8
9#[derive(Debug, Deserialize, Clone, Default)]
10pub struct TokenQuery {
11    pub client_id: Option<String>,
12    pub client_secret: Option<SecretString>, // optional: public clients won't send this
13    #[serde(flatten)]
14    pub grant: Option<TokenGrant>,
15    // OAuth 2.0 requires unknown params be ignored at /token (RFC 6749 §3.2)
16    #[serde(flatten)]
17    pub _extra: HashMap<String, String>,
18}
19
20#[derive(Debug, Clone)]
21pub struct TokenParams {
22    pub client_id: String,
23    pub client_secret: Option<SecretString>, // carry through; validation for presence is done per-client later
24    pub grant: TokenGrant,
25}
26
27impl OAuthValidate for TokenQuery {
28    type Output = TokenParams;
29
30    fn validate(&self) -> Result<Self::Output, ControllerError> {
31        let client_id = self.client_id.as_deref().unwrap_or_default();
32
33        if client_id.is_empty() {
34            return Err(ControllerError::new(
35                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
36                    error: OAuthErrorCode::InvalidClient.as_str().into(),
37                    error_description: "client_id is required".into(),
38                    redirect_uri: None,
39                    state: None,
40                    nonce: None,
41                })),
42                "Missing client_id",
43                None::<anyhow::Error>,
44            ));
45        }
46
47        // Grant-specific required params
48        let grant = match self.grant.clone() {
49            Some(grant @ TokenGrant::AuthorizationCode { .. }) => {
50                if let TokenGrant::AuthorizationCode {
51                    code, redirect_uri, ..
52                } = &grant
53                {
54                    if code.expose_secret().is_empty() {
55                        return Err(ControllerError::new(
56                            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
57                                error: OAuthErrorCode::InvalidRequest.as_str().into(),
58                                error_description: "code is required for authorization_code grant"
59                                    .into(),
60                                redirect_uri: None,
61                                state: None,
62                                nonce: None,
63                            })),
64                            "Missing authorization code",
65                            None::<anyhow::Error>,
66                        ));
67                    }
68                    // If redirect_uri is provided, it must not be empty
69                    if matches!(redirect_uri.as_deref(), Some("")) {
70                        return Err(ControllerError::new(
71                            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
72                                error: OAuthErrorCode::InvalidRequest.as_str().into(),
73                                error_description: "redirect_uri must not be empty when provided"
74                                    .into(),
75                                redirect_uri: None,
76                                state: None,
77                                nonce: None,
78                            })),
79                            "Empty redirect_uri",
80                            None::<anyhow::Error>,
81                        ));
82                    }
83                }
84                // PKCE code_verifier is verified at the token handler (if the code had a challenge)
85                grant
86            }
87            Some(grant @ TokenGrant::RefreshToken { .. }) => {
88                if let TokenGrant::RefreshToken { refresh_token, .. } = &grant
89                    && refresh_token.expose_secret().is_empty()
90                {
91                    return Err(ControllerError::new(
92                        ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
93                            error: OAuthErrorCode::InvalidRequest.as_str().into(),
94                            error_description: "refresh_token is required".into(),
95                            redirect_uri: None,
96                            state: None,
97                            nonce: None,
98                        })),
99                        "Missing refresh token",
100                        None::<anyhow::Error>,
101                    ));
102                }
103                grant
104            }
105            Some(TokenGrant::Unknown) => {
106                return Err(ControllerError::new(
107                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
108                        error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
109                        error_description: "unsupported grant type".into(),
110                        redirect_uri: None,
111                        state: None,
112                        nonce: None,
113                    })),
114                    "Unsupported grant type",
115                    None::<anyhow::Error>,
116                ));
117            }
118            None => {
119                return Err(ControllerError::new(
120                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
121                        error: OAuthErrorCode::InvalidRequest.as_str().into(),
122                        error_description: "grant_type is required".into(),
123                        redirect_uri: None,
124                        state: None,
125                        nonce: None,
126                    })),
127                    "Missing grant type",
128                    None::<anyhow::Error>,
129                ));
130            }
131        };
132
133        Ok(TokenParams {
134            client_id: client_id.to_string(),
135            client_secret: self.client_secret.clone(), // may be None for public clients
136            grant,
137        })
138    }
139}
140
141#[derive(Debug, Deserialize, Clone)]
142#[serde(tag = "grant_type", rename_all = "snake_case")]
143pub enum TokenGrant {
144    AuthorizationCode {
145        code: SecretString,
146        /// Optional per RFC 6749 §4.1.3 (required if it was present in the authorization request)
147        redirect_uri: Option<String>,
148        /// Optional; enforced later if the code stored a challenge
149        code_verifier: Option<SecretString>,
150    },
151    RefreshToken {
152        refresh_token: SecretString,
153        /// Optional down-scope
154        #[serde(default)]
155        scope: Option<String>,
156    },
157    #[serde(other)]
158    Unknown,
159}
160
161impl TokenGrant {
162    pub fn kind(&self) -> GrantTypeName {
163        match self {
164            TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
165            TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
166            TokenGrant::Unknown => {
167                unreachable!(
168                    "Unknown grant type should be caught by validate() before kind() is called"
169                )
170            }
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an OAuth error with the given code and description.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            // The Ok value is a TokenParams, not (); show it (secrets are redacted).
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: None,
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        // Allowed by validator; actual PKCE/redirect checks happen in handler.
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: None,
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: Some("verifier".into()),
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_unknown_grant_is_unsupported() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::Unknown),
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_unrecognised_grant_type_deserializes_to_unknown() {
        // `#[serde(other)]` must catch grant types we don't implement.
        let q: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "grant_type": "password",
            "username": "u",
            "password": "p"
        }))
        .unwrap();
        assert!(matches!(q.grant, Some(TokenGrant::Unknown)));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "abc".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "r1".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_kind_matches_variant() {
        let ac = TokenGrant::AuthorizationCode {
            code: "C".into(),
            redirect_uri: None,
            code_verifier: None,
        };
        assert!(matches!(ac.kind(), GrantTypeName::AuthorizationCode));

        let rt = TokenGrant::RefreshToken {
            refresh_token: "R".into(),
            scope: None,
        };
        assert!(matches!(rt.kind(), GrantTypeName::RefreshToken));
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        // authorization_code branch
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code.expose_secret(), "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(
                    code_verifier.as_ref().map(|v| v.expose_secret()),
                    Some("ver")
                );
            }
            _ => panic!("expected AuthorizationCode"),
        }

        // refresh_token branch
        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token.expose_secret(), "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }
}