headless_lms_server/domain/
models_requests.rs

1//! Contains helper functions that are passed to headless-lms-models where it needs to make requests to exercise services.
2
3use crate::prelude::*;
4use actix_http::Payload;
5use actix_web::{FromRequest, HttpRequest};
6use chrono::{DateTime, Duration, Utc};
7use futures::{
8    FutureExt,
9    future::{BoxFuture, Ready, ready},
10};
11use headless_lms_models::{
12    HttpErrorType, ModelError, ModelErrorType, ModelResult,
13    exercise_service_info::ExerciseServiceInfoApi,
14    exercise_task_gradings::{ExerciseTaskGradingRequest, ExerciseTaskGradingResult},
15    exercise_task_submissions::ExerciseTaskSubmission,
16    exercise_tasks::ExerciseTask,
17};
18
19use headless_lms_utils::error::backend_error::BackendError;
20use hmac::{Hmac, Mac};
21use jwt::{SignWithKey, VerifyWithKey};
22use models::SpecFetcher;
23use sha2::Sha256;
24use std::collections::HashMap;
25use std::sync::{Arc, Mutex};
26use std::{borrow::Cow, fmt::Debug};
27use url::Url;
28
29use super::error::{ControllerError, ControllerErrorType};
30
31// keep in sync with the shared-module constants
32const EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER: &str = "exercise-service-grading-update-claim";
33const EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER: &str = "exercise-service-upload-claim";
34
35/// A type for caching the spec fetching (only for the seed)
36type SpecCache = HashMap<(String, String, Option<String>), serde_json::Value>;
37
38#[derive(Clone, Debug)]
39pub struct JwtKey(Hmac<Sha256>);
40
41impl JwtKey {
42    pub fn try_from_env() -> anyhow::Result<Self> {
43        let jwt_password = std::env::var("JWT_PASSWORD").context("JWT_PASSWORD must be defined")?;
44        let jwt_key = Self::new(&jwt_password)?;
45        Ok(jwt_key)
46    }
47
48    pub fn new(key: &str) -> Result<Self, sha2::digest::InvalidLength> {
49        let key: Hmac<Sha256> = Hmac::new_from_slice(key.as_bytes())?;
50        Ok(Self(key))
51    }
52
53    #[cfg(test)]
54    pub fn test_key() -> Self {
55        let test_jwt_key = "sMG87WlKnNZoITzvL2+jczriTR7JRsCtGu/bSKaSIvw=asdfjklasd***FSDfsdASDFDS";
56        Self::new(test_jwt_key).unwrap()
57    }
58}
59
60#[derive(Debug, Serialize, Deserialize)]
61pub struct UploadClaim<'a> {
62    exercise_service_slug: Cow<'a, str>,
63    expiration_time: DateTime<Utc>,
64}
65
66impl<'a> UploadClaim<'a> {
67    pub fn exercise_service_slug(&self) -> &str {
68        self.exercise_service_slug.as_ref()
69    }
70
71    pub fn expiration_time(&self) -> &DateTime<Utc> {
72        &self.expiration_time
73    }
74
75    pub fn expiring_in_1_day(exercise_service_slug: Cow<'a, str>) -> Self {
76        Self {
77            exercise_service_slug,
78            expiration_time: Utc::now() + Duration::days(1),
79        }
80    }
81
82    pub fn sign(self, key: &JwtKey) -> String {
83        self.sign_with_key(&key.0).expect("should never fail")
84    }
85
86    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
87        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
88            ControllerError::new(
89                ControllerErrorType::BadRequest,
90                format!("Invalid jwt key: {}", err),
91                Some(err.into()),
92            )
93        })?;
94        if claim.expiration_time < Utc::now() {
95            return Err(ControllerError::new(
96                ControllerErrorType::BadRequest,
97                "Upload claim has expired".to_string(),
98                None,
99            ));
100        }
101        Ok(claim)
102    }
103}
104
105impl FromRequest for UploadClaim<'_> {
106    type Error = ControllerError;
107    type Future = Ready<Result<Self, Self::Error>>;
108
109    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
110        let try_from_request = move || {
111            let jwt_key = req
112                .app_data::<web::Data<JwtKey>>()
113                .expect("Missing JwtKey in app data");
114            let header = req
115                .headers()
116                .get(EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER)
117                .ok_or_else(|| {
118                    ControllerError::new(
119                        ControllerErrorType::BadRequest,
120                        format!("Missing header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER}",),
121                        None,
122                    )
123                })?;
124            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
125                ControllerError::new(
126                    ControllerErrorType::BadRequest,
127                    format!(
128                        "Invalid header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER} = {}",
129                        String::from_utf8_lossy(header.as_bytes())
130                    ),
131                    Some(err.into()),
132                )
133            })?;
134            let claim = UploadClaim::validate(header, jwt_key)?;
135            Result::<_, Self::Error>::Ok(claim)
136        };
137        ready(try_from_request())
138    }
139}
140
141#[derive(Debug, Serialize, Deserialize)]
142pub struct GradingUpdateClaim {
143    submission_id: Uuid,
144    expiration_time: DateTime<Utc>,
145}
146
147impl GradingUpdateClaim {
148    pub fn submission_id(&self) -> Uuid {
149        self.submission_id
150    }
151
152    pub fn expiration_time(&self) -> &DateTime<Utc> {
153        &self.expiration_time
154    }
155
156    pub fn expiring_in_1_day(submission_id: Uuid) -> Self {
157        Self {
158            submission_id,
159            expiration_time: Utc::now() + Duration::days(1),
160        }
161    }
162
163    pub fn sign(self, key: &JwtKey) -> String {
164        self.sign_with_key(&key.0).expect("should never fail")
165    }
166
167    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
168        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
169            ControllerError::new(
170                ControllerErrorType::BadRequest,
171                format!("Invalid jwt key: {}", err),
172                Some(err.into()),
173            )
174        })?;
175        if claim.expiration_time < Utc::now() {
176            return Err(ControllerError::new(
177                ControllerErrorType::BadRequest,
178                "Grading update claim has expired".to_string(),
179                None,
180            ));
181        }
182        Ok(claim)
183    }
184}
185
186impl FromRequest for GradingUpdateClaim {
187    type Error = ControllerError;
188    type Future = Ready<Result<Self, Self::Error>>;
189
190    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
191        let try_from_request = move || {
192            let jwt_key = req
193                .app_data::<web::Data<JwtKey>>()
194                .expect("Missing JwtKey in app data");
195            let header = req
196                .headers()
197                .get(EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER)
198                .ok_or_else(|| {
199                    ControllerError::new(
200                        ControllerErrorType::BadRequest,
201                        format!("Missing header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER}",),
202                        None,
203                    )
204                })?;
205            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
206                ControllerError::new(
207                    ControllerErrorType::BadRequest,
208                    format!(
209                        "Invalid header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER} = {}",
210                        String::from_utf8_lossy(header.as_bytes())
211                    ),
212                    Some(err.into()),
213                )
214            })?;
215            let claim = GradingUpdateClaim::validate(header, jwt_key)?;
216            Result::<_, Self::Error>::Ok(claim)
217        };
218        ready(try_from_request())
219    }
220}
221
222/// Accepted by the public-spec and model-solution endpoints of exercise services.
223#[derive(Debug, Serialize)]
224#[cfg_attr(feature = "ts_rs", derive(TS))]
225pub struct SpecRequest<'a> {
226    request_id: Uuid,
227    private_spec: Option<&'a serde_json::Value>,
228    upload_url: Option<String>,
229}
230
231/// Fetches a public/model spec based on the private spec from the given url.
232/// The slug and jwt key are used for an upload claim that allows the service
233/// to upload files as part of the spec.
234pub fn make_spec_fetcher(
235    base_url: String,
236    request_id: Uuid,
237    jwt_key: Arc<JwtKey>,
238) -> impl SpecFetcher {
239    move |url, exercise_service_slug, private_spec| {
240        let client = reqwest::Client::new();
241        let upload_claim = UploadClaim::expiring_in_1_day(exercise_service_slug.into());
242        let upload_url = Some(format!("{base_url}/api/v0/files/{exercise_service_slug}"));
243        let req = client
244            .post(url.clone())
245            .header(
246                EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER,
247                upload_claim.sign(&jwt_key),
248            )
249            .timeout(std::time::Duration::from_secs(120))
250            .json(&SpecRequest {
251                request_id,
252                private_spec,
253                upload_url,
254            })
255            .send();
256        async move {
257            let res = req.await.map_err(ModelError::from)?;
258            let status_code = res.status();
259            if !status_code.is_success() {
260                let error_text = res.text().await;
261                let error = error_text.as_deref().unwrap_or("(No text in response)");
262                error!(
263                    ?url,
264                    ?exercise_service_slug,
265                    ?private_spec,
266                    ?status_code,
267                    "Exercise service returned an error while generating a spec: {}",
268                    error
269                );
270                return Err(ModelError::new(
271                    ModelErrorType::HttpRequest {
272                        status_code: status_code.as_u16(),
273                        response_body: error.to_string(),
274                    },
275                    format!(
276                        "Failed to generate spec for exercise for {exercise_service_slug}: {error}."
277                    ),
278                    None,
279                ));
280            }
281            let json = parse_response_json(res).await?;
282            Ok(json)
283        }
284        .boxed()
285    }
286}
287
288// see `fetch_service_info_fast` while handling HTTP requests
289pub fn fetch_service_info(url: Url) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
290    fetch_service_info_with_timeout(url, 1000 * 120)
291}
292
293// use this while handling HTTP requests, see `fetch_service_info`
294pub fn fetch_service_info_fast(
295    url: Url,
296) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
297    fetch_service_info_with_timeout(url, 1000 * 5)
298}
299
300fn fetch_service_info_with_timeout(
301    url: Url,
302    timeout_ms: u64,
303) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
304    async move {
305        let client = reqwest::Client::new();
306        let res = client
307            .get(url) // e.g. http://example-exercise.default.svc.cluster.local:3002/example-exercise/api/service-info
308            .timeout(std::time::Duration::from_millis(timeout_ms))
309            .send()
310            .await
311            .map_err(ModelError::from)?;
312        let status = res.status();
313        if !status.is_success() {
314            let response_url = res.url().to_string();
315            let body = res.text().await.map_err(ModelError::from)?;
316            warn!(url=?response_url, status=?status, body=?body, "Could not fetch service info.");
317            return Err(ModelError::new(
318                ModelErrorType::HttpRequest {
319                    status_code: status.as_u16(),
320                    response_body: body,
321                },
322                "Could not fetch service info.".to_string(),
323                None,
324            ));
325        }
326        let res = parse_response_json(res).await?;
327        Ok(res)
328    }
329    .boxed()
330}
331
332pub fn make_grading_request_sender(
333    jwt_key: Arc<JwtKey>,
334) -> impl Fn(
335    Url,
336    &ExerciseTask,
337    &ExerciseTaskSubmission,
338) -> BoxFuture<'static, ModelResult<ExerciseTaskGradingResult>> {
339    move |grade_url, exercise_task, submission| {
340        let client = reqwest::Client::new();
341        // TODO: use real url
342        let grading_update_url = format!(
343            "http://project-331.local/api/v0/exercise-services/grading/grading-update/{}",
344            submission.id
345        );
346        let grading_update_claim = GradingUpdateClaim::expiring_in_1_day(submission.id);
347        let req = client
348            .post(grade_url)
349            .header(
350                EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER,
351                grading_update_claim.sign(&jwt_key),
352            )
353            .timeout(std::time::Duration::from_secs(120))
354            .json(&ExerciseTaskGradingRequest {
355                grading_update_url: &grading_update_url,
356                exercise_spec: &exercise_task.private_spec,
357                submission_data: &submission.data_json,
358            });
359        async move {
360            let res = req.send().await.map_err(ModelError::from)?;
361            let status = res.status();
362            if !status.is_success() {
363                let status_code = status.as_u16();
364                let response_body = res.text().await.unwrap_or_default();
365                error!(
366                    ?response_body,
367                    status_code = %status_code,
368                    "Grading request returned an unsuccesful status code"
369                );
370
371                return Err(ModelError::new(
372                    ModelErrorType::HttpRequest {
373                        status_code,
374                        response_body: response_body.clone(),
375                    },
376                    format!(
377                        "Grading failed with status: {} response: {}",
378                        status_code, response_body
379                    ),
380                    None,
381                ));
382            }
383            let obj = parse_response_json(res).await?;
384            info!("Received a grading result: {:#?}", &obj);
385            Ok(obj)
386        }
387        .boxed()
388    }
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392pub struct GivePeerReviewClaim {
393    pub exercise_slide_submission_id: Uuid,
394    pub peer_or_self_review_config_id: Uuid,
395    expiration_time: DateTime<Utc>,
396}
397
398impl GivePeerReviewClaim {
399    pub fn expiring_in_1_day(
400        exercise_slide_submission_id: Uuid,
401        peer_or_self_review_config_id: Uuid,
402    ) -> Self {
403        Self {
404            exercise_slide_submission_id,
405            peer_or_self_review_config_id,
406            expiration_time: Utc::now() + Duration::days(1),
407        }
408    }
409
410    pub fn sign(self, key: &JwtKey) -> String {
411        self.sign_with_key(&key.0).expect("should never fail")
412    }
413
414    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
415        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
416            ControllerError::new(
417                ControllerErrorType::BadRequest,
418                format!("Invalid claim: {}", err),
419                Some(err.into()),
420            )
421        })?;
422        if claim.expiration_time < Utc::now() {
423            return Err(ControllerError::new(
424                ControllerErrorType::BadRequest,
425                "The review has expired.".to_string(),
426                None,
427            ));
428        }
429        Ok(claim)
430    }
431}
432
433/// A caching spec fetcher ONLY FOR THE SEED that returns a cached spec if the same
434/// (url, exercise_service_slug, private_spec) is requested. Since this is only used during seeding,
435/// there is no cache eviction.
436pub fn make_seed_spec_fetcher_with_cache(
437    base_url: String,
438    request_id: Uuid,
439    jwt_key: Arc<JwtKey>,
440) -> impl SpecFetcher {
441    // Cache key: (url, exercise_service_slug, private_spec serialized)
442    let cache: Arc<Mutex<SpecCache>> = Arc::new(Mutex::new(HashMap::new()));
443
444    // Create the base non-caching spec fetcher and wrap it in Arc to make it clonable
445    let base_fetcher = Arc::new(make_spec_fetcher(base_url, request_id, jwt_key));
446
447    move |url, exercise_service_slug, private_spec| {
448        let url_str = url.to_string();
449        let service_slug = exercise_service_slug.to_string();
450        // Convert private_spec to string for cache key if present
451        let private_spec_str =
452            private_spec.map(|spec| serde_json::to_string(&spec).unwrap_or_default());
453        let key = (url_str.clone(), service_slug.clone(), private_spec_str);
454        let cache = Arc::clone(&cache);
455        let base_fetcher = Arc::clone(&base_fetcher);
456
457        async move {
458            // Try to get from cache first
459            if let Some(cached_spec) = cache
460                .lock()
461                .expect("Seed spec fetcher cache lock poisoned")
462                .get(&key)
463            {
464                return Ok(cached_spec.clone());
465            }
466
467            // Not in cache - fetch using base fetcher
468            let fetched_spec = base_fetcher(url, exercise_service_slug, private_spec).await?;
469
470            // Store in cache
471            cache
472                .lock()
473                .expect("Seed spec fetcher cache lock poisoned")
474                .insert(key, fetched_spec.clone());
475
476            Ok(fetched_spec)
477        }
478        .boxed()
479    }
480}
481
482/// Safely parses a response body as JSON, capturing the actual response body in error cases
483async fn parse_response_json<T>(response: reqwest::Response) -> ModelResult<T>
484where
485    T: serde::de::DeserializeOwned,
486{
487    let status = response.status();
488    let response_text = response.text().await.map_err(ModelError::from)?;
489
490    serde_json::from_str(&response_text).map_err(|err| {
491        ModelError::new(
492            ModelErrorType::HttpError {
493                error_type: HttpErrorType::ResponseDecodeFailed,
494                reason: err.to_string(),
495                status_code: Some(status.as_u16()),
496                response_body: Some(response_text),
497            },
498            format!("Failed to decode JSON response: {}", err),
499            None,
500        )
501    })
502}