headless_lms_server/domain/oauth/token_query.rs

1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Raw parameters of a token endpoint request, exactly as deserialized from the client.
///
/// Every field is optional at this stage. `OAuthValidate::validate` checks that the
/// required ones are present and converts the query into a [`TokenParams`].
///
/// NOTE(review): `grant` is a flattened `Option`. Serde may deserialize a grant that fails to
/// parse (e.g. `grant_type=authorization_code` with no `code` field) as `None` instead of
/// returning an error. In that case `validate` reports "grant_type is required". Confirm
/// this is the intended client-facing message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
pub struct TokenQuery {
    /// Client identifier; `validate` rejects a missing or empty value.
    pub client_id: Option<String>,
    pub client_secret: Option<String>, // optional: public clients won't send this
    /// Grant-specific parameters, selected by the `grant_type` field (see [`TokenGrant`]).
    #[serde(flatten)]
    pub grant: Option<TokenGrant>,
    // OAuth 2.0 requires unknown params be ignored at /token (RFC 6749 §3.2)
    #[serde(flatten)]
    pub _extra: HashMap<String, String>,
}
18
/// Token request parameters after a [`TokenQuery`] has passed validation.
///
/// Unlike `TokenQuery`, `client_id` and `grant` are always present here.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct TokenParams {
    /// Client identifier; non-empty when produced by `validate`.
    pub client_id: String,
    pub client_secret: Option<String>, // carry through; validation for presence is done per-client later
    /// The requested grant; `validate` never produces this with [`TokenGrant::Unknown`].
    pub grant: TokenGrant,
}
25
26impl OAuthValidate for TokenQuery {
27    type Output = TokenParams;
28
29    fn validate(&self) -> Result<Self::Output, ControllerError> {
30        let client_id = self.client_id.as_deref().unwrap_or_default();
31
32        if client_id.is_empty() {
33            return Err(ControllerError::new(
34                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
35                    error: OAuthErrorCode::InvalidClient.as_str().into(),
36                    error_description: "client_id is required".into(),
37                    redirect_uri: None,
38                    state: None,
39                    nonce: None,
40                })),
41                "Missing client_id",
42                None::<anyhow::Error>,
43            ));
44        }
45
46        // Grant-specific required params
47        match &self.grant {
48            Some(TokenGrant::AuthorizationCode {
49                code,
50                redirect_uri,
51                code_verifier: _,
52            }) => {
53                if code.is_empty() {
54                    return Err(ControllerError::new(
55                        ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
56                            error: OAuthErrorCode::InvalidRequest.as_str().into(),
57                            error_description: "code is required for authorization_code grant"
58                                .into(),
59                            redirect_uri: None,
60                            state: None,
61                            nonce: None,
62                        })),
63                        "Missing authorization code",
64                        None::<anyhow::Error>,
65                    ));
66                }
67                // If redirect_uri is provided, it must not be empty
68                if matches!(redirect_uri.as_deref(), Some("")) {
69                    return Err(ControllerError::new(
70                        ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
71                            error: OAuthErrorCode::InvalidRequest.as_str().into(),
72                            error_description: "redirect_uri must not be empty when provided"
73                                .into(),
74                            redirect_uri: None,
75                            state: None,
76                            nonce: None,
77                        })),
78                        "Empty redirect_uri",
79                        None::<anyhow::Error>,
80                    ));
81                }
82                // PKCE code_verifier is verified at the token handler (if the code had a challenge)
83            }
84            Some(TokenGrant::RefreshToken { refresh_token, .. }) => {
85                if refresh_token.is_empty() {
86                    return Err(ControllerError::new(
87                        ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
88                            error: OAuthErrorCode::InvalidRequest.as_str().into(),
89                            error_description: "refresh_token is required".into(),
90                            redirect_uri: None,
91                            state: None,
92                            nonce: None,
93                        })),
94                        "Missing refresh token",
95                        None::<anyhow::Error>,
96                    ));
97                }
98            }
99            Some(TokenGrant::Unknown) => {
100                return Err(ControllerError::new(
101                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
102                        error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
103                        error_description: "unsupported grant type".into(),
104                        redirect_uri: None,
105                        state: None,
106                        nonce: None,
107                    })),
108                    "Unsupported grant type",
109                    None::<anyhow::Error>,
110                ));
111            }
112            None => {
113                return Err(ControllerError::new(
114                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
115                        error: OAuthErrorCode::InvalidRequest.as_str().into(),
116                        error_description: "grant_type is required".into(),
117                        redirect_uri: None,
118                        state: None,
119                        nonce: None,
120                    })),
121                    "Missing grant type",
122                    None::<anyhow::Error>,
123                ));
124            }
125        }
126
127        Ok(TokenParams {
128            client_id: client_id.to_string(),
129            client_secret: self.client_secret.clone(), // may be None for public clients
130            grant: self.grant.clone().unwrap(),        // safe due to the match above
131        })
132    }
133}
134
/// Grant-specific parameters of a token request, discriminated by the `grant_type` field.
///
/// Variant names map to `grant_type` values in snake_case (`authorization_code`,
/// `refresh_token`). Any other `grant_type` value deserializes to [`TokenGrant::Unknown`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum TokenGrant {
    /// `grant_type=authorization_code` (RFC 6749 §4.1.3).
    AuthorizationCode {
        /// The authorization code; `validate` rejects an empty value.
        code: String,
        /// Optional per RFC 6749 §4.1.3 (required if it was present in the authorization request)
        redirect_uri: Option<String>,
        /// Optional; enforced later if the code stored a challenge
        code_verifier: Option<String>,
    },
    /// `grant_type=refresh_token` (RFC 6749 §6).
    RefreshToken {
        /// The refresh token; `validate` rejects an empty value.
        refresh_token: String,
        /// Optional down-scope
        #[serde(default)]
        scope: Option<String>,
    },
    /// Catch-all for unrecognized `grant_type` values; `validate` rejects it with
    /// `OAuthErrorCode::UnsupportedGrantType`.
    #[serde(other)]
    Unknown,
}
154
155impl TokenGrant {
156    pub fn kind(&self) -> GrantTypeName {
157        match self {
158            TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
159            TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
160            TokenGrant::Unknown => {
161                unreachable!(
162                    "Unknown grant type should be caught by validate() before kind() is called"
163                )
164            }
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    //! Unit tests for `TokenQuery` deserialization and `validate()`.

    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an OAuth error carrying `expected_error` and
    /// `expected_description`; panics with the actual value otherwise.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            // The Ok payload is a TokenParams, not (); show it to ease debugging.
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: None,
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_unknown_grant_type_is_unsupported() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::Unknown),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        // Allowed by validator; actual PKCE/redirect checks happen in handler.
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: None,
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: Some("verifier".into()),
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "abc".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "r1".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_unrecognized_grant_type_deserializes_to_unknown() {
        // `#[serde(other)]` should map any unlisted grant_type onto `Unknown`,
        // which `validate` then rejects as unsupported.
        let q: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "grant_type": "password",
            "username": "u",
            "password": "p"
        }))
        .unwrap();
        assert_eq!(q.grant, Some(TokenGrant::Unknown));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        // authorization_code branch
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code, "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(code_verifier.as_deref(), Some("ver"));
            }
            _ => panic!("expected AuthorizationCode"),
        }

        // refresh_token branch
        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token, "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }
}