1use std::fmt::Display;
14use std::future::Future;
15use std::time::Duration;
16use tokio::time;
17
18use crate::util::rng::{HasherRng, Rng};
19
20pub trait MakeBackoff {
22    type Backoff: Backoff;
24
25    fn make_backoff(&mut self) -> Self::Backoff;
27}
28
/// A stateful retry-delay policy.
///
/// Each call to [`Backoff::next_backoff`] yields a future that completes once
/// the delay chosen for that attempt has elapsed.
pub trait Backoff {
    /// Future that resolves once the current delay is over.
    type Future: Future<Output = ()>;

    /// Produces the wait for the next attempt, advancing any internal state.
    fn next_backoff(&mut self) -> Self::Future;
}
40
41#[derive(Debug, Clone)]
43pub struct ExponentialBackoffMaker<R = HasherRng> {
44    min: time::Duration,
46    max: time::Duration,
48    jitter: f64,
52    rng: R,
53}
54
55#[derive(Debug, Clone)]
64pub struct ExponentialBackoff<R = HasherRng> {
65    min: time::Duration,
66    max: time::Duration,
67    jitter: f64,
68    rng: R,
69    iterations: u32,
70}
71
72impl<R> ExponentialBackoffMaker<R>
73where
74    R: Rng,
75{
76    pub fn new(
87        min: time::Duration,
88        max: time::Duration,
89        jitter: f64,
90        rng: R,
91    ) -> Result<Self, InvalidBackoff> {
92        if min > max {
93            return Err(InvalidBackoff("maximum must not be less than minimum"));
94        }
95        if max == time::Duration::from_millis(0) {
96            return Err(InvalidBackoff("maximum must be non-zero"));
97        }
98        if jitter < 0.0 {
99            return Err(InvalidBackoff("jitter must not be negative"));
100        }
101        if jitter > 100.0 {
102            return Err(InvalidBackoff("jitter must not be greater than 100"));
103        }
104        if !jitter.is_finite() {
105            return Err(InvalidBackoff("jitter must be finite"));
106        }
107
108        Ok(ExponentialBackoffMaker {
109            min,
110            max,
111            jitter,
112            rng,
113        })
114    }
115}
116
117impl<R> MakeBackoff for ExponentialBackoffMaker<R>
118where
119    R: Rng + Clone,
120{
121    type Backoff = ExponentialBackoff<R>;
122
123    fn make_backoff(&mut self) -> Self::Backoff {
124        ExponentialBackoff {
125            max: self.max,
126            min: self.min,
127            jitter: self.jitter,
128            rng: self.rng.clone(),
129            iterations: 0,
130        }
131    }
132}
133
134impl<R: Rng> ExponentialBackoff<R> {
135    fn base(&self) -> time::Duration {
136        debug_assert!(
137            self.min <= self.max,
138            "maximum backoff must not be less than minimum backoff"
139        );
140        debug_assert!(
141            self.max > time::Duration::from_millis(0),
142            "Maximum backoff must be non-zero"
143        );
144        self.min
145            .checked_mul(2_u32.saturating_pow(self.iterations))
146            .unwrap_or(self.max)
147            .min(self.max)
148    }
149
150    fn jitter(&mut self, base: time::Duration) -> time::Duration {
153        if self.jitter == 0.0 {
154            time::Duration::default()
155        } else {
156            let jitter_factor = self.rng.next_f64();
157            debug_assert!(
158                jitter_factor > 0.0,
159                "rng returns values between 0.0 and 1.0"
160            );
161            let rand_jitter = jitter_factor * self.jitter;
162            let secs = (base.as_secs() as f64) * rand_jitter;
163            let nanos = (base.subsec_nanos() as f64) * rand_jitter;
164            let remaining = self.max - base;
165            time::Duration::new(secs as u64, nanos as u32).min(remaining)
166        }
167    }
168}
169
170impl<R> Backoff for ExponentialBackoff<R>
171where
172    R: Rng,
173{
174    type Future = tokio::time::Sleep;
175
176    fn next_backoff(&mut self) -> Self::Future {
177        let base = self.base();
178        let next = base + self.jitter(base);
179
180        self.iterations += 1;
181
182        tokio::time::sleep(next)
183    }
184}
185
186impl Default for ExponentialBackoffMaker {
187    fn default() -> Self {
188        ExponentialBackoffMaker::new(
189            Duration::from_millis(50),
190            Duration::from_millis(u64::MAX),
191            0.99,
192            HasherRng::default(),
193        )
194        .expect("Unable to create ExponentialBackoff")
195    }
196}
197
/// Error returned by `ExponentialBackoffMaker::new` when a setting is rejected.
///
/// The wrapped text names the rule that was violated.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let InvalidBackoff(reason) = self;
        f.write_str("invalid backoff: ")?;
        f.write_str(reason)
    }
}

impl std::error::Error for InvalidBackoff {}
209
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    // Builds a backoff (through the maker) from millisecond bounds.
    // Returns `None` when `ExponentialBackoffMaker::new` rejects the settings;
    // the properties below turn that into a discarded case.
    fn build(min_ms: u64, max_ms: u64, jitter: f64) -> Option<ExponentialBackoff> {
        let min = time::Duration::from_millis(min_ms);
        let max = time::Duration::from_millis(max_ms);
        ExponentialBackoffMaker::new(min, max, jitter, HasherRng::default())
            .ok()
            .map(|mut maker| maker.make_backoff())
    }

    quickcheck! {
        // Property: the first base delay (zero iterations) equals `min`.
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            match build(min_ms, max_ms, 0.0) {
                None => TestResult::discard(),
                Some(backoff) => {
                    let expected = time::Duration::from_millis(min_ms);
                    TestResult::from_bool(backoff.base() == expected)
                }
            }
        }

        // Property: for any iteration count the base delay lies in `[min, max]`.
        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let mut backoff = match build(min_ms, max_ms, 0.0) {
                Some(backoff) => backoff,
                None => return TestResult::discard(),
            };
            backoff.iterations = iterations;
            let delay = backoff.base();
            let within_bounds = time::Duration::from_millis(min_ms) <= delay
                && delay <= time::Duration::from_millis(max_ms);
            TestResult::from_bool(within_bounds)
        }

        // Property: the jitter is zero when jitter is disabled, the base is
        // zero, or there is no headroom under `max`; otherwise it is positive.
        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let mut backoff = match build(base_ms, max_ms, jitter) {
                Some(backoff) => backoff,
                None => return TestResult::discard(),
            };
            let j = backoff.jitter(time::Duration::from_millis(base_ms));
            let expect_zero = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            if expect_zero {
                TestResult::from_bool(j == time::Duration::default())
            } else {
                TestResult::from_bool(j > time::Duration::default())
            }
        }
    }
}