headless_lms_models/
oauth_client.rs

//! OAuth 2.1 / OIDC client model.
//!
//! Mirrors the `oauth_clients` table and its PostgreSQL enum types.
//! Also provides small policy helpers (public/confidential, PKCE, grants).
5
6use crate::{
7    library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8    prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
16#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
17#[serde(rename_all = "snake_case")]
18pub enum TokenEndpointAuthMethod {
19    None,
20    ClientSecretPost,
21}
22
23impl TokenEndpointAuthMethod {
24    pub fn is_public(self) -> bool {
25        matches!(self, Self::None)
26    }
27
28    pub fn is_confidential(self) -> bool {
29        !self.is_public()
30    }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
34#[sqlx(type_name = "application_type", rename_all = "snake_case")]
35#[serde(rename_all = "snake_case")]
36pub enum ApplicationType {
37    Web,
38    Native,
39    Spa,
40    Service,
41}
42
43#[derive(Debug, Serialize, Deserialize, FromRow)]
44pub struct OAuthClient {
45    pub id: Uuid,
46    pub client_id: String,
47    pub client_name: String,
48    pub application_type: ApplicationType,
49
50    pub token_endpoint_auth_method: TokenEndpointAuthMethod,
51
52    /// Hashed/HMACed secret for confidential clients (`BYTEA`); `None` for public clients.
53    pub client_secret: Option<Digest>,
54    pub client_secret_expires_at: Option<DateTime<Utc>>,
55
56    pub redirect_uris: Vec<String>,
57    pub post_logout_redirect_uris: Option<Vec<String>>,
58
59    pub allowed_grant_types: Vec<GrantTypeName>,
60    pub scopes: Vec<String>,
61
62    pub require_pkce: bool,
63    pub pkce_methods_allowed: Vec<PkceMethod>,
64
65    /// Optional list of allowed origins (same validation as redirect URIs). If None or empty, origin check is not enforced.
66    pub allowed_origins: Option<Vec<String>>,
67    pub bearer_allowed: bool,
68
69    pub created_at: DateTime<Utc>,
70    pub updated_at: DateTime<Utc>,
71    pub deleted_at: Option<DateTime<Utc>>,
72}
73
74impl OAuthClient {
75    pub fn is_public(&self) -> bool {
76        self.token_endpoint_auth_method.is_public()
77    }
78
79    pub fn is_confidential(&self) -> bool {
80        self.token_endpoint_auth_method.is_confidential()
81    }
82
83    pub fn allows_bearer(&self) -> bool {
84        self.bearer_allowed
85    }
86
87    pub fn requires_pkce(&self) -> bool {
88        self.require_pkce || self.is_public()
89    }
90
91    pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
92        self.pkce_methods_allowed.contains(&m)
93    }
94
95    pub fn allows_grant(&self, g: GrantTypeName) -> bool {
96        self.allowed_grant_types.contains(&g)
97    }
98}
99
100#[derive(Debug, Clone)]
101pub struct NewClientParams<'a> {
102    pub client_id: &'a str,
103    pub client_name: &'a str,
104    pub application_type: ApplicationType,
105    pub token_endpoint_auth_method: TokenEndpointAuthMethod,
106
107    pub client_secret: Option<&'a Digest>,
108    pub client_secret_expires_at: Option<DateTime<Utc>>,
109
110    pub redirect_uris: &'a [String],
111    pub post_logout_redirect_uris: Option<&'a [String]>,
112
113    pub allowed_grant_types: &'a [GrantTypeName],
114    pub scopes: &'a [String],
115
116    pub require_pkce: bool,
117    pub pkce_methods_allowed: &'a [PkceMethod],
118
119    pub allowed_origins: Option<&'a [String]>,
120    pub bearer_allowed: bool,
121}
122
123impl<'a> NewClientParams<'a> {
124    /// Lightweight pre-DB checks (DB constraints still enforce policy).
125    pub fn validate(&self) -> ModelResult<()> {
126        if self.client_id.trim().is_empty() {
127            return Err(ModelError::new(
128                ModelErrorType::InvalidRequest,
129                "client_id cannot be empty",
130                None::<anyhow::Error>,
131            ));
132        }
133
134        if self.client_name.trim().is_empty() {
135            return Err(ModelError::new(
136                ModelErrorType::InvalidRequest,
137                "client_name cannot be empty",
138                None::<anyhow::Error>,
139            ));
140        }
141
142        if self.redirect_uris.is_empty() {
143            return Err(ModelError::new(
144                ModelErrorType::InvalidRequest,
145                "redirect_uris must not be empty",
146                None::<anyhow::Error>,
147            ));
148        }
149
150        if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
151            return Err(ModelError::new(
152                ModelErrorType::PreconditionFailed,
153                "public clients must not include client_secret",
154                None::<anyhow::Error>,
155            ));
156        }
157
158        if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
159            return Err(ModelError::new(
160                ModelErrorType::PreconditionFailed,
161                "confidential clients must include client_secret",
162                None::<anyhow::Error>,
163            ));
164        }
165
166        if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
167            return Err(ModelError::new(
168                ModelErrorType::PreconditionFailed,
169                "public clients must require PKCE",
170                None::<anyhow::Error>,
171            ));
172        }
173
174        Ok(())
175    }
176}
177
178impl OAuthClient {
179    /// Find an **active** (non-soft-deleted) client by DB `id` (UUID).
180    pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
181        let client = sqlx::query_as!(
182            OAuthClient,
183            r#"
184        SELECT
185          id,
186          client_id,
187          client_name,
188          application_type                AS "application_type: _",
189          token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
190          client_secret                   AS "client_secret: _",
191          client_secret_expires_at,
192          redirect_uris,
193          post_logout_redirect_uris,
194          allowed_grant_types             AS "allowed_grant_types: _",
195          scopes,
196          require_pkce,
197          pkce_methods_allowed            AS "pkce_methods_allowed: _",
198          allowed_origins,
199          bearer_allowed,
200          created_at,
201          updated_at,
202          deleted_at
203        FROM oauth_clients
204        WHERE id = $1
205          AND deleted_at IS NULL
206        "#,
207            id
208        )
209        .fetch_one(conn)
210        .await?;
211
212        Ok(client)
213    }
214
215    /// Same as `find_by_id`, but returns `Ok(None)` when not found.
216    pub async fn find_by_id_optional(
217        conn: &mut PgConnection,
218        id: Uuid,
219    ) -> Result<Option<Self>, ModelError> {
220        Self::find_by_id(conn, id).await.optional()
221    }
222
223    /// Find an **active** (non-soft-deleted) client by `client_id`.
224    pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
225        let client = sqlx::query_as!(
226            OAuthClient,
227            r#"
228    SELECT
229      id,
230      client_id,
231      client_name,
232      application_type                AS "application_type: _",
233      token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
234      client_secret                   AS "client_secret: _",
235      client_secret_expires_at,
236      redirect_uris,
237      post_logout_redirect_uris,
238      allowed_grant_types             AS "allowed_grant_types: _",   -- or "grant_type[]"
239      scopes,
240      require_pkce,
241      pkce_methods_allowed            AS "pkce_methods_allowed: _",  -- or "pkce_method[]"
242      allowed_origins,
243      bearer_allowed,
244      created_at,
245      updated_at,
246      deleted_at
247    FROM oauth_clients
248    WHERE client_id = $1
249      AND deleted_at IS NULL
250    "#,
251            client_id
252        )
253        .fetch_one(conn)
254        .await?;
255
256        Ok(client)
257    }
258
259    /// Same as `find_by_client_id`, but returns `Ok(None)` when not found.
260    pub async fn find_by_client_id_optional(
261        conn: &mut PgConnection,
262        client_id: &str,
263    ) -> Result<Option<Self>, ModelError> {
264        Self::find_by_client_id(conn, client_id).await.optional()
265    }
266
267    /// Insert a new client and return the full hydrated row.
268    pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
269        p.validate()?;
270        let row = sqlx::query_as!(
271            OAuthClient,
272            r#"
273    INSERT INTO oauth_clients (
274        client_id,
275        client_name,
276        application_type,
277        token_endpoint_auth_method,
278        client_secret,
279        client_secret_expires_at,
280        redirect_uris,
281        post_logout_redirect_uris,
282        allowed_grant_types,
283        scopes,
284        require_pkce,
285        pkce_methods_allowed,
286        allowed_origins,
287        bearer_allowed
288    )
289    VALUES (
290        $1, $2, $3, $4,
291        $5, $6,
292        $7, COALESCE($8, '{}'::text[]),     -- << cast needed for text[]
293        $9, $10,
294        $11, $12,
295        $13, $14
296    )
297    RETURNING
298      id,
299      client_id,
300      client_name,
301      application_type                AS "application_type: _",
302      token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
303      client_secret                   AS "client_secret: _",
304      client_secret_expires_at,
305      redirect_uris,
306      post_logout_redirect_uris,
307      allowed_grant_types             AS "allowed_grant_types: _",
308      scopes,
309      require_pkce,
310      pkce_methods_allowed            AS "pkce_methods_allowed: _",
311      allowed_origins,
312      bearer_allowed,
313      created_at,
314      updated_at,
315      deleted_at
316    "#,
317            p.client_id,
318            p.client_name,
319            p.application_type as ApplicationType,
320            p.token_endpoint_auth_method as TokenEndpointAuthMethod,
321            p.client_secret.map(|d| d.as_bytes() as &[u8]),
322            p.client_secret_expires_at,
323            p.redirect_uris,
324            p.post_logout_redirect_uris,
325            p.allowed_grant_types as &[GrantTypeName],
326            p.scopes,
327            p.require_pkce,
328            p.pkce_methods_allowed as &[PkceMethod],
329            p.allowed_origins,
330            p.bearer_allowed
331        )
332        .fetch_one(conn)
333        .await?;
334
335        Ok(row)
336    }
337}