actix_session/
session.rs

1use std::{
2    cell::{Ref, RefCell},
3    collections::HashMap,
4    error::Error as StdError,
5    mem,
6    rc::Rc,
7};
8
9use actix_utils::future::{ready, Ready};
10use actix_web::{
11    body::BoxBody,
12    dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
13    error::Error,
14    FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
15};
16use anyhow::Context;
17use derive_more::derive::{Display, From};
18use serde::{de::DeserializeOwned, Serialize};
19
20/// The primary interface to access and modify session state.
21///
22/// [`Session`] is an [extractor](#impl-FromRequest)—you can specify it as an input type for your
23/// request handlers and it will be automatically extracted from the incoming request.
24///
25/// ```
26/// use actix_session::Session;
27///
28/// async fn index(session: Session) -> actix_web::Result<&'static str> {
29///     // access session data
30///     if let Some(count) = session.get::<i32>("counter")? {
31///         session.insert("counter", count + 1)?;
32///     } else {
33///         session.insert("counter", 1)?;
34///     }
35///
36///     // or use the shorthand
37///     session.update_or("counter", 1, |count: i32| count + 1);
38///
39///     Ok("Welcome!")
40/// }
41/// # actix_web::web::to(index);
42/// ```
43///
44/// You can also retrieve a [`Session`] object from an `HttpRequest` or a `ServiceRequest` using
45/// [`SessionExt`].
46///
47/// [`SessionExt`]: crate::SessionExt
48#[derive(Clone)]
49pub struct Session(Rc<RefCell<SessionInner>>);
50
/// Status of a [`Session`], describing what has happened to it since it was created or
/// retrieved.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SessionStatus {
    /// The session state was modified; the changes will have to be persisted to the backend.
    Changed,

    /// The session is flagged for deletion: the session cookie will be removed from the client
    /// and the session state will be deleted from the session store.
    ///
    /// Once a session is in this state, most further operations on it have no effect.
    Purged,

    /// The session is flagged for renewal.
    ///
    /// A new session key will be generated and the time-to-live of the session state will be
    /// extended.
    Renewed,

    /// The session state has not been touched since it was created or retrieved.
    ///
    /// This is the default status.
    #[default]
    Unchanged,
}
73
74#[derive(Default)]
75struct SessionInner {
76    state: HashMap<String, String>,
77    status: SessionStatus,
78}
79
80impl Session {
81    /// Get a `value` from the session.
82    ///
83    /// It returns an error if it fails to deserialize as `T` the JSON value associated with `key`.
84    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
85        if let Some(val_str) = self.0.borrow().state.get(key) {
86            Ok(Some(
87                serde_json::from_str(val_str)
88                    .with_context(|| {
89                        format!(
90                            "Failed to deserialize the JSON-encoded session data attached to key \
91                            `{}` as a `{}` type",
92                            key,
93                            std::any::type_name::<T>()
94                        )
95                    })
96                    .map_err(SessionGetError)?,
97            ))
98        } else {
99            Ok(None)
100        }
101    }
102
103    /// Returns `true` if the session contains a value for the specified `key`.
104    pub fn contains_key(&self, key: &str) -> bool {
105        self.0.borrow().state.contains_key(key)
106    }
107
108    /// Get all raw key-value data from the session.
109    ///
110    /// Note that values are JSON encoded.
111    pub fn entries(&self) -> Ref<'_, HashMap<String, String>> {
112        Ref::map(self.0.borrow(), |inner| &inner.state)
113    }
114
115    /// Returns session status.
116    pub fn status(&self) -> SessionStatus {
117        Ref::map(self.0.borrow(), |inner| &inner.status).clone()
118    }
119
120    /// Inserts a key-value pair into the session.
121    ///
122    /// Any serializable value can be used and will be encoded as JSON in session data, hence why
123    /// only a reference to the value is taken.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if JSON serialization of `value` fails.
128    pub fn insert<T: Serialize>(
129        &self,
130        key: impl Into<String>,
131        value: T,
132    ) -> Result<(), SessionInsertError> {
133        let mut inner = self.0.borrow_mut();
134
135        if inner.status != SessionStatus::Purged {
136            if inner.status != SessionStatus::Renewed {
137                inner.status = SessionStatus::Changed;
138            }
139
140            let key = key.into();
141            let val = serde_json::to_string(&value)
142                .with_context(|| {
143                    format!(
144                        "Failed to serialize the provided `{}` type instance as JSON in order to \
145                        attach as session data to the `{key}` key",
146                        std::any::type_name::<T>(),
147                    )
148                })
149                .map_err(SessionInsertError)?;
150
151            inner.state.insert(key, val);
152        }
153
154        Ok(())
155    }
156
157    /// Updates a key-value pair into the session.
158    ///
159    /// If the key exists then update it to the new value and place it back in. If the key does not
160    /// exist it will not be updated.
161    ///
162    /// Any serializable value can be used and will be encoded as JSON in the session data, hence
163    /// why only a reference to the value is taken.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if JSON serialization of the value fails.
168    pub fn update<T: Serialize + DeserializeOwned, F>(
169        &self,
170        key: impl Into<String>,
171        updater: F,
172    ) -> Result<(), SessionUpdateError>
173    where
174        F: FnOnce(T) -> T,
175    {
176        let mut inner = self.0.borrow_mut();
177        let key_str = key.into();
178
179        if let Some(val_str) = inner.state.get(&key_str) {
180            let value = serde_json::from_str(val_str)
181                .with_context(|| {
182                    format!(
183                        "Failed to deserialize the JSON-encoded session data attached to key \
184                        `{key_str}` as a `{}` type",
185                        std::any::type_name::<T>()
186                    )
187                })
188                .map_err(SessionUpdateError)?;
189
190            let val = serde_json::to_string(&updater(value))
191                .with_context(|| {
192                    format!(
193                        "Failed to serialize the provided `{}` type instance as JSON in order to \
194                        attach as session data to the `{key_str}` key",
195                        std::any::type_name::<T>(),
196                    )
197                })
198                .map_err(SessionUpdateError)?;
199
200            inner.state.insert(key_str, val);
201        }
202
203        Ok(())
204    }
205
206    /// Updates a key-value pair into the session, or inserts a default value.
207    ///
208    /// If the key exists then update it to the new value and place it back in. If the key does not
209    /// exist the default value will be inserted instead.
210    ///
211    /// Any serializable value can be used and will be encoded as JSON in session data, hence why
212    /// only a reference to the value is taken.
213    ///
214    /// # Errors
215    ///
216    /// Returns error if JSON serialization of a value fails.
217    pub fn update_or<T: Serialize + DeserializeOwned, F>(
218        &self,
219        key: &str,
220        default_value: T,
221        updater: F,
222    ) -> Result<(), SessionUpdateError>
223    where
224        F: FnOnce(T) -> T,
225    {
226        if self.contains_key(key) {
227            self.update(key, updater)
228        } else {
229            self.insert(key, default_value)
230                .map_err(|err| SessionUpdateError(err.into()))
231        }
232    }
233
234    /// Remove value from the session.
235    ///
236    /// If present, the JSON encoded value is returned.
237    pub fn remove(&self, key: &str) -> Option<String> {
238        let mut inner = self.0.borrow_mut();
239
240        if inner.status != SessionStatus::Purged {
241            if inner.status != SessionStatus::Renewed {
242                inner.status = SessionStatus::Changed;
243            }
244            return inner.state.remove(key);
245        }
246
247        None
248    }
249
250    /// Remove value from the session and deserialize.
251    ///
252    /// Returns `None` if key was not present in session. Returns `T` if deserialization succeeds,
253    /// otherwise returns un-deserialized JSON string.
254    pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, String>> {
255        self.remove(key)
256            .map(|val_str| match serde_json::from_str(&val_str) {
257                Ok(val) => Ok(val),
258                Err(_err) => {
259                    tracing::debug!(
260                        "Removed value (key: {}) could not be deserialized as {}",
261                        key,
262                        std::any::type_name::<T>()
263                    );
264
265                    Err(val_str)
266                }
267            })
268    }
269
270    /// Clear the session.
271    pub fn clear(&self) {
272        let mut inner = self.0.borrow_mut();
273
274        if inner.status != SessionStatus::Purged {
275            if inner.status != SessionStatus::Renewed {
276                inner.status = SessionStatus::Changed;
277            }
278            inner.state.clear()
279        }
280    }
281
282    /// Removes session both client and server side.
283    pub fn purge(&self) {
284        let mut inner = self.0.borrow_mut();
285        inner.status = SessionStatus::Purged;
286        inner.state.clear();
287    }
288
289    /// Renews the session key, assigning existing session state to new key.
290    pub fn renew(&self) {
291        let mut inner = self.0.borrow_mut();
292
293        if inner.status != SessionStatus::Purged {
294            inner.status = SessionStatus::Renewed;
295        }
296    }
297
298    /// Adds the given key-value pairs to the session on the request.
299    ///
300    /// Values that match keys already existing on the session will be overwritten. Values should
301    /// already be JSON serialized.
302    #[allow(clippy::needless_pass_by_ref_mut)]
303    pub(crate) fn set_session(
304        req: &mut ServiceRequest,
305        data: impl IntoIterator<Item = (String, String)>,
306    ) {
307        let session = Session::get_session(&mut req.extensions_mut());
308        let mut inner = session.0.borrow_mut();
309        inner.state.extend(data);
310    }
311
312    /// Returns session status and iterator of key-value pairs of changes.
313    ///
314    /// This is a destructive operation - the session state is removed from the request extensions
315    /// typemap, leaving behind a new empty map. It should only be used when the session is being
316    /// finalised (i.e. in `SessionMiddleware`).
317    #[allow(clippy::needless_pass_by_ref_mut)]
318    pub(crate) fn get_changes<B>(
319        res: &mut ServiceResponse<B>,
320    ) -> (SessionStatus, HashMap<String, String>) {
321        if let Some(s_impl) = res
322            .request()
323            .extensions()
324            .get::<Rc<RefCell<SessionInner>>>()
325        {
326            let state = mem::take(&mut s_impl.borrow_mut().state);
327            (s_impl.borrow().status.clone(), state)
328        } else {
329            (SessionStatus::Unchanged, HashMap::new())
330        }
331    }
332
333    pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
334        if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
335            return Session(Rc::clone(s_impl));
336        }
337
338        let inner = Rc::new(RefCell::new(SessionInner::default()));
339        extensions.insert(inner.clone());
340
341        Session(inner)
342    }
343}
344
345/// Extractor implementation for [`Session`]s.
346///
347/// # Examples
348/// ```
349/// # use actix_web::*;
350/// use actix_session::Session;
351///
352/// #[get("/")]
353/// async fn index(session: Session) -> Result<impl Responder> {
354///     // access session data
355///     if let Some(count) = session.get::<i32>("counter")? {
356///         session.insert("counter", count + 1)?;
357///     } else {
358///         session.insert("counter", 1)?;
359///     }
360///
361///     let count = session.get::<i32>("counter")?.unwrap();
362///     Ok(format!("Counter: {}", count))
363/// }
364/// ```
365impl FromRequest for Session {
366    type Error = Error;
367    type Future = Ready<Result<Session, Error>>;
368
369    #[inline]
370    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
371        ready(Ok(Session::get_session(&mut req.extensions_mut())))
372    }
373}
374
375/// Error returned by [`Session::get`].
376#[derive(Debug, Display, From)]
377#[display("{_0}")]
378pub struct SessionGetError(anyhow::Error);
379
380impl StdError for SessionGetError {
381    fn source(&self) -> Option<&(dyn StdError + 'static)> {
382        Some(self.0.as_ref())
383    }
384}
385
386impl ResponseError for SessionGetError {
387    fn error_response(&self) -> HttpResponse<BoxBody> {
388        HttpResponse::new(self.status_code())
389    }
390}
391
392/// Error returned by [`Session::insert`].
393#[derive(Debug, Display, From)]
394#[display("{_0}")]
395pub struct SessionInsertError(anyhow::Error);
396
397impl StdError for SessionInsertError {
398    fn source(&self) -> Option<&(dyn StdError + 'static)> {
399        Some(self.0.as_ref())
400    }
401}
402
403impl ResponseError for SessionInsertError {
404    fn error_response(&self) -> HttpResponse<BoxBody> {
405        HttpResponse::new(self.status_code())
406    }
407}
408
409/// Error returned by [`Session::update`].
410#[derive(Debug, Display, From)]
411#[display("{_0}")]
412pub struct SessionUpdateError(anyhow::Error);
413
414impl StdError for SessionUpdateError {
415    fn source(&self) -> Option<&(dyn StdError + 'static)> {
416        Some(self.0.as_ref())
417    }
418}
419
420impl ResponseError for SessionUpdateError {
421    fn error_response(&self) -> HttpResponse<BoxBody> {
422        HttpResponse::new(self.status_code())
423    }
424}