1use once_cell::sync::OnceCell;
4use serde::{Deserialize, Serialize};
5use std::{ops::DerefMut, sync::RwLock, time::Instant};
6use type_map::concurrent::TypeMap;
7
/// A progress notification delivered to the subscriber registered for payload type `T`.
///
/// Serialized with kebab-case field names (e.g. `percent_done` becomes `percent-done`).
/// The field order here is also the serialized order.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[cfg_attr(feature = "ts-rs", derive(ts_rs::TS))]
pub struct StatusUpdate<T> {
    /// Within this module, `true` only for the update sent when a stage is finished.
    pub finished: bool,
    /// Caller-supplied description of the current stage or step.
    pub message: String,
    /// Overall progress as a fraction in `[0.0, 1.0]` (not a 0–100 percentage, despite the name).
    pub percent_done: f64,
    /// Milliseconds since the global progress state was created
    /// (see `ProgressReporterContainer::elapsed_millis`).
    pub time: u32,
    /// Optional caller-supplied payload forwarded unchanged to the subscriber.
    pub data: Option<T>,
}
19
20type UpdateClosure<T> = dyn 'static + Sync + Send + Fn(StatusUpdate<T>);
22
/// Wrapper that stores a subscriber callback for payload type `T`.
///
/// Instances live in `ProgressReporterContainer::reporters` (a `TypeMap`), which is keyed by
/// type, so at most one `ProgressReporter<T>` exists per `T`.
struct ProgressReporter<T> {
    /// The subscriber's callback. It is invoked while the global write lock is held.
    progress_report: Box<UpdateClosure<T>>,
}
27
/// Global progress state shared by all stages and subscribers.
struct ProgressReporterContainer {
    /// Subscribers, keyed by their `ProgressReporter<T>` type (one per payload type `T`).
    reporters: TypeMap,
    /// Overall progress so far, as a fraction in `[0.0, 1.0]`.
    current_progress: f64,
    /// Steps not yet completed, summed over every open stage.
    total_steps_left: u32,
    /// Reference instant for the `time` field of status updates.
    start_time: Instant,
    /// Remaining steps of each open stage, used as a stack (innermost stage last).
    stage_steps: Vec<u32>,
}
36
37impl ProgressReporterContainer {
38 pub fn elapsed_millis(&self) -> u32 {
39 self.start_time.elapsed().as_millis() as u32
42 }
43}
44
/// Process-wide progress state, created lazily by the first `subscribe` call.
///
/// `start_stage`, `progress_stage` and `finish_stage` only look it up (they never create it),
/// so they do nothing until someone has subscribed.
static PROGRESS_REPORTERS: OnceCell<RwLock<ProgressReporterContainer>> = OnceCell::new();
46
47pub fn subscribe<T, F>(progress_report: F)
49where
50 T: 'static + Send + Sync,
51 F: 'static + Sync + Send + Fn(StatusUpdate<T>),
52{
53 let lock = PROGRESS_REPORTERS.get_or_init(|| {
54 RwLock::new(ProgressReporterContainer {
55 reporters: TypeMap::new(),
56 current_progress: 0.0,
57 total_steps_left: 0,
58 start_time: Instant::now(),
59 stage_steps: Vec::new(),
60 })
61 });
62 let mut guard = lock
63 .write()
64 .expect("only fails if the lock is poisoned; we should never panic while holding the lock");
65 let reporter = ProgressReporter {
66 progress_report: Box::new(progress_report),
67 };
68 guard.reporters.insert(reporter);
69}
70
71pub fn start_stage<T: 'static + Send + Sync>(total_steps: u32, message: String, data: Option<T>) {
73 if let Some(lock) = PROGRESS_REPORTERS.get() {
75 let mut reporter = lock.write().expect(
76 "only fails if the lock is poisoned; we should never panic while holding the lock",
77 );
78 let reporter = reporter.deref_mut();
79 reporter.total_steps_left += total_steps;
80 reporter.stage_steps.push(total_steps);
81
82 if let Some(progress_reporter) = reporter.reporters.get::<ProgressReporter<T>>() {
84 let status_update = StatusUpdate {
86 finished: false,
87 message,
88 percent_done: reporter.current_progress,
89 time: reporter.elapsed_millis(),
90 data,
91 };
92 progress_reporter.progress_report.as_ref()(status_update);
93 }
94 }
95}
96
97pub fn progress_stage<T: 'static + Send + Sync>(message: String, data: Option<T>) {
99 if let Some(lock) = PROGRESS_REPORTERS.get() {
101 let mut reporter = lock.write().expect(
102 "only fails if the lock is poisoned; we should never panic while holding the lock",
103 );
104 let reporter = reporter.deref_mut();
105
106 if let Some(stage_steps_left) = reporter.stage_steps.last_mut() {
108 if *stage_steps_left > 0 {
110 let step_progress =
111 (1.0 - reporter.current_progress) / reporter.total_steps_left as f64;
112 *stage_steps_left -= 1;
113 reporter.total_steps_left -= 1;
114 reporter.current_progress =
115 f64::min(reporter.current_progress + step_progress, 1.0);
116 }
118
119 let time = reporter.elapsed_millis();
121 if let Some(progress_reporter) = reporter.reporters.get_mut::<ProgressReporter<T>>() {
122 let status_update = StatusUpdate {
123 finished: false,
124 message,
125 percent_done: reporter.current_progress,
126 time,
127 data,
128 };
129 progress_reporter.progress_report.as_ref()(status_update);
130 }
131 }
132 }
133}
134
135pub fn finish_stage<T: 'static + Send + Sync>(message: String, data: Option<T>) {
137 if let Some(lock) = PROGRESS_REPORTERS.get() {
139 let mut reporter = lock.write().expect(
140 "only fails if the lock is poisoned; we should never panic while holding the lock",
141 );
142 let reporter = reporter.deref_mut();
143
144 if let Some(stage_steps_left) = reporter.stage_steps.pop() {
146 let step_progress =
147 (1.0 - reporter.current_progress) / reporter.total_steps_left as f64;
148 reporter.total_steps_left -= stage_steps_left;
149 reporter.current_progress = f64::min(
150 reporter.current_progress + stage_steps_left as f64 * step_progress,
151 1.0,
152 ); if let Some(progress_reporter) = reporter.reporters.get::<ProgressReporter<T>>() {
156 let status_update = StatusUpdate {
157 finished: true,
158 message,
159 percent_done: reporter.current_progress,
160 time: reporter.elapsed_millis(),
161 data,
162 };
163 progress_reporter.progress_report.as_ref()(status_update);
164 }
165 }
166
167 if reporter.total_steps_left == 0 && (reporter.current_progress - 1.0_f64).abs() < 0.001 {
169 reporter.current_progress = 0.0;
170 }
171 }
172}
173
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;
    use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module: they all mutate the global `PROGRESS_REPORTERS`.
    static PROGRESS_MUTEX: OnceCell<Mutex<()>> = OnceCell::new();

    /// Slot that the test subscriber fills with the most recent `StatusUpdate`.
    type LastUpdate = Arc<Mutex<Option<StatusUpdate<u32>>>>;

    /// Sets up logging, takes the test-serialization lock and resets the global progress state.
    ///
    /// A poisoned serialization mutex is recovered rather than unwrapped. A failing test panics
    /// while holding this guard, and unwrapping would make every later test fail with an
    /// unrelated `PoisonError` instead of reporting its own result.
    fn init() -> MutexGuard<'static, ()> {
        use log::LevelFilter;
        use simple_logger::SimpleLogger;
        // Only the first call can install the global logger; later attempts fail harmlessly.
        let _ = SimpleLogger::new().with_level(LevelFilter::Debug).init();

        let mutex = PROGRESS_MUTEX.get_or_init(|| Mutex::new(()));
        let guard = mutex.lock().unwrap_or_else(PoisonError::into_inner);
        if let Some(reporters) = PROGRESS_REPORTERS.get() {
            let mut reporters = reporters.write().unwrap();
            *reporters = ProgressReporterContainer {
                reporters: TypeMap::new(),
                current_progress: 0.0,
                total_steps_left: 0,
                start_time: Instant::now(),
                stage_steps: Vec::new(),
            };
        }
        guard
    }

    /// Subscribes a `u32` listener that records the most recent update it receives.
    fn subscribe_last() -> LastUpdate {
        let last: LastUpdate = Arc::new(Mutex::new(None));
        let sink = Arc::clone(&last);
        subscribe::<u32, _>(move |s| {
            log::debug!("got {s:#?}");
            *sink.lock().unwrap() = Some(s);
        });
        last
    }

    /// Asserts that the most recent update reported roughly `expected` progress.
    #[track_caller]
    fn assert_percent(last: &LastUpdate, expected: f64) {
        let got = last.lock().unwrap().as_ref().unwrap().percent_done;
        assert!((got - expected).abs() < 0.01, "expected ~{expected}, got {got}");
    }

    #[test]
    fn single_stage_progress() {
        let _lock = init();
        let last = subscribe_last();

        start_stage::<u32>(2, "starting".to_string(), None);

        progress_stage::<u32>("hello".to_string(), None);
        assert_percent(&last, 0.5000);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent(&last, 1.0000);
    }

    #[test]
    fn multi_stage_progress() {
        let _lock = init();
        let last = subscribe_last();

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.5000);

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.6666);

        start_stage::<u32>(2, "starting".to_string(), None);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.7499);
        progress_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.8333);
        finish_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.8333);

        finish_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 0.9166);

        finish_stage::<u32>("msg".to_string(), None);
        assert_percent(&last, 1.0000);
    }

    #[test]
    fn consecutive_progress() {
        let _lock = init();
        let last = subscribe_last();

        start_stage::<u32>(3, "starting".to_string(), None);
        progress_stage::<u32>("hello".to_string(), None);
        assert_percent(&last, 1.0000 / 3.0000);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent(&last, 2.0000 / 3.0000);
        finish_stage::<u32>("finished".to_string(), None);
        assert_percent(&last, 1.0000);

        start_stage::<u32>(2, "starting".to_string(), None);
        assert_percent(&last, 0.0000);
        progress_stage::<u32>("hello".to_string(), None);
        assert_percent(&last, 0.5000);
        progress_stage::<u32>("hello!".to_string(), Some(2));
        assert_percent(&last, 1.0000);
    }
}