headless_lms_server/domain/oauth/token_service.rs

1use chrono::{DateTime, Utc};
2use sqlx::{Connection, PgConnection};
3use uuid::Uuid;
4
5use crate::domain::oauth::errors::TokenGrantError;
6use crate::domain::oauth::pkce::verify_token_pkce;
7use headless_lms_models::library::oauth::Digest;
8use headless_lms_models::library::oauth::tokens::token_digest_sha256;
9use headless_lms_models::oauth_access_token::TokenType;
10use headless_lms_models::oauth_auth_code::OAuthAuthCode;
11use headless_lms_models::oauth_client::OAuthClient;
12use headless_lms_models::oauth_refresh_tokens::{
13    IssueTokensFromAuthCodeParams, OAuthRefreshTokens, RotateRefreshTokenParams,
14};
15
16use super::token_query::TokenGrant;
17
18/// A pair of access and refresh tokens with their digests.
19pub struct TokenPair {
20    pub access_token: String,
21    pub refresh_token: String,
22    pub access_digest: Digest,
23    pub refresh_digest: Digest,
24}
25
26/// Generate a new token pair (access token and refresh token) with their digests.
27pub fn generate_token_pair(key: &str) -> TokenPair {
28    let access_token = headless_lms_models::library::oauth::tokens::generate_access_token();
29    let refresh_token = headless_lms_models::library::oauth::tokens::generate_access_token();
30    TokenPair {
31        access_token: access_token.clone(),
32        refresh_token: refresh_token.clone(),
33        access_digest: token_digest_sha256(&access_token, key),
34        refresh_digest: token_digest_sha256(&refresh_token, key),
35    }
36}
37
38pub struct TokenGrantRequest<'a> {
39    pub grant: &'a TokenGrant,
40    pub client: &'a OAuthClient,
41    pub token_pair: TokenPair,
42    pub access_expires_at: DateTime<Utc>,
43    pub refresh_expires_at: DateTime<Utc>,
44    pub issued_token_type: TokenType,
45    pub dpop_jkt: Option<&'a str>,
46    pub token_hmac_key: &'a str,
47}
48
49pub struct TokenGrantResult {
50    pub user_id: Uuid,
51    pub scopes: Vec<String>,
52    pub nonce: Option<String>,
53    pub access_expires_at: DateTime<Utc>,
54    pub issue_id_token: bool,
55}
56
57pub async fn process_token_grant(
58    conn: &mut PgConnection,
59    request: TokenGrantRequest<'_>,
60) -> Result<TokenGrantResult, TokenGrantError> {
61    let mut tx = conn
62        .begin()
63        .await
64        .map_err(|e| TokenGrantError::ServerError(format!("Failed to start transaction: {}", e)))?;
65
66    let result = match request.grant {
67        TokenGrant::AuthorizationCode {
68            code,
69            redirect_uri,
70            code_verifier,
71        } => {
72            let code_digest = token_digest_sha256(code, request.token_hmac_key);
73            // Consume with client_id check in WHERE clause to prevent DoS attacks
74            let code_row = if let Some(ref_uri) = redirect_uri {
75                OAuthAuthCode::consume_with_redirect_in_transaction(
76                    &mut tx,
77                    code_digest,
78                    request.client.id,
79                    ref_uri,
80                )
81                .await
82                .map_err(|e| {
83                    tracing::warn!(
84                        err = %e,
85                        "OAuth token: auth code consume failed (redirect_uri check); possible causes: code already used, wrong redirect_uri, expired, or wrong client"
86                    );
87                    TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
88                })?
89            } else {
90                OAuthAuthCode::consume_in_transaction(&mut tx, code_digest, request.client.id)
91                    .await
92                    .map_err(|e| {
93                        tracing::warn!(
94                            err = %e,
95                            "OAuth token: auth code consume failed; possible causes: code already used, expired, or wrong client"
96                        );
97                        TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
98                    })?
99            };
100
101            // PKCE verification happens after client_id check (enforced in SQL)
102            verify_token_pkce(
103                request.client,
104                code_row.code_challenge.as_deref(),
105                code_row.code_challenge_method,
106                code_verifier.as_deref(),
107            )
108            .map_err(|_| TokenGrantError::PkceVerificationFailed)?;
109
110            OAuthRefreshTokens::issue_tokens_from_auth_code_in_transaction(
111                &mut tx,
112                IssueTokensFromAuthCodeParams {
113                    user_id: code_row.user_id,
114                    client_id: code_row.client_id,
115                    scopes: &code_row.scopes,
116                    access_token_digest: &request.token_pair.access_digest,
117                    refresh_token_digest: &request.token_pair.refresh_digest,
118                    access_token_expires_at: request.access_expires_at,
119                    refresh_token_expires_at: request.refresh_expires_at,
120                    access_token_type: request.issued_token_type,
121                    access_token_dpop_jkt: request.dpop_jkt,
122                    refresh_token_dpop_jkt: request.dpop_jkt,
123                },
124            )
125            .await
126            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
127
128            // Determine if ID token should be issued based on presence of "openid" scope
129            let has_openid = code_row.scopes.iter().any(|s| s == "openid");
130
131            Ok(TokenGrantResult {
132                user_id: code_row.user_id,
133                scopes: code_row.scopes,
134                nonce: code_row.nonce.clone(),
135                access_expires_at: request.access_expires_at,
136                issue_id_token: has_openid,
137            })
138        }
139        TokenGrant::RefreshToken { refresh_token, .. } => {
140            let presented = token_digest_sha256(refresh_token, request.token_hmac_key);
141            // Consume with client_id check in WHERE clause to prevent DoS attacks
142            let old =
143                OAuthRefreshTokens::consume_in_transaction(&mut tx, presented, request.client.id)
144                    .await
145                    .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?;
146
147            if let Some(expected_jkt) = old.dpop_jkt.as_deref() {
148                let presented_jkt = request.dpop_jkt.ok_or_else(|| {
149                    TokenGrantError::InvalidClient(
150                        "missing DPoP header for sender-constrained refresh".into(),
151                    )
152                })?;
153                if presented_jkt != expected_jkt {
154                    return Err(TokenGrantError::DpopMismatch);
155                }
156            }
157
158            let refresh_issue_type = if old.dpop_jkt.is_some() {
159                TokenType::DPoP
160            } else {
161                request.issued_token_type
162            };
163            let at_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
164            let refresh_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
165
166            OAuthRefreshTokens::complete_refresh_token_rotation_in_transaction(
167                &mut tx,
168                &old,
169                RotateRefreshTokenParams {
170                    new_refresh_token_digest: &request.token_pair.refresh_digest,
171                    new_access_token_digest: &request.token_pair.access_digest,
172                    access_token_expires_at: request.access_expires_at,
173                    refresh_token_expires_at: request.refresh_expires_at,
174                    access_token_type: refresh_issue_type,
175                    access_token_dpop_jkt: at_jkt,
176                    refresh_token_dpop_jkt: refresh_jkt,
177                },
178            )
179            .await
180            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
181
182            Ok(TokenGrantResult {
183                user_id: old.user_id,
184                scopes: old.scopes.clone(),
185                nonce: None,
186                access_expires_at: request.access_expires_at,
187                issue_id_token: false,
188            })
189        }
190        TokenGrant::Unknown => Err(TokenGrantError::UnsupportedGrantType),
191    };
192
193    match result {
194        Ok(res) => {
195            tx.commit().await.map_err(|e| {
196                TokenGrantError::ServerError(format!("Failed to commit transaction: {}", e))
197            })?;
198            Ok(res)
199        }
200        Err(e) => {
201            // Transaction will be rolled back on drop
202            Err(e)
203        }
204    }
205}