1use std::borrow::Cow;
2
3use serde::de::DeserializeOwned;
4
5use crate::algorithms::AlgorithmFamily;
6use crate::crypto::verify;
7use crate::errors::{new_error, ErrorKind, Result};
8use crate::header::Header;
9use crate::pem::decoder::PemEncodedKey;
10use crate::serialization::from_jwt_part_claims;
11use crate::validation::{validate, Validation};
12
/// Result of a successful decode: the parsed token header plus the claims
/// deserialized into the caller-chosen type `T`.
#[derive(Debug)]
pub struct TokenData<T> {
    /// Header parsed from the token's header segment.
    pub header: Header,
    /// Claims payload deserialized into `T`.
    pub claims: T,
}
21
/// Pulls exactly two items out of an iterator, yielding them as a `(first, second)` tuple.
///
/// If the iterator produces fewer than two items, or more than two, the *enclosing function*
/// returns early with `Err(new_error(ErrorKind::InvalidToken))`. This is why the macro can only
/// be used inside functions returning this module's `Result`.
macro_rules! expect_two {
    ($iter:expr) => {{
        let mut pieces = $iter;
        let head = pieces.next();
        let tail = pieces.next();
        // A third item means the input had too many parts.
        let leftover = pieces.next();
        if let (Some(a), Some(b), None) = (head, tail, leftover) {
            (a, b)
        } else {
            return Err(new_error(ErrorKind::InvalidToken));
        }
    }};
}
33
/// Internal representation of the key material held by a `DecodingKey`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum DecodingKeyKind<'a> {
    /// Raw key bytes: an HMAC secret, bytes extracted from a PEM file, or bytes passed in
    /// directly as DER. May borrow from the caller or own its data.
    SecretOrDer(Cow<'a, [u8]>),
    /// RSA public key supplied as separate modulus (`n`) and exponent (`e`) strings, stored
    /// exactly as received. NOTE(review): their encoding is interpreted by the verifier
    /// (presumably base64url) — confirm in `crate::crypto`.
    RsaModulusExponent { n: Cow<'a, str>, e: Cow<'a, str> },
}
39
/// Key used to verify token signatures, tagged with the algorithm family it belongs to.
///
/// Some constructors borrow the caller's data (e.g. `from_secret`, `from_rsa_der`);
/// `into_static` converts any key into one that owns all of its data.
#[derive(Debug, Clone, PartialEq)]
pub struct DecodingKey<'a> {
    /// Algorithm family of this key; `decode` rejects a `Validation` that lists an
    /// algorithm from a different family.
    pub(crate) family: AlgorithmFamily,
    /// The key material itself (raw bytes or RSA components).
    pub(crate) kind: DecodingKeyKind<'a>,
}
47
48impl<'a> DecodingKey<'a> {
49 pub fn from_secret(secret: &'a [u8]) -> Self {
51 DecodingKey {
52 family: AlgorithmFamily::Hmac,
53 kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(secret)),
54 }
55 }
56
57 pub fn from_base64_secret(secret: &str) -> Result<Self> {
59 let out = base64::decode(&secret)?;
60 Ok(DecodingKey {
61 family: AlgorithmFamily::Hmac,
62 kind: DecodingKeyKind::SecretOrDer(Cow::Owned(out)),
63 })
64 }
65
66 pub fn from_rsa_pem(key: &'a [u8]) -> Result<Self> {
68 let pem_key = PemEncodedKey::new(key)?;
69 let content = pem_key.as_rsa_key()?;
70 Ok(DecodingKey {
71 family: AlgorithmFamily::Rsa,
72 kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
73 })
74 }
75
76 pub fn from_rsa_components(modulus: &'a str, exponent: &'a str) -> Self {
78 DecodingKey {
79 family: AlgorithmFamily::Rsa,
80 kind: DecodingKeyKind::RsaModulusExponent {
81 n: Cow::Borrowed(modulus),
82 e: Cow::Borrowed(exponent),
83 },
84 }
85 }
86
87 pub fn from_ec_pem(key: &'a [u8]) -> Result<Self> {
89 let pem_key = PemEncodedKey::new(key)?;
90 let content = pem_key.as_ec_public_key()?;
91 Ok(DecodingKey {
92 family: AlgorithmFamily::Ec,
93 kind: DecodingKeyKind::SecretOrDer(Cow::Owned(content.to_vec())),
94 })
95 }
96
97 pub fn from_rsa_der(der: &'a [u8]) -> Self {
99 DecodingKey {
100 family: AlgorithmFamily::Rsa,
101 kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
102 }
103 }
104
105 pub fn from_ec_der(der: &'a [u8]) -> Self {
107 DecodingKey {
108 family: AlgorithmFamily::Ec,
109 kind: DecodingKeyKind::SecretOrDer(Cow::Borrowed(der)),
110 }
111 }
112
113 pub fn into_static(self) -> DecodingKey<'static> {
115 use DecodingKeyKind::*;
116 let DecodingKey { family, kind } = self;
117 let static_kind = match kind {
118 SecretOrDer(key) => SecretOrDer(Cow::Owned(key.into_owned())),
119 RsaModulusExponent { n, e } => {
120 RsaModulusExponent { n: Cow::Owned(n.into_owned()), e: Cow::Owned(e.into_owned()) }
121 }
122 };
123 DecodingKey { family, kind: static_kind }
124 }
125
126 pub(crate) fn as_bytes(&self) -> &[u8] {
127 match &self.kind {
128 DecodingKeyKind::SecretOrDer(b) => &b,
129 DecodingKeyKind::RsaModulusExponent { .. } => unreachable!(),
130 }
131 }
132}
133
134pub fn decode<T: DeserializeOwned>(
153 token: &str,
154 key: &DecodingKey,
155 validation: &Validation,
156) -> Result<TokenData<T>> {
157 for alg in &validation.algorithms {
158 if key.family != alg.family() {
159 return Err(new_error(ErrorKind::InvalidAlgorithm));
160 }
161 }
162
163 let (signature, message) = expect_two!(token.rsplitn(2, '.'));
164 let (claims, header) = expect_two!(message.rsplitn(2, '.'));
165 let header = Header::from_encoded(header)?;
166
167 if !validation.algorithms.contains(&header.alg) {
168 return Err(new_error(ErrorKind::InvalidAlgorithm));
169 }
170
171 if !verify(signature, message, key, header.alg)? {
172 return Err(new_error(ErrorKind::InvalidSignature));
173 }
174
175 let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
176 validate(&claims_map, validation)?;
177
178 Ok(TokenData { header, claims: decoded_claims })
179}
180
181pub fn dangerous_insecure_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
200 let (_, message) = expect_two!(token.rsplitn(2, '.'));
201 let (claims, header) = expect_two!(message.rsplitn(2, '.'));
202 let header = Header::from_encoded(header)?;
203
204 let (decoded_claims, _): (T, _) = from_jwt_part_claims(claims)?;
205
206 Ok(TokenData { header, claims: decoded_claims })
207}
208
209pub fn dangerous_insecure_decode_with_validation<T: DeserializeOwned>(
230 token: &str,
231 validation: &Validation,
232) -> Result<TokenData<T>> {
233 let (_, message) = expect_two!(token.rsplitn(2, '.'));
234 let (claims, header) = expect_two!(message.rsplitn(2, '.'));
235 let header = Header::from_encoded(header)?;
236
237 if !validation.algorithms.contains(&header.alg) {
238 return Err(new_error(ErrorKind::InvalidAlgorithm));
239 }
240
241 let (decoded_claims, claims_map): (T, _) = from_jwt_part_claims(claims)?;
242 validate(&claims_map, validation)?;
243
244 Ok(TokenData { header, claims: decoded_claims })
245}
246
247#[deprecated(
249 note = "This function has been renamed to `dangerous_insecure_decode` and will be removed in a later version."
250)]
251pub fn dangerous_unsafe_decode<T: DeserializeOwned>(token: &str) -> Result<TokenData<T>> {
252 dangerous_insecure_decode(token)
253
254}
255
256pub fn decode_header(token: &str) -> Result<Header> {
267 let (_, message) = expect_two!(token.rsplitn(2, '.'));
268 let (_, header) = expect_two!(message.rsplitn(2, '.'));
269 Header::from_encoded(header)
270}