1use crate::{
7 library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8 prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
16#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
17#[serde(rename_all = "snake_case")]
18pub enum TokenEndpointAuthMethod {
19 None,
20 ClientSecretPost,
21}
22
23impl TokenEndpointAuthMethod {
24 pub fn is_public(self) -> bool {
25 matches!(self, Self::None)
26 }
27
28 pub fn is_confidential(self) -> bool {
29 !self.is_public()
30 }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
34#[sqlx(type_name = "application_type", rename_all = "snake_case")]
35#[serde(rename_all = "snake_case")]
36pub enum ApplicationType {
37 Web,
38 Native,
39 Spa,
40 Service,
41}
42
43#[derive(Debug, Serialize, Deserialize, FromRow)]
44pub struct OAuthClient {
45 pub id: Uuid,
46 pub client_id: String,
47 pub client_name: String,
48 pub application_type: ApplicationType,
49
50 pub token_endpoint_auth_method: TokenEndpointAuthMethod,
51
52 pub client_secret: Option<Digest>,
54 pub client_secret_expires_at: Option<DateTime<Utc>>,
55
56 pub redirect_uris: Vec<String>,
57 pub post_logout_redirect_uris: Option<Vec<String>>,
58
59 pub allowed_grant_types: Vec<GrantTypeName>,
60 pub scopes: Vec<String>,
61
62 pub require_pkce: bool,
63 pub pkce_methods_allowed: Vec<PkceMethod>,
64
65 pub origin: String,
66 pub bearer_allowed: bool,
67
68 pub created_at: DateTime<Utc>,
69 pub updated_at: DateTime<Utc>,
70 pub deleted_at: Option<DateTime<Utc>>,
71}
72
73impl OAuthClient {
74 pub fn is_public(&self) -> bool {
75 self.token_endpoint_auth_method.is_public()
76 }
77
78 pub fn is_confidential(&self) -> bool {
79 self.token_endpoint_auth_method.is_confidential()
80 }
81
82 pub fn allows_bearer(&self) -> bool {
83 self.bearer_allowed
84 }
85
86 pub fn requires_pkce(&self) -> bool {
87 self.require_pkce || self.is_public()
88 }
89
90 pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
91 self.pkce_methods_allowed.contains(&m)
92 }
93
94 pub fn allows_grant(&self, g: GrantTypeName) -> bool {
95 self.allowed_grant_types.contains(&g)
96 }
97}
98
99#[derive(Debug, Clone)]
100pub struct NewClientParams<'a> {
101 pub client_id: &'a str,
102 pub client_name: &'a str,
103 pub application_type: ApplicationType,
104 pub token_endpoint_auth_method: TokenEndpointAuthMethod,
105
106 pub client_secret: Option<&'a Digest>,
107 pub client_secret_expires_at: Option<DateTime<Utc>>,
108
109 pub redirect_uris: &'a [String],
110 pub post_logout_redirect_uris: Option<&'a [String]>,
111
112 pub allowed_grant_types: &'a [GrantTypeName],
113 pub scopes: &'a [String],
114
115 pub require_pkce: bool,
116 pub pkce_methods_allowed: &'a [PkceMethod],
117
118 pub origin: &'a str,
119 pub bearer_allowed: bool,
120}
121
122impl<'a> NewClientParams<'a> {
123 pub fn validate(&self) -> ModelResult<()> {
125 if self.client_id.trim().is_empty() {
126 return Err(ModelError::new(
127 ModelErrorType::InvalidRequest,
128 "client_id cannot be empty",
129 None::<anyhow::Error>,
130 ));
131 }
132
133 if self.client_name.trim().is_empty() {
134 return Err(ModelError::new(
135 ModelErrorType::InvalidRequest,
136 "client_name cannot be empty",
137 None::<anyhow::Error>,
138 ));
139 }
140
141 if self.redirect_uris.is_empty() {
142 return Err(ModelError::new(
143 ModelErrorType::InvalidRequest,
144 "redirect_uris must not be empty",
145 None::<anyhow::Error>,
146 ));
147 }
148
149 if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
150 return Err(ModelError::new(
151 ModelErrorType::PreconditionFailed,
152 "public clients must not include client_secret",
153 None::<anyhow::Error>,
154 ));
155 }
156
157 if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
158 return Err(ModelError::new(
159 ModelErrorType::PreconditionFailed,
160 "confidential clients must include client_secret",
161 None::<anyhow::Error>,
162 ));
163 }
164
165 if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
166 return Err(ModelError::new(
167 ModelErrorType::PreconditionFailed,
168 "public clients must require PKCE",
169 None::<anyhow::Error>,
170 ));
171 }
172
173 Ok(())
174 }
175}
176
177impl OAuthClient {
178 pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
180 let client = sqlx::query_as!(
181 OAuthClient,
182 r#"
183 SELECT
184 id,
185 client_id,
186 client_name,
187 application_type AS "application_type: _",
188 token_endpoint_auth_method AS "token_endpoint_auth_method: _",
189 client_secret AS "client_secret: _",
190 client_secret_expires_at,
191 redirect_uris,
192 post_logout_redirect_uris,
193 allowed_grant_types AS "allowed_grant_types: _",
194 scopes,
195 require_pkce,
196 pkce_methods_allowed AS "pkce_methods_allowed: _",
197 origin,
198 bearer_allowed,
199 created_at,
200 updated_at,
201 deleted_at
202 FROM oauth_clients
203 WHERE id = $1
204 AND deleted_at IS NULL
205 "#,
206 id
207 )
208 .fetch_one(conn)
209 .await?;
210
211 Ok(client)
212 }
213
214 pub async fn find_by_id_optional(
216 conn: &mut PgConnection,
217 id: Uuid,
218 ) -> Result<Option<Self>, ModelError> {
219 Self::find_by_id(conn, id).await.optional()
220 }
221
222 pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
224 let client = sqlx::query_as!(
225 OAuthClient,
226 r#"
227 SELECT
228 id,
229 client_id,
230 client_name,
231 application_type AS "application_type: _",
232 token_endpoint_auth_method AS "token_endpoint_auth_method: _",
233 client_secret AS "client_secret: _",
234 client_secret_expires_at,
235 redirect_uris,
236 post_logout_redirect_uris,
237 allowed_grant_types AS "allowed_grant_types: _", -- or "grant_type[]"
238 scopes,
239 require_pkce,
240 pkce_methods_allowed AS "pkce_methods_allowed: _", -- or "pkce_method[]"
241 origin,
242 bearer_allowed,
243 created_at,
244 updated_at,
245 deleted_at
246 FROM oauth_clients
247 WHERE client_id = $1
248 AND deleted_at IS NULL
249 "#,
250 client_id
251 )
252 .fetch_one(conn)
253 .await?;
254
255 Ok(client)
256 }
257
258 pub async fn find_by_client_id_optional(
260 conn: &mut PgConnection,
261 client_id: &str,
262 ) -> Result<Option<Self>, ModelError> {
263 Self::find_by_client_id(conn, client_id).await.optional()
264 }
265
266 pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
268 p.validate()?;
269 let row = sqlx::query_as!(
270 OAuthClient,
271 r#"
272 INSERT INTO oauth_clients (
273 client_id,
274 client_name,
275 application_type,
276 token_endpoint_auth_method,
277 client_secret,
278 client_secret_expires_at,
279 redirect_uris,
280 post_logout_redirect_uris,
281 allowed_grant_types,
282 scopes,
283 require_pkce,
284 pkce_methods_allowed,
285 origin,
286 bearer_allowed
287 )
288 VALUES (
289 $1, $2, $3, $4,
290 $5, $6,
291 $7, COALESCE($8, '{}'::text[]), -- << cast needed for text[]
292 $9, $10,
293 $11, $12,
294 $13, $14
295 )
296 RETURNING
297 id,
298 client_id,
299 client_name,
300 application_type AS "application_type: _",
301 token_endpoint_auth_method AS "token_endpoint_auth_method: _",
302 client_secret AS "client_secret: _",
303 client_secret_expires_at,
304 redirect_uris,
305 post_logout_redirect_uris,
306 allowed_grant_types AS "allowed_grant_types: _",
307 scopes,
308 require_pkce,
309 pkce_methods_allowed AS "pkce_methods_allowed: _",
310 origin,
311 bearer_allowed,
312 created_at,
313 updated_at,
314 deleted_at
315 "#,
316 p.client_id,
317 p.client_name,
318 p.application_type as ApplicationType,
319 p.token_endpoint_auth_method as TokenEndpointAuthMethod,
320 p.client_secret.map(|d| d.as_bytes() as &[u8]),
321 p.client_secret_expires_at,
322 p.redirect_uris,
323 p.post_logout_redirect_uris,
324 p.allowed_grant_types as &[GrantTypeName],
325 p.scopes,
326 p.require_pkce,
327 p.pkce_methods_allowed as &[PkceMethod],
328 p.origin,
329 p.bearer_allowed
330 )
331 .fetch_one(conn)
332 .await?;
333
334 Ok(row)
335 }
336}