1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use std::collections::HashMap;
5
/// Raw query parameters of an authorization request, as received from the client.
///
/// Every field is an `Option` so that deserialization succeeds even when a
/// parameter is absent; presence checks are performed by
/// [`OAuthValidate::validate`], which turns this into an [`AuthorizeParams`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
pub struct AuthorizeQuery {
    /// Requested response type; `validate` only accepts `"code"`.
    pub response_type: Option<String>,
    /// Client identifier; required (non-empty) by `validate`.
    pub client_id: Option<String>,
    /// Redirect URI; required (non-empty) by `validate`, but not checked
    /// against any registered URI here.
    pub redirect_uri: Option<String>,
    /// Space-delimited scope string; required (non-empty) by `validate`.
    pub scope: Option<String>,
    /// Opaque client state; echoed back in validation errors.
    pub state: Option<String>,
    /// Optional nonce; passed through unvalidated.
    pub nonce: Option<String>,
    /// PKCE code challenge; passed through unvalidated.
    pub code_challenge: Option<String>,
    /// PKCE challenge method (e.g. `"S256"`); passed through unvalidated.
    pub code_challenge_method: Option<String>,

    /// `prompt` parameter; passed through unvalidated.
    pub prompt: Option<String>,
    /// Raw `request` parameter; passed through without being interpreted here.
    pub request: Option<String>,

    /// Any parameters not matched by the fields above, captured via
    /// `#[serde(flatten)]`. Values must be strings. `validate` does not read
    /// this map.
    #[serde(flatten)]
    pub _extra: HashMap<String, String>,
}
26
/// Authorization request parameters after validation.
///
/// When produced by [`OAuthValidate::validate`] on an [`AuthorizeQuery`],
/// `response_type` is `"code"` and `client_id`, `redirect_uri` and `scope` are
/// non-empty. The type itself does not enforce this: the fields are public and
/// `Default` yields empty strings.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
pub struct AuthorizeParams {
    /// Always `"code"` when produced by `validate`.
    pub response_type: String,
    pub client_id: String,
    pub redirect_uri: String,
    pub scope: String,
    // The remaining fields are copied verbatim from the query.
    pub state: Option<String>,
    pub nonce: Option<String>,
    pub code_challenge: Option<String>,
    pub code_challenge_method: Option<String>,
    pub prompt: Option<String>,
    pub request: Option<String>,
}
40
41impl OAuthValidate for AuthorizeQuery {
44 type Output = AuthorizeParams;
45
46 fn validate(&self) -> Result<Self::Output, ControllerError> {
47 let rt_opt = self.response_type.as_deref();
48 let rt = rt_opt.unwrap_or_default();
49
50 let client_id = self.client_id.as_deref().unwrap_or_default();
51 let redirect_uri = self.redirect_uri.as_deref().unwrap_or_default();
52 let scope = self.scope.as_deref().unwrap_or_default();
53
54 let state_opt = self.state.clone();
56
57 if client_id.is_empty() || redirect_uri.is_empty() || scope.is_empty() {
59 return Err(ControllerError::new(
60 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
61 error: OAuthErrorCode::InvalidRequest.as_str().into(),
62 error_description: "client_id, redirect_uri, and scope are required".into(),
63 redirect_uri: None,
64 state: state_opt.clone(),
65 nonce: None,
66 })),
67 "Missing required OAuth parameters",
68 None,
69 ));
70 }
71
72 if rt.is_empty() {
74 return Err(ControllerError::new(
75 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
76 error: OAuthErrorCode::InvalidRequest.as_str().into(),
77 error_description: "response_type is required".into(),
78 redirect_uri: None,
79 state: state_opt.clone(),
80 nonce: None,
81 })),
82 "Missing response_type",
83 None::<anyhow::Error>,
84 ));
85 }
86
87 if rt != "code" {
89 return Err(ControllerError::new(
90 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
91 error: OAuthErrorCode::UnsupportedResponseType.as_str().into(),
92 error_description: "unsupported response_type".into(),
93 redirect_uri: None, state: state_opt,
95 nonce: None,
96 })),
97 "Unsupported response_type",
98 None::<anyhow::Error>,
99 ));
100 }
101
102 Ok(AuthorizeParams {
103 response_type: rt.to_string(),
104 client_id: client_id.to_string(),
105 redirect_uri: redirect_uri.to_string(),
106 scope: scope.to_string(),
107 state: self.state.clone(),
108 nonce: self.nonce.clone(),
109 code_challenge: self.code_challenge.clone(),
110 code_challenge_method: self.code_challenge_method.clone(),
111 prompt: self.prompt.clone(),
112 request: self.request.clone(),
113 })
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// A request that passes validation: `response_type=code`, client id,
    /// redirect URI and the `openid` scope; every other field is `None`/empty.
    /// Individual tests override fields with struct-update syntax.
    fn base_query() -> AuthorizeQuery {
        AuthorizeQuery {
            response_type: Some("code".into()),
            client_id: Some("cid".into()),
            redirect_uri: Some("http://localhost".into()),
            scope: Some("openid".into()),
            ..Default::default()
        }
    }

    /// Asserts that `result` failed with an `OAuthError` carrying the given
    /// error code and description.
    fn assert_oauth_error(
        result: Result<AuthorizeParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
    ) {
        let Err(err) = result else {
            panic!("Expected Err, got Ok(_)");
        };
        match err.error_type() {
            ControllerErrorType::OAuthError(data) => {
                assert_eq!(data.error, expected_error.as_str());
                assert_eq!(data.error_description, expected_description);
            }
            other => panic!("Expected OAuthError, got {:?}", other),
        }
    }

    #[test]
    fn authorize_missing_fields() {
        let query = AuthorizeQuery {
            client_id: None,
            redirect_uri: None,
            scope: None,
            state: Some("xyz".into()),
            ..base_query()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
        );
    }

    #[test]
    fn authorize_unsupported_response_type() {
        let query = AuthorizeQuery {
            response_type: Some("token".into()),
            ..base_query()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::UnsupportedResponseType,
            "unsupported response_type",
        );
    }

    #[test]
    fn authorize_valid_code_flow_openid_without_nonce_is_ok() {
        let query = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: None,
            ..base_query()
        };
        assert!(query.validate().is_ok());
    }

    #[test]
    fn authorize_unknown_params_are_captured_in_extra() {
        let raw: Value = json!({
            "response_type": "code",
            "client_id": "cid",
            "redirect_uri": "http://localhost",
            "scope": "openid",
            "state": "s",
            "foo": "bar",
            "x": "y"
        });
        let query: AuthorizeQuery = serde_json::from_value(raw).unwrap();
        for (key, expected) in [("foo", "bar"), ("x", "y")] {
            assert_eq!(query._extra.get(key).map(String::as_str), Some(expected));
        }
        assert!(query.validate().is_ok());
    }

    #[test]
    fn authorize_missing_response_type_is_invalid_request() {
        let query = AuthorizeQuery {
            response_type: None,
            ..base_query()
        };
        assert_oauth_error(
            query.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
        );
    }

    #[test]
    fn authorize_pkce_fields_passthrough() {
        let query = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: Some("n".into()),
            code_challenge: Some("abcDEF123-_".into()),
            code_challenge_method: Some("S256".into()),
            ..base_query()
        };
        let params = query.validate().expect("validate should pass");
        assert_eq!(params.code_challenge.as_deref(), Some("abcDEF123-_"));
        assert_eq!(params.code_challenge_method.as_deref(), Some("S256"));
    }
}