headless_lms_server/domain/oauth/
authorize_query.rs1use super::oauth_validate::OAuthValidate;
2use crate::prelude::*;
3use domain::error::{OAuthErrorCode, OAuthErrorData};
4use std::collections::HashMap;
5
6#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
7pub struct AuthorizeQuery {
8 pub response_type: Option<String>,
9 pub client_id: Option<String>,
10 pub redirect_uri: Option<String>,
11 pub scope: Option<String>,
12 pub state: Option<String>,
13 pub nonce: Option<String>,
14 pub code_challenge: Option<String>,
15 pub code_challenge_method: Option<String>,
16
17 #[serde(flatten)]
20 pub _extra: HashMap<String, String>,
21}
22
23#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default)]
24pub struct AuthorizeParams {
25 pub response_type: String,
26 pub client_id: String,
27 pub redirect_uri: String,
28 pub scope: String,
29 pub state: Option<String>,
30 pub nonce: Option<String>,
31 pub code_challenge: Option<String>,
32 pub code_challenge_method: Option<String>,
33}
34
35impl OAuthValidate for AuthorizeQuery {
38 type Output = AuthorizeParams;
39
40 fn validate(&self) -> Result<Self::Output, ControllerError> {
41 let rt_opt = self.response_type.as_deref();
42 let rt = rt_opt.unwrap_or_default();
43
44 let client_id = self.client_id.as_deref().unwrap_or_default();
45 let redirect_uri = self.redirect_uri.as_deref().unwrap_or_default();
46 let scope = self.scope.as_deref().unwrap_or_default();
47
48 let state_opt = self.state.clone();
50
51 if client_id.is_empty() || redirect_uri.is_empty() || scope.is_empty() {
53 return Err(ControllerError::new(
54 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
55 error: OAuthErrorCode::InvalidRequest.as_str().into(),
56 error_description: "client_id, redirect_uri, and scope are required".into(),
57 redirect_uri: None,
58 state: state_opt.clone(),
59 nonce: None,
60 })),
61 "Missing required OAuth parameters",
62 None,
63 ));
64 }
65
66 if rt.is_empty() {
68 return Err(ControllerError::new(
69 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
70 error: OAuthErrorCode::InvalidRequest.as_str().into(),
71 error_description: "response_type is required".into(),
72 redirect_uri: None,
73 state: state_opt.clone(),
74 nonce: None,
75 })),
76 "Missing response_type",
77 None::<anyhow::Error>,
78 ));
79 }
80
81 if rt != "code" {
83 return Err(ControllerError::new(
84 ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
85 error: OAuthErrorCode::UnsupportedResponseType.as_str().into(),
86 error_description: "unsupported response_type".into(),
87 redirect_uri: None, state: state_opt,
89 nonce: None,
90 })),
91 "Unsupported response_type",
92 None::<anyhow::Error>,
93 ));
94 }
95
96 Ok(AuthorizeParams {
97 response_type: rt.to_string(),
98 client_id: client_id.to_string(),
99 redirect_uri: redirect_uri.to_string(),
100 scope: scope.to_string(),
101 state: self.state.clone(),
102 nonce: self.nonce.clone(),
103 code_challenge: self.code_challenge.clone(),
104 code_challenge_method: self.code_challenge_method.clone(),
105 })
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode};
    use serde_json::{Value, json};

    /// A fully valid authorization-code request; individual tests override fields
    /// with struct-update syntax.
    fn valid_query() -> AuthorizeQuery {
        AuthorizeQuery {
            response_type: Some("code".into()),
            client_id: Some("cid".into()),
            redirect_uri: Some("http://localhost".into()),
            scope: Some("openid".into()),
            ..Default::default()
        }
    }

    /// Asserts that `result` is an OAuth error with the expected code and description,
    /// that it echoes `expected_state`, and that it carries no redirect URI.
    fn assert_oauth_error(
        result: Result<AuthorizeParams, ControllerError>,
        expected_error: OAuthErrorCode,
        expected_description: &str,
        expected_state: Option<&str>,
    ) {
        match result {
            Err(err) => match err.error_type() {
                ControllerErrorType::OAuthError(data) => {
                    assert_eq!(data.error, expected_error.as_str());
                    assert_eq!(data.error_description, expected_description);
                    assert_eq!(data.state.as_deref(), expected_state);
                    assert!(data.redirect_uri.is_none());
                }
                other => panic!("Expected OAuthError, got {:?}", other),
            },
            Ok(_) => panic!("Expected Err, got Ok(_)"),
        }
    }

    #[test]
    fn authorize_missing_fields() {
        let q = AuthorizeQuery {
            client_id: None,
            redirect_uri: None,
            scope: None,
            state: Some("xyz".into()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
            Some("xyz"),
        );
    }

    #[test]
    fn authorize_empty_required_field_is_invalid_request() {
        let cases = [
            AuthorizeQuery {
                client_id: Some(String::new()),
                ..valid_query()
            },
            AuthorizeQuery {
                redirect_uri: Some(String::new()),
                ..valid_query()
            },
            AuthorizeQuery {
                scope: Some(String::new()),
                ..valid_query()
            },
        ];
        for q in cases {
            assert_oauth_error(
                q.validate(),
                OAuthErrorCode::InvalidRequest,
                "client_id, redirect_uri, and scope are required",
                None,
            );
        }
    }

    #[test]
    fn authorize_required_fields_are_checked_before_response_type() {
        let q = AuthorizeQuery {
            response_type: None,
            client_id: None,
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "client_id, redirect_uri, and scope are required",
            None,
        );
    }

    #[test]
    fn authorize_unsupported_response_type() {
        let q = AuthorizeQuery {
            response_type: Some("token".into()),
            state: Some("abc".into()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::UnsupportedResponseType,
            "unsupported response_type",
            Some("abc"),
        );
    }

    #[test]
    fn authorize_valid_code_flow_openid_without_nonce_is_ok() {
        let q = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: None,
            ..valid_query()
        };
        let p = q.validate().expect("validate should pass");
        assert_eq!(
            p,
            AuthorizeParams {
                response_type: "code".into(),
                client_id: "cid".into(),
                redirect_uri: "http://localhost".into(),
                scope: "openid profile".into(),
                state: Some("s".into()),
                ..Default::default()
            }
        );
    }

    #[test]
    fn authorize_unknown_params_are_captured_in_extra() {
        let v: Value = json!({
            "response_type": "code",
            "client_id": "cid",
            "redirect_uri": "http://localhost",
            "scope": "openid",
            "state": "s",
            "foo": "bar",
            "x": "y"
        });
        let q: AuthorizeQuery = serde_json::from_value(v).unwrap();
        assert_eq!(q._extra.get("foo").map(String::as_str), Some("bar"));
        assert_eq!(q._extra.get("x").map(String::as_str), Some("y"));
        assert_eq!(q._extra.len(), 2);
        assert!(q.validate().is_ok());
    }

    #[test]
    fn authorize_missing_response_type_is_invalid_request() {
        let q = AuthorizeQuery {
            response_type: None,
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
            None,
        );
    }

    #[test]
    fn authorize_empty_response_type_is_invalid_request() {
        let q = AuthorizeQuery {
            response_type: Some(String::new()),
            ..valid_query()
        };
        assert_oauth_error(
            q.validate(),
            OAuthErrorCode::InvalidRequest,
            "response_type is required",
            None,
        );
    }

    #[test]
    fn authorize_pkce_fields_passthrough() {
        let q = AuthorizeQuery {
            scope: Some("openid profile".into()),
            state: Some("s".into()),
            nonce: Some("n".into()),
            code_challenge: Some("abcDEF123-_".into()),
            code_challenge_method: Some("S256".into()),
            ..valid_query()
        };
        let p = q.validate().expect("validate should pass");
        assert_eq!(p.code_challenge.as_deref(), Some("abcDEF123-_"));
        assert_eq!(p.code_challenge_method.as_deref(), Some("S256"));
        assert_eq!(p.nonce.as_deref(), Some("n"));
        assert_eq!(p.state.as_deref(), Some("s"));
    }
}
246}