Skip to main content

headless_lms_server/domain/oauth/oidc.rs

1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use chrono::{DateTime, Utc};
3use jsonwebtoken::{EncodingKey, Header, encode};
4use rsa::RsaPublicKey;
5use rsa::pkcs1::DecodeRsaPublicKey;
6use rsa::pkcs8::{DecodePublicKey, EncodePublicKey};
7use rsa::traits::PublicKeyParts;
8use secrecy::ExposeSecret;
9use sha2::{Digest as ShaDigest, Sha256};
10
11use crate::domain::error::{ControllerError, ControllerErrorType, OAuthErrorCode, OAuthErrorData};
12use crate::domain::oauth::claims::Claims;
13use crate::prelude::{ApplicationConfiguration, BackendError};
14
15pub fn rsa_n_e_and_kid_from_pem(public_pem: &str) -> anyhow::Result<(String, String, String)> {
16    let pubkey = match RsaPublicKey::from_pkcs1_pem(public_pem) {
17        Ok(k) => k,
18        Err(_) => RsaPublicKey::from_public_key_pem(public_pem)?,
19    };
20
21    let n_b64 = URL_SAFE_NO_PAD.encode(pubkey.n().to_bytes_be());
22    let e_b64 = URL_SAFE_NO_PAD.encode(pubkey.e().to_bytes_be());
23
24    let spki_der = pubkey.to_public_key_der()?;
25    let kid = URL_SAFE_NO_PAD.encode(Sha256::digest(spki_der.as_bytes()));
26
27    Ok((n_b64, e_b64, kid))
28}
29
30/// Generate an ID token. `nonce` should be `Some` only when the authorization request
31/// included a nonce; when absent or empty, the nonce claim is omitted from the id_token.
32pub fn generate_id_token(
33    user_id: uuid::Uuid,
34    client_id: &str,
35    nonce: Option<&str>,
36    expires_at: DateTime<Utc>,
37    issuer: &str,
38    cfg: &ApplicationConfiguration,
39) -> Result<String, ControllerError> {
40    let now = Utc::now().timestamp();
41    let exp = expires_at.timestamp();
42
43    let (_, _, kid) = rsa_n_e_and_kid_from_pem(&cfg.oauth_server_configuration.rsa_public_key)
44        .map_err(|e| {
45            ControllerError::new(
46                ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
47                    error: OAuthErrorCode::ServerError.as_str().into(),
48                    error_description: "Failed to derive key id (kid) from public key".into(),
49                    redirect_uri: None,
50                    state: None,
51                    nonce: None,
52                })),
53                "Failed to derive kid from public key",
54                Some(e),
55            )
56        })?;
57
58    let nonce_claim = nonce.and_then(|s| {
59        if s.is_empty() {
60            None
61        } else {
62            Some(s.to_string())
63        }
64    });
65
66    let claims = Claims {
67        sub: user_id.to_string(),
68        aud: client_id.to_string(),
69        iss: issuer.to_string(),
70        iat: now,
71        exp,
72        nonce: nonce_claim,
73    };
74
75    let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
76    header.kid = Some(kid);
77
78    let enc_key = EncodingKey::from_rsa_pem(
79        cfg.oauth_server_configuration
80            .rsa_private_key
81            .expose_secret()
82            .as_bytes(),
83    )
84    .map_err(|e| {
85        ControllerError::new(
86            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
87                error: OAuthErrorCode::ServerError.as_str().into(),
88                error_description: "Failed to generate ID token".into(),
89                redirect_uri: None,
90                state: None,
91                nonce: None,
92            })),
93            "Failed to generate ID token (invalid private key)",
94            Some(e.into()),
95        )
96    })?;
97
98    encode(&header, &claims, &enc_key).map_err(|e| {
99        ControllerError::new(
100            ControllerErrorType::OAuthError(Box::new(OAuthErrorData {
101                error: OAuthErrorCode::ServerError.as_str().into(),
102                error_description: "Failed to generate ID token".into(),
103                redirect_uri: None,
104                state: None,
105                nonce: None,
106            })),
107            "Failed to generate ID token",
108            Some(e.into()),
109        )
110    })
111}