headless_lms_server/domain/oauth/
pkce.rs

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use sha2::{Digest, Sha256};
3
4use crate::domain::error::{ControllerError, PkceFlowError};
5use crate::domain::oauth::helpers::oauth_invalid_request;
6use headless_lms_models::oauth_client::OAuthClient;
7
8/// Re-export PkceMethod from models (it's used in SQL queries, so must stay in models crate)
9pub use headless_lms_models::library::oauth::pkce::PkceMethod;
10
/// Shortest `code_verifier` accepted, per RFC 7636 §4.1 (43 characters).
pub const VERIFIER_MIN_LEN: usize = 43;
/// Longest `code_verifier` accepted, per RFC 7636 §4.1 (128 characters).
pub const VERIFIER_MAX_LEN: usize = 128;
14
15/// Errors constructing/validating PKCE values.
16#[derive(Debug, thiserror::Error)]
17pub enum PkceError {
18    #[error("code_verifier length out of bounds")]
19    BadLength,
20    #[error("code_verifier contains invalid characters")]
21    BadCharset,
22    #[error("Disallowed PKCE method")]
23    BadMethod,
24}
25
26/// Validated PKCE code_verifier (RFC 7636).
27#[derive(Debug, Clone)]
28pub struct CodeVerifier(String);
29
30impl CodeVerifier {
31    /// Construct after validating length and allowed charset.
32    pub fn new(s: &str) -> Result<Self, PkceError> {
33        validate_verifier(s)?;
34        Ok(Self(s.to_owned()))
35    }
36
37    /// Construct without allocation if you already own a String; still validates.
38    pub fn try_from_string(s: String) -> Result<Self, PkceError> {
39        validate_verifier(&s)?;
40        Ok(Self(s))
41    }
42
43    /// Borrow the inner str.
44    pub fn as_str(&self) -> &str {
45        &self.0
46    }
47
48    /// Compute the PKCE `code_challenge` for this verifier using `method`.
49    pub fn to_challenge(&self, method: PkceMethod) -> CodeChallenge {
50        match method {
51            PkceMethod::Plain => CodeChallenge(self.0.clone()),
52            PkceMethod::S256 => {
53                let digest = Sha256::digest(self.0.as_bytes());
54                CodeChallenge(URL_SAFE_NO_PAD.encode(digest))
55            }
56        }
57    }
58}
59
60/// Stored PKCE code_challenge.
61#[derive(Debug, Clone)]
62pub struct CodeChallenge(String);
63
64impl CodeChallenge {
65    /// Wrap a stored challenge (e.g., from DB). No extra validation needed at this layer.
66    pub fn from_stored<S: Into<String>>(s: S) -> Self {
67        Self(s.into())
68    }
69
70    pub fn as_str(&self) -> &str {
71        &self.0
72    }
73
74    /// Verify that `verifier` corresponds to this challenge under `method`.
75    pub fn verify(&self, verifier: &CodeVerifier, method: PkceMethod) -> bool {
76        let computed = verifier.to_challenge(method);
77        constant_time_eq(self.as_str(), computed.as_str())
78    }
79}
80
81/// Strict RFC 7636 validator: length 43–128 and only unreserved characters.
82///
83/// Unreserved: ALPHA / DIGIT / "-" / "." / "_" / "~"
84fn validate_verifier(v: &str) -> Result<(), PkceError> {
85    let len = v.len();
86    if !(VERIFIER_MIN_LEN..=VERIFIER_MAX_LEN).contains(&len) {
87        return Err(PkceError::BadLength);
88    }
89    if !v.bytes().all(|b| {
90        matches!(
91            b,
92            b'A'..=b'Z' |
93            b'a'..=b'z' |
94            b'0'..=b'9' |
95            b'-' | b'.' | b'_' | b'~'
96        )
97    }) {
98        return Err(PkceError::BadCharset);
99    }
100    Ok(())
101}
102
/// Compares two strings without short-circuiting on the first differing byte.
///
/// Strings of different lengths are rejected immediately, so length is not
/// hidden. Equal-length inputs are compared by OR-accumulating the XOR of every
/// byte pair, which visits all bytes regardless of where they differ. That is
/// adequate for the ASCII base64url / plain PKCE values compared in this module.
fn constant_time_eq(a: &str, b: &str) -> bool {
    if a.len() == b.len() {
        let difference = a
            .as_bytes()
            .iter()
            .zip(b.as_bytes())
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        difference == 0
    } else {
        false
    }
}
114
115/// Validate PKCE parameters during `/authorize`.
116pub fn parse_authorize_pkce(
117    client: &OAuthClient,
118    code_challenge: Option<&str>,
119    code_challenge_method: Option<&str>,
120    redirect_uri: &str,
121    state: Option<&str>,
122) -> Result<Option<PkceMethod>, ControllerError> {
123    let pkce_required = client.requires_pkce();
124    let parsed = match (code_challenge, code_challenge_method) {
125        (Some(ch), Some(method_str)) => {
126            let method = PkceMethod::parse(method_str).ok_or_else(|| {
127                oauth_invalid_request(
128                    "unsupported code_challenge_method",
129                    Some(redirect_uri),
130                    state,
131                )
132            })?;
133
134            if !client.allows_pkce_method(method) {
135                return Err(oauth_invalid_request(
136                    "code_challenge_method not allowed for this client",
137                    Some(redirect_uri),
138                    state,
139                ));
140            }
141
142            match method {
143                PkceMethod::S256 => {
144                    let bytes = URL_SAFE_NO_PAD.decode(ch).map_err(|_| {
145                        oauth_invalid_request(
146                            "invalid code_challenge for S256 (not base64url/no-pad)",
147                            Some(redirect_uri),
148                            state,
149                        )
150                    })?;
151                    if bytes.len() != 32 {
152                        return Err(oauth_invalid_request(
153                            "invalid code_challenge for S256 (must decode to 32 bytes)",
154                            Some(redirect_uri),
155                            state,
156                        ));
157                    }
158                }
159                PkceMethod::Plain => {
160                    CodeVerifier::new(ch).map_err(|_| {
161                        oauth_invalid_request(
162                            "invalid code_challenge for plain",
163                            Some(redirect_uri),
164                            state,
165                        )
166                    })?;
167                }
168            }
169
170            Some(method)
171        }
172        (None, None) => None,
173        _ => {
174            return Err(oauth_invalid_request(
175                "code_challenge and code_challenge_method must be used together",
176                Some(redirect_uri),
177                state,
178            ));
179        }
180    };
181
182    if pkce_required && parsed.is_none() {
183        return Err(oauth_invalid_request(
184            "PKCE required for this client",
185            Some(redirect_uri),
186            state,
187        ));
188    }
189
190    Ok(parsed)
191}
192
193/// Verify PKCE bindings during `/token`.
194pub fn verify_token_pkce(
195    client: &OAuthClient,
196    stored_challenge: Option<&str>,
197    stored_method: Option<PkceMethod>,
198    provided_verifier: Option<&str>,
199) -> Result<(), ControllerError> {
200    match (stored_challenge, stored_method) {
201        (Some(stored_chal), Some(method)) => {
202            if !client.allows_pkce_method(method) {
203                return Err(PkceFlowError::InvalidRequest(
204                    "pkce method not allowed for this client",
205                )
206                .into());
207            }
208
209            let verifier_str =
210                provided_verifier.ok_or(PkceFlowError::InvalidRequest("code_verifier required"))?;
211            let verifier = CodeVerifier::new(verifier_str)
212                .map_err(|_| PkceFlowError::InvalidRequest("invalid code_verifier"))?;
213
214            let challenge = CodeChallenge::from_stored(stored_chal);
215            if !challenge.verify(&verifier, method) {
216                return Err(PkceFlowError::InvalidGrant("PKCE verification failed").into());
217            }
218        }
219        (None, None) => {
220            if client.requires_pkce() {
221                return Err(PkceFlowError::InvalidRequest("PKCE required for this client").into());
222            }
223        }
224        _ => {
225            return Err(PkceFlowError::ServerError("inconsistent PKCE state").into());
226        }
227    }
228    Ok(())
229}