headless_lms_server/domain/oauth/
token_query.rs

1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
9pub struct TokenQuery {
10    pub client_id: Option<String>,
11    pub client_secret: Option<String>, // optional: public clients won't send this
12    #[serde(flatten)]
13    pub grant: Option<TokenGrant>,
14    // OAuth 2.0 requires unknown params be ignored at /token (RFC 6749 §3.2)
15    #[serde(flatten)]
16    pub _extra: HashMap<String, String>,
17}
18
19#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
20pub struct TokenParams {
21    pub client_id: String,
22    pub client_secret: Option<String>, // carry through; validation for presence is done per-client later
23    pub grant: TokenGrant,
24}
25
26impl OAuthValidate for TokenQuery {
27    type Output = TokenParams;
28
29    fn validate(&self) -> Result<Self::Output, ControllerError> {
30        let client_id = self.client_id.as_deref().unwrap_or_default();
31
32        if client_id.is_empty() {
33            return Err(ControllerError::new(
34                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
35                    error: OAuthErrorCode::InvalidClient.as_str().into(),
36                    error_description: "client_id is required".into(),
37                    redirect_uri: None,
38                    state: None,
39                    nonce: None,
40                })),
41                "Missing client_id",
42                None::<anyhow::Error>,
43            ));
44        }
45
46        // Grant-specific required params
47        let grant = match self.grant.clone() {
48            Some(grant @ TokenGrant::AuthorizationCode { .. }) => {
49                if let TokenGrant::AuthorizationCode {
50                    code, redirect_uri, ..
51                } = &grant
52                {
53                    if code.is_empty() {
54                        return Err(ControllerError::new(
55                            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
56                                error: OAuthErrorCode::InvalidRequest.as_str().into(),
57                                error_description: "code is required for authorization_code grant"
58                                    .into(),
59                                redirect_uri: None,
60                                state: None,
61                                nonce: None,
62                            })),
63                            "Missing authorization code",
64                            None::<anyhow::Error>,
65                        ));
66                    }
67                    // If redirect_uri is provided, it must not be empty
68                    if matches!(redirect_uri.as_deref(), Some("")) {
69                        return Err(ControllerError::new(
70                            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
71                                error: OAuthErrorCode::InvalidRequest.as_str().into(),
72                                error_description: "redirect_uri must not be empty when provided"
73                                    .into(),
74                                redirect_uri: None,
75                                state: None,
76                                nonce: None,
77                            })),
78                            "Empty redirect_uri",
79                            None::<anyhow::Error>,
80                        ));
81                    }
82                }
83                // PKCE code_verifier is verified at the token handler (if the code had a challenge)
84                grant
85            }
86            Some(grant @ TokenGrant::RefreshToken { .. }) => {
87                if let TokenGrant::RefreshToken { refresh_token, .. } = &grant
88                    && refresh_token.is_empty()
89                {
90                    return Err(ControllerError::new(
91                        ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
92                            error: OAuthErrorCode::InvalidRequest.as_str().into(),
93                            error_description: "refresh_token is required".into(),
94                            redirect_uri: None,
95                            state: None,
96                            nonce: None,
97                        })),
98                        "Missing refresh token",
99                        None::<anyhow::Error>,
100                    ));
101                }
102                grant
103            }
104            Some(TokenGrant::Unknown) => {
105                return Err(ControllerError::new(
106                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
107                        error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
108                        error_description: "unsupported grant type".into(),
109                        redirect_uri: None,
110                        state: None,
111                        nonce: None,
112                    })),
113                    "Unsupported grant type",
114                    None::<anyhow::Error>,
115                ));
116            }
117            None => {
118                return Err(ControllerError::new(
119                    ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
120                        error: OAuthErrorCode::InvalidRequest.as_str().into(),
121                        error_description: "grant_type is required".into(),
122                        redirect_uri: None,
123                        state: None,
124                        nonce: None,
125                    })),
126                    "Missing grant type",
127                    None::<anyhow::Error>,
128                ));
129            }
130        };
131
132        Ok(TokenParams {
133            client_id: client_id.to_string(),
134            client_secret: self.client_secret.clone(), // may be None for public clients
135            grant,
136        })
137    }
138}
139
140#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
141#[serde(tag = "grant_type", rename_all = "snake_case")]
142pub enum TokenGrant {
143    AuthorizationCode {
144        code: String,
145        /// Optional per RFC 6749 §4.1.3 (required if it was present in the authorization request)
146        redirect_uri: Option<String>,
147        /// Optional; enforced later if the code stored a challenge
148        code_verifier: Option<String>,
149    },
150    RefreshToken {
151        refresh_token: String,
152        /// Optional down-scope
153        #[serde(default)]
154        scope: Option<String>,
155    },
156    #[serde(other)]
157    Unknown,
158}
159
160impl TokenGrant {
161    pub fn kind(&self) -> GrantTypeName {
162        match self {
163            TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
164            TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
165            TokenGrant::Unknown => {
166                unreachable!(
167                    "Unknown grant type should be caught by validate() before kind() is called"
168                )
169            }
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Builds a query for a confidential client (`cid` / `sec`) carrying `grant`.
    fn query_with(grant: Option<TokenGrant>) -> TokenQuery {
        TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant,
            _extra: Default::default(),
        }
    }

    /// Builds an authorization_code grant from borrowed parts.
    fn auth_code_grant(
        code: &str,
        redirect_uri: Option<&str>,
        code_verifier: Option<&str>,
    ) -> TokenGrant {
        TokenGrant::AuthorizationCode {
            code: code.into(),
            redirect_uri: redirect_uri.map(Into::into),
            code_verifier: code_verifier.map(Into::into),
        }
    }

    /// Builds a refresh_token grant without a down-scope.
    fn refresh_grant(refresh_token: &str) -> TokenGrant {
        TokenGrant::RefreshToken {
            refresh_token: refresh_token.into(),
            scope: None,
        }
    }

    /// Asserts that `result` is an OAuth error with the given code and description.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidClient,
            "client_id is required",
        );
    }

    #[test]
    fn token_empty_client_id_is_rejected() {
        // An empty string is treated the same as an absent client_id.
        let q = TokenQuery {
            client_id: Some(String::new()),
            ..query_with(Some(refresh_grant("rt")))
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidClient,
            "client_id is required",
        );
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_secret: None,
            ..query_with(Some(refresh_grant("rt")))
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = query_with(None);
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_unknown_grant_type_is_unsupported() {
        let q = query_with(Some(TokenGrant::Unknown));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = query_with(Some(auth_code_grant("", Some("http://localhost"), None)));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = query_with(Some(auth_code_grant("C", Some(""), None)));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        // Allowed by validator; actual PKCE/redirect checks happen in handler.
        let q = query_with(Some(auth_code_grant("C", None, None)));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = query_with(Some(auth_code_grant(
            "C",
            Some("http://localhost"),
            Some("verifier"),
        )));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = query_with(Some(refresh_grant("")));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = query_with(Some(auth_code_grant("abc", Some("http://localhost"), None)));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = query_with(Some(refresh_grant("r1")));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_params_carry_inputs_through() {
        // The validated output must hold exactly the client id, secret and grant
        // that were supplied.
        let q = query_with(Some(refresh_grant("r1")));
        assert_eq!(
            q.validate().ok(),
            Some(TokenParams {
                client_id: "cid".into(),
                client_secret: Some("sec".into()),
                grant: refresh_grant("r1"),
            })
        );
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        // authorization_code branch
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code, "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(code_verifier.as_deref(), Some("ver"));
            }
            _ => panic!("expected AuthorizationCode"),
        }

        // refresh_token branch
        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token, "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }
}
417}