1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
9pub struct TokenQuery {
10 pub client_id: Option<String>,
11 pub client_secret: Option<String>, #[serde(flatten)]
13 pub grant: Option<TokenGrant>,
14 #[serde(flatten)]
16 pub _extra: HashMap<String, String>,
17}
18
19#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
20pub struct TokenParams {
21 pub client_id: String,
22 pub client_secret: Option<String>, pub grant: TokenGrant,
24}
25
26impl OAuthValidate for TokenQuery {
27 type Output = TokenParams;
28
29 fn validate(&self) -> Result<Self::Output, ControllerError> {
30 let client_id = self.client_id.as_deref().unwrap_or_default();
31
32 if client_id.is_empty() {
33 return Err(ControllerError::new(
34 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
35 error: OAuthErrorCode::InvalidClient.as_str().into(),
36 error_description: "client_id is required".into(),
37 redirect_uri: None,
38 state: None,
39 nonce: None,
40 })),
41 "Missing client_id",
42 None::<anyhow::Error>,
43 ));
44 }
45
46 let grant = match self.grant.clone() {
48 Some(grant @ TokenGrant::AuthorizationCode { .. }) => {
49 if let TokenGrant::AuthorizationCode {
50 code, redirect_uri, ..
51 } = &grant
52 {
53 if code.is_empty() {
54 return Err(ControllerError::new(
55 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
56 error: OAuthErrorCode::InvalidRequest.as_str().into(),
57 error_description: "code is required for authorization_code grant"
58 .into(),
59 redirect_uri: None,
60 state: None,
61 nonce: None,
62 })),
63 "Missing authorization code",
64 None::<anyhow::Error>,
65 ));
66 }
67 if matches!(redirect_uri.as_deref(), Some("")) {
69 return Err(ControllerError::new(
70 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
71 error: OAuthErrorCode::InvalidRequest.as_str().into(),
72 error_description: "redirect_uri must not be empty when provided"
73 .into(),
74 redirect_uri: None,
75 state: None,
76 nonce: None,
77 })),
78 "Empty redirect_uri",
79 None::<anyhow::Error>,
80 ));
81 }
82 }
83 grant
85 }
86 Some(grant @ TokenGrant::RefreshToken { .. }) => {
87 if let TokenGrant::RefreshToken { refresh_token, .. } = &grant
88 && refresh_token.is_empty()
89 {
90 return Err(ControllerError::new(
91 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
92 error: OAuthErrorCode::InvalidRequest.as_str().into(),
93 error_description: "refresh_token is required".into(),
94 redirect_uri: None,
95 state: None,
96 nonce: None,
97 })),
98 "Missing refresh token",
99 None::<anyhow::Error>,
100 ));
101 }
102 grant
103 }
104 Some(TokenGrant::Unknown) => {
105 return Err(ControllerError::new(
106 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
107 error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
108 error_description: "unsupported grant type".into(),
109 redirect_uri: None,
110 state: None,
111 nonce: None,
112 })),
113 "Unsupported grant type",
114 None::<anyhow::Error>,
115 ));
116 }
117 None => {
118 return Err(ControllerError::new(
119 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
120 error: OAuthErrorCode::InvalidRequest.as_str().into(),
121 error_description: "grant_type is required".into(),
122 redirect_uri: None,
123 state: None,
124 nonce: None,
125 })),
126 "Missing grant type",
127 None::<anyhow::Error>,
128 ));
129 }
130 };
131
132 Ok(TokenParams {
133 client_id: client_id.to_string(),
134 client_secret: self.client_secret.clone(), grant,
136 })
137 }
138}
139
140#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
141#[serde(tag = "grant_type", rename_all = "snake_case")]
142pub enum TokenGrant {
143 AuthorizationCode {
144 code: String,
145 redirect_uri: Option<String>,
147 code_verifier: Option<String>,
149 },
150 RefreshToken {
151 refresh_token: String,
152 #[serde(default)]
154 scope: Option<String>,
155 },
156 #[serde(other)]
157 Unknown,
158}
159
160impl TokenGrant {
161 pub fn kind(&self) -> GrantTypeName {
162 match self {
163 TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
164 TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
165 TokenGrant::Unknown => {
166 unreachable!(
167 "Unknown grant type should be caught by validate() before kind() is called"
168 )
169 }
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an OAuth error carrying exactly the expected
    /// error code string and description.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            // Print the unexpected success value so the failure is diagnosable.
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    /// Builds a query for client `cid` with secret `sec` and the given grant.
    fn query_with(grant: Option<TokenGrant>) -> TokenQuery {
        TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant,
            _extra: Default::default(),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_empty_client_id_is_rejected() {
        // An explicitly empty client_id is treated like a missing one, even
        // when the rest of the request is valid.
        let q = TokenQuery {
            client_id: Some("".into()),
            ..query_with(Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }))
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_secret: None,
            ..query_with(Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }))
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = query_with(None);
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_unknown_grant_type_is_unsupported() {
        let q = query_with(Some(TokenGrant::Unknown));
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = query_with(Some(TokenGrant::AuthorizationCode {
            code: "".into(),
            redirect_uri: Some("http://localhost".into()),
            code_verifier: None,
        }));
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = query_with(Some(TokenGrant::AuthorizationCode {
            code: "C".into(),
            redirect_uri: Some("".into()),
            code_verifier: None,
        }));
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        let q = query_with(Some(TokenGrant::AuthorizationCode {
            code: "C".into(),
            redirect_uri: None,
            code_verifier: None,
        }));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = query_with(Some(TokenGrant::AuthorizationCode {
            code: "C".into(),
            redirect_uri: Some("http://localhost".into()),
            code_verifier: Some("verifier".into()),
        }));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = query_with(Some(TokenGrant::RefreshToken {
            refresh_token: "".into(),
            scope: None,
        }));
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = query_with(Some(TokenGrant::AuthorizationCode {
            code: "abc".into(),
            redirect_uri: Some("http://localhost".into()),
            code_verifier: None,
        }));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = query_with(Some(TokenGrant::RefreshToken {
            refresh_token: "r1".into(),
            scope: None,
        }));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code, "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(code_verifier.as_deref(), Some("ver"));
            }
            _ => panic!("expected AuthorizationCode"),
        }

        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token, "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }
}