1use crate::{
7 library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8 prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
16#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
17#[serde(rename_all = "snake_case")]
18pub enum TokenEndpointAuthMethod {
19 None,
20 ClientSecretPost,
21}
22
23impl TokenEndpointAuthMethod {
24 pub fn is_public(self) -> bool {
25 matches!(self, Self::None)
26 }
27
28 pub fn is_confidential(self) -> bool {
29 !self.is_public()
30 }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
34#[sqlx(type_name = "application_type", rename_all = "snake_case")]
35#[serde(rename_all = "snake_case")]
36pub enum ApplicationType {
37 Web,
38 Native,
39 Spa,
40 Service,
41}
42
43#[derive(Debug, Serialize, Deserialize, FromRow)]
44pub struct OAuthClient {
45 pub id: Uuid,
46 pub client_id: String,
47 pub client_name: String,
48 pub application_type: ApplicationType,
49
50 pub token_endpoint_auth_method: TokenEndpointAuthMethod,
51
52 pub client_secret: Option<Digest>,
54 pub client_secret_expires_at: Option<DateTime<Utc>>,
55
56 pub redirect_uris: Vec<String>,
57 pub post_logout_redirect_uris: Option<Vec<String>>,
58
59 pub allowed_grant_types: Vec<GrantTypeName>,
60 pub scopes: Vec<String>,
61
62 pub require_pkce: bool,
63 pub pkce_methods_allowed: Vec<PkceMethod>,
64
65 pub allowed_origins: Option<Vec<String>>,
67 pub bearer_allowed: bool,
68
69 pub created_at: DateTime<Utc>,
70 pub updated_at: DateTime<Utc>,
71 pub deleted_at: Option<DateTime<Utc>>,
72}
73
74impl OAuthClient {
75 pub fn is_public(&self) -> bool {
76 self.token_endpoint_auth_method.is_public()
77 }
78
79 pub fn is_confidential(&self) -> bool {
80 self.token_endpoint_auth_method.is_confidential()
81 }
82
83 pub fn allows_bearer(&self) -> bool {
84 self.bearer_allowed
85 }
86
87 pub fn requires_pkce(&self) -> bool {
88 self.require_pkce || self.is_public()
89 }
90
91 pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
92 self.pkce_methods_allowed.contains(&m)
93 }
94
95 pub fn allows_grant(&self, g: GrantTypeName) -> bool {
96 self.allowed_grant_types.contains(&g)
97 }
98}
99
100#[derive(Debug, Clone)]
101pub struct NewClientParams<'a> {
102 pub client_id: &'a str,
103 pub client_name: &'a str,
104 pub application_type: ApplicationType,
105 pub token_endpoint_auth_method: TokenEndpointAuthMethod,
106
107 pub client_secret: Option<&'a Digest>,
108 pub client_secret_expires_at: Option<DateTime<Utc>>,
109
110 pub redirect_uris: &'a [String],
111 pub post_logout_redirect_uris: Option<&'a [String]>,
112
113 pub allowed_grant_types: &'a [GrantTypeName],
114 pub scopes: &'a [String],
115
116 pub require_pkce: bool,
117 pub pkce_methods_allowed: &'a [PkceMethod],
118
119 pub allowed_origins: Option<&'a [String]>,
120 pub bearer_allowed: bool,
121}
122
123impl<'a> NewClientParams<'a> {
124 pub fn validate(&self) -> ModelResult<()> {
126 if self.client_id.trim().is_empty() {
127 return Err(ModelError::new(
128 ModelErrorType::InvalidRequest,
129 "client_id cannot be empty",
130 None::<anyhow::Error>,
131 ));
132 }
133
134 if self.client_name.trim().is_empty() {
135 return Err(ModelError::new(
136 ModelErrorType::InvalidRequest,
137 "client_name cannot be empty",
138 None::<anyhow::Error>,
139 ));
140 }
141
142 if self.redirect_uris.is_empty() {
143 return Err(ModelError::new(
144 ModelErrorType::InvalidRequest,
145 "redirect_uris must not be empty",
146 None::<anyhow::Error>,
147 ));
148 }
149
150 if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
151 return Err(ModelError::new(
152 ModelErrorType::PreconditionFailed,
153 "public clients must not include client_secret",
154 None::<anyhow::Error>,
155 ));
156 }
157
158 if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
159 return Err(ModelError::new(
160 ModelErrorType::PreconditionFailed,
161 "confidential clients must include client_secret",
162 None::<anyhow::Error>,
163 ));
164 }
165
166 if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
167 return Err(ModelError::new(
168 ModelErrorType::PreconditionFailed,
169 "public clients must require PKCE",
170 None::<anyhow::Error>,
171 ));
172 }
173
174 Ok(())
175 }
176}
177
178impl OAuthClient {
179 pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
181 let client = sqlx::query_as!(
182 OAuthClient,
183 r#"
184 SELECT *
185 FROM oauth_clients
186 WHERE id = $1
187 AND deleted_at IS NULL
188 "#,
189 id
190 )
191 .fetch_one(conn)
192 .await?;
193
194 Ok(client)
195 }
196
197 pub async fn find_by_id_optional(
199 conn: &mut PgConnection,
200 id: Uuid,
201 ) -> Result<Option<Self>, ModelError> {
202 Self::find_by_id(conn, id).await.optional()
203 }
204
205 pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
207 let client = sqlx::query_as!(
208 OAuthClient,
209 r#"
210 SELECT *
211 FROM oauth_clients
212 WHERE client_id = $1
213 AND deleted_at IS NULL
214 "#,
215 client_id
216 )
217 .fetch_one(conn)
218 .await?;
219
220 Ok(client)
221 }
222
223 pub async fn find_by_client_id_optional(
225 conn: &mut PgConnection,
226 client_id: &str,
227 ) -> Result<Option<Self>, ModelError> {
228 Self::find_by_client_id(conn, client_id).await.optional()
229 }
230
231 pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
233 p.validate()?;
234 let row = sqlx::query_as!(
235 OAuthClient,
236 r#"
237 INSERT INTO oauth_clients (
238 client_id,
239 client_name,
240 application_type,
241 token_endpoint_auth_method,
242 client_secret,
243 client_secret_expires_at,
244 redirect_uris,
245 post_logout_redirect_uris,
246 allowed_grant_types,
247 scopes,
248 require_pkce,
249 pkce_methods_allowed,
250 allowed_origins,
251 bearer_allowed
252 )
253 VALUES (
254 $1, $2, $3, $4,
255 $5, $6,
256 $7, COALESCE($8, '{}'::text[]), -- << cast needed for text[]
257 $9, $10,
258 $11, $12,
259 $13, $14
260 )
261 RETURNING
262 *
263 "#,
264 p.client_id,
265 p.client_name,
266 p.application_type as ApplicationType,
267 p.token_endpoint_auth_method as TokenEndpointAuthMethod,
268 p.client_secret.map(|d| d.as_bytes() as &[u8]),
269 p.client_secret_expires_at,
270 p.redirect_uris,
271 p.post_logout_redirect_uris,
272 p.allowed_grant_types as &[GrantTypeName],
273 p.scopes,
274 p.require_pkce,
275 p.pkce_methods_allowed as &[PkceMethod],
276 p.allowed_origins,
277 p.bearer_allowed
278 )
279 .fetch_one(conn)
280 .await?;
281
282 Ok(row)
283 }
284}