1use crate::library::oauth::Digest;
2use crate::prelude::*;
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::{self, FromRow, PgConnection, Type};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
9#[sqlx(type_name = "token_type")]
10pub enum TokenType {
11 Bearer,
12 DPoP,
13}
14
15#[derive(Debug, Serialize, Deserialize, FromRow)]
28pub struct OAuthAccessToken {
29 pub digest: Digest,
30 pub user_id: Option<Uuid>,
31 pub client_id: Uuid,
32 pub scopes: Vec<String>,
33 pub audience: Option<Vec<String>>,
34 pub jti: Uuid,
35
36 pub dpop_jkt: Option<String>,
38
39 pub token_type: TokenType,
40
41 pub metadata: serde_json::Value,
42 pub expires_at: DateTime<Utc>,
43 pub created_at: DateTime<Utc>,
44 pub updated_at: DateTime<Utc>,
45}
46
47#[derive(Debug, Clone)]
48pub struct NewAccessTokenParams<'a> {
49 pub digest: &'a Digest,
50 pub user_id: Option<Uuid>,
51 pub client_id: Uuid,
52 pub scopes: &'a [String],
53 pub audience: Option<&'a [String]>,
54
55 pub token_type: TokenType,
58 pub dpop_jkt: Option<&'a str>,
59
60 pub metadata: serde_json::Map<String, serde_json::Value>,
61 pub expires_at: DateTime<Utc>,
62}
63
64impl OAuthAccessToken {
65 pub async fn insert(
71 conn: &mut PgConnection,
72 params: NewAccessTokenParams<'_>,
73 ) -> ModelResult<()> {
74 match (params.token_type, params.dpop_jkt) {
75 (TokenType::Bearer, None) => {}
76 (TokenType::Bearer, Some(_)) => {
77 return Err(ModelError::new(
78 ModelErrorType::InvalidRequest,
79 "Bearer tokens must not include dpop_jkt",
80 None::<anyhow::Error>,
81 ));
82 }
83 (TokenType::DPoP, Some(_)) => {}
84 (TokenType::DPoP, None) => {
85 return Err(ModelError::new(
86 ModelErrorType::InvalidRequest,
87 "DPoP tokens must include dpop_jkt",
88 None::<anyhow::Error>,
89 ));
90 }
91 }
92
93 sqlx::query!(
94 r#"
95 INSERT INTO oauth_access_tokens
96 (digest, user_id, client_id, scopes, audience, token_type, dpop_jkt, metadata, expires_at)
97 VALUES
98 ($1, $2, $3, $4, $5, $6, $7, $8, $9)
99 "#,
100 params.digest.as_bytes(),
101 params.user_id,
102 params.client_id,
103 params.scopes,
104 params.audience,
105 params.token_type as TokenType,
106 params.dpop_jkt,
107 serde_json::Value::Object(params.metadata),
108 params.expires_at
109 )
110 .execute(conn)
111 .await?;
112 Ok(())
113 }
114
115 pub async fn find_valid(
117 conn: &mut PgConnection,
118 digest: Digest,
119 ) -> ModelResult<OAuthAccessToken> {
120 let token = sqlx::query_as!(
121 OAuthAccessToken,
122 r#"
123 SELECT *
124 FROM oauth_access_tokens
125 WHERE digest = $1 AND expires_at > now()
126 "#,
127 digest.as_bytes()
128 )
129 .fetch_one(conn)
130 .await?;
131 Ok(token)
132 }
133
134 pub async fn find_valid_for_sender(
138 conn: &mut PgConnection,
139 digest: Digest,
140 sender_jkt: Option<&str>,
141 ) -> ModelResult<OAuthAccessToken> {
142 let t = Self::find_valid(conn, digest).await?;
143
144 match t.token_type {
145 TokenType::Bearer => Ok(t),
146 TokenType::DPoP => {
147 let Some(expected) = t.dpop_jkt.as_deref() else {
148 return Err(ModelError::new(
149 ModelErrorType::PreconditionFailed,
150 "token missing dpop_jkt",
151 None::<anyhow::Error>,
152 ));
153 };
154 let Some(presented) = sender_jkt else {
155 return Err(ModelError::new(
156 ModelErrorType::PreconditionFailed,
157 "DPoP proof missing JKT",
158 None::<anyhow::Error>,
159 ));
160 };
161 if expected != presented {
162 return Err(ModelError::new(
163 ModelErrorType::PreconditionFailed,
164 "DPoP JKT mismatch",
165 None::<anyhow::Error>,
166 ));
167 }
168 Ok(t)
169 }
170 }
171 }
172
173 pub async fn delete_all_by_user_client(
174 conn: &mut PgConnection,
175 user_id: Uuid,
176 client_id: Uuid,
177 ) -> ModelResult<()> {
178 let mut tx = conn.begin().await?;
179 sqlx::query!(
180 r#"
181 DELETE FROM oauth_access_tokens
182 WHERE user_id = $1 AND client_id = $2
183 "#,
184 user_id,
185 client_id
186 )
187 .execute(&mut *tx)
188 .await?;
189 tx.commit().await?;
190 Ok(())
191 }
192
193 pub async fn revoke_by_digest(conn: &mut PgConnection, digest: Digest) -> ModelResult<()> {
198 sqlx::query!(
199 r#"
200 DELETE FROM oauth_access_tokens
201 WHERE digest = $1
202 "#,
203 digest.as_bytes()
204 )
205 .execute(conn)
206 .await?;
207 Ok(())
208 }
209}