headless_lms_utils/pagination.rs

1use std::{fmt, num::ParseIntError};
2
3use anyhow::bail;
4use serde::{
5    Deserialize, Deserializer,
6    de::{self, MapAccess, Visitor},
7};
8
/// Represents the URL query parameters `page` and `limit`, used for paginating database queries.
///
/// The fields are private: values are created through [`Pagination::new`], [`Default`], or
/// deserialization, each of which supplies or validates both fields.
#[derive(Debug, Clone, Copy)]
pub struct Pagination {
    /// 1-based page number. Deserialization falls back to the `Default` value when the `page`
    /// parameter is absent.
    page: u32,
    /// Maximum number of items per page. Deserialization falls back to the `Default` value when
    /// the `limit` parameter is absent.
    limit: u32,
}
18
19impl Pagination {
20    /// Errors on non-positive page or limit values.
21    pub fn new(page: u32, limit: u32) -> anyhow::Result<Self> {
22        if page == 0 {
23            bail!("Page must be a positive value.");
24        }
25        if limit == 0 {
26            bail!("Limit must be a positive value.");
27        }
28        if limit > 10_000 {
29            bail!("Limit can be at most 10000.")
30        }
31        Ok(Pagination { page, limit })
32    }
33
34    /// Guaranteed to be positive.
35    pub fn page(&self) -> i64 {
36        self.page.into()
37    }
38
39    /// Guaranteed to be positive.
40    pub fn limit(&self) -> i64 {
41        self.limit.into()
42    }
43
44    /// Guaranteed to be nonnegative.
45    pub fn offset(&self) -> i64 {
46        (self.limit * (self.page - 1)).into()
47    }
48
49    /// Guaranteed to be positive.
50    pub fn total_pages(&self, total_count: u32) -> u32 {
51        let remainder = total_count % self.limit;
52        if remainder == 0 {
53            total_count / self.limit
54        } else {
55            total_count / self.limit + 1
56        }
57    }
58
59    /// Helper to paginate an existing Vec efficiently.
60    pub fn paginate<T>(&self, v: &mut Vec<T>) {
61        let limit = self.limit as usize;
62        let start = limit * (self.page as usize - 1);
63        v.truncate(start + limit);
64        v.drain(..start);
65    }
66
67    pub fn next_page(&mut self) {
68        self.page += 1;
69    }
70}
71
72impl Default for Pagination {
73    fn default() -> Self {
74        Self {
75            page: 1,
76            limit: 100,
77        }
78    }
79}
80
81impl<'de> Deserialize<'de> for Pagination {
82    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
83    where
84        D: Deserializer<'de>,
85    {
86        struct PaginationVisitor;
87
88        impl<'de> Visitor<'de> for PaginationVisitor {
89            type Value = Pagination;
90
91            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
92                formatter.write_str("query parameters `page` and `limit`")
93            }
94
95            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
96            where
97                A: MapAccess<'de>,
98            {
99                let mut page = None;
100                let mut limit = None;
101                while let Some(key) = map.next_key().map_err(|e| {
102                    de::Error::custom(format!("Failed to deserialize map key: {}", e))
103                })? {
104                    match key {
105                        "page" => {
106                            if page.is_some() {
107                                return Err(de::Error::duplicate_field("page"));
108                            }
109                            let value: StrOrInt = map.next_value().map_err(|e| {
110                                de::Error::custom(format!(
111                                    "Failed to deserialize page value: {}",
112                                    e
113                                ))
114                            })?;
115                            let value = value.into_int().map_err(|e| {
116                                de::Error::custom(format!(
117                                    "Failed to deserialize page value: {}",
118                                    e
119                                ))
120                            })?;
121                            if value < 1 {
122                                return Err(de::Error::custom(
123                                    "query parameter `page` must be a positive integer",
124                                ));
125                            }
126                            page = Some(value);
127                        }
128                        "limit" => {
129                            if limit.is_some() {
130                                return Err(de::Error::duplicate_field("limit"));
131                            }
132                            let value: StrOrInt = map.next_value().map_err(|e| {
133                                de::Error::custom(format!(
134                                    "Failed to deserialize limit value: {}",
135                                    e
136                                ))
137                            })?;
138                            let value = value.into_int().map_err(|e| {
139                                de::Error::custom(format!(
140                                    "Failed to deserialize limit value: {}",
141                                    e
142                                ))
143                            })?;
144                            if !(1..=10000).contains(&value) {
145                                return Err(de::Error::custom(
146                                    "query parameter `limit` must be an integer between 1 and 10000",
147                                ));
148                            }
149                            limit = Some(value);
150                        }
151                        field => {
152                            return Err(de::Error::custom(format!(
153                                "unexpected parameter `{}`",
154                                field
155                            )));
156                        }
157                    }
158                }
159                Ok(Pagination {
160                    page: page.unwrap_or(Pagination::default().page),
161                    limit: limit.unwrap_or(Pagination::default().limit),
162                })
163            }
164        }
165
166        deserializer.deserialize_struct("Pagination", &["page", "limit"], PaginationVisitor)
167    }
168}
169
170// for some reason, it seems like when there are only numeric query parameters, actix gives them to serde as numbers,
171// but if there's a string mixed in, they are all given as strings. this helper is used to handle both cases
172#[derive(Debug, Deserialize)]
173#[serde(untagged)]
174enum StrOrInt<'a> {
175    Str(&'a str),
176    Int(u32),
177}
178
179impl StrOrInt<'_> {
180    fn into_int(self) -> Result<u32, ParseIntError> {
181        match self {
182            Self::Str(s) => s.parse(),
183            Self::Int(i) => Ok(i),
184        }
185    }
186}
187
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn paginates() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let pagination = Pagination::new(2, 3).unwrap();
        pagination.paginate(&mut v);
        assert_eq!(v, &[4, 5, 6]);
    }

    #[test]
    fn paginates_non_existent_page() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let pagination = Pagination::new(3, 4).unwrap();
        pagination.paginate(&mut v);
        assert_eq!(v, &[] as &[i32]);
    }

    #[test]
    fn paginates_incomplete_page() {
        let mut v = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let pagination = Pagination::new(2, 5).unwrap();
        pagination.paginate(&mut v);
        assert_eq!(v, &[6, 7, 8]);
    }

    /// `new` rejects a zero page, a zero limit, and a limit above 10000.
    #[test]
    fn rejects_out_of_range_values() {
        assert!(Pagination::new(0, 10).is_err());
        assert!(Pagination::new(1, 0).is_err());
        assert!(Pagination::new(1, 10_001).is_err());
        assert!(Pagination::new(1, 10_000).is_ok());
    }

    /// The offset skips all items of the preceding pages.
    #[test]
    fn computes_offset() {
        assert_eq!(Pagination::new(1, 25).unwrap().offset(), 0);
        assert_eq!(Pagination::new(3, 10).unwrap().offset(), 20);
    }

    /// Page counts round up, and an empty result set has no pages.
    #[test]
    fn counts_total_pages() {
        let pagination = Pagination::new(1, 10).unwrap();
        assert_eq!(pagination.total_pages(0), 0);
        assert_eq!(pagination.total_pages(10), 1);
        assert_eq!(pagination.total_pages(11), 2);
    }

    /// `next_page` advances the page and, with it, the offset.
    #[test]
    fn advances_to_next_page() {
        let mut pagination = Pagination::new(1, 5).unwrap();
        pagination.next_page();
        assert_eq!(pagination.page(), 2);
        assert_eq!(pagination.offset(), 5);
    }

    /// The default is the first page with 100 items per page.
    #[test]
    fn default_values() {
        let pagination = Pagination::default();
        assert_eq!(pagination.page(), 1);
        assert_eq!(pagination.limit(), 100);
    }
}