headless_lms_utils/
pagination.rs

1use std::{fmt, num::ParseIntError};
2
3use anyhow::bail;
4use serde::{
5    Deserialize, Deserializer,
6    de::{self, MapAccess, Visitor},
7};
8#[cfg(feature = "ts_rs")]
9use ts_rs::TS;
10
/// Pagination settings read from the URL query parameters `page` and `limit`, used to page
/// through the results of database queries.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "ts_rs", derive(TS))]
pub struct Pagination {
    // 1-based page number; may be omitted in the query, the `Deserialize` impl supplies a default
    #[cfg_attr(feature = "ts_rs", ts(type = "number | undefined"))]
    page: u32,
    // items per page; may be omitted in the query, the `Deserialize` impl supplies a default
    #[cfg_attr(feature = "ts_rs", ts(type = "number | undefined"))]
    limit: u32,
}
22
23impl Pagination {
24    /// Errors on non-positive page or limit values.
25    pub fn new(page: u32, limit: u32) -> anyhow::Result<Self> {
26        if page == 0 {
27            bail!("Page must be a positive value.");
28        }
29        if limit == 0 {
30            bail!("Limit must be a positive value.");
31        }
32        if limit > 10_000 {
33            bail!("Limit can be at most 10000.")
34        }
35        Ok(Pagination { page, limit })
36    }
37
38    /// Guaranteed to be positive.
39    pub fn page(&self) -> i64 {
40        self.page.into()
41    }
42
43    /// Guaranteed to be positive.
44    pub fn limit(&self) -> i64 {
45        self.limit.into()
46    }
47
48    /// Guaranteed to be nonnegative.
49    pub fn offset(&self) -> i64 {
50        (self.limit * (self.page - 1)).into()
51    }
52
53    /// Guaranteed to be positive.
54    pub fn total_pages(&self, total_count: u32) -> u32 {
55        let remainder = total_count % self.limit;
56        if remainder == 0 {
57            total_count / self.limit
58        } else {
59            total_count / self.limit + 1
60        }
61    }
62
63    /// Helper to paginate an existing Vec efficiently.
64    pub fn paginate<T>(&self, v: &mut Vec<T>) {
65        let limit = self.limit as usize;
66        let start = limit * (self.page as usize - 1);
67        v.truncate(start + limit);
68        v.drain(..start);
69    }
70
71    pub fn next_page(&mut self) {
72        self.page += 1;
73    }
74}
75
76impl Default for Pagination {
77    fn default() -> Self {
78        Self {
79            page: 1,
80            limit: 100,
81        }
82    }
83}
84
85impl<'de> Deserialize<'de> for Pagination {
86    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
87    where
88        D: Deserializer<'de>,
89    {
90        struct PaginationVisitor;
91
92        impl<'de> Visitor<'de> for PaginationVisitor {
93            type Value = Pagination;
94
95            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
96                formatter.write_str("query parameters `page` and `limit`")
97            }
98
99            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
100            where
101                A: MapAccess<'de>,
102            {
103                let mut page = None;
104                let mut limit = None;
105                while let Some(key) = map.next_key().map_err(|e| {
106                    de::Error::custom(format!("Failed to deserialize map key: {}", e))
107                })? {
108                    match key {
109                        "page" => {
110                            if page.is_some() {
111                                return Err(de::Error::duplicate_field("page"));
112                            }
113                            let value: StrOrInt = map.next_value().map_err(|e| {
114                                de::Error::custom(format!(
115                                    "Failed to deserialize page value: {}",
116                                    e
117                                ))
118                            })?;
119                            let value = value.into_int().map_err(|e| {
120                                de::Error::custom(format!(
121                                    "Failed to deserialize page value: {}",
122                                    e
123                                ))
124                            })?;
125                            if value < 1 {
126                                return Err(de::Error::custom(
127                                    "query parameter `page` must be a positive integer",
128                                ));
129                            }
130                            page = Some(value);
131                        }
132                        "limit" => {
133                            if limit.is_some() {
134                                return Err(de::Error::duplicate_field("limit"));
135                            }
136                            let value: StrOrInt = map.next_value().map_err(|e| {
137                                de::Error::custom(format!(
138                                    "Failed to deserialize limit value: {}",
139                                    e
140                                ))
141                            })?;
142                            let value = value.into_int().map_err(|e| {
143                                de::Error::custom(format!(
144                                    "Failed to deserialize limit value: {}",
145                                    e
146                                ))
147                            })?;
148                            if !(1..=10000).contains(&value) {
149                                return Err(de::Error::custom(
150                                    "query parameter `limit` must be an integer between 1 and 10000",
151                                ));
152                            }
153                            limit = Some(value);
154                        }
155                        field => {
156                            return Err(de::Error::custom(format!(
157                                "unexpected parameter `{}`",
158                                field
159                            )));
160                        }
161                    }
162                }
163                Ok(Pagination {
164                    page: page.unwrap_or(Pagination::default().page),
165                    limit: limit.unwrap_or(Pagination::default().limit),
166                })
167            }
168        }
169
170        deserializer.deserialize_struct("Pagination", &["page", "limit"], PaginationVisitor)
171    }
172}
173
174// for some reason, it seems like when there are only numeric query parameters, actix gives them to serde as numbers,
175// but if there's a string mixed in, they are all given as strings. this helper is used to handle both cases
176#[derive(Debug, Deserialize)]
177#[serde(untagged)]
178enum StrOrInt<'a> {
179    Str(&'a str),
180    Int(u32),
181}
182
183impl StrOrInt<'_> {
184    fn into_int(self) -> Result<u32, ParseIntError> {
185        match self {
186            Self::Str(s) => s.parse(),
187            Self::Int(i) => Ok(i),
188        }
189    }
190}
191
#[cfg(test)]
mod test {
    use super::*;

    /// Paginates the numbers 1 through 8 with `Pagination::new(page, limit)`.
    fn paginate_one_to_eight(page: u32, limit: u32) -> Vec<i32> {
        let mut items: Vec<i32> = (1..=8).collect();
        Pagination::new(page, limit)
            .expect("test pagination parameters are valid")
            .paginate(&mut items);
        items
    }

    #[test]
    fn paginates() {
        assert_eq!(paginate_one_to_eight(2, 3), [4, 5, 6]);
    }

    #[test]
    fn paginates_non_existent_page() {
        assert!(paginate_one_to_eight(3, 4).is_empty());
    }

    #[test]
    fn paginates_incomplete_page() {
        assert_eq!(paginate_one_to_eight(2, 5), [6, 7, 8]);
    }
}