headless_lms_server/controllers/main_frontend/oauth/
authorize.rs1use crate::domain::oauth::authorize_query::AuthorizeQuery;
2use crate::domain::oauth::helpers::{oauth_error, oauth_invalid_request};
3use crate::domain::oauth::oauth_validated::OAuthValidated;
4use crate::domain::oauth::pkce::parse_authorize_pkce;
5use crate::domain::oauth::redirects::{
6 build_authorize_qs, build_consent_redirect, build_login_redirect, redirect_with_code,
7};
8use crate::domain::rate_limit_middleware_builder::{RateLimit, RateLimitConfig};
9use crate::prelude::*;
10use actix_web::web;
11use chrono::{Duration, Utc};
12use itertools::Itertools;
13use models::{
14 library::oauth::{generate_access_token, token_digest_sha256},
15 oauth_auth_code::{NewAuthCodeParams, OAuthAuthCode},
16 oauth_client::OAuthClient,
17 oauth_user_client_scopes::OAuthUserClientScopes,
18};
19use sqlx::PgPool;
20use std::collections::HashSet;
21
/// Records which values of the `prompt` authorization parameter were present.
///
/// All flags are `false` in the `Default` value, which is what an absent
/// `prompt` parameter maps to.
#[derive(Clone, Copy, Debug, Default)]
struct PromptFlags {
    /// `true` when the `none` value was supplied.
    none: bool,
    /// `true` when the `consent` value was supplied.
    consent: bool,
    /// `true` when the `login` value was supplied.
    login: bool,
    /// `true` when the `select_account` value was supplied.
    select_account: bool,
}
29
30fn parse_prompt(prompt: Option<&str>) -> Result<PromptFlags, &'static str> {
31 let mut f = PromptFlags::default();
32 let Some(p) = prompt else { return Ok(f) };
33
34 for v in p.split_whitespace() {
35 match v {
36 "none" => f.none = true,
37 "consent" => f.consent = true,
38 "login" => f.login = true,
39 "select_account" => f.select_account = true,
40 _ => return Err("unsupported prompt value"),
41 }
42 }
43
44 if f.none && (f.consent || f.login || f.select_account) {
45 return Err("prompt=none cannot be combined with other values");
46 }
47
48 Ok(f)
49}
50
51pub async fn authorize(
88 pool: web::Data<PgPool>,
89 OAuthValidated(query): OAuthValidated<AuthorizeQuery>,
90 user: Option<AuthUser>,
91 app_conf: web::Data<headless_lms_utils::ApplicationConfiguration>,
92) -> ControllerResult<HttpResponse> {
93 let mut conn = pool.acquire().await?;
94 let server_token = skip_authorize();
95
96 let client = OAuthClient::find_by_client_id(&mut conn, &query.client_id)
97 .await
98 .map_err(|e| {
99 tracing::error!(err = %e, "OAuth authorize: client lookup failed");
100 oauth_invalid_request(
101 "invalid client_id",
102 None, query.state.as_deref(),
104 )
105 })?;
106
107 tracing::Span::current().record("client_id", &query.client_id);
109 tracing::Span::current().record("response_type", &query.response_type);
110
111 if !client.redirect_uris.contains(&query.redirect_uri) {
112 return Err(oauth_invalid_request(
113 "redirect_uri does not match client",
114 None, query.state.as_deref(),
116 ));
117 }
118
119 if query.request.is_some() {
120 return Err(oauth_error(
121 "request_not_supported",
122 "request object is not supported",
123 Some(&query.redirect_uri),
124 query.state.as_deref(),
125 ));
126 }
127
128 let prompt = parse_prompt(query.prompt.as_deref()).map_err(|msg| {
129 oauth_invalid_request(msg, Some(&query.redirect_uri), query.state.as_deref())
130 })?;
131
132 if prompt.login {
133 return Err(oauth_error(
134 "inalid_request",
135 "prompt=login is not supported",
136 Some(&query.redirect_uri),
137 query.state.as_deref(),
138 ));
139 }
140
141 if prompt.select_account {
142 return Err(oauth_error(
143 "inalid_request",
144 "prompt=select_account is not supported",
145 Some(&query.redirect_uri),
146 query.state.as_deref(),
147 ));
148 }
149
150 let parsed_pkce_method = parse_authorize_pkce(
151 &client,
152 query.code_challenge.as_deref(),
153 query.code_challenge_method.as_deref(),
154 &query.redirect_uri,
155 query.state.as_deref(),
156 )?;
157
158 let redirect_url = match user {
159 Some(user) => {
160 let granted_scopes: Vec<String> =
161 OAuthUserClientScopes::find_scopes(&mut conn, user.id, client.id).await?;
162
163 let requested: HashSet<&str> = query.scope.split_whitespace().collect();
164 let granted: HashSet<&str> = granted_scopes.iter().map(|s| s.as_str()).collect();
165 let missing: Vec<&str> = requested.difference(&granted).copied().collect();
166 if prompt.none && !missing.is_empty() {
167 return Err(oauth_error(
168 "consent_required",
169 "end-user consent is required",
170 Some(&query.redirect_uri),
171 query.state.as_deref(),
172 ));
173 }
174
175 if prompt.consent || !missing.is_empty() {
176 let return_to = format!(
177 "/api/v0/main-frontend/oauth/authorize?{}",
178 build_authorize_qs(&query)
179 );
180 build_consent_redirect(&query, &return_to)
181 } else {
182 let code = generate_access_token();
183 let expires_at = Utc::now() + Duration::minutes(10);
184 let token_hmac_key = &app_conf.oauth_server_configuration.oauth_token_hmac_key;
185 let code_digest = token_digest_sha256(&code, token_hmac_key);
186
187 let new_auth_code_params = NewAuthCodeParams {
188 digest: &code_digest,
189 user_id: user.id,
190 client_id: client.id,
191 redirect_uri: &query.redirect_uri,
192 scopes: &query
193 .scope
194 .split_whitespace()
195 .map(|s| s.to_string())
196 .collect_vec(),
197 nonce: query.nonce.as_deref(),
198 code_challenge: query.code_challenge.as_deref(),
199 code_challenge_method: parsed_pkce_method,
200 dpop_jkt: None, expires_at,
202 metadata: serde_json::Map::new(),
203 };
204
205 OAuthAuthCode::insert(&mut conn, new_auth_code_params).await?;
206 redirect_with_code(&query.redirect_uri, &code, query.state.as_deref())
207 }
208 }
209 None => {
210 if prompt.none {
211 return Err(oauth_error(
212 "login_required",
213 "end-user is not logged in",
214 Some(&query.redirect_uri),
215 query.state.as_deref(),
216 ));
217 }
218 build_login_redirect(&query)
219 }
220 };
221
222 server_token.authorized_ok(
223 HttpResponse::Found()
224 .append_header(("Location", redirect_url))
225 .finish(),
226 )
227}
228
229pub fn _add_routes(cfg: &mut web::ServiceConfig) {
230 cfg.service(
231 web::resource("/authorize")
232 .wrap(RateLimit::new(RateLimitConfig {
233 per_minute: Some(100),
234 per_hour: Some(500),
235 per_day: Some(2000),
236 per_month: None,
237 }))
238 .route(web::get().to(authorize))
239 .route(web::post().to(authorize)),
240 );
241}