Skip to main content

headless_lms_models/
oauth_client.rs

1//! OAuth 2.1 / OIDC Client model
2//!
3//! Mirrors the `oauth_clients` table and PostgreSQL enums.
4//! Includes small policy helpers (public/confidential, PKCE, grants).
5
6use crate::{
7    library::oauth::{Digest, GrantTypeName, pkce::PkceMethod},
8    prelude::*,
9};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use sqlx::{FromRow, PgConnection, Type};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
16#[sqlx(type_name = "token_endpoint_auth_method", rename_all = "snake_case")]
17#[serde(rename_all = "snake_case")]
18pub enum TokenEndpointAuthMethod {
19    None,
20    ClientSecretPost,
21}
22
23impl TokenEndpointAuthMethod {
24    pub fn is_public(self) -> bool {
25        matches!(self, Self::None)
26    }
27
28    pub fn is_confidential(self) -> bool {
29        !self.is_public()
30    }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Type)]
34#[sqlx(type_name = "application_type", rename_all = "snake_case")]
35#[serde(rename_all = "snake_case")]
36pub enum ApplicationType {
37    Web,
38    Native,
39    Spa,
40    Service,
41}
42
43#[derive(Debug, Serialize, Deserialize, FromRow)]
44pub struct OAuthClient {
45    pub id: Uuid,
46    pub client_id: String,
47    pub client_name: String,
48    pub application_type: ApplicationType,
49
50    pub token_endpoint_auth_method: TokenEndpointAuthMethod,
51
52    /// Hashed/HMACed secret for confidential clients (`BYTEA`); `None` for public clients.
53    pub client_secret: Option<Digest>,
54    pub client_secret_expires_at: Option<DateTime<Utc>>,
55
56    pub redirect_uris: Vec<String>,
57    pub post_logout_redirect_uris: Option<Vec<String>>,
58
59    pub allowed_grant_types: Vec<GrantTypeName>,
60    pub scopes: Vec<String>,
61
62    pub require_pkce: bool,
63    pub pkce_methods_allowed: Vec<PkceMethod>,
64
65    /// Optional list of allowed origins (same validation as redirect URIs). If None or empty, origin check is not enforced.
66    pub allowed_origins: Option<Vec<String>>,
67    pub bearer_allowed: bool,
68
69    pub created_at: DateTime<Utc>,
70    pub updated_at: DateTime<Utc>,
71    pub deleted_at: Option<DateTime<Utc>>,
72}
73
74impl OAuthClient {
75    pub fn is_public(&self) -> bool {
76        self.token_endpoint_auth_method.is_public()
77    }
78
79    pub fn is_confidential(&self) -> bool {
80        self.token_endpoint_auth_method.is_confidential()
81    }
82
83    pub fn allows_bearer(&self) -> bool {
84        self.bearer_allowed
85    }
86
87    pub fn requires_pkce(&self) -> bool {
88        self.require_pkce || self.is_public()
89    }
90
91    pub fn allows_pkce_method(&self, m: PkceMethod) -> bool {
92        self.pkce_methods_allowed.contains(&m)
93    }
94
95    pub fn allows_grant(&self, g: GrantTypeName) -> bool {
96        self.allowed_grant_types.contains(&g)
97    }
98}
99
100#[derive(Debug, Clone)]
101pub struct NewClientParams<'a> {
102    pub client_id: &'a str,
103    pub client_name: &'a str,
104    pub application_type: ApplicationType,
105    pub token_endpoint_auth_method: TokenEndpointAuthMethod,
106
107    pub client_secret: Option<&'a Digest>,
108    pub client_secret_expires_at: Option<DateTime<Utc>>,
109
110    pub redirect_uris: &'a [String],
111    pub post_logout_redirect_uris: Option<&'a [String]>,
112
113    pub allowed_grant_types: &'a [GrantTypeName],
114    pub scopes: &'a [String],
115
116    pub require_pkce: bool,
117    pub pkce_methods_allowed: &'a [PkceMethod],
118
119    pub allowed_origins: Option<&'a [String]>,
120    pub bearer_allowed: bool,
121}
122
123impl<'a> NewClientParams<'a> {
124    /// Lightweight pre-DB checks (DB constraints still enforce policy).
125    pub fn validate(&self) -> ModelResult<()> {
126        if self.client_id.trim().is_empty() {
127            return Err(ModelError::new(
128                ModelErrorType::InvalidRequest,
129                "client_id cannot be empty",
130                None::<anyhow::Error>,
131            ));
132        }
133
134        if self.client_name.trim().is_empty() {
135            return Err(ModelError::new(
136                ModelErrorType::InvalidRequest,
137                "client_name cannot be empty",
138                None::<anyhow::Error>,
139            ));
140        }
141
142        if self.redirect_uris.is_empty() {
143            return Err(ModelError::new(
144                ModelErrorType::InvalidRequest,
145                "redirect_uris must not be empty",
146                None::<anyhow::Error>,
147            ));
148        }
149
150        if self.token_endpoint_auth_method.is_public() && self.client_secret.is_some() {
151            return Err(ModelError::new(
152                ModelErrorType::PreconditionFailed,
153                "public clients must not include client_secret",
154                None::<anyhow::Error>,
155            ));
156        }
157
158        if self.token_endpoint_auth_method.is_confidential() && self.client_secret.is_none() {
159            return Err(ModelError::new(
160                ModelErrorType::PreconditionFailed,
161                "confidential clients must include client_secret",
162                None::<anyhow::Error>,
163            ));
164        }
165
166        if !self.require_pkce && self.token_endpoint_auth_method.is_public() {
167            return Err(ModelError::new(
168                ModelErrorType::PreconditionFailed,
169                "public clients must require PKCE",
170                None::<anyhow::Error>,
171            ));
172        }
173
174        Ok(())
175    }
176}
177
178impl OAuthClient {
179    /// Find an **active** (non-soft-deleted) client by DB `id` (UUID).
180    pub async fn find_by_id(conn: &mut PgConnection, id: Uuid) -> ModelResult<Self> {
181        let client = sqlx::query_as!(
182            OAuthClient,
183            r#"
184        SELECT *
185        FROM oauth_clients
186        WHERE id = $1
187          AND deleted_at IS NULL
188        "#,
189            id
190        )
191        .fetch_one(conn)
192        .await?;
193
194        Ok(client)
195    }
196
197    /// Same as `find_by_id`, but returns `Ok(None)` when not found.
198    pub async fn find_by_id_optional(
199        conn: &mut PgConnection,
200        id: Uuid,
201    ) -> Result<Option<Self>, ModelError> {
202        Self::find_by_id(conn, id).await.optional()
203    }
204
205    /// Find an **active** (non-soft-deleted) client by `client_id`.
206    pub async fn find_by_client_id(conn: &mut PgConnection, client_id: &str) -> ModelResult<Self> {
207        let client = sqlx::query_as!(
208            OAuthClient,
209            r#"
210    SELECT *
211    FROM oauth_clients
212    WHERE client_id = $1
213      AND deleted_at IS NULL
214    "#,
215            client_id
216        )
217        .fetch_one(conn)
218        .await?;
219
220        Ok(client)
221    }
222
223    /// Same as `find_by_client_id`, but returns `Ok(None)` when not found.
224    pub async fn find_by_client_id_optional(
225        conn: &mut PgConnection,
226        client_id: &str,
227    ) -> Result<Option<Self>, ModelError> {
228        Self::find_by_client_id(conn, client_id).await.optional()
229    }
230
231    /// Insert a new client and return the full hydrated row.
232    pub async fn insert(conn: &mut PgConnection, p: NewClientParams<'_>) -> ModelResult<Self> {
233        p.validate()?;
234        let row = sqlx::query_as!(
235            OAuthClient,
236            r#"
237    INSERT INTO oauth_clients (
238        client_id,
239        client_name,
240        application_type,
241        token_endpoint_auth_method,
242        client_secret,
243        client_secret_expires_at,
244        redirect_uris,
245        post_logout_redirect_uris,
246        allowed_grant_types,
247        scopes,
248        require_pkce,
249        pkce_methods_allowed,
250        allowed_origins,
251        bearer_allowed
252    )
253    VALUES (
254        $1, $2, $3, $4,
255        $5, $6,
256        $7, COALESCE($8, '{}'::text[]),     -- << cast needed for text[]
257        $9, $10,
258        $11, $12,
259        $13, $14
260    )
261    RETURNING
262      *
263    "#,
264            p.client_id,
265            p.client_name,
266            p.application_type as ApplicationType,
267            p.token_endpoint_auth_method as TokenEndpointAuthMethod,
268            p.client_secret.map(|d| d.as_bytes() as &[u8]),
269            p.client_secret_expires_at,
270            p.redirect_uris,
271            p.post_logout_redirect_uris,
272            p.allowed_grant_types as &[GrantTypeName],
273            p.scopes,
274            p.require_pkce,
275            p.pkce_methods_allowed as &[PkceMethod],
276            p.allowed_origins,
277            p.bearer_allowed
278        )
279        .fetch_one(conn)
280        .await?;
281
282        Ok(row)
283    }
284}