1use std::borrow::Cow;
2
3use serde::de::DeserializeOwned;
4
5use crate::algorithms::AlgorithmFamily;
6use crate::crypto::verify;
7use crate::errors::{new_error, ErrorKind, Result};
8use crate::header::Header;
9use crate::pem::decoder::PemEncodedKey;
10use crate::serialization::from_jwt_part_claims;
11use crate::validation::{validate, Validation};
12
13#[derive(Debug)]
15pub struct TokenData<T> {
16    pub header: Header,
18    pub claims: T,
20}
21
/// Pulls exactly two items out of an iterator.
///
/// Evaluates to a `(first, second)` tuple when the iterator yields two items and then ends.
/// For any other item count the *enclosing function* returns
/// `Err(new_error(ErrorKind::InvalidToken))`, so the macro must be expanded directly in a
/// function body whose error type matches what `new_error` produces (not inside a closure).
///
/// The iterator is always advanced three times, regardless of what the earlier calls returned.
macro_rules! expect_two {
    ($iter:expr) => {{
        let mut parts = $iter;
        let first = parts.next();
        let second = parts.next();
        // A third item means there were too many parts.
        let extra = parts.next();
        match (first, second, extra) {
            (Some(a), Some(b), None) => (a, b),
            _ => return Err(new_error(ErrorKind::InvalidToken)),
        }
    }};
}
33
/// The concrete representation of the key material held by a `DecodingKey`.
///
/// Both variants use `Cow` so the material can either be borrowed from the caller or owned.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum DecodingKeyKind<'a> {
    /// Raw key bytes. In this file the variant is used for HMAC secrets, for DER input and
    /// for the content extracted from PEM input.
    SecretOrDer(Cow<'a, [u8]>),
    /// An RSA key given as two separate strings: the modulus `n` and the exponent `e`.
    RsaModulusExponent {
        n: Cow<'a, str>,
        e: Cow<'a, str>,
    },
}
39
40#[derive(Debug, Clone, PartialEq)]
43pub struct DecodingKey<'a> {
44    pub(crate) family: AlgorithmFamily,
45    pub(crate) kind: DecodingKeyKind<'a>,
46}
47
48impl<'a> DecodingKey<'a> {
49    pub fn from_secret(secret: &'a [u8]) -> Self {
51        DecodingKey {
52            family: AlgorithmFamily::Hmac,
53            kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
54        }
55    }
56
57    pub fn from_base64_secret(secret: &str) -> Result<Self> {
59        let out = base64::decode(&secret)?;
60        Ok(DecodingKey {
61            family: AlgorithmFamily::Hmac,
62            kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
63        })
64    }
65
66    pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
68        let pem_key = PemEncodedKey::new(key)?;
69        let content = pem_key.as_rsa_key()?;
70        Ok(DecodingKey {
71            family: AlgorithmFamily::Rsa,
72            kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
73        })
74    }
75
76    pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
78        DecodingKey {
79            family: AlgorithmFamily::Rsa,
80            kind: DecodingKeyKind::RsaModulusExponent {
81                n: Cow::Borrowed(modulus),
82                e: Cow::Borrowed(exponent),
83            },
84        }
85    }
86
87    pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
89        let pem_key = PemEncodedKey::new(key)?;
90        let content = pem_key.as_ec_public_key()?;
91        Ok(DecodingKey {
92            family: AlgorithmFamily::Ec,
93            kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
94        })
95    }
96
97    pub fn from_rsa_der(der: &'a [u8]) -> Self {
99        DecodingKey {
100            family: AlgorithmFamily::Rsa,
101            kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
102        }
103    }
104
105    pub fn from_ec_der(der: &'a [u8]) -> Self {
107        DecodingKey {
108            family: AlgorithmFamily::Ec,
109            kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
110        }
111    }
112
113    pub fn into_static(self) -> DecodingKey<'static> {
115        use DecodingKeyKind::*;
116        let DecodingKey { family, kind } = self;
117        let static_kind = match kind {
118            SecretOrDer(key) => SecretOrDer(Cow::Owned(key.into_owned())),
119            RsaModulusExponent { n, e } => {
120                RsaModulusExponent { n: Cow::Owned(n.into_owned()), e: Cow::Owned(e.into_owned()) }
121            }
122        };
123        DecodingKey { family, kind: static_kind }
124    }
125
126    pub(crate) fn as_bytes(&self) -> &[u8] {
127        match &self.kind {
128            DecodingKeyKind::SecretOrDer(b) => &b,
129            DecodingKeyKind::RsaModulusExponent { .. } => unreachable!(),
130        }
131    }
132}
133
134pub fn decode<T: DeserializeOwned>(
153    token: &str,
154    key: &DecodingKey,
155    validation: &Validation,
156) -> Result<TokenData<T>> {
157    for alg in &validation.algorithms {
158        if key.family != alg.family() {
159            return Err(new_error(ErrorKind::InvalidAlgorithm));
160        }
161    }
162
163    let (signature, message) = expect_two!(token.rsplitn(2, '.'));
164    let (claims, header) = expect_two!(message.rsplitn(2, '.'));
165    let header = Header::from_encoded(header)?;
166
167    if !validation.algorithms.contains(&header.alg) {
168        return Err(new_error(ErrorKind::InvalidAlgorithm));
169    }
170
171    if !verify(signature, message, key, header.alg)? {
172        return Err(new_error(ErrorKind::InvalidSignature));
173    }
174
175    let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
176    validate(&claims_map, validation)?;
177
178    Ok(TokenData { header, claims: decoded_claims })
179}
180
181pub fn dangerous_insecure_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
200    let (_, message) = expect_two!(token.rsplitn(2, '.'));
201    let (claims, header) = expect_two!(message.rsplitn(2, '.'));
202    let header = Header::from_encoded(header)?;
203
204    let (decoded_claims, _): (T, _) = from_jwt_part_claims(claims)?;
205
206    Ok(TokenData { header, claims: decoded_claims })
207}
208
209pub fn dangerous_insecure_decode_with_validation<T: DeserializeOwned>(
230    token: &str,
231    validation: &Validation,
232) -> Result<TokenData<T>> {
233    let (_, message) = expect_two!(token.rsplitn(2, '.'));
234    let (claims, header) = expect_two!(message.rsplitn(2, '.'));
235    let header = Header::from_encoded(header)?;
236
237    if !validation.algorithms.contains(&header.alg) {
238        return Err(new_error(ErrorKind::InvalidAlgorithm));
239    }
240
241    let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
242    validate(&claims_map, validation)?;
243
244    Ok(TokenData { header, claims: decoded_claims })
245}
246
247#[deprecated(
249    note = "This function has been renamed to `dangerous_insecure_decode` and will be removed in a later version."
250)]
251pub fn dangerous_unsafe_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
252    dangerous_insecure_decode(token)
253
254}
255
256pub fn decode_header(token: &str) -> Result<Header> {
267    let (_, message) = expect_two!(token.rsplitn(2, '.'));
268    let (_, header) = expect_two!(message.rsplitn(2, '.'));
269    Header::from_encoded(header)
270}