1use crate::{library::oauth::Digest, prelude::*};
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use sqlx::{FromRow, PgConnection};
5use uuid::Uuid;
6
7use crate::library::oauth::pkce::PkceMethod;
8
/// A row of the `oauth_auth_codes` table: an issued OAuth authorization code
/// together with the request context it was stored with.
///
/// Rows are written by [`OAuthAuthCode::insert`] and flagged as used by
/// [`OAuthAuthCode::consume_in_transaction`] or
/// [`OAuthAuthCode::consume_with_redirect_in_transaction`], which also return
/// the row.
#[derive(Debug, Serialize, Deserialize, FromRow)]
pub struct OAuthAuthCode {
    /// Lookup key for the code; the consume queries match on `digest = $1`
    /// using `Digest::as_bytes()`. Presumably a hash of the raw code so the
    /// plaintext is not stored — TODO confirm against `library::oauth::Digest`.
    pub digest: Digest,
    /// ID of the user associated with the code.
    pub user_id: Uuid,
    /// ID of the client the code belongs to; consuming requires a matching
    /// `client_id`.
    pub client_id: Uuid,
    /// Redirect URI recorded when the code was stored; compared for equality by
    /// `consume_with_redirect_in_transaction`.
    pub redirect_uri: String,
    /// Scopes stored with the code.
    pub scopes: Vec<String>,
    /// Not written by `insert`, so presumably populated by a column default —
    /// TODO confirm. The name suggests a JWT ID for tokens derived from this code.
    pub jti: Uuid,
    /// Optional nonce from the authorization request (presumably the OpenID
    /// Connect `nonce` — confirm with the caller).
    pub nonce: Option<String>,

    /// PKCE code challenge. `insert` only stores it together with
    /// `code_challenge_method` (enforced by `NewAuthCodeParams::validate`).
    pub code_challenge: Option<String>,
    /// PKCE method paired with `code_challenge`.
    pub code_challenge_method: Option<PkceMethod>,

    /// Optional DPoP binding; the name suggests the JWK thumbprint (`jkt`) of
    /// the client's DPoP key — TODO confirm.
    pub dpop_jkt: Option<String>,

    /// Set to `true` by the consume functions, which only match rows where it
    /// is still `false`. Not written by `insert` (presumably defaults to
    /// `false` — confirm against the migration).
    pub used: bool,
    /// The consume functions only match rows where `expires_at > now()`.
    pub expires_at: DateTime<Utc>,
    /// Not written by `insert`; presumably maintained by the database.
    pub created_at: DateTime<Utc>,
    /// Not written by `insert` or the consume queries; presumably maintained by
    /// the database (e.g. an update trigger) — confirm.
    pub updated_at: DateTime<Utc>,
    /// Free-form JSON metadata; `insert` always stores a JSON object here.
    pub metadata: serde_json::Value,
}
42
/// Input for [`OAuthAuthCode::insert`]: the caller-supplied columns of a new
/// `oauth_auth_codes` row, borrowed from the caller for lifetime `'a`.
///
/// [`NewAuthCodeParams::validate`] checks PKCE consistency; `insert` calls it
/// before writing anything.
#[derive(Debug, Clone)]
pub struct NewAuthCodeParams<'a> {
    /// Lookup key written to the `digest` column via `digest.as_bytes()`.
    pub digest: &'a Digest,
    pub user_id: Uuid,
    pub client_id: Uuid,
    /// Redirect URI to record; `consume_with_redirect_in_transaction` later
    /// requires the presented URI to equal this value.
    pub redirect_uri: &'a str,
    pub scopes: &'a [String],
    pub nonce: Option<&'a str>,

    /// PKCE challenge; must be `Some` exactly when `code_challenge_method` is.
    pub code_challenge: Option<&'a str>,
    /// PKCE method; must be `Some` exactly when `code_challenge` is.
    pub code_challenge_method: Option<PkceMethod>,

    /// Optional DPoP binding (name suggests a JWK thumbprint — confirm).
    pub dpop_jkt: Option<&'a str>,

    /// Expiry time; the consume queries only match while `expires_at > now()`.
    pub expires_at: DateTime<Utc>,
    /// Extra metadata; written to the `metadata` column as a JSON object.
    pub metadata: serde_json::Map<String, serde_json::Value>,
}
60
61impl<'a> NewAuthCodeParams<'a> {
62 pub fn validate(&self) -> ModelResult<()> {
63 match (self.code_challenge, self.code_challenge_method) {
65 (Some(_), Some(_)) | (None, None) => {}
66 _ => {
67 return Err(ModelError::new(
68 ModelErrorType::InvalidRequest,
69 "PKCE: code_challenge and code_challenge_method must be provided together",
70 None::<anyhow::Error>,
71 ));
72 }
73 }
74 Ok(())
75 }
76}
77
78impl OAuthAuthCode {
79 pub async fn insert(conn: &mut PgConnection, params: NewAuthCodeParams<'_>) -> ModelResult<()> {
80 params.validate()?;
81
82 sqlx::query!(
83 r#"
84 INSERT INTO oauth_auth_codes (
85 digest,
86 user_id,
87 client_id,
88 redirect_uri,
89 scopes,
90 nonce,
91 code_challenge,
92 code_challenge_method,
93 dpop_jkt,
94 expires_at,
95 metadata
96 )
97 VALUES (
98 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11
99 )
100 "#,
101 params.digest.as_bytes(),
102 params.user_id,
103 params.client_id,
104 params.redirect_uri,
105 params.scopes,
106 params.nonce,
107 params.code_challenge,
108 params.code_challenge_method as Option<PkceMethod>,
110 params.dpop_jkt,
111 params.expires_at,
112 serde_json::Value::Object(params.metadata)
113 )
114 .execute(conn)
115 .await?;
116
117 Ok(())
118 }
119
120 pub async fn consume_in_transaction(
128 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
129 digest: Digest,
130 client_id: Uuid,
131 ) -> ModelResult<OAuthAuthCode> {
132 let auth_code = sqlx::query_as!(
133 OAuthAuthCode,
134 r#"
135 UPDATE oauth_auth_codes
136 SET used = true
137 WHERE digest = $1
138 AND client_id = $2
139 AND used = false
140 AND expires_at > now()
141 RETURNING *
142 "#,
143 digest.as_bytes(),
144 client_id
145 )
146 .fetch_one(&mut **tx)
147 .await?;
148
149 Ok(auth_code)
150 }
151
152 pub async fn consume_with_redirect_in_transaction(
160 tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
161 digest: Digest,
162 client_id: Uuid,
163 redirect_uri: &str,
164 ) -> ModelResult<OAuthAuthCode> {
165 let auth_code = sqlx::query_as!(
166 OAuthAuthCode,
167 r#"
168 UPDATE oauth_auth_codes
169 SET used = true
170 WHERE digest = $1
171 AND client_id = $2
172 AND redirect_uri = $3
173 AND used = false
174 AND expires_at > now()
175 RETURNING *
176 "#,
177 digest.as_bytes(),
178 client_id,
179 redirect_uri
180 )
181 .fetch_one(&mut **tx)
182 .await?;
183
184 Ok(auth_code)
185 }
186}