headless_lms_server/domain/
models_requests.rs

1//! Contains helper functions that are passed to headless-lms-models where it needs to make requests to exercise services.
2
3use crate::prelude::*;
4use actix_http::Payload;
5use actix_web::{FromRequest, HttpRequest};
6use chrono::{DateTime, Duration, Utc};
7use futures::{
8    FutureExt,
9    future::{BoxFuture, Ready, ready},
10};
11use headless_lms_models::{
12    HttpErrorType, ModelError, ModelErrorType, ModelResult,
13    exercise_service_info::ExerciseServiceInfoApi,
14    exercise_task_gradings::{ExerciseTaskGradingRequest, ExerciseTaskGradingResult},
15    exercise_task_submissions::ExerciseTaskSubmission,
16    exercise_tasks::ExerciseTask,
17};
18
19use headless_lms_utils::error::backend_error::BackendError;
20use hmac::{Hmac, Mac};
21use jwt::{SignWithKey, VerifyWithKey};
22use models::SpecFetcher;
23use sha2::Sha256;
24use std::collections::HashMap;
25use std::sync::{Arc, Mutex};
26use std::{borrow::Cow, fmt::Debug};
27use url::Url;
28
29use super::error::{ControllerError, ControllerErrorType};
30
// Header names shared with the exercise services; keep these in sync with the shared-module constants.

/// Request header carrying a signed `UploadClaim` JWT.
const EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER: &str = "exercise-service-upload-claim";
/// Request header carrying a signed `GradingUpdateClaim` JWT.
const EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER: &str = "exercise-service-grading-update-claim";
34
/// A type for caching the spec fetching (only for the seed).
///
/// Used by `make_seed_spec_fetcher_with_cache`: the key is
/// `(url, exercise_service_slug, serialized private_spec)` and the value is the fetched spec.
/// Entries are never evicted.
type SpecCache = HashMap<(String, String, Option<String>), serde_json::Value>;
37
38#[derive(Clone, Debug)]
39pub struct JwtKey(Hmac<Sha256>);
40
41impl JwtKey {
42    pub fn try_from_env() -> anyhow::Result<Self> {
43        let jwt_password = std::env::var("JWT_PASSWORD").context("JWT_PASSWORD must be defined")?;
44        let jwt_key = Self::new(&jwt_password)?;
45        Ok(jwt_key)
46    }
47
48    pub fn new(key: &str) -> Result<Self, sha2::digest::InvalidLength> {
49        let key: Hmac<Sha256> = Hmac::new_from_slice(key.as_bytes())?;
50        Ok(Self(key))
51    }
52
53    #[cfg(test)]
54    pub fn test_key() -> Self {
55        let test_jwt_key = "sMG87WlKnNZoITzvL2+jczriTR7JRsCtGu/bSKaSIvw=asdfjklasd***FSDfsdASDFDS";
56        Self::new(test_jwt_key).unwrap()
57    }
58}
59
60#[derive(Debug, Serialize, Deserialize)]
61pub struct UploadClaim<'a> {
62    exercise_service_slug: Cow<'a, str>,
63    expiration_time: DateTime<Utc>,
64}
65
66impl<'a> UploadClaim<'a> {
67    pub fn exercise_service_slug(&self) -> &str {
68        self.exercise_service_slug.as_ref()
69    }
70
71    pub fn expiration_time(&self) -> &DateTime<Utc> {
72        &self.expiration_time
73    }
74
75    pub fn expiring_in_1_day(exercise_service_slug: Cow<'a, str>) -> Self {
76        Self {
77            exercise_service_slug,
78            expiration_time: Utc::now() + Duration::days(1),
79        }
80    }
81
82    pub fn sign(self, key: &JwtKey) -> String {
83        self.sign_with_key(&key.0)
84            .expect("JWT signing failed - this should never happen with a valid key")
85    }
86
87    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
88        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
89            ControllerError::new(
90                ControllerErrorType::BadRequest,
91                format!("Invalid jwt key: {}", err),
92                Some(err.into()),
93            )
94        })?;
95        if claim.expiration_time < Utc::now() {
96            return Err(ControllerError::new(
97                ControllerErrorType::BadRequest,
98                "Upload claim has expired".to_string(),
99                None,
100            ));
101        }
102        Ok(claim)
103    }
104}
105
106impl FromRequest for UploadClaim<'_> {
107    type Error = ControllerError;
108    type Future = Ready<Result<Self, Self::Error>>;
109
110    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
111        let try_from_request = move || {
112            let jwt_key = req.app_data::<web::Data<JwtKey>>().ok_or_else(|| {
113                ControllerError::new(
114                    ControllerErrorType::InternalServerError,
115                    "Missing JwtKey in app data - server configuration error".to_string(),
116                    None,
117                )
118            })?;
119            let header = req
120                .headers()
121                .get(EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER)
122                .ok_or_else(|| {
123                    ControllerError::new(
124                        ControllerErrorType::BadRequest,
125                        format!("Missing header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER}",),
126                        None,
127                    )
128                })?;
129            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
130                ControllerError::new(
131                    ControllerErrorType::BadRequest,
132                    format!(
133                        "Invalid header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER} = {}",
134                        String::from_utf8_lossy(header.as_bytes())
135                    ),
136                    Some(err.into()),
137                )
138            })?;
139            let claim = UploadClaim::validate(header, jwt_key)?;
140            Result::<_, Self::Error>::Ok(claim)
141        };
142        ready(try_from_request())
143    }
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147pub struct GradingUpdateClaim {
148    submission_id: Uuid,
149    expiration_time: DateTime<Utc>,
150}
151
152impl GradingUpdateClaim {
153    pub fn submission_id(&self) -> Uuid {
154        self.submission_id
155    }
156
157    pub fn expiration_time(&self) -> &DateTime<Utc> {
158        &self.expiration_time
159    }
160
161    pub fn expiring_in_1_day(submission_id: Uuid) -> Self {
162        Self {
163            submission_id,
164            expiration_time: Utc::now() + Duration::days(1),
165        }
166    }
167
168    pub fn sign(self, key: &JwtKey) -> String {
169        self.sign_with_key(&key.0)
170            .expect("JWT signing failed - this should never happen with a valid key")
171    }
172
173    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
174        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
175            ControllerError::new(
176                ControllerErrorType::BadRequest,
177                format!("Invalid jwt key: {}", err),
178                Some(err.into()),
179            )
180        })?;
181        if claim.expiration_time < Utc::now() {
182            return Err(ControllerError::new(
183                ControllerErrorType::BadRequest,
184                "Grading update claim has expired".to_string(),
185                None,
186            ));
187        }
188        Ok(claim)
189    }
190}
191
192impl FromRequest for GradingUpdateClaim {
193    type Error = ControllerError;
194    type Future = Ready<Result<Self, Self::Error>>;
195
196    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
197        let try_from_request = move || {
198            let jwt_key = req.app_data::<web::Data<JwtKey>>().ok_or_else(|| {
199                ControllerError::new(
200                    ControllerErrorType::InternalServerError,
201                    "Missing JwtKey in app data - server configuration error".to_string(),
202                    None,
203                )
204            })?;
205            let header = req
206                .headers()
207                .get(EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER)
208                .ok_or_else(|| {
209                    ControllerError::new(
210                        ControllerErrorType::BadRequest,
211                        format!("Missing header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER}",),
212                        None,
213                    )
214                })?;
215            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
216                ControllerError::new(
217                    ControllerErrorType::BadRequest,
218                    format!(
219                        "Invalid header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER} = {}",
220                        String::from_utf8_lossy(header.as_bytes())
221                    ),
222                    Some(err.into()),
223                )
224            })?;
225            let claim = GradingUpdateClaim::validate(header, jwt_key)?;
226            Result::<_, Self::Error>::Ok(claim)
227        };
228        ready(try_from_request())
229    }
230}
231
/// Accepted by the public-spec and model-solution endpoints of exercise services.
///
/// Sent as the JSON body of the POST request made by [`make_spec_fetcher`].
#[derive(Debug, Serialize)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
pub struct SpecRequest<'a> {
    /// The `request_id` given to `make_spec_fetcher`.
    request_id: Uuid,
    /// Private spec that the public spec / model solution is generated from, if any.
    private_spec: Option<&'a serde_json::Value>,
    /// Endpoint the service may upload files to, authorized by the upload claim header.
    upload_url: Option<String>,
}
240
241/// Fetches a public/model spec based on the private spec from the given url.
242/// The slug and jwt key are used for an upload claim that allows the service
243/// to upload files as part of the spec.
244pub fn make_spec_fetcher(
245    base_url: String,
246    request_id: Uuid,
247    jwt_key: Arc<JwtKey>,
248) -> impl SpecFetcher {
249    move |url, exercise_service_slug, private_spec| {
250        let client = reqwest::Client::new();
251        let upload_claim = UploadClaim::expiring_in_1_day(exercise_service_slug.into());
252        let upload_url = Some(format!("{base_url}/api/v0/files/{exercise_service_slug}"));
253        let req = client
254            .post(url.clone())
255            .header(
256                EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER,
257                upload_claim.sign(&jwt_key),
258            )
259            .timeout(std::time::Duration::from_secs(120))
260            .json(&SpecRequest {
261                request_id,
262                private_spec,
263                upload_url,
264            })
265            .send();
266        async move {
267            let res = req.await.map_err(ModelError::from)?;
268            let status_code = res.status();
269            if !status_code.is_success() {
270                let error_text = res.text().await;
271                let error = error_text.as_deref().unwrap_or("(No text in response)");
272                error!(
273                    ?url,
274                    ?exercise_service_slug,
275                    ?private_spec,
276                    ?status_code,
277                    "Exercise service returned an error while generating a spec: {}",
278                    error
279                );
280                return Err(ModelError::new(
281                    ModelErrorType::HttpRequest {
282                        status_code: status_code.as_u16(),
283                        response_body: error.to_string(),
284                    },
285                    format!(
286                        "Failed to generate spec for exercise for {exercise_service_slug}: {error}."
287                    ),
288                    None,
289                ));
290            }
291            let json = parse_response_json(res).await?;
292            Ok(json)
293        }
294        .boxed()
295    }
296}
297
298// see `fetch_service_info_fast` while handling HTTP requests
299pub fn fetch_service_info(url: Url) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
300    fetch_service_info_with_timeout(url, 1000 * 120)
301}
302
303// use this while handling HTTP requests, see `fetch_service_info`
304pub fn fetch_service_info_fast(
305    url: Url,
306) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
307    fetch_service_info_with_timeout(url, 1000 * 5)
308}
309
310fn fetch_service_info_with_timeout(
311    url: Url,
312    timeout_ms: u64,
313) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
314    async move {
315        let client = reqwest::Client::new();
316        let res = client
317            .get(url) // e.g. http://example-exercise.default.svc.cluster.local:3002/example-exercise/api/service-info
318            .timeout(std::time::Duration::from_millis(timeout_ms))
319            .send()
320            .await
321            .map_err(ModelError::from)?;
322        let status = res.status();
323        if !status.is_success() {
324            let response_url = res.url().to_string();
325            let body = res.text().await.map_err(ModelError::from)?;
326            warn!(url=?response_url, status=?status, body=?body, "Could not fetch service info.");
327            return Err(ModelError::new(
328                ModelErrorType::HttpRequest {
329                    status_code: status.as_u16(),
330                    response_body: body,
331                },
332                "Could not fetch service info.".to_string(),
333                None,
334            ));
335        }
336        let res = parse_response_json(res).await?;
337        Ok(res)
338    }
339    .boxed()
340}
341
342pub fn make_grading_request_sender(
343    jwt_key: Arc<JwtKey>,
344) -> impl Fn(
345    Url,
346    &ExerciseTask,
347    &ExerciseTaskSubmission,
348) -> BoxFuture<'static, ModelResult<ExerciseTaskGradingResult>> {
349    move |grade_url, exercise_task, submission| {
350        let client = reqwest::Client::new();
351        // TODO: use real url
352        let grading_update_url = format!(
353            "http://project-331.local/api/v0/exercise-services/grading/grading-update/{}",
354            submission.id
355        );
356        let grading_update_claim = GradingUpdateClaim::expiring_in_1_day(submission.id);
357        let req = client
358            .post(grade_url)
359            .header(
360                EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER,
361                grading_update_claim.sign(&jwt_key),
362            )
363            .timeout(std::time::Duration::from_secs(120))
364            .json(&ExerciseTaskGradingRequest {
365                grading_update_url: &grading_update_url,
366                exercise_spec: &exercise_task.private_spec,
367                submission_data: &submission.data_json,
368            });
369        async move {
370            let res = req.send().await.map_err(ModelError::from)?;
371            let status = res.status();
372            if !status.is_success() {
373                let status_code = status.as_u16();
374                let response_body = res.text().await.unwrap_or_default();
375                error!(
376                    ?response_body,
377                    status_code = %status_code,
378                    "Grading request returned an unsuccesful status code"
379                );
380
381                return Err(ModelError::new(
382                    ModelErrorType::HttpRequest {
383                        status_code,
384                        response_body: response_body.clone(),
385                    },
386                    format!(
387                        "Grading failed with status: {} response: {}",
388                        status_code, response_body
389                    ),
390                    None,
391                ));
392            }
393            let obj = parse_response_json(res).await?;
394            info!("Received a grading result: {:#?}", &obj);
395            Ok(obj)
396        }
397        .boxed()
398    }
399}
400
401#[derive(Debug, Serialize, Deserialize)]
402pub struct GivePeerReviewClaim {
403    pub exercise_slide_submission_id: Uuid,
404    pub peer_or_self_review_config_id: Uuid,
405    expiration_time: DateTime<Utc>,
406}
407
408impl GivePeerReviewClaim {
409    pub fn expiring_in_1_day(
410        exercise_slide_submission_id: Uuid,
411        peer_or_self_review_config_id: Uuid,
412    ) -> Self {
413        Self {
414            exercise_slide_submission_id,
415            peer_or_self_review_config_id,
416            expiration_time: Utc::now() + Duration::days(1),
417        }
418    }
419
420    pub fn sign(self, key: &JwtKey) -> String {
421        self.sign_with_key(&key.0)
422            .expect("JWT signing failed - this should never happen with a valid key")
423    }
424
425    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
426        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
427            ControllerError::new(
428                ControllerErrorType::BadRequest,
429                format!("Invalid claim: {}", err),
430                Some(err.into()),
431            )
432        })?;
433        if claim.expiration_time < Utc::now() {
434            return Err(ControllerError::new(
435                ControllerErrorType::BadRequest,
436                "The review has expired.".to_string(),
437                None,
438            ));
439        }
440        Ok(claim)
441    }
442}
443
444/// A caching spec fetcher ONLY FOR THE SEED that returns a cached spec if the same
445/// (url, exercise_service_slug, private_spec) is requested. Since this is only used during seeding,
446/// there is no cache eviction.
447pub fn make_seed_spec_fetcher_with_cache(
448    base_url: String,
449    request_id: Uuid,
450    jwt_key: Arc<JwtKey>,
451) -> impl SpecFetcher {
452    // Cache key: (url, exercise_service_slug, private_spec serialized)
453    let cache: Arc<Mutex<SpecCache>> = Arc::new(Mutex::new(HashMap::new()));
454
455    // Create the base non-caching spec fetcher and wrap it in Arc to make it clonable
456    let base_fetcher = Arc::new(make_spec_fetcher(base_url, request_id, jwt_key));
457
458    move |url, exercise_service_slug, private_spec| {
459        let url_str = url.to_string();
460        let service_slug = exercise_service_slug.to_string();
461        // Convert private_spec to string for cache key if present
462        let private_spec_str =
463            private_spec.map(|spec| serde_json::to_string(&spec).unwrap_or_default());
464        let key = (url_str.clone(), service_slug.clone(), private_spec_str);
465        let cache = Arc::clone(&cache);
466        let base_fetcher = Arc::clone(&base_fetcher);
467
468        async move {
469            // Try to get from cache first
470            if let Some(cached_spec) = cache
471                .lock()
472                .expect("Seed spec fetcher cache lock poisoned")
473                .get(&key)
474            {
475                return Ok(cached_spec.clone());
476            }
477
478            // Not in cache - fetch using base fetcher
479            let fetched_spec = base_fetcher(url, exercise_service_slug, private_spec).await?;
480
481            // Store in cache
482            cache
483                .lock()
484                .expect("Seed spec fetcher cache lock poisoned")
485                .insert(key, fetched_spec.clone());
486
487            Ok(fetched_spec)
488        }
489        .boxed()
490    }
491}
492
493/// Safely parses a response body as JSON, capturing the actual response body in error cases
494async fn parse_response_json<T>(response: reqwest::Response) -> ModelResult<T>
495where
496    T: serde::de::DeserializeOwned,
497{
498    let status = response.status();
499    let response_text = response.text().await.map_err(ModelError::from)?;
500
501    serde_json::from_str(&response_text).map_err(|err| {
502        ModelError::new(
503            ModelErrorType::HttpError {
504                error_type: HttpErrorType::ResponseDecodeFailed,
505                reason: err.to_string(),
506                status_code: Some(status.as_u16()),
507                response_body: Some(response_text),
508            },
509            format!("Failed to decode JSON response: {}", err),
510            None,
511        )
512    })
513}