1use crate::{
7 library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8 prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
/// How an OAuth client authenticates at the token endpoint.
///
/// Stored as the Postgres enum type `token_endpoint_auth_method` and (de)serialized by
/// serde; both use snake_case variant names (`none`, `client_secret_post`), so renaming a
/// variant changes the database and wire representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum TokenEndpointAuthMethod {
    /// No client authentication; the only variant treated as a *public* client
    /// (see `is_public`). Serialized as `none`.
    None,
    /// Client authenticates with a client secret; treated as *confidential*.
    /// Serialized as `client_secret_post` (presumably the RFC 7591 value where the secret
    /// is sent in the request body — confirm against the token endpoint handler).
    ClientSecretPost,
}
22
23impl TokenEndpointAuthMethod {
24 pub fn is_public(self) -> bool {
25 matches!(self, Self::None)
26 }
27
28 pub fn is_confidential(self) -> bool {
29 !self.is_public()
30 }
31}
32
/// Kind of application registered as an OAuth client.
///
/// Stored as the Postgres enum type `application_type` and (de)serialized by serde; both use
/// snake_case names, so renaming a variant changes the database and wire representation.
/// The validation and helper methods in this module do not consult it — PKCE and secret
/// rules are keyed off [`TokenEndpointAuthMethod`] instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
#[sqlx(type_name = "application_type", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ApplicationType {
    /// Serialized as `web`.
    Web,
    /// Serialized as `native`.
    Native,
    /// Serialized as `spa`.
    Spa,
    /// Serialized as `service`.
    Service,
}
42
/// A registered OAuth client, as stored in the `oauth_clients` table.
///
/// Rows are produced by the `query_as!` queries on this type (`find_by_id`,
/// `find_by_client_id`, `insert`), which select every field below by name.
///
/// NOTE(review): the derived `Serialize` and `Debug` impls include `client_secret`; confirm
/// this struct is never returned directly in an API response or logged with `{:?}`.
#[derive(Debug, Serialize, Deserialize, FromRow)]
pub struct OAuthClient {
    /// Row identifier. `insert` does not supply it, so it comes from the table default.
    pub id: Uuid,
    /// OAuth `client_id`; lookups by it go through `find_by_client_id`.
    pub client_id: String,
    /// Human-readable client name (validated as non-blank on insert).
    pub client_name: String,
    pub application_type: ApplicationType,

    /// Decides whether the client is public or confidential (see `is_public`).
    pub token_endpoint_auth_method: TokenEndpointAuthMethod,

    /// Stored secret digest; `None` for public clients and `Some` for confidential ones
    /// when created via `NewClientParams::validate`. Presumably a hash rather than the
    /// plaintext secret — confirm against `Digest`.
    pub client_secret: Option<Digest>,
    /// Optional expiry of `client_secret`. Nothing in this type enforces it.
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    pub redirect_uris: Vec<String>,
    pub post_logout_redirect_uris: Option<Vec<String>>,

    /// Grants this client may use; checked by `allows_grant`.
    pub allowed_grant_types: Vec<GrantTypeName>,
    pub scopes: Vec<String>,

    /// Stored PKCE flag. Public clients require PKCE regardless (see `requires_pkce`).
    pub require_pkce: bool,
    /// PKCE methods accepted for this client; checked by `allows_pkce_method`.
    pub pkce_methods_allowed: Vec<PkceMethod>,

    pub allowed_origins: Option<Vec<String>>,
    /// Raw flag exposed through `allows_bearer`.
    pub bearer_allowed: bool,

    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Soft-delete marker; the lookup queries only return rows where this is `NULL`.
    pub deleted_at: Option<DateTime<Utc>>,
}
73
74impl OAuthClient {
75 pub fn is_public(&self) -> bool {
76 self.token_endpoint_auth_method.is_public()
77 }
78
79 pub fn is_confidential(&self) -> bool {
80 self.token_endpoint_auth_method.is_confidential()
81 }
82
83 pub fn allows_bearer(&self) -> bool {
84 self.bearer_allowed
85 }
86
87 pub fn requires_pkce(&self) -> bool {
88 self.require_pkce || self.is_public()
89 }
90
91 pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
92 self.pkce_methods_allowed.contains(&m)
93 }
94
95 pub fn allows_grant(&self, g: GrantTypeName) -> bool {
96 self.allowed_grant_types.contains(&g)
97 }
98}
99
/// Borrowed input for `OAuthClient::insert`.
///
/// Fields mirror the columns written by `insert`. `id`, `created_at`, `updated_at` and
/// `deleted_at` are not part of the insert and come from the table. `insert` calls
/// `validate` before touching the database.
#[derive(Debug, Clone)]
pub struct NewClientParams<'a> {
    /// Must not be blank.
    pub client_id: &'a str,
    /// Must not be blank.
    pub client_name: &'a str,
    pub application_type: ApplicationType,
    /// Decides whether `client_secret` must be present (confidential) or absent (public).
    pub token_endpoint_auth_method: TokenEndpointAuthMethod,

    /// Digest to store; its bytes are bound to the `client_secret` column by `insert`.
    pub client_secret: Option<&'a Digest>,
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    /// Must contain at least one URI.
    pub redirect_uris: &'a [String],
    /// `None` is stored as an empty array by `insert` (`COALESCE($8, '{}'::text[])`).
    pub post_logout_redirect_uris: Option<&'a [String]>,

    pub allowed_grant_types: &'a [GrantTypeName],
    pub scopes: &'a [String],

    /// Must be `true` for public clients.
    pub require_pkce: bool,
    pub pkce_methods_allowed: &'a [PkceMethod],

    /// Passed to `insert` as-is (no COALESCE, unlike `post_logout_redirect_uris`).
    pub allowed_origins: Option<&'a [String]>,
    pub bearer_allowed: bool,
}
122
123impl<'a> NewClientParams<'a> {
124 pub fn validate(&self) -> ModelResult<()> {
126 if self.client_id.trim().is_empty() {
127 return Err(ModelError::new(
128 ModelErrorType::InvalidRequest,
129 "client_id cannot be empty",
130 None::<anyhow::Error>,
131 ));
132 }
133
134 if self.client_name.trim().is_empty() {
135 return Err(ModelError::new(
136 ModelErrorType::InvalidRequest,
137 "client_name cannot be empty",
138 None::<anyhow::Error>,
139 ));
140 }
141
142 if self.redirect_uris.is_empty() {
143 return Err(ModelError::new(
144 ModelErrorType::InvalidRequest,
145 "redirect_uris must not be empty",
146 None::<anyhow::Error>,
147 ));
148 }
149
150 if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
151 return Err(ModelError::new(
152 ModelErrorType::PreconditionFailed,
153 "public clients must not include client_secret",
154 None::<anyhow::Error>,
155 ));
156 }
157
158 if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
159 return Err(ModelError::new(
160 ModelErrorType::PreconditionFailed,
161 "confidential clients must include client_secret",
162 None::<anyhow::Error>,
163 ));
164 }
165
166 if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
167 return Err(ModelError::new(
168 ModelErrorType::PreconditionFailed,
169 "public clients must require PKCE",
170 None::<anyhow::Error>,
171 ));
172 }
173
174 Ok(())
175 }
176}
177
impl OAuthClient {
    /// Loads the client whose row `id` matches, ignoring soft-deleted rows
    /// (`deleted_at IS NULL`).
    ///
    /// # Errors
    ///
    /// Fails when no live row matches (sqlx's `fetch_one` errors if the query returns no
    /// row) or on any database error; the sqlx error is converted by `?` into the
    /// [`ModelError`] of [`ModelResult`].
    pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
        // `AS "col: _"` is sqlx's wildcard type override: the macro decodes those columns
        // (custom Postgres enums, enum arrays, the secret digest) into whatever type the
        // matching `OAuthClient` field has instead of the type inferred from the schema.
        let client = sqlx::query_as!(
            OAuthClient,
            r#"
            SELECT
                id,
                client_id,
                client_name,
                application_type AS "application_type: _",
                token_endpoint_auth_method AS "token_endpoint_auth_method: _",
                client_secret AS "client_secret: _",
                client_secret_expires_at,
                redirect_uris,
                post_logout_redirect_uris,
                allowed_grant_types AS "allowed_grant_types: _",
                scopes,
                require_pkce,
                pkce_methods_allowed AS "pkce_methods_allowed: _",
                allowed_origins,
                bearer_allowed,
                created_at,
                updated_at,
                deleted_at
            FROM oauth_clients
            WHERE id = $1
              AND deleted_at IS NULL
            "#,
            id
        )
        .fetch_one(conn)
        .await?;

        Ok(client)
    }

    /// Like `find_by_id`, but for callers that treat "not found" as a normal outcome.
    ///
    /// Presumably the prelude's `.optional()` maps the not-found error to `Ok(None)` and
    /// passes every other error through — confirm against its definition.
    pub async fn find_by_id_optional(
        conn: &mut PgConnection,
        id: Uuid,
    ) -> Result<Option<Self>, ModelError> {
        Self::find_by_id(conn, id).await.optional()
    }

    /// Loads the client whose OAuth `client_id` matches, ignoring soft-deleted rows
    /// (`deleted_at IS NULL`).
    ///
    /// # Errors
    ///
    /// Fails when no live row matches (`fetch_one` errors on an empty result) or on any
    /// database error, converted by `?` into [`ModelError`].
    pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
        // Same column list and `"col: _"` type overrides as `find_by_id`; only the WHERE
        // clause differs. The `-- or ...` SQL comments are part of the query text.
        let client = sqlx::query_as!(
            OAuthClient,
            r#"
            SELECT
                id,
                client_id,
                client_name,
                application_type AS "application_type: _",
                token_endpoint_auth_method AS "token_endpoint_auth_method: _",
                client_secret AS "client_secret: _",
                client_secret_expires_at,
                redirect_uris,
                post_logout_redirect_uris,
                allowed_grant_types AS "allowed_grant_types: _", -- or "grant_type[]"
                scopes,
                require_pkce,
                pkce_methods_allowed AS "pkce_methods_allowed: _", -- or "pkce_method[]"
                allowed_origins,
                bearer_allowed,
                created_at,
                updated_at,
                deleted_at
            FROM oauth_clients
            WHERE client_id = $1
              AND deleted_at IS NULL
            "#,
            client_id
        )
        .fetch_one(conn)
        .await?;

        Ok(client)
    }

    /// Like `find_by_client_id`, but for callers that treat "not found" as a normal
    /// outcome (see the note on `find_by_id_optional` about `.optional()`).
    pub async fn find_by_client_id_optional(
        conn: &mut PgConnection,
        client_id: &str,
    ) -> Result<Option<Self>, ModelError> {
        Self::find_by_client_id(conn, client_id).await.optional()
    }

    /// Validates `p` and inserts it as a new row, returning the stored client.
    ///
    /// `id`, `created_at`, `updated_at` and `deleted_at` are not supplied, so their values
    /// in the returned row come from the table definition.
    ///
    /// NOTE(review): `post_logout_redirect_uris: None` is written as an empty array via
    /// `COALESCE($8, '{}'::text[])`, so the returned row holds `Some(vec![])`, while
    /// `allowed_origins` ($13) is bound without COALESCE. Confirm the asymmetry is intended.
    ///
    /// # Errors
    ///
    /// Returns the error from `NewClientParams::validate` without touching the database,
    /// or a [`ModelError`] converted from any sqlx error (e.g. a constraint violation).
    pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
        p.validate()?;
        let row = sqlx::query_as!(
            OAuthClient,
            r#"
            INSERT INTO oauth_clients (
                client_id,
                client_name,
                application_type,
                token_endpoint_auth_method,
                client_secret,
                client_secret_expires_at,
                redirect_uris,
                post_logout_redirect_uris,
                allowed_grant_types,
                scopes,
                require_pkce,
                pkce_methods_allowed,
                allowed_origins,
                bearer_allowed
            )
            VALUES (
                $1, $2, $3, $4,
                $5, $6,
                $7, COALESCE($8, '{}'::text[]), -- << cast needed for text[]
                $9, $10,
                $11, $12,
                $13, $14
            )
            RETURNING
                id,
                client_id,
                client_name,
                application_type AS "application_type: _",
                token_endpoint_auth_method AS "token_endpoint_auth_method: _",
                client_secret AS "client_secret: _",
                client_secret_expires_at,
                redirect_uris,
                post_logout_redirect_uris,
                allowed_grant_types AS "allowed_grant_types: _",
                scopes,
                require_pkce,
                pkce_methods_allowed AS "pkce_methods_allowed: _",
                allowed_origins,
                bearer_allowed,
                created_at,
                updated_at,
                deleted_at
            "#,
            p.client_id,
            p.client_name,
            // `expr as Type` inside query macros is sqlx's bind-parameter type override,
            // needed here for the custom Postgres enum / enum-array parameters.
            p.application_type as ApplicationType,
            p.token_endpoint_auth_method as TokenEndpointAuthMethod,
            // Binds the digest's raw bytes; presumably the column is `bytea` — confirm.
            p.client_secret.map(|d| d.as_bytes() as &[u8]),
            p.client_secret_expires_at,
            p.redirect_uris,
            p.post_logout_redirect_uris,
            p.allowed_grant_types as &[GrantTypeName],
            p.scopes,
            p.require_pkce,
            p.pkce_methods_allowed as &[PkceMethod],
            p.allowed_origins,
            p.bearer_allowed
        )
        .fetch_one(conn)
        .await?;

        Ok(row)
    }
}