1use std::{
2    cell::{Ref, RefCell},
3    collections::HashMap,
4    error::Error as StdError,
5    mem,
6    rc::Rc,
7};
8
9use actix_utils::future::{ready, Ready};
10use actix_web::{
11    body::BoxBody,
12    dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
13    error::Error,
14    FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
15};
16use anyhow::Context;
17use derive_more::derive::{Display, From};
18use serde::{de::DeserializeOwned, Serialize};
19
20#[derive(Clone)]
49pub struct Session(Rc<RefCell<SessionInner>>);
50
/// What has happened to a session while the current request was being handled.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SessionStatus {
    /// Session entries were modified, e.g. through `Session::insert`, `Session::remove` or
    /// `Session::clear`.
    Changed,

    /// The session was purged with `Session::purge`: its entries were cleared and later
    /// modifications (and renewal) are ignored.
    Purged,

    /// The session was renewed with `Session::renew`. Later modifications keep this status
    /// instead of downgrading it to `Changed`.
    Renewed,

    /// Nothing has happened to the session (the initial state).
    #[default]
    Unchanged,
}
73
74#[derive(Default)]
75struct SessionInner {
76    state: HashMap<String, String>,
77    status: SessionStatus,
78}
79
80impl Session {
81    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
85        if let Some(val_str) = self.0.borrow().state.get(key) {
86            Ok(Some(
87                serde_json::from_str(val_str)
88                    .with_context(|| {
89                        format!(
90                            "Failed to deserialize the JSON-encoded session data attached to key \
91                            `{}` as a `{}` type",
92                            key,
93                            std::any::type_name::<T>()
94                        )
95                    })
96                    .map_err(SessionGetError)?,
97            ))
98        } else {
99            Ok(None)
100        }
101    }
102
103    pub fn contains_key(&self, key: &str) -> bool {
105        self.0.borrow().state.contains_key(key)
106    }
107
108    pub fn entries(&self) -> Ref<'_, HashMap<String, String>> {
112        Ref::map(self.0.borrow(), |inner| &inner.state)
113    }
114
115    pub fn status(&self) -> SessionStatus {
117        Ref::map(self.0.borrow(), |inner| &inner.status).clone()
118    }
119
120    pub fn insert<T: Serialize>(
129        &self,
130        key: impl Into<String>,
131        value: T,
132    ) -> Result<(), SessionInsertError> {
133        let mut inner = self.0.borrow_mut();
134
135        if inner.status != SessionStatus::Purged {
136            if inner.status != SessionStatus::Renewed {
137                inner.status = SessionStatus::Changed;
138            }
139
140            let key = key.into();
141            let val = serde_json::to_string(&value)
142                .with_context(|| {
143                    format!(
144                        "Failed to serialize the provided `{}` type instance as JSON in order to \
145                        attach as session data to the `{key}` key",
146                        std::any::type_name::<T>(),
147                    )
148                })
149                .map_err(SessionInsertError)?;
150
151            inner.state.insert(key, val);
152        }
153
154        Ok(())
155    }
156
157    pub fn update<T: Serialize + DeserializeOwned, F>(
169        &self,
170        key: impl Into<String>,
171        updater: F,
172    ) -> Result<(), SessionUpdateError>
173    where
174        F: FnOnce(T) -> T,
175    {
176        let mut inner = self.0.borrow_mut();
177        let key_str = key.into();
178
179        if let Some(val_str) = inner.state.get(&key_str) {
180            let value = serde_json::from_str(val_str)
181                .with_context(|| {
182                    format!(
183                        "Failed to deserialize the JSON-encoded session data attached to key \
184                        `{key_str}` as a `{}` type",
185                        std::any::type_name::<T>()
186                    )
187                })
188                .map_err(SessionUpdateError)?;
189
190            let val = serde_json::to_string(&updater(value))
191                .with_context(|| {
192                    format!(
193                        "Failed to serialize the provided `{}` type instance as JSON in order to \
194                        attach as session data to the `{key_str}` key",
195                        std::any::type_name::<T>(),
196                    )
197                })
198                .map_err(SessionUpdateError)?;
199
200            inner.state.insert(key_str, val);
201        }
202
203        Ok(())
204    }
205
206    pub fn update_or<T: Serialize + DeserializeOwned, F>(
218        &self,
219        key: &str,
220        default_value: T,
221        updater: F,
222    ) -> Result<(), SessionUpdateError>
223    where
224        F: FnOnce(T) -> T,
225    {
226        if self.contains_key(key) {
227            self.update(key, updater)
228        } else {
229            self.insert(key, default_value)
230                .map_err(|err| SessionUpdateError(err.into()))
231        }
232    }
233
234    pub fn remove(&self, key: &str) -> Option<String> {
238        let mut inner = self.0.borrow_mut();
239
240        if inner.status != SessionStatus::Purged {
241            if inner.status != SessionStatus::Renewed {
242                inner.status = SessionStatus::Changed;
243            }
244            return inner.state.remove(key);
245        }
246
247        None
248    }
249
250    pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, String>> {
255        self.remove(key)
256            .map(|val_str| match serde_json::from_str(&val_str) {
257                Ok(val) => Ok(val),
258                Err(_err) => {
259                    tracing::debug!(
260                        "Removed value (key: {}) could not be deserialized as {}",
261                        key,
262                        std::any::type_name::<T>()
263                    );
264
265                    Err(val_str)
266                }
267            })
268    }
269
270    pub fn clear(&self) {
272        let mut inner = self.0.borrow_mut();
273
274        if inner.status != SessionStatus::Purged {
275            if inner.status != SessionStatus::Renewed {
276                inner.status = SessionStatus::Changed;
277            }
278            inner.state.clear()
279        }
280    }
281
282    pub fn purge(&self) {
284        let mut inner = self.0.borrow_mut();
285        inner.status = SessionStatus::Purged;
286        inner.state.clear();
287    }
288
289    pub fn renew(&self) {
291        let mut inner = self.0.borrow_mut();
292
293        if inner.status != SessionStatus::Purged {
294            inner.status = SessionStatus::Renewed;
295        }
296    }
297
298    #[allow(clippy::needless_pass_by_ref_mut)]
303    pub(crate) fn set_session(
304        req: &mut ServiceRequest,
305        data: impl IntoIterator<Item = (String, String)>,
306    ) {
307        let session = Session::get_session(&mut req.extensions_mut());
308        let mut inner = session.0.borrow_mut();
309        inner.state.extend(data);
310    }
311
312    #[allow(clippy::needless_pass_by_ref_mut)]
318    pub(crate) fn get_changes<B>(
319        res: &mut ServiceResponse<B>,
320    ) -> (SessionStatus, HashMap<String, String>) {
321        if let Some(s_impl) = res
322            .request()
323            .extensions()
324            .get::<Rc<RefCell<SessionInner>>>()
325        {
326            let state = mem::take(&mut s_impl.borrow_mut().state);
327            (s_impl.borrow().status.clone(), state)
328        } else {
329            (SessionStatus::Unchanged, HashMap::new())
330        }
331    }
332
333    pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
334        if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
335            return Session(Rc::clone(s_impl));
336        }
337
338        let inner = Rc::new(RefCell::new(SessionInner::default()));
339        extensions.insert(inner.clone());
340
341        Session(inner)
342    }
343}
344
345impl FromRequest for Session {
366    type Error = Error;
367    type Future = Ready<Result<Session, Error>>;
368
369    #[inline]
370    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
371        ready(Ok(Session::get_session(&mut req.extensions_mut())))
372    }
373}
374
375#[derive(Debug, Display, From)]
377#[display("{_0}")]
378pub struct SessionGetError(anyhow::Error);
379
380impl StdError for SessionGetError {
381    fn source(&self) -> Option<&(dyn StdError + 'static)> {
382        Some(self.0.as_ref())
383    }
384}
385
386impl ResponseError for SessionGetError {
387    fn error_response(&self) -> HttpResponse<BoxBody> {
388        HttpResponse::new(self.status_code())
389    }
390}
391
392#[derive(Debug, Display, From)]
394#[display("{_0}")]
395pub struct SessionInsertError(anyhow::Error);
396
397impl StdError for SessionInsertError {
398    fn source(&self) -> Option<&(dyn StdError + 'static)> {
399        Some(self.0.as_ref())
400    }
401}
402
403impl ResponseError for SessionInsertError {
404    fn error_response(&self) -> HttpResponse<BoxBody> {
405        HttpResponse::new(self.status_code())
406    }
407}
408
409#[derive(Debug, Display, From)]
411#[display("{_0}")]
412pub struct SessionUpdateError(anyhow::Error);
413
414impl StdError for SessionUpdateError {
415    fn source(&self) -> Option<&(dyn StdError + 'static)> {
416        Some(self.0.as_ref())
417    }
418}
419
420impl ResponseError for SessionUpdateError {
421    fn error_response(&self) -> HttpResponse<BoxBody> {
422        HttpResponse::new(self.status_code())
423    }
424}