headless_lms_server/domain/oauth/
token_service.rs1use chrono::{DateTime, Utc};
2use sqlx::{Connection, PgConnection};
3use uuid::Uuid;
4
5use crate::domain::oauth::errors::TokenGrantError;
6use crate::domain::oauth::pkce::verify_token_pkce;
7use headless_lms_models::library::oauth::Digest;
8use headless_lms_models::library::oauth::tokens::token_digest_sha256;
9use headless_lms_models::oauth_access_token::TokenType;
10use headless_lms_models::oauth_auth_code::OAuthAuthCode;
11use headless_lms_models::oauth_client::OAuthClient;
12use headless_lms_models::oauth_refresh_tokens::{
13 IssueTokensFromAuthCodeParams, OAuthRefreshTokens, RotateRefreshTokenParams,
14};
15
16use super::token_query::TokenGrant;
17
18pub struct TokenPair {
20 pub access_token: String,
21 pub refresh_token: String,
22 pub access_digest: Digest,
23 pub refresh_digest: Digest,
24}
25
26pub fn generate_token_pair(key: &str) -> TokenPair {
28 let access_token = headless_lms_models::library::oauth::tokens::generate_access_token();
29 let refresh_token = headless_lms_models::library::oauth::tokens::generate_access_token();
30 TokenPair {
31 access_token: access_token.clone(),
32 refresh_token: refresh_token.clone(),
33 access_digest: token_digest_sha256(&access_token, key),
34 refresh_digest: token_digest_sha256(&refresh_token, key),
35 }
36}
37
38pub struct TokenGrantRequest<'a> {
39 pub grant: &'a TokenGrant,
40 pub client: &'a OAuthClient,
41 pub token_pair: TokenPair,
42 pub access_expires_at: DateTime<Utc>,
43 pub refresh_expires_at: DateTime<Utc>,
44 pub issued_token_type: TokenType,
45 pub dpop_jkt: Option<&'a str>,
46 pub token_hmac_key: &'a str,
47}
48
49pub struct TokenGrantResult {
50 pub user_id: Uuid,
51 pub scopes: Vec<String>,
52 pub nonce: Option<String>,
53 pub access_expires_at: DateTime<Utc>,
54 pub issue_id_token: bool,
55}
56
57pub async fn process_token_grant(
58 conn: &mut PgConnection,
59 request: TokenGrantRequest<'_>,
60) -> Result<TokenGrantResult, TokenGrantError> {
61 let mut tx = conn
62 .begin()
63 .await
64 .map_err(|e| TokenGrantError::ServerError(format!("Failed to start transaction: {}", e)))?;
65
66 let result = match request.grant {
67 TokenGrant::AuthorizationCode {
68 code,
69 redirect_uri,
70 code_verifier,
71 } => {
72 let code_digest = token_digest_sha256(code, request.token_hmac_key);
73 let code_row = if let Some(ref_uri) = redirect_uri {
75 OAuthAuthCode::consume_with_redirect_in_transaction(
76 &mut tx,
77 code_digest,
78 request.client.id,
79 ref_uri,
80 )
81 .await
82 .map_err(|e| {
83 tracing::warn!(
84 err = %e,
85 "OAuth token: auth code consume failed (redirect_uri check); possible causes: code already used, wrong redirect_uri, expired, or wrong client"
86 );
87 TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
88 })?
89 } else {
90 OAuthAuthCode::consume_in_transaction(&mut tx, code_digest, request.client.id)
91 .await
92 .map_err(|e| {
93 tracing::warn!(
94 err = %e,
95 "OAuth token: auth code consume failed; possible causes: code already used, expired, or wrong client"
96 );
97 TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
98 })?
99 };
100
101 verify_token_pkce(
103 request.client,
104 code_row.code_challenge.as_deref(),
105 code_row.code_challenge_method,
106 code_verifier.as_deref(),
107 )
108 .map_err(|_| TokenGrantError::PkceVerificationFailed)?;
109
110 OAuthRefreshTokens::issue_tokens_from_auth_code_in_transaction(
111 &mut tx,
112 IssueTokensFromAuthCodeParams {
113 user_id: code_row.user_id,
114 client_id: code_row.client_id,
115 scopes: &code_row.scopes,
116 access_token_digest: &request.token_pair.access_digest,
117 refresh_token_digest: &request.token_pair.refresh_digest,
118 access_token_expires_at: request.access_expires_at,
119 refresh_token_expires_at: request.refresh_expires_at,
120 access_token_type: request.issued_token_type,
121 access_token_dpop_jkt: request.dpop_jkt,
122 refresh_token_dpop_jkt: request.dpop_jkt,
123 },
124 )
125 .await
126 .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
127
128 let has_openid = code_row.scopes.iter().any(|s| s == "openid");
130
131 Ok(TokenGrantResult {
132 user_id: code_row.user_id,
133 scopes: code_row.scopes,
134 nonce: code_row.nonce.clone(),
135 access_expires_at: request.access_expires_at,
136 issue_id_token: has_openid,
137 })
138 }
139 TokenGrant::RefreshToken { refresh_token, .. } => {
140 let presented = token_digest_sha256(refresh_token, request.token_hmac_key);
141 let old =
143 OAuthRefreshTokens::consume_in_transaction(&mut tx, presented, request.client.id)
144 .await
145 .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?;
146
147 if let Some(expected_jkt) = old.dpop_jkt.as_deref() {
148 let presented_jkt = request.dpop_jkt.ok_or_else(|| {
149 TokenGrantError::InvalidClient(
150 "missing DPoP header for sender-constrained refresh".into(),
151 )
152 })?;
153 if presented_jkt != expected_jkt {
154 return Err(TokenGrantError::DpopMismatch);
155 }
156 }
157
158 let refresh_issue_type = if old.dpop_jkt.is_some() {
159 TokenType::DPoP
160 } else {
161 request.issued_token_type
162 };
163 let at_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
164 let refresh_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
165
166 OAuthRefreshTokens::complete_refresh_token_rotation_in_transaction(
167 &mut tx,
168 &old,
169 RotateRefreshTokenParams {
170 new_refresh_token_digest: &request.token_pair.refresh_digest,
171 new_access_token_digest: &request.token_pair.access_digest,
172 access_token_expires_at: request.access_expires_at,
173 refresh_token_expires_at: request.refresh_expires_at,
174 access_token_type: refresh_issue_type,
175 access_token_dpop_jkt: at_jkt,
176 refresh_token_dpop_jkt: refresh_jkt,
177 },
178 )
179 .await
180 .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
181
182 Ok(TokenGrantResult {
183 user_id: old.user_id,
184 scopes: old.scopes.clone(),
185 nonce: None,
186 access_expires_at: request.access_expires_at,
187 issue_id_token: false,
188 })
189 }
190 TokenGrant::Unknown => Err(TokenGrantError::UnsupportedGrantType),
191 };
192
193 match result {
194 Ok(res) => {
195 tx.commit().await.map_err(|e| {
196 TokenGrantError::ServerError(format!("Failed to commit transaction: {}", e))
197 })?;
198 Ok(res)
199 }
200 Err(e) => {
201 Err(e)
203 }
204 }
205}