headless_lms_server/programs/seed/mod.rs

1#![allow(clippy::unwrap_used)]
2
3pub mod builder;
4pub mod seed_application_task_llms;
5pub mod seed_certificate_fonts;
6pub mod seed_courses;
7pub mod seed_exercise_services;
8pub mod seed_file_storage;
9pub mod seed_generic_emails;
10pub mod seed_helpers;
11pub mod seed_oauth_clients;
12pub mod seed_organizations;
13pub mod seed_playground_examples;
14pub mod seed_roles;
15mod seed_user_research_consents;
16pub mod seed_users;
17
18use std::{env, process::Command, sync::Arc, time::Duration};
19
20use crate::{
21    domain::models_requests::JwtKey,
22    programs::seed::{
23        seed_application_task_llms::seed_application_task_llms,
24        seed_oauth_clients::seed_oauth_clients,
25    },
26    setup_tracing,
27};
28
29use anyhow::Context;
30use futures::try_join;
31
32use headless_lms_utils::futures::run_parallelly;
33use sqlx::{Pool, Postgres, migrate::MigrateDatabase, postgres::PgPoolOptions};
34use tracing::info;
35
36pub async fn main() -> anyhow::Result<()> {
37    let base_url = std::env::var("BASE_URL").context("BASE_URL must be defined")?;
38    let db_pool = setup_seed_environment().await?;
39    let jwt_key = Arc::new(JwtKey::try_from_env().expect("Failed to create JwtKey"));
40
41    // Initialize the global spec fetcher before any seeding
42    seed_helpers::init_seed_spec_fetcher(base_url.clone(), Arc::clone(&jwt_key))
43        .expect("Failed to initialize seed spec fetcher");
44
45    // Run parallelly to improve performance.
46    let (_, seed_users_result, _, seed_llms_result) = try_join!(
47        run_parallelly(seed_exercise_services::seed_exercise_services(
48            db_pool.clone()
49        )),
50        run_parallelly(seed_users::seed_users(db_pool.clone())),
51        run_parallelly(seed_playground_examples::seed_playground_examples(
52            db_pool.clone()
53        )),
54        run_parallelly(seed_application_task_llms(db_pool.clone()))
55    )?;
56
57    // Not run parallely because waits another future that is not send.
58    let seed_file_storage_result = seed_file_storage::seed_file_storage().await?;
59
60    let (uh_cs_organization_result, _uh_mathstat_organization_id, _no_users_organization_id) = try_join!(
61        run_parallelly(seed_organizations::uh_cs::seed_organization_uh_cs(
62            db_pool.clone(),
63            seed_users_result,
64            base_url.clone(),
65            Arc::clone(&jwt_key),
66            seed_file_storage_result.clone()
67        )),
68        run_parallelly(
69            seed_organizations::uh_mathstat::seed_organization_uh_mathstat(
70                db_pool.clone(),
71                seed_users_result,
72                seed_llms_result,
73                base_url.clone(),
74                Arc::clone(&jwt_key),
75                seed_file_storage_result.clone()
76            )
77        ),
78        run_parallelly(seed_organizations::no_users::seed_organization_no_users(
79            db_pool.clone()
80        ))
81    )?;
82
83    try_join!(
84        run_parallelly(seed_roles::seed_roles(
85            db_pool.clone(),
86            seed_users_result,
87            uh_cs_organization_result
88        )),
89        run_parallelly(seed_user_research_consents::seed_user_research_consents(
90            db_pool.clone(),
91            seed_users_result
92        )),
93        run_parallelly(seed_certificate_fonts::seed_certificate_fonts(
94            db_pool.clone()
95        )),
96        run_parallelly(seed_generic_emails::seed_generic_emails(
97            db_pool.clone(),
98            seed_users_result
99        )),
100        run_parallelly(seed_oauth_clients(db_pool.clone()))
101    )?;
102
103    Ok(())
104}
105
106async fn setup_seed_environment() -> anyhow::Result<Pool<Postgres>> {
107    // TODO: Audit that the environment access only happens in single-threaded code.
108    unsafe { env::set_var("RUST_LOG", "info,sqlx=warn,headless_lms_models=info") };
109
110    dotenv::dotenv().ok();
111    setup_tracing()?;
112
113    let clean = env::args().any(|a| a == "clean");
114
115    let db_url = env::var("DATABASE_URL")?;
116    let cpu_count = std::thread::available_parallelism()
117        .map(|n| n.get())
118        .unwrap_or(2);
119
120    let max_conns: u32 = std::cmp::max(2, cpu_count as u32);
121
122    let min_conns: u32 = std::cmp::max(1, (cpu_count / 2) as u32);
123
124    let db_pool = PgPoolOptions::new()
125        .max_connections(max_conns)
126        .min_connections(min_conns)
127        // Since this is the seed, it should be fine to wait for a long time for connections
128        .acquire_timeout(Duration::from_secs(10 * 60))
129        .connect(&db_url)
130        .await?;
131
132    if clean {
133        info!("cleaning");
134        // hardcoded for now
135        let status = Command::new("dropdb")
136            .args(["-U", "headless-lms"])
137            .args(["-h", "localhost"])
138            .args(["-p", "54328"])
139            .arg("--force")
140            .arg("-e")
141            .arg("headless_lms_dev")
142            .status()?;
143        assert!(status.success());
144        let db_url = env::var("DATABASE_URL")?;
145        Postgres::create_database(&db_url).await?;
146    }
147
148    if clean {
149        let mut conn = db_pool.acquire().await?;
150        info!("running migrations");
151        sqlx::migrate!("../migrations").run(&mut conn).await?;
152    }
153    Ok(db_pool)
154}