headless_lms_server/domain/oauth/token_service.rs

1use chrono::{DateTime, Utc};
2use sqlx::{Connection, PgConnection};
3use uuid::Uuid;
4
5use crate::domain::oauth::errors::TokenGrantError;
6use crate::domain::oauth::pkce::verify_token_pkce;
7use headless_lms_models::library::oauth::Digest;
8use headless_lms_models::library::oauth::tokens::token_digest_sha256;
9use headless_lms_models::oauth_access_token::TokenType;
10use headless_lms_models::oauth_auth_code::OAuthAuthCode;
11use headless_lms_models::oauth_client::OAuthClient;
12use headless_lms_models::oauth_refresh_tokens::{
13    IssueTokensFromAuthCodeParams, OAuthRefreshTokens, RotateRefreshTokenParams,
14};
15
16use super::token_query::TokenGrant;
17
18/// A pair of access and refresh tokens with their digests.
19pub struct TokenPair {
20    pub access_token: String,
21    pub refresh_token: String,
22    pub access_digest: Digest,
23    pub refresh_digest: Digest,
24}
25
26/// Generate a new token pair (access token and refresh token) with their digests.
27pub fn generate_token_pair(key: &str) -> TokenPair {
28    let access_token = headless_lms_models::library::oauth::tokens::generate_access_token();
29    let refresh_token = headless_lms_models::library::oauth::tokens::generate_access_token();
30    TokenPair {
31        access_token: access_token.clone(),
32        refresh_token: refresh_token.clone(),
33        access_digest: token_digest_sha256(&access_token, key),
34        refresh_digest: token_digest_sha256(&refresh_token, key),
35    }
36}
37
38pub struct TokenGrantRequest<'a> {
39    pub grant: &'a TokenGrant,
40    pub client: &'a OAuthClient,
41    pub token_pair: TokenPair,
42    pub access_expires_at: DateTime<Utc>,
43    pub refresh_expires_at: DateTime<Utc>,
44    pub issued_token_type: TokenType,
45    pub dpop_jkt: Option<&'a str>,
46    pub token_hmac_key: &'a str,
47}
48
49pub struct TokenGrantResult {
50    pub user_id: Uuid,
51    pub scopes: Vec<String>,
52    pub nonce: Option<String>,
53    pub access_expires_at: DateTime<Utc>,
54    pub issue_id_token: bool,
55}
56
57pub async fn process_token_grant(
58    conn: &mut PgConnection,
59    request: TokenGrantRequest<'_>,
60) -> Result<TokenGrantResult, TokenGrantError> {
61    let mut tx = conn
62        .begin()
63        .await
64        .map_err(|e| TokenGrantError::ServerError(format!("Failed to start transaction: {}", e)))?;
65
66    let result = match request.grant {
67        TokenGrant::AuthorizationCode {
68            code,
69            redirect_uri,
70            code_verifier,
71        } => {
72            let code_digest = token_digest_sha256(code, request.token_hmac_key);
73            // Consume with client_id check in WHERE clause to prevent DoS attacks
74            let code_row = if let Some(ref_uri) = redirect_uri {
75                OAuthAuthCode::consume_with_redirect_in_transaction(
76                    &mut tx,
77                    code_digest,
78                    request.client.id,
79                    ref_uri,
80                )
81                .await
82                .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?
83            } else {
84                OAuthAuthCode::consume_in_transaction(&mut tx, code_digest, request.client.id)
85                    .await
86                    .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?
87            };
88
89            // PKCE verification happens after client_id check (enforced in SQL)
90            verify_token_pkce(
91                request.client,
92                code_row.code_challenge.as_deref(),
93                code_row.code_challenge_method,
94                code_verifier.as_deref(),
95            )
96            .map_err(|_| TokenGrantError::PkceVerificationFailed)?;
97
98            OAuthRefreshTokens::issue_tokens_from_auth_code_in_transaction(
99                &mut tx,
100                IssueTokensFromAuthCodeParams {
101                    user_id: code_row.user_id,
102                    client_id: code_row.client_id,
103                    scopes: &code_row.scopes,
104                    access_token_digest: &request.token_pair.access_digest,
105                    refresh_token_digest: &request.token_pair.refresh_digest,
106                    access_token_expires_at: request.access_expires_at,
107                    refresh_token_expires_at: request.refresh_expires_at,
108                    access_token_type: request.issued_token_type,
109                    access_token_dpop_jkt: request.dpop_jkt,
110                    refresh_token_dpop_jkt: request.dpop_jkt,
111                },
112            )
113            .await
114            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
115
116            // Determine if ID token should be issued based on presence of "openid" scope
117            let has_openid = code_row.scopes.iter().any(|s| s == "openid");
118
119            Ok(TokenGrantResult {
120                user_id: code_row.user_id,
121                scopes: code_row.scopes,
122                nonce: code_row.nonce.clone(),
123                access_expires_at: request.access_expires_at,
124                issue_id_token: has_openid,
125            })
126        }
127        TokenGrant::RefreshToken { refresh_token, .. } => {
128            let presented = token_digest_sha256(refresh_token, request.token_hmac_key);
129            // Consume with client_id check in WHERE clause to prevent DoS attacks
130            let old =
131                OAuthRefreshTokens::consume_in_transaction(&mut tx, presented, request.client.id)
132                    .await
133                    .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?;
134
135            if let Some(expected_jkt) = old.dpop_jkt.as_deref() {
136                let presented_jkt = request.dpop_jkt.ok_or_else(|| {
137                    TokenGrantError::InvalidClient(
138                        "missing DPoP header for sender-constrained refresh".into(),
139                    )
140                })?;
141                if presented_jkt != expected_jkt {
142                    return Err(TokenGrantError::DpopMismatch);
143                }
144            }
145
146            let refresh_issue_type = if old.dpop_jkt.is_some() {
147                TokenType::DPoP
148            } else {
149                request.issued_token_type
150            };
151            let at_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
152            let refresh_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
153
154            OAuthRefreshTokens::complete_refresh_token_rotation_in_transaction(
155                &mut tx,
156                &old,
157                RotateRefreshTokenParams {
158                    new_refresh_token_digest: &request.token_pair.refresh_digest,
159                    new_access_token_digest: &request.token_pair.access_digest,
160                    access_token_expires_at: request.access_expires_at,
161                    refresh_token_expires_at: request.refresh_expires_at,
162                    access_token_type: refresh_issue_type,
163                    access_token_dpop_jkt: at_jkt,
164                    refresh_token_dpop_jkt: refresh_jkt,
165                },
166            )
167            .await
168            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
169
170            Ok(TokenGrantResult {
171                user_id: old.user_id,
172                scopes: old.scopes.clone(),
173                nonce: None,
174                access_expires_at: request.access_expires_at,
175                issue_id_token: false,
176            })
177        }
178        TokenGrant::Unknown => Err(TokenGrantError::UnsupportedGrantType),
179    };
180
181    match result {
182        Ok(res) => {
183            tx.commit().await.map_err(|e| {
184                TokenGrantError::ServerError(format!("Failed to commit transaction: {}", e))
185            })?;
186            Ok(res)
187        }
188        Err(e) => {
189            // Transaction will be rolled back on drop
190            Err(e)
191        }
192    }
193}