headless_lms_server/controllers/main_frontend/oauth/
authorize.rs1use crate::domain::oauth::authorize_query::AuthorizeQuery;
2use crate::domain::oauth::helpers::oauth_invalid_request;
3use crate::domain::oauth::oauth_validated::OAuthValidated;
4use crate::domain::oauth::pkce::parse_authorize_pkce;
5use crate::domain::oauth::redirects::{
6 build_authorize_qs, build_consent_redirect, build_login_redirect, redirect_with_code,
7};
8use crate::domain::rate_limit_middleware_builder::build_rate_limiting_middleware;
9use crate::prelude::*;
10use actix_web::web;
11use chrono::{Duration, Utc};
12use itertools::Itertools;
13use models::{
14 library::oauth::{generate_access_token, token_digest_sha256},
15 oauth_auth_code::{NewAuthCodeParams, OAuthAuthCode},
16 oauth_client::OAuthClient,
17 oauth_user_client_scopes::OAuthUserClientScopes,
18};
19use sqlx::PgPool;
20use std::collections::HashSet;
21use std::time::Duration as StdDuration;
22
23pub async fn authorize(
60 pool: web::Data<PgPool>,
61 OAuthValidated(query): OAuthValidated<AuthorizeQuery>,
62 user: Option<AuthUser>,
63 app_conf: web::Data<headless_lms_utils::ApplicationConfiguration>,
64) -> ControllerResult<HttpResponse> {
65 let mut conn = pool.acquire().await?;
66 let server_token = skip_authorize();
67
68 let client = OAuthClient::find_by_client_id(&mut conn, &query.client_id)
69 .await
70 .map_err(|_| {
71 oauth_invalid_request(
72 "invalid client_id",
73 None, query.state.as_deref(),
75 )
76 })?;
77
78 tracing::Span::current().record("client_id", &query.client_id);
80 tracing::Span::current().record("response_type", &query.response_type);
81
82 if !client.redirect_uris.contains(&query.redirect_uri) {
83 return Err(oauth_invalid_request(
84 "redirect_uri does not match client",
85 None, query.state.as_deref(),
87 ));
88 }
89
90 let parsed_pkce_method = parse_authorize_pkce(
91 &client,
92 query.code_challenge.as_deref(),
93 query.code_challenge_method.as_deref(),
94 &query.redirect_uri,
95 query.state.as_deref(),
96 )?;
97
98 let redirect_url = match user {
99 Some(user) => {
100 let granted_scopes: Vec<String> =
101 OAuthUserClientScopes::find_scopes(&mut conn, user.id, client.id).await?;
102
103 let requested: HashSet<&str> = query.scope.split_whitespace().collect();
104 let granted: HashSet<&str> = granted_scopes.iter().map(|s| s.as_str()).collect();
105 let missing: Vec<&str> = requested.difference(&granted).copied().collect();
106
107 if !missing.is_empty() {
108 let return_to = format!(
109 "/api/v0/main-frontend/oauth/authorize?{}",
110 build_authorize_qs(&query)
111 );
112 build_consent_redirect(&query, &return_to)
113 } else {
114 let code = generate_access_token();
115 let expires_at = Utc::now() + Duration::minutes(10);
116 let token_hmac_key = &app_conf.oauth_server_configuration.oauth_token_hmac_key;
117 let code_digest = token_digest_sha256(&code, token_hmac_key);
118
119 let new_auth_code_params = NewAuthCodeParams {
120 digest: &code_digest,
121 user_id: user.id,
122 client_id: client.id,
123 redirect_uri: &query.redirect_uri,
124 scopes: &query
125 .scope
126 .split_whitespace()
127 .map(|s| s.to_string())
128 .collect_vec(),
129 nonce: query.nonce.as_deref(),
130 code_challenge: query.code_challenge.as_deref(),
131 code_challenge_method: parsed_pkce_method,
132 dpop_jkt: None, expires_at,
134 metadata: serde_json::Map::new(),
135 };
136
137 OAuthAuthCode::insert(&mut conn, new_auth_code_params).await?;
138 redirect_with_code(&query.redirect_uri, &code, query.state.as_deref())
139 }
140 }
141 None => build_login_redirect(&query),
142 };
143
144 server_token.authorized_ok(
145 HttpResponse::Found()
146 .append_header(("Location", redirect_url))
147 .finish(),
148 )
149}
150
151pub fn _add_routes(cfg: &mut web::ServiceConfig) {
152 cfg.service(
153 web::resource("/authorize")
154 .wrap(build_rate_limiting_middleware(
155 StdDuration::from_secs(60),
156 100,
157 ))
158 .wrap(build_rate_limiting_middleware(
159 StdDuration::from_secs(60 * 60),
160 500,
161 ))
162 .wrap(build_rate_limiting_middleware(
163 StdDuration::from_secs(60 * 60 * 24),
164 2000,
165 ))
166 .route(web::get().to(authorize))
167 .route(web::post().to(authorize)),
168 );
169}