headless_lms_server/domain/oauth/
token_service.rs

1use chrono::{DateTime, Utc};
2use sqlx::{Connection, PgConnection};
3use uuid::Uuid;
4
5use crate::domain::oauth::errors::TokenGrantError;
6use crate::domain::oauth::pkce::verify_token_pkce;
7use headless_lms_models::library::oauth::Digest;
8use headless_lms_models::library::oauth::tokens::token_digest_sha256;
9use headless_lms_models::oauth_access_token::TokenType;
10use headless_lms_models::oauth_auth_code::OAuthAuthCode;
11use headless_lms_models::oauth_client::OAuthClient;
12use headless_lms_models::oauth_refresh_tokens::{
13    IssueTokensFromAuthCodeParams, OAuthRefreshTokens, RotateRefreshTokenParams,
14};
15
16use super::token_query::TokenGrant;
17
18/// A pair of access and refresh tokens with their digests.
19pub struct TokenPair {
20    pub access_token: String,
21    pub refresh_token: String,
22    pub access_digest: Digest,
23    pub refresh_digest: Digest,
24}
25
26/// Generate a new token pair (access token and refresh token) with their digests.
27pub fn generate_token_pair(key: &str) -> TokenPair {
28    let access_token = headless_lms_models::library::oauth::tokens::generate_access_token();
29    let refresh_token = headless_lms_models::library::oauth::tokens::generate_access_token();
30    TokenPair {
31        access_token: access_token.clone(),
32        refresh_token: refresh_token.clone(),
33        access_digest: token_digest_sha256(&access_token, key),
34        refresh_digest: token_digest_sha256(&refresh_token, key),
35    }
36}
37
38pub struct TokenGrantRequest<'a> {
39    pub grant: &'a TokenGrant,
40    pub client: &'a OAuthClient,
41    pub token_pair: TokenPair,
42    pub access_expires_at: DateTime<Utc>,
43    pub refresh_expires_at: DateTime<Utc>,
44    pub issued_token_type: TokenType,
45    pub dpop_jkt: Option<&'a str>,
46    pub token_hmac_key: &'a str,
47}
48
49pub struct TokenGrantResult {
50    pub user_id: Uuid,
51    pub scopes: Vec<String>,
52    pub nonce: Option<String>,
53    pub access_expires_at: DateTime<Utc>,
54    pub issue_id_token: bool,
55}
56
57pub async fn process_token_grant(
58    conn: &mut PgConnection,
59    request: TokenGrantRequest<'_>,
60) -> Result<TokenGrantResult, TokenGrantError> {
61    let mut tx = conn
62        .begin()
63        .await
64        .map_err(|e| TokenGrantError::ServerError(format!("Failed to start transaction: {}", e)))?;
65
66    let result = match request.grant {
67        TokenGrant::AuthorizationCode {
68            code,
69            redirect_uri,
70            code_verifier,
71        } => {
72            let code_digest = token_digest_sha256(code, request.token_hmac_key);
73            // Consume with client_id check in WHERE clause to prevent DoS attacks
74            let code_row = if let Some(ref_uri) = redirect_uri {
75                OAuthAuthCode::consume_with_redirect_in_transaction(
76                    &mut tx,
77                    code_digest,
78                    request.client.id,
79                    ref_uri,
80                )
81                .await
82                .map_err(|_| TokenGrantError::InvalidGrant("Given grant is invalid".to_string()))?
83            } else {
84                OAuthAuthCode::consume_in_transaction(&mut tx, code_digest, request.client.id)
85                    .await
86                    .map_err(|_| {
87                        TokenGrantError::InvalidGrant("Given grant is invalid".to_string())
88                    })?
89            };
90
91            // PKCE verification happens after client_id check (enforced in SQL)
92            verify_token_pkce(
93                request.client,
94                code_row.code_challenge.as_deref(),
95                code_row.code_challenge_method,
96                code_verifier.as_deref(),
97            )
98            .map_err(|_| TokenGrantError::PkceVerificationFailed)?;
99
100            OAuthRefreshTokens::issue_tokens_from_auth_code_in_transaction(
101                &mut tx,
102                IssueTokensFromAuthCodeParams {
103                    user_id: code_row.user_id,
104                    client_id: code_row.client_id,
105                    scopes: &code_row.scopes,
106                    access_token_digest: &request.token_pair.access_digest,
107                    refresh_token_digest: &request.token_pair.refresh_digest,
108                    access_token_expires_at: request.access_expires_at,
109                    refresh_token_expires_at: request.refresh_expires_at,
110                    access_token_type: request.issued_token_type,
111                    access_token_dpop_jkt: request.dpop_jkt,
112                    refresh_token_dpop_jkt: request.dpop_jkt,
113                },
114            )
115            .await
116            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
117
118            // Determine if ID token should be issued based on presence of "openid" scope
119            let has_openid = code_row.scopes.iter().any(|s| s == "openid");
120
121            Ok(TokenGrantResult {
122                user_id: code_row.user_id,
123                scopes: code_row.scopes,
124                nonce: code_row.nonce.clone(),
125                access_expires_at: request.access_expires_at,
126                issue_id_token: has_openid,
127            })
128        }
129        TokenGrant::RefreshToken { refresh_token, .. } => {
130            let presented = token_digest_sha256(refresh_token, request.token_hmac_key);
131            // Consume with client_id check in WHERE clause to prevent DoS attacks
132            let old =
133                OAuthRefreshTokens::consume_in_transaction(&mut tx, presented, request.client.id)
134                    .await
135                    .map_err(|e| TokenGrantError::InvalidGrant(format!("{}", e)))?;
136
137            if let Some(expected_jkt) = old.dpop_jkt.as_deref() {
138                let presented_jkt = request.dpop_jkt.ok_or_else(|| {
139                    TokenGrantError::InvalidClient(
140                        "missing DPoP header for sender-constrained refresh".into(),
141                    )
142                })?;
143                if presented_jkt != expected_jkt {
144                    return Err(TokenGrantError::DpopMismatch);
145                }
146            }
147
148            let refresh_issue_type = if old.dpop_jkt.is_some() {
149                TokenType::DPoP
150            } else {
151                request.issued_token_type
152            };
153            let at_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
154            let refresh_jkt = old.dpop_jkt.as_deref().or(request.dpop_jkt);
155
156            OAuthRefreshTokens::complete_refresh_token_rotation_in_transaction(
157                &mut tx,
158                &old,
159                RotateRefreshTokenParams {
160                    new_refresh_token_digest: &request.token_pair.refresh_digest,
161                    new_access_token_digest: &request.token_pair.access_digest,
162                    access_token_expires_at: request.access_expires_at,
163                    refresh_token_expires_at: request.refresh_expires_at,
164                    access_token_type: refresh_issue_type,
165                    access_token_dpop_jkt: at_jkt,
166                    refresh_token_dpop_jkt: refresh_jkt,
167                },
168            )
169            .await
170            .map_err(|e| TokenGrantError::ServerError(format!("{}", e)))?;
171
172            Ok(TokenGrantResult {
173                user_id: old.user_id,
174                scopes: old.scopes.clone(),
175                nonce: None,
176                access_expires_at: request.access_expires_at,
177                issue_id_token: false,
178            })
179        }
180        TokenGrant::Unknown => Err(TokenGrantError::UnsupportedGrantType),
181    };
182
183    match result {
184        Ok(res) => {
185            tx.commit().await.map_err(|e| {
186                TokenGrantError::ServerError(format!("Failed to commit transaction: {}", e))
187            })?;
188            Ok(res)
189        }
190        Err(e) => {
191            // Transaction will be rolled back on drop
192            Err(e)
193        }
194    }
195}