1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use secrecy::{ExposeSecret, SecretString};
6use serde::Deserialize;
7use std::collections::HashMap;
8
9#[derive(Debug, Deserialize, Clone, Default)]
10pub struct TokenQuery {
11 pub client_id: Option<String>,
12 pub client_secret: Option<SecretString>, #[serde(flatten)]
14 pub grant: Option<TokenGrant>,
15 #[serde(flatten)]
17 pub _extra: HashMap<String, String>,
18}
19
20#[derive(Debug, Clone)]
21pub struct TokenParams {
22 pub client_id: String,
23 pub client_secret: Option<SecretString>, pub grant: TokenGrant,
25}
26
27impl OAuthValidate for TokenQuery {
28 type Output = TokenParams;
29
30 fn validate(&self) -> Result<Self::Output, ControllerError> {
31 let client_id = self.client_id.as_deref().unwrap_or_default();
32
33 if client_id.is_empty() {
34 return Err(ControllerError::new(
35 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
36 error: OAuthErrorCode::InvalidClient.as_str().into(),
37 error_description: "client_id is required".into(),
38 redirect_uri: None,
39 state: None,
40 nonce: None,
41 })),
42 "Missing client_id",
43 None::<anyhow::Error>,
44 ));
45 }
46
47 let grant = match self.grant.clone() {
49 Some(grant @ TokenGrant::AuthorizationCode { .. }) => {
50 if let TokenGrant::AuthorizationCode {
51 code, redirect_uri, ..
52 } = &grant
53 {
54 if code.expose_secret().is_empty() {
55 return Err(ControllerError::new(
56 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
57 error: OAuthErrorCode::InvalidRequest.as_str().into(),
58 error_description: "code is required for authorization_code grant"
59 .into(),
60 redirect_uri: None,
61 state: None,
62 nonce: None,
63 })),
64 "Missing authorization code",
65 None::<anyhow::Error>,
66 ));
67 }
68 if matches!(redirect_uri.as_deref(), Some("")) {
70 return Err(ControllerError::new(
71 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
72 error: OAuthErrorCode::InvalidRequest.as_str().into(),
73 error_description: "redirect_uri must not be empty when provided"
74 .into(),
75 redirect_uri: None,
76 state: None,
77 nonce: None,
78 })),
79 "Empty redirect_uri",
80 None::<anyhow::Error>,
81 ));
82 }
83 }
84 grant
86 }
87 Some(grant @ TokenGrant::RefreshToken { .. }) => {
88 if let TokenGrant::RefreshToken { refresh_token, .. } = &grant
89 && refresh_token.expose_secret().is_empty()
90 {
91 return Err(ControllerError::new(
92 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
93 error: OAuthErrorCode::InvalidRequest.as_str().into(),
94 error_description: "refresh_token is required".into(),
95 redirect_uri: None,
96 state: None,
97 nonce: None,
98 })),
99 "Missing refresh token",
100 None::<anyhow::Error>,
101 ));
102 }
103 grant
104 }
105 Some(TokenGrant::Unknown) => {
106 return Err(ControllerError::new(
107 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
108 error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
109 error_description: "unsupported grant type".into(),
110 redirect_uri: None,
111 state: None,
112 nonce: None,
113 })),
114 "Unsupported grant type",
115 None::<anyhow::Error>,
116 ));
117 }
118 None => {
119 return Err(ControllerError::new(
120 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
121 error: OAuthErrorCode::InvalidRequest.as_str().into(),
122 error_description: "grant_type is required".into(),
123 redirect_uri: None,
124 state: None,
125 nonce: None,
126 })),
127 "Missing grant type",
128 None::<anyhow::Error>,
129 ));
130 }
131 };
132
133 Ok(TokenParams {
134 client_id: client_id.to_string(),
135 client_secret: self.client_secret.clone(), grant,
137 })
138 }
139}
140
141#[derive(Debug, Deserialize, Clone)]
142#[serde(tag = "grant_type", rename_all = "snake_case")]
143pub enum TokenGrant {
144 AuthorizationCode {
145 code: SecretString,
146 redirect_uri: Option<String>,
148 code_verifier: Option<SecretString>,
150 },
151 RefreshToken {
152 refresh_token: SecretString,
153 #[serde(default)]
155 scope: Option<String>,
156 },
157 #[serde(other)]
158 Unknown,
159}
160
161impl TokenGrant {
162 pub fn kind(&self) -> GrantTypeName {
163 match self {
164 TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
165 TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
166 TokenGrant::Unknown => {
167 unreachable!(
168 "Unknown grant type should be caught by validate() before kind() is called"
169 )
170 }
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an `OAuthError` with the given code and description.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            // Show the unexpected success value instead of a misleading `Ok(())`.
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_empty_client_id_is_invalid_client() {
        let q = TokenQuery {
            client_id: Some(String::new()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidClient,
            "client_id is required",
        );
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: None,
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_unknown_grant_is_unsupported() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::Unknown),
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_unrecognized_grant_type_deserializes_to_unknown() {
        let q: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "grant_type": "client_credentials",
            "scope": "read"
        }))
        .unwrap();
        assert!(matches!(q.grant, Some(TokenGrant::Unknown)));
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: None,
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: Some("verifier".into()),
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "abc".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "r1".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code.expose_secret(), "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(
                    code_verifier.as_ref().map(|v| v.expose_secret()),
                    Some("ver")
                );
            }
            _ => panic!("expected AuthorizationCode"),
        }

        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token.expose_secret(), "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }
}