1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use models::library::oauth::GrantTypeName;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
9pub struct TokenQuery {
10 pub client_id: Option<String>,
11 pub client_secret: Option<String>, #[serde(flatten)]
13 pub grant: Option<TokenGrant>,
14 #[serde(flatten)]
16 pub _extra: HashMap<String, String>,
17}
18
19#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
20pub struct TokenParams {
21 pub client_id: String,
22 pub client_secret: Option<String>, pub grant: TokenGrant,
24}
25
26impl OAuthValidate for TokenQuery {
27 type Output = TokenParams;
28
29 fn validate(&self) -> Result<Self::Output, ControllerError> {
30 let client_id = self.client_id.as_deref().unwrap_or_default();
31
32 if client_id.is_empty() {
33 return Err(ControllerError::new(
34 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
35 error: OAuthErrorCode::InvalidClient.as_str().into(),
36 error_description: "client_id is required".into(),
37 redirect_uri: None,
38 state: None,
39 nonce: None,
40 })),
41 "Missing client_id",
42 None::<anyhow::Error>,
43 ));
44 }
45
46 match &self.grant {
48 Some(TokenGrant::AuthorizationCode {
49 code,
50 redirect_uri,
51 code_verifier: _,
52 }) => {
53 if code.is_empty() {
54 return Err(ControllerError::new(
55 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
56 error: OAuthErrorCode::InvalidRequest.as_str().into(),
57 error_description: "code is required for authorization_code grant"
58 .into(),
59 redirect_uri: None,
60 state: None,
61 nonce: None,
62 })),
63 "Missing authorization code",
64 None::<anyhow::Error>,
65 ));
66 }
67 if matches!(redirect_uri.as_deref(), Some("")) {
69 return Err(ControllerError::new(
70 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
71 error: OAuthErrorCode::InvalidRequest.as_str().into(),
72 error_description: "redirect_uri must not be empty when provided"
73 .into(),
74 redirect_uri: None,
75 state: None,
76 nonce: None,
77 })),
78 "Empty redirect_uri",
79 None::<anyhow::Error>,
80 ));
81 }
82 }
84 Some(TokenGrant::RefreshToken { refresh_token, .. }) => {
85 if refresh_token.is_empty() {
86 return Err(ControllerError::new(
87 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
88 error: OAuthErrorCode::InvalidRequest.as_str().into(),
89 error_description: "refresh_token is required".into(),
90 redirect_uri: None,
91 state: None,
92 nonce: None,
93 })),
94 "Missing refresh token",
95 None::<anyhow::Error>,
96 ));
97 }
98 }
99 Some(TokenGrant::Unknown) => {
100 return Err(ControllerError::new(
101 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
102 error: OAuthErrorCode::UnsupportedGrantType.as_str().into(),
103 error_description: "unsupported grant type".into(),
104 redirect_uri: None,
105 state: None,
106 nonce: None,
107 })),
108 "Unsupported grant type",
109 None::<anyhow::Error>,
110 ));
111 }
112 None => {
113 return Err(ControllerError::new(
114 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
115 error: OAuthErrorCode::InvalidRequest.as_str().into(),
116 error_description: "grant_type is required".into(),
117 redirect_uri: None,
118 state: None,
119 nonce: None,
120 })),
121 "Missing grant type",
122 None::<anyhow::Error>,
123 ));
124 }
125 }
126
127 Ok(TokenParams {
128 client_id: client_id.to_string(),
129 client_secret: self.client_secret.clone(), grant: self.grant.clone().unwrap(), })
132 }
133}
134
/// Grant-specific parameters of a token request.
///
/// Internally tagged by the `grant_type` field, whose value is the snake_case
/// variant name (`"authorization_code"`, `"refresh_token"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum TokenGrant {
    /// Parameters for `grant_type=authorization_code`.
    AuthorizationCode {
        /// The authorization code; `validate` rejects an empty value.
        code: String,
        /// Optional redirect URI; `validate` rejects a present-but-empty value.
        redirect_uri: Option<String>,
        /// Optional PKCE code verifier.
        code_verifier: Option<String>,
    },
    /// Parameters for `grant_type=refresh_token`.
    RefreshToken {
        /// The refresh token; `validate` rejects an empty value.
        refresh_token: String,
        /// Optional requested scope (e.g. `"read write"`); `None` when absent.
        #[serde(default)]
        scope: Option<String>,
    },
    /// Any other `grant_type` value; `validate` rejects it as unsupported.
    #[serde(other)]
    Unknown,
}
154
155impl TokenGrant {
156 pub fn kind(&self) -> GrantTypeName {
157 match self {
158 TokenGrant::AuthorizationCode { .. } => GrantTypeName::AuthorizationCode,
159 TokenGrant::RefreshToken { .. } => GrantTypeName::RefreshToken,
160 TokenGrant::Unknown => {
161 unreachable!(
162 "Unknown grant type should be caught by validate() before kind() is called"
163 )
164 }
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// Asserts that `result` is an OAuth error with the given code and description.
    fn assert_oauth_error(
        result: Result<TokenParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            // Show the actual parameters instead of a misleading `Ok(())`.
            Ok(params) => panic!("Expected Err, got Ok({:?})", params),
        }
    }

    #[test]
    fn token_missing_client_id() {
        let q = TokenQuery {
            client_id: None,
            client_secret: None,
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(res, OAuthErrorCode::InvalidClient, "client_id is required");
    }

    #[test]
    fn token_empty_client_id_is_rejected() {
        let q = TokenQuery {
            client_id: Some("".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidClient,
            "client_id is required",
        );
    }

    #[test]
    fn token_public_client_without_secret_is_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: None,
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "rt".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_missing_grant_type() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: None,
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "grant_type is required",
        );
    }

    #[test]
    fn token_unknown_grant_type_is_unsupported() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::Unknown),
            _extra: Default::default(),
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedGrantType,
            "unsupported grant type",
        );
    }

    #[test]
    fn token_auth_code_missing_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "code is required for authorization_code grant",
        );
    }

    #[test]
    fn token_auth_code_empty_redirect_uri_is_invalid() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "redirect_uri must not be empty when provided",
        );
    }

    #[test]
    fn token_auth_code_minimal_ok_without_redirect_uri_or_pkce() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: None,
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_auth_code_with_pkce_ok() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "C".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: Some("verifier".into()),
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_refresh_missing_field() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        let res = q.validate();
        assert_oauth_error(
            res,
            OAuthErrorCode::InvalidRequest,
            "refresh_token is required",
        );
    }

    #[test]
    fn token_valid_auth_code() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::AuthorizationCode {
                code: "abc".into(),
                redirect_uri: Some("http://localhost".into()),
                code_verifier: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_refresh_token() {
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(TokenGrant::RefreshToken {
                refresh_token: "r1".into(),
                scope: None,
            }),
            _extra: Default::default(),
        };
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_valid_request_produces_expected_params() {
        let grant = TokenGrant::AuthorizationCode {
            code: "abc".into(),
            redirect_uri: Some("http://localhost".into()),
            code_verifier: Some("ver".into()),
        };
        let q = TokenQuery {
            client_id: Some("cid".into()),
            client_secret: Some("sec".into()),
            grant: Some(grant.clone()),
            _extra: Default::default(),
        };
        let Ok(params) = q.validate() else {
            panic!("expected Ok for a valid authorization_code request");
        };
        assert_eq!(
            params,
            TokenParams {
                client_id: "cid".into(),
                client_secret: Some("sec".into()),
                grant,
            }
        );
    }

    #[test]
    fn token_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "rt",
            "extra_param": "zzz"
        });
        let q: TokenQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("extra_param").map(String::as_str), Some("zzz"));
        assert!(q.validate().is_ok());
    }

    #[test]
    fn token_grant_tagging_deserializes_properly() {
        let ac: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "authorization_code",
            "code": "C",
            "redirect_uri": "http://localhost",
            "code_verifier": "ver"
        }))
        .unwrap();
        match ac.grant {
            Some(TokenGrant::AuthorizationCode {
                code,
                redirect_uri,
                code_verifier,
            }) => {
                assert_eq!(code, "C");
                assert_eq!(redirect_uri.as_deref(), Some("http://localhost"));
                assert_eq!(code_verifier.as_deref(), Some("ver"));
            }
            _ => panic!("expected AuthorizationCode"),
        }

        let rt: TokenQuery = serde_json::from_value(json!({
            "client_id": "cid",
            "client_secret": "sec",
            "grant_type": "refresh_token",
            "refresh_token": "R",
            "scope": "read write"
        }))
        .unwrap();
        match rt.grant {
            Some(TokenGrant::RefreshToken {
                refresh_token,
                scope,
            }) => {
                assert_eq!(refresh_token, "R");
                assert_eq!(scope.as_deref(), Some("read write"));
            }
            _ => panic!("expected RefreshToken"),
        }
    }

    #[test]
    fn token_grant_kind_maps_known_variants() {
        let ac = TokenGrant::AuthorizationCode {
            code: "C".into(),
            redirect_uri: None,
            code_verifier: None,
        };
        assert!(matches!(ac.kind(), GrantTypeName::AuthorizationCode));

        let rt = TokenGrant::RefreshToken {
            refresh_token: "R".into(),
            scope: None,
        };
        assert!(matches!(rt.kind(), GrantTypeName::RefreshToken));
    }

    #[test]
    #[should_panic(expected = "Unknown grant type should be caught by validate()")]
    fn token_grant_kind_panics_on_unknown() {
        let _ = TokenGrant::Unknown.kind();
    }
}