// headless_lms_models/oauth_client.rs

//! OAuth 2.1 / OpenID Connect client model.
//!
//! Mirrors the `oauth_clients` table and its PostgreSQL enum types.
//! Also provides small policy helpers (public vs. confidential clients, PKCE, allowed grants).
5
6use crate::{
7    library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8    prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
/// How a client authenticates at the token endpoint.
///
/// Mirrors the PostgreSQL enum `token_endpoint_auth_method`. Variants are stored and
/// (de)serialized in `snake_case` (`none`, `client_secret_post`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum TokenEndpointAuthMethod {
    /// No client authentication: the client is public
    /// (see [`TokenEndpointAuthMethod::is_public`]).
    None,
    /// `client_secret_post` (RFC 7591 method name): the client is confidential.
    ClientSecretPost,
}
22
23impl TokenEndpointAuthMethod {
24    pub fn is_public(self) -> bool {
25        matches!(self, Self::None)
26    }
27
28    pub fn is_confidential(self) -> bool {
29        !self.is_public()
30    }
31}
32
/// Kind of application registered as an OAuth client.
///
/// Mirrors the PostgreSQL enum `application_type`. Variants are stored and (de)serialized
/// in `snake_case` (`web`, `native`, `spa`, `service`).
///
/// Whether a client is public or confidential is decided by [`TokenEndpointAuthMethod`],
/// not by this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
#[sqlx(type_name = "application_type", rename_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ApplicationType {
    Web,
    Native,
    Spa,
    Service,
}
42
/// A registered OAuth 2.1 / OIDC client: one row of the `oauth_clients` table.
///
/// Field names mirror the table's columns; the queries on this type select every column
/// into this struct.
#[derive(Debug, Serialize, Deserialize, FromRow)]
pub struct OAuthClient {
    /// Row id (UUID). Not supplied on insert, so it is assigned by the database.
    pub id: Uuid,
    /// Client identifier, looked up by [`OAuthClient::find_by_client_id`].
    pub client_id: String,
    /// Display name of the client.
    pub client_name: String,
    /// Registered application kind. The public/confidential policy helpers do not read it;
    /// that decision comes from `token_endpoint_auth_method`.
    pub application_type: ApplicationType,

    /// Token endpoint authentication method. Decides whether the client is public or
    /// confidential (see [`OAuthClient::is_public`]).
    pub token_endpoint_auth_method: TokenEndpointAuthMethod,

    /// Hashed/HMACed secret for confidential clients (`BYTEA`); `None` for public clients.
    pub client_secret: Option<Digest>,
    /// Optional expiry of `client_secret`.
    /// NOTE(review): no helper on this type checks it — confirm callers enforce it.
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    /// Registered redirect URIs.
    pub redirect_uris: Vec<String>,
    /// Registered post-logout redirect URIs. [`OAuthClient::insert`] stores an empty array
    /// (not `NULL`) when none are given.
    pub post_logout_redirect_uris: Option<Vec<String>>,

    /// Grant types this client may use (checked by [`OAuthClient::allows_grant`]).
    pub allowed_grant_types: Vec<GrantTypeName>,
    /// Scopes associated with the client (presumably the set it may request — confirm).
    pub scopes: Vec<String>,

    /// Explicit PKCE requirement. [`OAuthClient::requires_pkce`] also forces PKCE for
    /// public clients regardless of this flag.
    pub require_pkce: bool,
    /// PKCE methods accepted for this client (checked by
    /// [`OAuthClient::allows_pkce_method`]).
    pub pkce_methods_allowed: Vec<PkceMethod>,

    /// NOTE(review): presumably the client's web origin — confirm its format and use.
    pub origin: String,
    /// Flag exposed through [`OAuthClient::allows_bearer`]; its exact meaning is decided
    /// by callers.
    pub bearer_allowed: bool,

    /// Creation timestamp. Not supplied on insert, so it comes from the database.
    pub created_at: DateTime<Utc>,
    /// Last-update timestamp. Not supplied on insert, so it comes from the database.
    pub updated_at: DateTime<Utc>,
    /// Soft-delete marker. The `find_*` queries only return rows where it is `NULL`.
    pub deleted_at: Option<DateTime<Utc>>,
}
72
73impl OAuthClient {
74    pub fn is_public(&self) -> bool {
75        self.token_endpoint_auth_method.is_public()
76    }
77
78    pub fn is_confidential(&self) -> bool {
79        self.token_endpoint_auth_method.is_confidential()
80    }
81
82    pub fn allows_bearer(&self) -> bool {
83        self.bearer_allowed
84    }
85
86    pub fn requires_pkce(&self) -> bool {
87        self.require_pkce || self.is_public()
88    }
89
90    pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
91        self.pkce_methods_allowed.contains(&m)
92    }
93
94    pub fn allows_grant(&self, g: GrantTypeName) -> bool {
95        self.allowed_grant_types.contains(&g)
96    }
97}
98
/// Borrowed input for [`OAuthClient::insert`], one field per column the insert writes.
///
/// `id`, `created_at`, `updated_at` and `deleted_at` are absent because the insert does
/// not supply them. [`NewClientParams::validate`] runs lightweight checks; `insert` calls
/// it before touching the database.
#[derive(Debug, Clone)]
pub struct NewClientParams<'a> {
    pub client_id: &'a str,
    pub client_name: &'a str,
    pub application_type: ApplicationType,
    /// Decides public vs. confidential; `validate` cross-checks it against `client_secret`
    /// and `require_pkce`.
    pub token_endpoint_auth_method: TokenEndpointAuthMethod,

    /// Secret digest, stored as raw bytes. Must be `Some` for confidential and `None` for
    /// public clients (enforced by `validate`).
    pub client_secret: Option<&'a Digest>,
    pub client_secret_expires_at: Option<DateTime<Utc>>,

    /// Must not be empty (enforced by `validate`).
    pub redirect_uris: &'a [String],
    /// `None` is stored as an empty array by `insert`.
    pub post_logout_redirect_uris: Option<&'a [String]>,

    pub allowed_grant_types: &'a [GrantTypeName],
    pub scopes: &'a [String],

    /// Must be `true` for public clients (enforced by `validate`).
    pub require_pkce: bool,
    pub pkce_methods_allowed: &'a [PkceMethod],

    pub origin: &'a str,
    pub bearer_allowed: bool,
}
121
122impl<'a> NewClientParams<'a> {
123    /// Lightweight pre-DB checks (DB constraints still enforce policy).
124    pub fn validate(&self) -> ModelResult<()> {
125        if self.client_id.trim().is_empty() {
126            return Err(ModelError::new(
127                ModelErrorType::InvalidRequest,
128                "client_id cannot be empty",
129                None::<anyhow::Error>,
130            ));
131        }
132
133        if self.client_name.trim().is_empty() {
134            return Err(ModelError::new(
135                ModelErrorType::InvalidRequest,
136                "client_name cannot be empty",
137                None::<anyhow::Error>,
138            ));
139        }
140
141        if self.redirect_uris.is_empty() {
142            return Err(ModelError::new(
143                ModelErrorType::InvalidRequest,
144                "redirect_uris must not be empty",
145                None::<anyhow::Error>,
146            ));
147        }
148
149        if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
150            return Err(ModelError::new(
151                ModelErrorType::PreconditionFailed,
152                "public clients must not include client_secret",
153                None::<anyhow::Error>,
154            ));
155        }
156
157        if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
158            return Err(ModelError::new(
159                ModelErrorType::PreconditionFailed,
160                "confidential clients must include client_secret",
161                None::<anyhow::Error>,
162            ));
163        }
164
165        if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
166            return Err(ModelError::new(
167                ModelErrorType::PreconditionFailed,
168                "public clients must require PKCE",
169                None::<anyhow::Error>,
170            ));
171        }
172
173        Ok(())
174    }
175}
176
/// Database access for `oauth_clients`.
///
/// The SQL strings are checked at compile time by `sqlx::query_as!`. Edit them only together
/// with the schema and regenerated sqlx offline query data.
impl OAuthClient {
    /// Find an **active** (non-soft-deleted) client by DB `id` (UUID).
    ///
    /// # Errors
    ///
    /// Fails when no active row matches (`fetch_one` finds nothing) or when the query fails.
    /// Use [`OAuthClient::find_by_id_optional`] to get `Ok(None)` for the not-found case.
    pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
        // `AS "col: _"` tells sqlx to take the column's Rust type from the matching
        // `OAuthClient` field (used for the enum, enum-array and `Digest` columns).
        let client = sqlx::query_as!(
            OAuthClient,
            r#"
        SELECT
          id,
          client_id,
          client_name,
          application_type                AS "application_type: _",
          token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
          client_secret                   AS "client_secret: _",
          client_secret_expires_at,
          redirect_uris,
          post_logout_redirect_uris,
          allowed_grant_types             AS "allowed_grant_types: _",
          scopes,
          require_pkce,
          pkce_methods_allowed            AS "pkce_methods_allowed: _",
          origin,
          bearer_allowed,
          created_at,
          updated_at,
          deleted_at
        FROM oauth_clients
        WHERE id = $1
          AND deleted_at IS NULL
        "#,
            id
        )
        .fetch_one(conn)
        .await?;

        Ok(client)
    }

    /// Same as `find_by_id`, but returns `Ok(None)` when not found.
    ///
    /// Delegates to `find_by_id` and converts its result with the prelude's `optional()` helper.
    pub async fn find_by_id_optional(
        conn: &mut PgConnection,
        id: Uuid,
    ) -> Result<Option<Self>, ModelError> {
        Self::find_by_id(conn, id).await.optional()
    }

    /// Find an **active** (non-soft-deleted) client by `client_id`.
    ///
    /// # Errors
    ///
    /// Fails when no active row matches (`fetch_one` finds nothing) or when the query fails.
    /// Use [`OAuthClient::find_by_client_id_optional`] to get `Ok(None)` for the not-found case.
    pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
        // NOTE(review): the trailing `-- or "..."` SQL comments below are leftover notes
        // inside the query text. Removing them changes the query string, so sqlx offline
        // query data would need regenerating.
        let client = sqlx::query_as!(
            OAuthClient,
            r#"
    SELECT
      id,
      client_id,
      client_name,
      application_type                AS "application_type: _",
      token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
      client_secret                   AS "client_secret: _",
      client_secret_expires_at,
      redirect_uris,
      post_logout_redirect_uris,
      allowed_grant_types             AS "allowed_grant_types: _",   -- or "grant_type[]"
      scopes,
      require_pkce,
      pkce_methods_allowed            AS "pkce_methods_allowed: _",  -- or "pkce_method[]"
      origin,
      bearer_allowed,
      created_at,
      updated_at,
      deleted_at
    FROM oauth_clients
    WHERE client_id = $1
      AND deleted_at IS NULL
    "#,
            client_id
        )
        .fetch_one(conn)
        .await?;

        Ok(client)
    }

    /// Same as `find_by_client_id`, but returns `Ok(None)` when not found.
    ///
    /// Delegates to `find_by_client_id` and converts its result with the prelude's
    /// `optional()` helper.
    pub async fn find_by_client_id_optional(
        conn: &mut PgConnection,
        client_id: &str,
    ) -> Result<Option<Self>, ModelError> {
        Self::find_by_client_id(conn, client_id).await.optional()
    }

    /// Insert a new client and return the full hydrated row.
    ///
    /// Runs [`NewClientParams::validate`] first, so invalid params never reach the database.
    /// `id`, `created_at`, `updated_at` and `deleted_at` are not supplied here; their values
    /// come from the database (presumably column defaults — confirm in the migration).
    /// A `None` `post_logout_redirect_uris` is stored as an empty `text[]`, not `NULL`.
    ///
    /// # Errors
    ///
    /// Returns validation errors from `validate`, or any database error from the insert.
    pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
        p.validate()?;
        let row = sqlx::query_as!(
            OAuthClient,
            r#"
    INSERT INTO oauth_clients (
        client_id,
        client_name,
        application_type,
        token_endpoint_auth_method,
        client_secret,
        client_secret_expires_at,
        redirect_uris,
        post_logout_redirect_uris,
        allowed_grant_types,
        scopes,
        require_pkce,
        pkce_methods_allowed,
        origin,
        bearer_allowed
    )
    VALUES (
        $1, $2, $3, $4,
        $5, $6,
        $7, COALESCE($8, '{}'::text[]),     -- << cast needed for text[]
        $9, $10,
        $11, $12,
        $13, $14
    )
    RETURNING
      id,
      client_id,
      client_name,
      application_type                AS "application_type: _",
      token_endpoint_auth_method      AS "token_endpoint_auth_method: _",
      client_secret                   AS "client_secret: _",
      client_secret_expires_at,
      redirect_uris,
      post_logout_redirect_uris,
      allowed_grant_types             AS "allowed_grant_types: _",
      scopes,
      require_pkce,
      pkce_methods_allowed            AS "pkce_methods_allowed: _",
      origin,
      bearer_allowed,
      created_at,
      updated_at,
      deleted_at
    "#,
            p.client_id,
            p.client_name,
            // `as Type` casts let sqlx bind the custom PostgreSQL enum / enum-array types.
            p.application_type as ApplicationType,
            p.token_endpoint_auth_method as TokenEndpointAuthMethod,
            // The digest is bound as raw bytes for the `BYTEA` column.
            p.client_secret.map(|d| d.as_bytes() as &[u8]),
            p.client_secret_expires_at,
            p.redirect_uris,
            p.post_logout_redirect_uris,
            p.allowed_grant_types as &[GrantTypeName],
            p.scopes,
            p.require_pkce,
            p.pkce_methods_allowed as &[PkceMethod],
            p.origin,
            p.bearer_allowed
        )
        .fetch_one(conn)
        .await?;

        Ok(row)
    }
}