// headless_lms_models/library/oauth/digest.rs

1use core::fmt;
2use std::borrow::Cow;
3use std::str::FromStr;
4
5use secrecy::{ExposeSecret, SecretBox};
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use subtle::ConstantTimeEq;
8
9use sqlx::encode::IsNull;
10use sqlx::postgres::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef};
11use sqlx::{Decode, Encode, Postgres, Type, error::BoxDynError};
12
13#[derive(Debug)]
14pub enum DigestError {
15    WrongLength(usize),
16    Hex(hex::FromHexError),
17    Base64(base64::DecodeError),
18}
19
20impl fmt::Display for DigestError {
21    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22        match self {
23            Self::WrongLength(n) => write!(f, "expected 32 bytes, got {}", n),
24            Self::Hex(e) => write!(f, "hex decode error: {e}"),
25            Self::Base64(e) => write!(f, "base64 decode error: {e}"),
26        }
27    }
28}
29impl std::error::Error for DigestError {}
30
31/// Secure Digest (zeroizes on drop via `secrecy::SecretBox`)
32pub struct Digest(SecretBox<[u8; Self::LEN]>);
33
34impl Digest {
35    pub const LEN: usize = 32;
36
37    pub fn new(bytes: [u8; Self::LEN]) -> Self {
38        Self(SecretBox::new(Box::new(bytes)))
39    }
40
41    pub fn from_slice(slice: &[u8]) -> Result<Self, DigestError> {
42        if slice.len() != Self::LEN {
43            return Err(DigestError::WrongLength(slice.len()));
44        }
45        let mut arr = [0u8; Self::LEN];
46        arr.copy_from_slice(slice);
47        Ok(Self::new(arr))
48    }
49
50    pub fn as_bytes(&self) -> &[u8; Self::LEN] {
51        self.0.expose_secret()
52    }
53
54    pub fn as_slice(&self) -> &[u8] {
55        &self.0.expose_secret()[..]
56    }
57
58    /// Constant-time equality helper that returns bool.
59    pub fn constant_eq(&self, other: &Self) -> bool {
60        self.as_slice().ct_eq(other.as_slice()).unwrap_u8() == 1
61    }
62}
63
64impl FromStr for Digest {
65    type Err = DigestError;
66    fn from_str(s: &str) -> Result<Self, Self::Err> {
67        let bytes = hex::decode(s).map_err(DigestError::Hex)?;
68        Self::from_slice(&bytes)
69    }
70}
71
72impl fmt::Debug for Digest {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        // Never expose contents
75        write!(f, "Digest(…redacted…)")
76    }
77}
78
79// Intentionally no Clone, no AsRef, no Display.
80
81impl From<[u8; Digest::LEN]> for Digest {
82    fn from(v: [u8; Digest::LEN]) -> Self {
83        Self::new(v)
84    }
85}
86
87impl core::convert::TryFrom<Vec<u8>> for Digest {
88    type Error = DigestError;
89    fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
90        Self::from_slice(&v)
91    }
92}
93
94impl From<Digest> for Vec<u8> {
95    fn from(d: Digest) -> Self {
96        d.as_slice().to_vec()
97    }
98}
99
100impl Serialize for Digest {
101    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
102        ser.serialize_bytes(self.as_slice())
103    }
104}
105impl<'de> Deserialize<'de> for Digest {
106    fn deserialize<D: Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
107        let bytes: Cow<'de, [u8]> = serde_bytes::deserialize(de)?;
108        Digest::from_slice(&bytes).map_err(serde::de::Error::custom)
109    }
110}
111
112impl<'r> Decode<'r, Postgres> for Digest {
113    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
114        let bytes: Vec<u8> = <Vec<u8> as Decode<Postgres>>::decode(value)?;
115        Ok(Digest::from_slice(&bytes)?)
116    }
117}
118
119impl<'q> Encode<'q, Postgres> for Digest {
120    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
121        <&[u8] as Encode<Postgres>>::encode_by_ref(&self.as_slice(), buf)
122    }
123}
124
125impl Type<Postgres> for Digest {
126    fn type_info() -> PgTypeInfo {
127        <Vec<u8> as Type<Postgres>>::type_info()
128    }
129}
130impl PgHasArrayType for Digest {
131    fn array_type_info() -> PgTypeInfo {
132        <Vec<u8> as PgHasArrayType>::array_type_info()
133    }
134}