headless_lms_server/domain/
models_requests.rs

1//! Contains helper functions that are passed to headless-lms-models where it needs to make requests to exercise services.
2
3use crate::prelude::*;
4use actix_http::Payload;
5use actix_web::{FromRequest, HttpRequest};
6use chrono::{DateTime, Duration, Utc};
7use futures::{
8    FutureExt,
9    future::{BoxFuture, Ready, ready},
10};
11use headless_lms_models::{
12    ModelError, ModelErrorType, ModelResult,
13    exercise_service_info::ExerciseServiceInfoApi,
14    exercise_task_gradings::{ExerciseTaskGradingRequest, ExerciseTaskGradingResult},
15    exercise_task_submissions::ExerciseTaskSubmission,
16    exercise_tasks::ExerciseTask,
17};
18use headless_lms_utils::error::backend_error::BackendError;
19use hmac::{Hmac, Mac};
20use jwt::{SignWithKey, VerifyWithKey};
21use models::SpecFetcher;
22use sha2::Sha256;
23use std::collections::HashMap;
24use std::sync::{Arc, Mutex};
25use std::{borrow::Cow, fmt::Debug};
26use url::Url;
27
28use super::error::{ControllerError, ControllerErrorType};
29
// HTTP header names that carry the signed claims exchanged with exercise services.
// These values must stay identical to the matching constants in the shared module.
const EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER: &str = "exercise-service-upload-claim";
const EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER: &str = "exercise-service-grading-update-claim";
33
34/// A type for caching the spec fetching (only for the seed)
35type SpecCache = HashMap<(String, String, Option<String>), serde_json::Value>;
36
37#[derive(Clone, Debug)]
38pub struct JwtKey(Hmac<Sha256>);
39
40impl JwtKey {
41    pub fn try_from_env() -> anyhow::Result<Self> {
42        let jwt_password = std::env::var("JWT_PASSWORD").context("JWT_PASSWORD must be defined")?;
43        let jwt_key = Self::new(&jwt_password)?;
44        Ok(jwt_key)
45    }
46
47    pub fn new(key: &str) -> Result<Self, sha2::digest::InvalidLength> {
48        let key: Hmac<Sha256> = Hmac::new_from_slice(key.as_bytes())?;
49        Ok(Self(key))
50    }
51
52    #[cfg(test)]
53    pub fn test_key() -> Self {
54        let test_jwt_key = "sMG87WlKnNZoITzvL2+jczriTR7JRsCtGu/bSKaSIvw=asdfjklasd***FSDfsdASDFDS";
55        Self::new(test_jwt_key).unwrap()
56    }
57}
58
59#[derive(Debug, Serialize, Deserialize)]
60pub struct UploadClaim<'a> {
61    exercise_service_slug: Cow<'a, str>,
62    expiration_time: DateTime<Utc>,
63}
64
65impl<'a> UploadClaim<'a> {
66    pub fn exercise_service_slug(&self) -> &str {
67        self.exercise_service_slug.as_ref()
68    }
69
70    pub fn expiration_time(&self) -> &DateTime<Utc> {
71        &self.expiration_time
72    }
73
74    pub fn expiring_in_1_day(exercise_service_slug: Cow<'a, str>) -> Self {
75        Self {
76            exercise_service_slug,
77            expiration_time: Utc::now() + Duration::days(1),
78        }
79    }
80
81    pub fn sign(self, key: &JwtKey) -> String {
82        self.sign_with_key(&key.0).expect("should never fail")
83    }
84
85    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
86        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
87            ControllerError::new(
88                ControllerErrorType::BadRequest,
89                format!("Invalid jwt key: {}", err),
90                Some(err.into()),
91            )
92        })?;
93        if claim.expiration_time < Utc::now() {
94            return Err(ControllerError::new(
95                ControllerErrorType::BadRequest,
96                "Upload claim has expired".to_string(),
97                None,
98            ));
99        }
100        Ok(claim)
101    }
102}
103
104impl FromRequest for UploadClaim<'_> {
105    type Error = ControllerError;
106    type Future = Ready<Result<Self, Self::Error>>;
107
108    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
109        let try_from_request = move || {
110            let jwt_key = req
111                .app_data::<web::Data<JwtKey>>()
112                .expect("Missing JwtKey in app data");
113            let header = req
114                .headers()
115                .get(EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER)
116                .ok_or_else(|| {
117                    ControllerError::new(
118                        ControllerErrorType::BadRequest,
119                        format!("Missing header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER}",),
120                        None,
121                    )
122                })?;
123            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
124                ControllerError::new(
125                    ControllerErrorType::BadRequest,
126                    format!(
127                        "Invalid header {EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER} = {}",
128                        String::from_utf8_lossy(header.as_bytes())
129                    ),
130                    Some(err.into()),
131                )
132            })?;
133            let claim = UploadClaim::validate(header, jwt_key)?;
134            Result::<_, Self::Error>::Ok(claim)
135        };
136        ready(try_from_request())
137    }
138}
139
140#[derive(Debug, Serialize, Deserialize)]
141pub struct GradingUpdateClaim {
142    submission_id: Uuid,
143    expiration_time: DateTime<Utc>,
144}
145
146impl GradingUpdateClaim {
147    pub fn submission_id(&self) -> Uuid {
148        self.submission_id
149    }
150
151    pub fn expiration_time(&self) -> &DateTime<Utc> {
152        &self.expiration_time
153    }
154
155    pub fn expiring_in_1_day(submission_id: Uuid) -> Self {
156        Self {
157            submission_id,
158            expiration_time: Utc::now() + Duration::days(1),
159        }
160    }
161
162    pub fn sign(self, key: &JwtKey) -> String {
163        self.sign_with_key(&key.0).expect("should never fail")
164    }
165
166    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
167        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
168            ControllerError::new(
169                ControllerErrorType::BadRequest,
170                format!("Invalid jwt key: {}", err),
171                Some(err.into()),
172            )
173        })?;
174        if claim.expiration_time < Utc::now() {
175            return Err(ControllerError::new(
176                ControllerErrorType::BadRequest,
177                "Grading update claim has expired".to_string(),
178                None,
179            ));
180        }
181        Ok(claim)
182    }
183}
184
185impl FromRequest for GradingUpdateClaim {
186    type Error = ControllerError;
187    type Future = Ready<Result<Self, Self::Error>>;
188
189    fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
190        let try_from_request = move || {
191            let jwt_key = req
192                .app_data::<web::Data<JwtKey>>()
193                .expect("Missing JwtKey in app data");
194            let header = req
195                .headers()
196                .get(EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER)
197                .ok_or_else(|| {
198                    ControllerError::new(
199                        ControllerErrorType::BadRequest,
200                        format!("Missing header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER}",),
201                        None,
202                    )
203                })?;
204            let header = std::str::from_utf8(header.as_bytes()).map_err(|err| {
205                ControllerError::new(
206                    ControllerErrorType::BadRequest,
207                    format!(
208                        "Invalid header {EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER} = {}",
209                        String::from_utf8_lossy(header.as_bytes())
210                    ),
211                    Some(err.into()),
212                )
213            })?;
214            let claim = GradingUpdateClaim::validate(header, jwt_key)?;
215            Result::<_, Self::Error>::Ok(claim)
216        };
217        ready(try_from_request())
218    }
219}
220
221fn reqwest_err(err: reqwest::Error) -> ModelError {
222    ModelError::new(
223        ModelErrorType::Generic,
224        format!("Error during request: {err}"),
225        None,
226    )
227}
228
229/// Accepted by the public-spec and model-solution endpoints of exercise services.
230#[derive(Debug, Serialize)]
231#[cfg_attr(feature = "ts_rs", derive(TS))]
232pub struct SpecRequest<'a> {
233    request_id: Uuid,
234    private_spec: Option<&'a serde_json::Value>,
235    upload_url: Option<String>,
236}
237
238/// Fetches a public/model spec based on the private spec from the given url.
239/// The slug and jwt key are used for an upload claim that allows the service
240/// to upload files as part of the spec.
241pub fn make_spec_fetcher(
242    base_url: String,
243    request_id: Uuid,
244    jwt_key: Arc<JwtKey>,
245) -> impl SpecFetcher {
246    move |url, exercise_service_slug, private_spec| {
247        let client = reqwest::Client::new();
248        let upload_claim = UploadClaim::expiring_in_1_day(exercise_service_slug.into());
249        let upload_url = Some(format!("{base_url}/api/v0/files/{exercise_service_slug}"));
250        let req = client
251            .post(url.clone())
252            .header(
253                EXERCISE_SERVICE_UPLOAD_CLAIM_HEADER,
254                upload_claim.sign(&jwt_key),
255            )
256            .timeout(std::time::Duration::from_secs(120))
257            .json(&SpecRequest {
258                request_id,
259                private_spec,
260                upload_url,
261            })
262            .send();
263        async move {
264            let res = req.await.map_err(reqwest_err)?;
265            let status_code = res.status();
266            if !status_code.is_success() {
267                let error_text = res.text().await;
268                let error = error_text.as_deref().unwrap_or("(No text in response)");
269                error!(
270                    ?url,
271                    ?exercise_service_slug,
272                    ?private_spec,
273                    ?status_code,
274                    "Exercise service returned an error while generating a spec: {}",
275                    error
276                );
277                return Err(ModelError::new(
278                    ModelErrorType::Generic,
279                    format!(
280                        "Failed to generate spec for exercise for {exercise_service_slug}: {error}."
281                    ),
282                    None,
283                ));
284            }
285            let json = res.json().await.map_err(reqwest_err)?;
286            Ok(json)
287        }
288        .boxed()
289    }
290}
291
292// see `fetch_service_info_fast` while handling HTTP requests
293pub fn fetch_service_info(url: Url) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
294    fetch_service_info_with_timeout(url, 1000 * 120)
295}
296
297// use this while handling HTTP requests, see `fetch_service_info`
298pub fn fetch_service_info_fast(
299    url: Url,
300) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
301    fetch_service_info_with_timeout(url, 1000 * 5)
302}
303
304fn fetch_service_info_with_timeout(
305    url: Url,
306    timeout_ms: u64,
307) -> BoxFuture<'static, ModelResult<ExerciseServiceInfoApi>> {
308    async move {
309        let client = reqwest::Client::new();
310        let res = client
311            .get(url) // e.g. http://example-exercise.default.svc.cluster.local:3002/example-exercise/api/service-info
312            .timeout(std::time::Duration::from_millis(timeout_ms))
313            .send()
314            .await
315            .map_err(reqwest_err)?;
316        let status = res.status();
317        if !status.is_success() {
318            let response_url = res.url().to_string();
319            let body = res.text().await.map_err(reqwest_err)?;
320            warn!(url=?response_url, status=?status, body=?body, "Could not fetch service info.");
321            return Err(ModelError::new(
322                ModelErrorType::Generic,
323                "Could not fetch service info.".to_string(),
324                None,
325            ));
326        }
327        let res = res
328            .json::<ExerciseServiceInfoApi>()
329            .await
330            .map_err(reqwest_err)?;
331        Ok(res)
332    }
333    .boxed()
334}
335
336pub fn make_grading_request_sender(
337    jwt_key: Arc<JwtKey>,
338) -> impl Fn(
339    Url,
340    &ExerciseTask,
341    &ExerciseTaskSubmission,
342) -> BoxFuture<'static, ModelResult<ExerciseTaskGradingResult>> {
343    move |grade_url, exercise_task, submission| {
344        let client = reqwest::Client::new();
345        // TODO: use real url
346        let grading_update_url = format!(
347            "http://project-331.local/api/v0/exercise-services/grading/grading-update/{}",
348            submission.id
349        );
350        let grading_update_claim = GradingUpdateClaim::expiring_in_1_day(submission.id);
351        let req = client
352            .post(grade_url)
353            .header(
354                EXERCISE_SERVICE_GRADING_UPDATE_CLAIM_HEADER,
355                grading_update_claim.sign(&jwt_key),
356            )
357            .timeout(std::time::Duration::from_secs(120))
358            .json(&ExerciseTaskGradingRequest {
359                grading_update_url: &grading_update_url,
360                exercise_spec: &exercise_task.private_spec,
361                submission_data: &submission.data_json,
362            });
363        async move {
364            let res = req.send().await.map_err(reqwest_err)?;
365            let status = res.status();
366            if !status.is_success() {
367                let response_body = res.text().await;
368                error!(
369                    ?response_body,
370                    "Grading request returned an unsuccesful status code"
371                );
372                let source_error = ModelError::new(
373                    ModelErrorType::Generic,
374                    format!("{:?}", response_body),
375                    None,
376                );
377                return Err(ModelError::new(
378                    ModelErrorType::Generic,
379                    "Grading failed".to_string(),
380                    Some(source_error.into()),
381                ));
382            }
383            let obj = res
384                .json::<ExerciseTaskGradingResult>()
385                .await
386                .map_err(reqwest_err)?;
387            info!("Received a grading result: {:#?}", &obj);
388            Ok(obj)
389        }
390        .boxed()
391    }
392}
393
394#[derive(Debug, Serialize, Deserialize)]
395pub struct GivePeerReviewClaim {
396    pub exercise_slide_submission_id: Uuid,
397    pub peer_or_self_review_config_id: Uuid,
398    expiration_time: DateTime<Utc>,
399}
400
401impl GivePeerReviewClaim {
402    pub fn expiring_in_1_day(
403        exercise_slide_submission_id: Uuid,
404        peer_or_self_review_config_id: Uuid,
405    ) -> Self {
406        Self {
407            exercise_slide_submission_id,
408            peer_or_self_review_config_id,
409            expiration_time: Utc::now() + Duration::days(1),
410        }
411    }
412
413    pub fn sign(self, key: &JwtKey) -> String {
414        self.sign_with_key(&key.0).expect("should never fail")
415    }
416
417    pub fn validate(token: &str, key: &JwtKey) -> Result<Self, ControllerError> {
418        let claim: Self = token.verify_with_key(&key.0).map_err(|err| {
419            ControllerError::new(
420                ControllerErrorType::BadRequest,
421                format!("Invalid claim: {}", err),
422                Some(err.into()),
423            )
424        })?;
425        if claim.expiration_time < Utc::now() {
426            return Err(ControllerError::new(
427                ControllerErrorType::BadRequest,
428                "The review has expired.".to_string(),
429                None,
430            ));
431        }
432        Ok(claim)
433    }
434}
435
436/// A caching spec fetcher ONLY FOR THE SEED that returns a cached spec if the same
437/// (url, exercise_service_slug, private_spec) is requested. Since this is only used during seeding,
438/// there is no cache eviction.
439pub fn make_seed_spec_fetcher_with_cache(
440    base_url: String,
441    request_id: Uuid,
442    jwt_key: Arc<JwtKey>,
443) -> impl SpecFetcher {
444    // Cache key: (url, exercise_service_slug, private_spec serialized)
445    let cache: Arc<Mutex<SpecCache>> = Arc::new(Mutex::new(HashMap::new()));
446
447    // Create the base non-caching spec fetcher and wrap it in Arc to make it clonable
448    let base_fetcher = Arc::new(make_spec_fetcher(base_url, request_id, jwt_key));
449
450    move |url, exercise_service_slug, private_spec| {
451        let url_str = url.to_string();
452        let service_slug = exercise_service_slug.to_string();
453        // Convert private_spec to string for cache key if present
454        let private_spec_str =
455            private_spec.map(|spec| serde_json::to_string(&spec).unwrap_or_default());
456        let key = (url_str.clone(), service_slug.clone(), private_spec_str);
457        let cache = Arc::clone(&cache);
458        let base_fetcher = Arc::clone(&base_fetcher);
459
460        async move {
461            // Try to get from cache first
462            if let Some(cached_spec) = cache
463                .lock()
464                .expect("Seed spec fetcher cache lock poisoned")
465                .get(&key)
466            {
467                return Ok(cached_spec.clone());
468            }
469
470            // Not in cache - fetch using base fetcher
471            let fetched_spec = base_fetcher(url, exercise_service_slug, private_spec).await?;
472
473            // Store in cache
474            cache
475                .lock()
476                .expect("Seed spec fetcher cache lock poisoned")
477                .insert(key, fetched_spec.clone());
478
479            Ok(fetched_spec)
480        }
481        .boxed()
482    }
483}