headless_lms_server/controllers/main_frontend/oauth/
authorize.rs1use crate::domain::oauth::authorize_query::AuthorizeQuery;
2use crate::domain::oauth::helpers::{oauth_error, oauth_invalid_request};
3use crate::domain::oauth::oauth_validated::OAuthValidated;
4use crate::domain::oauth::pkce::parse_authorize_pkce;
5use crate::domain::oauth::redirects::{
6 build_authorize_qs, build_consent_redirect, build_login_redirect, redirect_with_code,
7};
8use crate::domain::rate_limit_middleware_builder::{RateLimit, RateLimitConfig};
9use crate::prelude::*;
10use actix_web::web;
11use chrono::{Duration, Utc};
12use itertools::Itertools;
13use models::{
14 library::oauth::{generate_access_token, token_digest_sha256},
15 oauth_auth_code::{NewAuthCodeParams, OAuthAuthCode},
16 oauth_client::OAuthClient,
17 oauth_user_client_scopes::OAuthUserClientScopes,
18};
19use sqlx::PgPool;
20use std::collections::HashSet;
21
/// Which tokens of the OpenID Connect `prompt` request parameter were present.
///
/// Each flag is `true` when its token appeared at least once; an absent or blank
/// parameter leaves every flag `false`.
#[derive(Debug, Clone, Copy, Default)]
struct PromptFlags {
    /// The `none` token was present.
    none: bool,
    /// The `consent` token was present.
    consent: bool,
    /// The `login` token was present.
    login: bool,
    /// The `select_account` token was present.
    select_account: bool,
}

/// Parses the optional, whitespace-separated `prompt` parameter into [`PromptFlags`].
///
/// Tokens may repeat. A missing or blank parameter yields all-`false` flags.
///
/// # Errors
///
/// Returns a static error message when:
/// - a token is not one of `none`, `consent`, `login`, `select_account`
///   (checked first, token by token, in order); or
/// - `none` is combined with any other recognised token.
fn parse_prompt(prompt: Option<&str>) -> Result<PromptFlags, &'static str> {
    let flags = prompt.unwrap_or_default().split_whitespace().try_fold(
        PromptFlags::default(),
        |mut acc, token| -> Result<PromptFlags, &'static str> {
            let flag = match token {
                "none" => &mut acc.none,
                "consent" => &mut acc.consent,
                "login" => &mut acc.login,
                "select_account" => &mut acc.select_account,
                _ => return Err("unsupported prompt value"),
            };
            *flag = true;
            Ok(acc)
        },
    )?;

    // `none` forbids any interaction, so asking for it alongside an interactive
    // prompt is contradictory.
    let none_with_others = flags.none && (flags.consent || flags.login || flags.select_account);
    if none_with_others {
        return Err("prompt=none cannot be combined with other values");
    }

    Ok(flags)
}
50
51pub async fn authorize(
88 pool: web::Data<PgPool>,
89 OAuthValidated(query): OAuthValidated<AuthorizeQuery>,
90 user: Option<AuthUser>,
91 app_conf: web::Data<headless_lms_utils::ApplicationConfiguration>,
92) -> ControllerResult<HttpResponse> {
93 let mut conn = pool.acquire().await?;
94 let server_token = skip_authorize();
95
96 let client = OAuthClient::find_by_client_id(&mut conn, &query.client_id)
97 .await
98 .map_err(|_| {
99 oauth_invalid_request(
100 "invalid client_id",
101 None, query.state.as_deref(),
103 )
104 })?;
105
106 tracing::Span::current().record("client_id", &query.client_id);
108 tracing::Span::current().record("response_type", &query.response_type);
109
110 if !client.redirect_uris.contains(&query.redirect_uri) {
111 return Err(oauth_invalid_request(
112 "redirect_uri does not match client",
113 None, query.state.as_deref(),
115 ));
116 }
117
118 if query.request.is_some() {
119 return Err(oauth_error(
120 "request_not_supported",
121 "request object is not supported",
122 Some(&query.redirect_uri),
123 query.state.as_deref(),
124 ));
125 }
126
127 let prompt = parse_prompt(query.prompt.as_deref()).map_err(|msg| {
128 oauth_invalid_request(msg, Some(&query.redirect_uri), query.state.as_deref())
129 })?;
130
131 if prompt.login {
132 return Err(oauth_error(
133 "inalid_request",
134 "prompt=login is not supported",
135 Some(&query.redirect_uri),
136 query.state.as_deref(),
137 ));
138 }
139
140 if prompt.select_account {
141 return Err(oauth_error(
142 "inalid_request",
143 "prompt=select_account is not supported",
144 Some(&query.redirect_uri),
145 query.state.as_deref(),
146 ));
147 }
148
149 let parsed_pkce_method = parse_authorize_pkce(
150 &client,
151 query.code_challenge.as_deref(),
152 query.code_challenge_method.as_deref(),
153 &query.redirect_uri,
154 query.state.as_deref(),
155 )?;
156
157 let redirect_url = match user {
158 Some(user) => {
159 let granted_scopes: Vec<String> =
160 OAuthUserClientScopes::find_scopes(&mut conn, user.id, client.id).await?;
161
162 let requested: HashSet<&str> = query.scope.split_whitespace().collect();
163 let granted: HashSet<&str> = granted_scopes.iter().map(|s| s.as_str()).collect();
164 let missing: Vec<&str> = requested.difference(&granted).copied().collect();
165 if prompt.none && !missing.is_empty() {
166 return Err(oauth_error(
167 "consent_required",
168 "end-user consent is required",
169 Some(&query.redirect_uri),
170 query.state.as_deref(),
171 ));
172 }
173
174 if prompt.consent || !missing.is_empty() {
175 let return_to = format!(
176 "/api/v0/main-frontend/oauth/authorize?{}",
177 build_authorize_qs(&query)
178 );
179 build_consent_redirect(&query, &return_to)
180 } else {
181 let code = generate_access_token();
182 let expires_at = Utc::now() + Duration::minutes(10);
183 let token_hmac_key = &app_conf.oauth_server_configuration.oauth_token_hmac_key;
184 let code_digest = token_digest_sha256(&code, token_hmac_key);
185
186 let new_auth_code_params = NewAuthCodeParams {
187 digest: &code_digest,
188 user_id: user.id,
189 client_id: client.id,
190 redirect_uri: &query.redirect_uri,
191 scopes: &query
192 .scope
193 .split_whitespace()
194 .map(|s| s.to_string())
195 .collect_vec(),
196 nonce: query.nonce.as_deref(),
197 code_challenge: query.code_challenge.as_deref(),
198 code_challenge_method: parsed_pkce_method,
199 dpop_jkt: None, expires_at,
201 metadata: serde_json::Map::new(),
202 };
203
204 OAuthAuthCode::insert(&mut conn, new_auth_code_params).await?;
205 redirect_with_code(&query.redirect_uri, &code, query.state.as_deref())
206 }
207 }
208 None => {
209 if prompt.none {
210 return Err(oauth_error(
211 "login_required",
212 "end-user is not logged in",
213 Some(&query.redirect_uri),
214 query.state.as_deref(),
215 ));
216 }
217 build_login_redirect(&query)
218 }
219 };
220
221 server_token.authorized_ok(
222 HttpResponse::Found()
223 .append_header(("Location", redirect_url))
224 .finish(),
225 )
226}
227
228pub fn _add_routes(cfg: &mut web::ServiceConfig) {
229 cfg.service(
230 web::resource("/authorize")
231 .wrap(RateLimit::new(RateLimitConfig {
232 per_minute: Some(100),
233 per_hour: Some(500),
234 per_day: Some(2000),
235 per_month: None,
236 }))
237 .route(web::get().to(authorize))
238 .route(web::post().to(authorize)),
239 );
240}