Skip to main content

headless_lms_server/domain/oauth/
token_service.rs

1use chrono::{DateTime, Utc};
2use secrecy::{ExposeSecret, SecretString};
3use sqlx::{Connection, PgConnection};
4use uuid::Uuid;
5
6use crate::domain::oauth::errors::TokenGrantError;
7use crate::domain::oauth::pkce::verify_token_pkce;
8use headless_lms_models::library::oauth::Digest;
9use headless_lms_models::library::oauth::tokens::token_digest_sha256;
10use headless_lms_models::oauth_access_token::TokenType;
11use headless_lms_models::oauth_auth_code::OAuthAuthCode;
12use headless_lms_models::oauth_client::OAuthClient;
13use headless_lms_models::oauth_refresh_tokens::{
14    IssueTokensFromAuthCodeParams, OAuthRefreshTokens, RotateRefreshTokenParams,
15};
16
17use super::token_query::TokenGrant;
18
/// A pair of access and refresh tokens with their digests.
///
/// As built by `generate_token_pair`, each digest is `token_digest_sha256` of the
/// corresponding token under the caller-supplied key.
pub struct TokenPair {
    /// The access token value itself (not its digest).
    pub access_token: String,
    /// The refresh token value itself (not its digest).
    pub refresh_token: String,
    /// Digest of `access_token`.
    pub access_digest: Digest,
    /// Digest of `refresh_token`.
    pub refresh_digest: Digest,
}
26
27/// Generate a new token pair (access token and refresh token) with their digests.
28pub fn generate_token_pair(key: &SecretString) -> TokenPair {
29    let access_token = headless_lms_models::library::oauth::tokens::generate_access_token();
30    let refresh_token = headless_lms_models::library::oauth::tokens::generate_access_token();
31    TokenPair {
32        access_token: access_token.clone(),
33        refresh_token: refresh_token.clone(),
34        access_digest: token_digest_sha256(&access_token, key),
35        refresh_digest: token_digest_sha256(&refresh_token, key),
36    }
37}
38
/// Everything `process_token_grant` needs to redeem one grant for a new token pair.
pub struct TokenGrantRequest<'a> {
    /// The grant being redeemed (authorization code, refresh token, or unknown).
    pub grant: &'a TokenGrant,
    /// The requesting client; its `id` is passed to the consume calls and it is handed
    /// to PKCE verification.
    pub client: &'a OAuthClient,
    /// Pre-generated tokens; `process_token_grant` only reads the two digests from it.
    pub token_pair: TokenPair,
    /// Expiry recorded for the new access token and echoed back in the result.
    pub access_expires_at: DateTime<Utc>,
    /// Expiry recorded for the new refresh token.
    pub refresh_expires_at: DateTime<Utc>,
    /// Token type recorded for the new access token. On refresh it is replaced by
    /// `TokenType::DPoP` when the consumed refresh token carries a `dpop_jkt`.
    pub issued_token_type: TokenType,
    /// DPoP key thumbprint (`jkt`) supplied with this request, if any.
    pub dpop_jkt: Option<&'a str>,
    /// Key passed to `token_digest_sha256` when digesting the presented code or refresh token.
    pub token_hmac_key: &'a SecretString,
}
49
/// Outcome of a successfully processed grant, as returned by `process_token_grant`.
pub struct TokenGrantResult {
    /// User taken from the consumed authorization code or refresh token row.
    pub user_id: Uuid,
    /// Scopes taken from the consumed authorization code or refresh token row.
    pub scopes: Vec<String>,
    /// Nonce stored with the authorization code; always `None` for refresh-token grants.
    pub nonce: Option<String>,
    /// Access token expiry, copied from the request.
    pub access_expires_at: DateTime<Utc>,
    /// `true` only for authorization-code grants whose scopes contain `"openid"`;
    /// always `false` for refresh-token grants.
    pub issue_id_token: bool,
}
57
58pub async fn process_token_grant(
59    conn: &mut PgConnection,
60    request: TokenGrantRequest<'_>,
61) -> Result<TokenGrantResult, TokenGrantError> {
62    let mut tx = conn
63        .begin()
64        .await
65        .map_err(|e| TokenGrantError::ServerError(format!("Failed to start transaction: {}", e)))?;
66
67    let result = match request.grant {
68        TokenGrant::AuthorizationCode {
69            code,
70            redirect_uri,
71            code_verifier,
72        } => {
73            let code_digest = token_digest_sha256(code.expose_secret(), request.token_hmac_key);
74            // Consume with client_id check in WHERE clause to prevent DoS attacks
75            let code_row = if let Some(ref_uri) = redirect_uri {
76                OAuthAuthCode::consume_with_redirect_in_transaction(
77                    &mut tx,
78                    code_digest,
79                    request.client.id,
80                    ref_uri,
81                )
82                .await
83                .map_err(|e| {
84                    tracing::warn!(
85                        err = %e,
86                        "OAuth token: auth code consume failed (redirect_uri check); possible causes: code already used, wrong redirect_uri, expired, or wrong client"
87                    );
88                    TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
89                })?
90            } else {
91                OAuthAuthCode::consume_in_transaction(&mut tx, code_digest, request.client.id)
92                    .await
93                    .map_err(|e| {
94                        tracing::warn!(
95                            err = %e,
96                            "OAuth token: auth code consume failed; possible causes: code already used, expired, or wrong client"
97                        );
98                        TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
99                    })?
100            };
101
102            // PKCE verification happens after client_id check (enforced in SQL)
103            verify_token_pkce(
104                request.client,
105                code_row.code_challenge.as_deref(),
106                code_row.code_challenge_method,
107                code_verifier.as_ref().map(|v| v.expose_secret()),
108            )
109            .map_err(|_| TokenGrantError::PkceVerificationFailed)?;
110
111            OAuthRefreshTokens::issue_tokens_from_auth_code_in_transaction(
112                &mut tx,
113                IssueTokensFromAuthCodeParams {
114                    user_id: code_row.user_id,
115                    client_id: code_row.client_id,
116                    scopes: &code_row.scopes,
117                    access_token_digest: &request.token_pair.access_digest,
118                    refresh_token_digest: &request.token_pair.refresh_digest,
119                    access_token_expires_at: request.access_expires_at,
120                    refresh_token_expires_at: request.refresh_expires_at,
121                    access_token_type: request.issued_token_type,
122                    access_token_dpop_jkt: request.dpop_jkt,
123                    refresh_token_dpop_jkt: request.dpop_jkt,
124                },
125            )
126            .await
127            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
128
129            // Determine if ID token should be issued based on presence of "openid" scope
130            let has_openid = code_row.scopes.iter().any(|s| s == "openid");
131
132            Ok(TokenGrantResult {
133                user_id: code_row.user_id,
134                scopes: code_row.scopes,
135                nonce: code_row.nonce.clone(),
136                access_expires_at: request.access_expires_at,
137                issue_id_token: has_openid,
138            })
139        }
140        TokenGrant::RefreshToken { refresh_token, .. } => {
141            let presented =
142                token_digest_sha256(refresh_token.expose_secret(), request.token_hmac_key);
143            // Consume with client_id check in WHERE clause to prevent DoS attacks
144            let old =
145                OAuthRefreshTokens::consume_in_transaction(&mut tx, presented, request.client.id)
146                    .await
147                    .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?;
148
149            if let Some(expected_jkt) = old.dpop_jkt.as_deref() {
150                let presented_jkt = request.dpop_jkt.ok_or_else(|| {
151                    TokenGrantError::InvalidClient(
152                        "missing DPoP header for sender-constrained refresh".into(),
153                    )
154                })?;
155                if presented_jkt != expected_jkt {
156                    return Err(TokenGrantError::DpopMismatch);
157                }
158            }
159
160            let refresh_issue_type = if old.dpop_jkt.is_some() {
161                TokenType::DPoP
162            } else {
163                request.issued_token_type
164            };
165            let at_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
166            let refresh_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
167
168            OAuthRefreshTokens::complete_refresh_token_rotation_in_transaction(
169                &mut tx,
170                &old,
171                RotateRefreshTokenParams {
172                    new_refresh_token_digest: &request.token_pair.refresh_digest,
173                    new_access_token_digest: &request.token_pair.access_digest,
174                    access_token_expires_at: request.access_expires_at,
175                    refresh_token_expires_at: request.refresh_expires_at,
176                    access_token_type: refresh_issue_type,
177                    access_token_dpop_jkt: at_jkt,
178                    refresh_token_dpop_jkt: refresh_jkt,
179                },
180            )
181            .await
182            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
183
184            Ok(TokenGrantResult {
185                user_id: old.user_id,
186                scopes: old.scopes.clone(),
187                nonce: None,
188                access_expires_at: request.access_expires_at,
189                issue_id_token: false,
190            })
191        }
192        TokenGrant::Unknown => Err(TokenGrantError::UnsupportedGrantType),
193    };
194
195    match result {
196        Ok(res) => {
197            tx.commit().await.map_err(|e| {
198                TokenGrantError::ServerError(format!("Failed to commit transaction: {}", e))
199            })?;
200            Ok(res)
201        }
202        Err(e) => {
203            // Transaction will be rolled back on drop
204            Err(e)
205        }
206    }
207}