Skip to main content

headless_lms_models/
oauth_access_token.rs

1use crate::library::oauth::Digest;
2use crate::prelude::*;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::{self, FromRow, PgConnection, Type};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
9#[sqlx(type_name = "token_type")]
10pub enum TokenType {
11    Bearer,
12    DPoP,
13}
14
15/// **INTERNAL/DATABASE-ONLY MODEL - DO NOT EXPOSE TO CLIENTS**
16///
17/// This struct is a database model that contains a `Digest` field, which contains raw bytes
18/// and uses custom (de)serialization. This model must **never** be serialized into external
19/// API payloads or returned directly to clients.
20///
21/// For external-facing responses, use DTOs such as `TokenResponse`, `UserInfoResponse`, or
22/// an explicit redacting wrapper that strips or converts `Digest` fields to safe types (e.g., strings).
23///
24/// **Rationale**: The `Digest` type contains sensitive raw bytes and uses custom serialization
25/// that is not suitable for external APIs. Exposing this model directly could leak internal
26/// implementation details or cause serialization issues.
27#[derive(Debug, Serialize, Deserialize, FromRow)]
28pub struct OAuthAccessToken {
29    pub digest: Digest,
30    pub user_id: Option<Uuid>,
31    pub client_id: Uuid,
32    pub scopes: Vec<String>,
33    pub audience: Option<Vec<String>>,
34    pub jti: Uuid,
35
36    /// Sender constraint: present only when `token_type = DPoP`
37    pub dpop_jkt: Option<String>,
38
39    pub token_type: TokenType,
40
41    pub metadata: serde_json::Value,
42    pub expires_at: DateTime<Utc>,
43    pub created_at: DateTime<Utc>,
44    pub updated_at: DateTime<Utc>,
45}
46
47#[derive(Debug, Clone)]
48pub struct NewAccessTokenParams<'a> {
49    pub digest: &'a Digest,
50    pub user_id: Option<Uuid>,
51    pub client_id: Uuid,
52    pub scopes: &'a [String],
53    pub audience: Option<&'a [String]>,
54
55    /// Set to `TokenType::Bearer` **and** leave `dpop_jkt` = None for Bearer tokens.
56    /// Set to `TokenType::DPoP` **and** provide `dpop_jkt = Some(...)` for DPoP tokens.
57    pub token_type: TokenType,
58    pub dpop_jkt: Option<&'a str>,
59
60    pub metadata: serde_json::Map<String, serde_json::Value>,
61    pub expires_at: DateTime<Utc>,
62}
63
64impl OAuthAccessToken {
65    /// Insert a new access token (with jti).
66    ///
67    /// DB constraint requires:
68    ///  - Bearer  => dpop_jkt = NULL
69    ///  - DPoP    => dpop_jkt IS NOT NULL
70    pub async fn insert(
71        conn: &mut PgConnection,
72        params: NewAccessTokenParams<'_>,
73    ) -> ModelResult<()> {
74        match (params.token_type, params.dpop_jkt) {
75            (TokenType::Bearer, None) => {}
76            (TokenType::Bearer, Some(_)) => {
77                return Err(ModelError::new(
78                    ModelErrorType::InvalidRequest,
79                    "Bearer tokens must not include dpop_jkt",
80                    None::<anyhow::Error>,
81                ));
82            }
83            (TokenType::DPoP, Some(_)) => {}
84            (TokenType::DPoP, None) => {
85                return Err(ModelError::new(
86                    ModelErrorType::InvalidRequest,
87                    "DPoP tokens must include dpop_jkt",
88                    None::<anyhow::Error>,
89                ));
90            }
91        }
92
93        sqlx::query!(
94            r#"
95            INSERT INTO oauth_access_tokens
96              (digest, user_id, client_id, scopes, audience, token_type, dpop_jkt, metadata, expires_at)
97            VALUES
98              ($1,    $2,     $3,       $4,     $5,       $6,         $7,       $8,       $9)
99            "#,
100            params.digest.as_bytes(),
101            params.user_id,
102            params.client_id,
103            params.scopes,
104            params.audience,
105            params.token_type as TokenType,
106            params.dpop_jkt,
107            serde_json::Value::Object(params.metadata),
108            params.expires_at
109        )
110        .execute(conn)
111        .await?;
112        Ok(())
113    }
114
115    /// Find a still-valid token by digest (no sender enforcement).
116    pub async fn find_valid(
117        conn: &mut PgConnection,
118        digest: Digest,
119    ) -> ModelResult<OAuthAccessToken> {
120        let token = sqlx::query_as!(
121            OAuthAccessToken,
122            r#"
123            SELECT *
124            FROM oauth_access_tokens
125            WHERE digest = $1 AND expires_at > now()
126            "#,
127            digest.as_bytes()
128        )
129        .fetch_one(conn)
130        .await?;
131        Ok(token)
132    }
133
134    /// Find a still-valid token by digest and enforce sender:
135    ///  - DPoP => `dpop_jkt` must match `sender_jkt`
136    ///  - Bearer => sender is ignored
137    pub async fn find_valid_for_sender(
138        conn: &mut PgConnection,
139        digest: Digest,
140        sender_jkt: Option<&str>,
141    ) -> ModelResult<OAuthAccessToken> {
142        let t = Self::find_valid(conn, digest).await?;
143
144        match t.token_type {
145            TokenType::Bearer => Ok(t),
146            TokenType::DPoP => {
147                let Some(expected) = t.dpop_jkt.as_deref() else {
148                    return Err(ModelError::new(
149                        ModelErrorType::PreconditionFailed,
150                        "token missing dpop_jkt",
151                        None::<anyhow::Error>,
152                    ));
153                };
154                let Some(presented) = sender_jkt else {
155                    return Err(ModelError::new(
156                        ModelErrorType::PreconditionFailed,
157                        "DPoP proof missing JKT",
158                        None::<anyhow::Error>,
159                    ));
160                };
161                if expected != presented {
162                    return Err(ModelError::new(
163                        ModelErrorType::PreconditionFailed,
164                        "DPoP JKT mismatch",
165                        None::<anyhow::Error>,
166                    ));
167                }
168                Ok(t)
169            }
170        }
171    }
172
173    pub async fn delete_all_by_user_client(
174        conn: &mut PgConnection,
175        user_id: Uuid,
176        client_id: Uuid,
177    ) -> ModelResult<()> {
178        let mut tx = conn.begin().await?;
179        sqlx::query!(
180            r#"
181            DELETE FROM oauth_access_tokens
182            WHERE user_id = $1 AND client_id = $2
183            "#,
184            user_id,
185            client_id
186        )
187        .execute(&mut *tx)
188        .await?;
189        tx.commit().await?;
190        Ok(())
191    }
192
193    /// Revoke (delete) an access token by its digest.
194    ///
195    /// This method is used for the OAuth 2.0 token revocation endpoint (RFC 7009).
196    /// Access tokens are deleted rather than marked as revoked since they are short-lived.
197    pub async fn revoke_by_digest(conn: &mut PgConnection, digest: Digest) -> ModelResult<()> {
198        sqlx::query!(
199            r#"
200            DELETE FROM oauth_access_tokens
201            WHERE digest = $1
202            "#,
203            digest.as_bytes()
204        )
205        .execute(conn)
206        .await?;
207        Ok(())
208    }
209}