tmc_langs_util/
progress_reporter.rs

1//! Utility struct for printing progress reports.
2
3use once_cell::sync::OnceCell;
4use serde::{Deserialize, Serialize};
5use std::{ops::DerefMut, sync::RwLock, time::Instant};
6use type_map::concurrent::TypeMap;
7
8/// The format for all status updates. May contain some data.
9#[derive(Debug, Serialize, Deserialize)]
10#[serde(rename_all = "kebab-case")]
11#[cfg_attr(feature = "ts-rs", derive(ts_rs::TS))]
12pub struct StatusUpdate<T> {
13    pub finished: bool,
14    pub message: String,
15    pub percent_done: f64,
16    pub time: u32,
17    pub data: Option<T>,
18}
19
20// the closure called to report progress, could for example print the report as JSON
21type UpdateClosure<T> = dyn 'static + Sync + Send + Fn(StatusUpdate<T>);
22
23/// The struct that keeps track of progress for a given progress update type T and contains a closure for reporting whenever progress is made.
24struct ProgressReporter<T> {
25    progress_report: Box<UpdateClosure<T>>,
26}
27
28/// Contains all the different progress reporters and keeps track of the overall progress.
29struct ProgressReporterContainer {
30    reporters: TypeMap,
31    current_progress: f64,
32    total_steps_left: u32,
33    start_time: Instant,
34    stage_steps: Vec<u32>, // steps left
35}
36
37impl ProgressReporterContainer {
38    pub fn elapsed_millis(&self) -> u32 {
39        // for the time to not fit into a u32, the time elapsed would have to be over a month
40        // which isn't going to happen here
41        self.start_time.elapsed().as_millis() as u32
42    }
43}
44
45static PROGRESS_REPORTERS: OnceCell<RwLock<ProgressReporterContainer>> = OnceCell::new();
46
47/// Subscribes to progress reports of type T with callback of type F called every time progress is made with type T.
48pub fn subscribe<T, F>(progress_report: F)
49where
50    T: 'static + Send + Sync,
51    F: 'static + Sync + Send + Fn(StatusUpdate<T>),
52{
53    let lock = PROGRESS_REPORTERS.get_or_init(|| {
54        RwLock::new(ProgressReporterContainer {
55            reporters: TypeMap::new(),
56            current_progress: 0.0,
57            total_steps_left: 0,
58            start_time: Instant::now(),
59            stage_steps: Vec::new(),
60        })
61    });
62    let mut guard = lock
63        .write()
64        .expect("only fails if the lock is poisoned; we should never panic while holding the lock");
65    let reporter = ProgressReporter {
66        progress_report: Box::new(progress_report),
67    };
68    guard.reporters.insert(reporter);
69}
70
71/// Starts a new stage.
72pub fn start_stage<T: 'static + Send + Sync>(total_steps: u32, message: String, data: Option<T>) {
73    // check for init
74    if let Some(lock) = PROGRESS_REPORTERS.get() {
75        let mut reporter = lock.write().expect(
76            "only fails if the lock is poisoned; we should never panic while holding the lock",
77        );
78        let reporter = reporter.deref_mut();
79        reporter.total_steps_left += total_steps;
80        reporter.stage_steps.push(total_steps);
81
82        // check for subscriber
83        if let Some(progress_reporter) = reporter.reporters.get::<ProgressReporter<T>>() {
84            // report status
85            let status_update = StatusUpdate {
86                finished: false,
87                message,
88                percent_done: reporter.current_progress,
89                time: reporter.elapsed_millis(),
90                data,
91            };
92            progress_reporter.progress_report.as_ref()(status_update);
93        }
94    }
95}
96
97/// Progresses the current stage.
98pub fn progress_stage<T: 'static + Send + Sync>(message: String, data: Option<T>) {
99    // check for init
100    if let Some(lock) = PROGRESS_REPORTERS.get() {
101        let mut reporter = lock.write().expect(
102            "only fails if the lock is poisoned; we should never panic while holding the lock",
103        );
104        let reporter = reporter.deref_mut();
105
106        // check for stage
107        if let Some(stage_steps_left) = reporter.stage_steps.last_mut() {
108            // check if steps left in stage
109            if *stage_steps_left > 0 {
110                let step_progress =
111                    (1.0 - reporter.current_progress) / reporter.total_steps_left as f64;
112                *stage_steps_left -= 1;
113                reporter.total_steps_left -= 1;
114                reporter.current_progress =
115                    f64::min(reporter.current_progress + step_progress, 1.0);
116                // guard against going over 1.0
117            }
118
119            // check for subscriber
120            let time = reporter.elapsed_millis();
121            if let Some(progress_reporter) = reporter.reporters.get_mut::<ProgressReporter<T>>() {
122                let status_update = StatusUpdate {
123                    finished: false,
124                    message,
125                    percent_done: reporter.current_progress,
126                    time,
127                    data,
128                };
129                progress_reporter.progress_report.as_ref()(status_update);
130            }
131        }
132    }
133}
134
135/// Finishes the current stage.
136pub fn finish_stage<T: 'static + Send + Sync>(message: String, data: Option<T>) {
137    // check for init
138    if let Some(lock) = PROGRESS_REPORTERS.get() {
139        let mut reporter = lock.write().expect(
140            "only fails if the lock is poisoned; we should never panic while holding the lock",
141        );
142        let reporter = reporter.deref_mut();
143
144        // check for stage
145        if let Some(stage_steps_left) = reporter.stage_steps.pop() {
146            let step_progress =
147                (1.0 - reporter.current_progress) / reporter.total_steps_left as f64;
148            reporter.total_steps_left -= stage_steps_left;
149            reporter.current_progress = f64::min(
150                reporter.current_progress + stage_steps_left as f64 * step_progress,
151                1.0,
152            ); // guard against going over 1.0
153
154            // check for subscriber
155            if let Some(progress_reporter) = reporter.reporters.get::<ProgressReporter<T>>() {
156                let status_update = StatusUpdate {
157                    finished: true,
158                    message,
159                    percent_done: reporter.current_progress,
160                    time: reporter.elapsed_millis(),
161                    data,
162                };
163                progress_reporter.progress_report.as_ref()(status_update);
164            }
165        }
166
167        // All of the stages have been finished, resetting progress for future events.
168        if reporter.total_steps_left == 0 && (reporter.current_progress - 1.0_f64).abs() < 0.001 {
169            reporter.current_progress = 0.0;
170        }
171    }
172}
173
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Shared slot holding the most recent status update seen by a test subscriber.
    type LatestUpdate = Arc<Mutex<Option<StatusUpdate<u32>>>>;

    // Serializes the tests, since they all operate on the same global reporter state.
    static PROGRESS_MUTEX: OnceCell<Mutex<()>> = OnceCell::new();

    /// Sets up logging, takes the test-wide lock and resets the global reporter state.
    ///
    /// The returned guard must be kept alive for the whole test.
    fn init() -> MutexGuard<'static, ()> {
        use log::LevelFilter;
        use simple_logger::SimpleLogger;
        let _ = SimpleLogger::new().with_level(LevelFilter::Debug).init();

        // wait for lock and clear reporter map
        let mutex = PROGRESS_MUTEX.get_or_init(|| Mutex::new(()));
        // A test that fails an assertion panics while holding this guard, which poisons the
        // mutex. It protects no data, so recover the guard instead of letting one failure
        // cascade into every test that runs afterwards.
        let guard = mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(reporters) = PROGRESS_REPORTERS.get() {
            let mut reporters = reporters.write().unwrap();
            *reporters = ProgressReporterContainer {
                reporters: TypeMap::new(),
                current_progress: 0.0,
                total_steps_left: 0,
                start_time: Instant::now(),
                stage_steps: Vec::new(),
            };
        }
        guard
    }

    /// Subscribes to `u32` updates and returns the slot the latest update is written to.
    fn subscribe_latest() -> LatestUpdate {
        let latest: LatestUpdate = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&latest);
        subscribe::<u32, _>(move |s| {
            log::debug!("got {s:#?}");
            *sink.lock().unwrap() = Some(s);
        });
        latest
    }

    /// Asserts that the most recent update reports a progress within 0.01 of `expected`.
    #[track_caller]
    fn assert_percent_done(latest: &LatestUpdate, expected: f64) {
        // copy the value out so the slot's guard is released before the assertion can panic
        let actual = latest.lock().unwrap().as_ref().unwrap().percent_done;
        assert!(
            (actual - expected).abs() < 0.01,
            "expected percent_done of about {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn single_stage_progress() {
        let _lock = init();
        let latest = subscribe_latest();

        start_stage::<u32>(2, "starting".to_string(), None);

        progress_stage::<u32>("hello".to_string(), None);
        assert_percent_done(&latest, 0.5);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent_done(&latest, 1.0);
    }

    #[test]
    fn multi_stage_progress() {
        let _lock = init();
        let latest = subscribe_latest();

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.5);

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.6666);

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.7499);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.8333);
        finish_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.8333);

        finish_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 0.9166);

        finish_stage::<u32>("msg".to_string(), None);
        assert_percent_done(&latest, 1.0);
    }

    #[test]
    fn consecutive_progress() {
        let _lock = init();
        let latest = subscribe_latest();

        start_stage::<u32>(3, "starting".to_string(), None);
        progress_stage::<u32>("hello".to_string(), None);
        assert_percent_done(&latest, 1.0 / 3.0);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent_done(&latest, 2.0 / 3.0);
        finish_stage::<u32>("finished".to_string(), None);
        assert_percent_done(&latest, 1.0);

        // finishing the last stage resets the progress, so a new stage starts from zero
        start_stage::<u32>(2, "starting".to_string(), None);
        assert_percent_done(&latest, 0.0);
        progress_stage::<u32>("hello".to_string(), None);
        assert_percent_done(&latest, 0.5);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent_done(&latest, 1.0);
    }
}