headless_lms_server/domain/oauth/
pkce.rs1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use sha2::{Digest, Sha256};
3
4use crate::domain::error::{ControllerError, PkceFlowError};
5use crate::domain::oauth::helpers::oauth_invalid_request;
6use headless_lms_models::oauth_client::OAuthClient;
7
8pub use headless_lms_models::library::oauth::pkce::PkceMethod;
10
/// Smallest accepted `code_verifier` length, in bytes (RFC 7636 §4.1 lower bound).
pub const VERIFIER_MIN_LEN: usize = 43;

/// Largest accepted `code_verifier` length, in bytes (RFC 7636 §4.1 upper bound).
pub const VERIFIER_MAX_LEN: usize = 128;
14
15#[derive(Debug, thiserror::Error)]
17pub enum PkceError {
18 #[error("code_verifier length out of bounds")]
19 BadLength,
20 #[error("code_verifier contains invalid characters")]
21 BadCharset,
22 #[error("Disallowed PKCE method")]
23 BadMethod,
24}
25
26#[derive(Debug, Clone)]
28pub struct CodeVerifier(String);
29
30impl CodeVerifier {
31 pub fn new(s: &str) -> Result<Self, PkceError> {
33 validate_verifier(s)?;
34 Ok(Self(s.to_owned()))
35 }
36
37 pub fn try_from_string(s: String) -> Result<Self, PkceError> {
39 validate_verifier(&s)?;
40 Ok(Self(s))
41 }
42
43 pub fn as_str(&self) -> &str {
45 &self.0
46 }
47
48 pub fn to_challenge(&self, method: PkceMethod) -> CodeChallenge {
50 match method {
51 PkceMethod::Plain => CodeChallenge(self.0.clone()),
52 PkceMethod::S256 => {
53 let digest = Sha256::digest(self.0.as_bytes());
54 CodeChallenge(URL_SAFE_NO_PAD.encode(digest))
55 }
56 }
57 }
58}
59
60#[derive(Debug, Clone)]
62pub struct CodeChallenge(String);
63
64impl CodeChallenge {
65 pub fn from_stored<S: Into<String>>(s: S) -> Self {
67 Self(s.into())
68 }
69
70 pub fn as_str(&self) -> &str {
71 &self.0
72 }
73
74 pub fn verify(&self, verifier: &CodeVerifier, method: PkceMethod) -> bool {
76 let computed = verifier.to_challenge(method);
77 constant_time_eq(self.as_str(), computed.as_str())
78 }
79}
80
81fn validate_verifier(v: &str) -> Result<(), PkceError> {
85 let len = v.len();
86 if !(VERIFIER_MIN_LEN..=VERIFIER_MAX_LEN).contains(&len) {
87 return Err(PkceError::BadLength);
88 }
89 if !v.bytes().all(|b| {
90 matches!(
91 b,
92 b'A'..=b'Z' |
93 b'a'..=b'z' |
94 b'0'..=b'9' |
95 b'-' | b'.' | b'_' | b'~'
96 )
97 }) {
98 return Err(PkceError::BadCharset);
99 }
100 Ok(())
101}
102
/// Compares two strings byte-wise without stopping at the first mismatch.
///
/// Strings of different length are rejected immediately. For equal-length
/// inputs every byte pair is XOR-ed and OR-ed into one accumulator, and the
/// strings are equal only if the accumulator stays zero.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    let accumulated = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (left, right)| acc | (left ^ right));
    accumulated == 0
}
114
115pub fn parse_authorize_pkce(
117 client: &OAuthClient,
118 code_challenge: Option<&str>,
119 code_challenge_method: Option<&str>,
120 redirect_uri: &str,
121 state: Option<&str>,
122) -> Result<Option<PkceMethod>, ControllerError> {
123 let pkce_required = client.requires_pkce();
124 let parsed = match (code_challenge, code_challenge_method) {
125 (Some(ch), Some(method_str)) => {
126 let method = PkceMethod::parse(method_str).ok_or_else(|| {
127 oauth_invalid_request(
128 "unsupported code_challenge_method",
129 Some(redirect_uri),
130 state,
131 )
132 })?;
133
134 if !client.allows_pkce_method(method) {
135 return Err(oauth_invalid_request(
136 "code_challenge_method not allowed for this client",
137 Some(redirect_uri),
138 state,
139 ));
140 }
141
142 match method {
143 PkceMethod::S256 => {
144 let bytes = URL_SAFE_NO_PAD.decode(ch).map_err(|_| {
145 oauth_invalid_request(
146 "invalid code_challenge for S256 (not base64url/no-pad)",
147 Some(redirect_uri),
148 state,
149 )
150 })?;
151 if bytes.len() != 32 {
152 return Err(oauth_invalid_request(
153 "invalid code_challenge for S256 (must decode to 32 bytes)",
154 Some(redirect_uri),
155 state,
156 ));
157 }
158 }
159 PkceMethod::Plain => {
160 CodeVerifier::new(ch).map_err(|_| {
161 oauth_invalid_request(
162 "invalid code_challenge for plain",
163 Some(redirect_uri),
164 state,
165 )
166 })?;
167 }
168 }
169
170 Some(method)
171 }
172 (None, None) => None,
173 _ => {
174 return Err(oauth_invalid_request(
175 "code_challenge and code_challenge_method must be used together",
176 Some(redirect_uri),
177 state,
178 ));
179 }
180 };
181
182 if pkce_required && parsed.is_none() {
183 return Err(oauth_invalid_request(
184 "PKCE required for this client",
185 Some(redirect_uri),
186 state,
187 ));
188 }
189
190 Ok(parsed)
191}
192
193pub fn verify_token_pkce(
195 client: &OAuthClient,
196 stored_challenge: Option<&str>,
197 stored_method: Option<PkceMethod>,
198 provided_verifier: Option<&str>,
199) -> Result<(), ControllerError> {
200 match (stored_challenge, stored_method) {
201 (Some(stored_chal), Some(method)) => {
202 if !client.allows_pkce_method(method) {
203 return Err(PkceFlowError::InvalidRequest(
204 "pkce method not allowed for this client",
205 )
206 .into());
207 }
208
209 let verifier_str =
210 provided_verifier.ok_or(PkceFlowError::InvalidRequest("code_verifier required"))?;
211 let verifier = CodeVerifier::new(verifier_str)
212 .map_err(|_| PkceFlowError::InvalidRequest("invalid code_verifier"))?;
213
214 let challenge = CodeChallenge::from_stored(stored_chal);
215 if !challenge.verify(&verifier, method) {
216 return Err(PkceFlowError::InvalidGrant("PKCE verification failed").into());
217 }
218 }
219 (None, None) => {
220 if client.requires_pkce() {
221 return Err(PkceFlowError::InvalidRequest("PKCE required for this client").into());
222 }
223 }
224 _ => {
225 return Err(PkceFlowError::ServerError("inconsistent PKCE state").into());
226 }
227 }
228 Ok(())
229}