actix_session/
session.rs

1use std::{
2    cell::{Ref, RefCell},
3    collections::HashMap,
4    error::Error as StdError,
5    mem,
6    rc::Rc,
7};
8
9use actix_utils::future::{ready, Ready};
10use actix_web::{
11    body::BoxBody,
12    dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
13    error::Error,
14    FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
15};
16use anyhow::Context;
17use derive_more::derive::{Display, From};
18use serde::{de::DeserializeOwned, Serialize};
19
20/// The primary interface to access and modify session state.
21///
22/// [`Session`] is an [extractor](#impl-FromRequest)—you can specify it as an input type for your
23/// request handlers and it will be automatically extracted from the incoming request.
24///
25/// ```
26/// use actix_session::Session;
27///
28/// async fn index(session: Session) -> actix_web::Result<&'static str> {
29///     // access session data
30///     if let Some(count) = session.get::<i32>("counter")? {
31///         session.insert("counter", count + 1)?;
32///     } else {
33///         session.insert("counter", 1)?;
34///     }
35///
36///     Ok("Welcome!")
37/// }
38/// # actix_web::web::to(index);
39/// ```
40///
41/// You can also retrieve a [`Session`] object from an `HttpRequest` or a `ServiceRequest` using
42/// [`SessionExt`].
43///
44/// [`SessionExt`]: crate::SessionExt
45#[derive(Clone)]
46pub struct Session(Rc<RefCell<SessionInner>>);
47
/// Lifecycle state of a [`Session`].
///
/// A session that was just created or loaded starts out as [`SessionStatus::Unchanged`], which is
/// also the [`Default`] value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionStatus {
    /// The session state was modified, so the changes must be written back to the backend.
    Changed,

    /// The session is marked for deletion: the session cookie gets removed from the client and
    /// the session state gets deleted from the session store.
    ///
    /// Once a session is in this state, most further operations on it are no-ops.
    Purged,

    /// The session is marked for renewal.
    ///
    /// A new session key will be generated and the time-to-live of the session state will be
    /// extended.
    Renewed,

    /// The session state is exactly as it was when the session was created or retrieved.
    #[default]
    Unchanged,
}
70
71#[derive(Default)]
72struct SessionInner {
73    state: HashMap<String, String>,
74    status: SessionStatus,
75}
76
77impl Session {
78    /// Get a `value` from the session.
79    ///
80    /// It returns an error if it fails to deserialize as `T` the JSON value associated with `key`.
81    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
82        if let Some(val_str) = self.0.borrow().state.get(key) {
83            Ok(Some(
84                serde_json::from_str(val_str)
85                    .with_context(|| {
86                        format!(
87                            "Failed to deserialize the JSON-encoded session data attached to key \
88                            `{}` as a `{}` type",
89                            key,
90                            std::any::type_name::<T>()
91                        )
92                    })
93                    .map_err(SessionGetError)?,
94            ))
95        } else {
96            Ok(None)
97        }
98    }
99
100    /// Get all raw key-value data from the session.
101    ///
102    /// Note that values are JSON encoded.
103    pub fn entries(&self) -> Ref<'_, HashMap<String, String>> {
104        Ref::map(self.0.borrow(), |inner| &inner.state)
105    }
106
107    /// Returns session status.
108    pub fn status(&self) -> SessionStatus {
109        Ref::map(self.0.borrow(), |inner| &inner.status).clone()
110    }
111
112    /// Inserts a key-value pair into the session.
113    ///
114    /// Any serializable value can be used and will be encoded as JSON in session data, hence why
115    /// only a reference to the value is taken.
116    ///
117    /// It returns an error if it fails to serialize `value` to JSON.
118    pub fn insert<T: Serialize>(
119        &self,
120        key: impl Into<String>,
121        value: T,
122    ) -> Result<(), SessionInsertError> {
123        let mut inner = self.0.borrow_mut();
124
125        if inner.status != SessionStatus::Purged {
126            if inner.status != SessionStatus::Renewed {
127                inner.status = SessionStatus::Changed;
128            }
129
130            let key = key.into();
131            let val = serde_json::to_string(&value)
132                .with_context(|| {
133                    format!(
134                        "Failed to serialize the provided `{}` type instance as JSON in order to \
135                        attach as session data to the `{}` key",
136                        std::any::type_name::<T>(),
137                        &key
138                    )
139                })
140                .map_err(SessionInsertError)?;
141
142            inner.state.insert(key, val);
143        }
144
145        Ok(())
146    }
147
148    /// Remove value from the session.
149    ///
150    /// If present, the JSON encoded value is returned.
151    pub fn remove(&self, key: &str) -> Option<String> {
152        let mut inner = self.0.borrow_mut();
153
154        if inner.status != SessionStatus::Purged {
155            if inner.status != SessionStatus::Renewed {
156                inner.status = SessionStatus::Changed;
157            }
158            return inner.state.remove(key);
159        }
160
161        None
162    }
163
164    /// Remove value from the session and deserialize.
165    ///
166    /// Returns `None` if key was not present in session. Returns `T` if deserialization succeeds,
167    /// otherwise returns un-deserialized JSON string.
168    pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, String>> {
169        self.remove(key)
170            .map(|val_str| match serde_json::from_str(&val_str) {
171                Ok(val) => Ok(val),
172                Err(_err) => {
173                    tracing::debug!(
174                        "Removed value (key: {}) could not be deserialized as {}",
175                        key,
176                        std::any::type_name::<T>()
177                    );
178
179                    Err(val_str)
180                }
181            })
182    }
183
184    /// Clear the session.
185    pub fn clear(&self) {
186        let mut inner = self.0.borrow_mut();
187
188        if inner.status != SessionStatus::Purged {
189            if inner.status != SessionStatus::Renewed {
190                inner.status = SessionStatus::Changed;
191            }
192            inner.state.clear()
193        }
194    }
195
196    /// Removes session both client and server side.
197    pub fn purge(&self) {
198        let mut inner = self.0.borrow_mut();
199        inner.status = SessionStatus::Purged;
200        inner.state.clear();
201    }
202
203    /// Renews the session key, assigning existing session state to new key.
204    pub fn renew(&self) {
205        let mut inner = self.0.borrow_mut();
206
207        if inner.status != SessionStatus::Purged {
208            inner.status = SessionStatus::Renewed;
209        }
210    }
211
212    /// Adds the given key-value pairs to the session on the request.
213    ///
214    /// Values that match keys already existing on the session will be overwritten. Values should
215    /// already be JSON serialized.
216    #[allow(clippy::needless_pass_by_ref_mut)]
217    pub(crate) fn set_session(
218        req: &mut ServiceRequest,
219        data: impl IntoIterator<Item = (String, String)>,
220    ) {
221        let session = Session::get_session(&mut req.extensions_mut());
222        let mut inner = session.0.borrow_mut();
223        inner.state.extend(data);
224    }
225
226    /// Returns session status and iterator of key-value pairs of changes.
227    ///
228    /// This is a destructive operation - the session state is removed from the request extensions
229    /// typemap, leaving behind a new empty map. It should only be used when the session is being
230    /// finalised (i.e. in `SessionMiddleware`).
231    #[allow(clippy::needless_pass_by_ref_mut)]
232    pub(crate) fn get_changes<B>(
233        res: &mut ServiceResponse<B>,
234    ) -> (SessionStatus, HashMap<String, String>) {
235        if let Some(s_impl) = res
236            .request()
237            .extensions()
238            .get::<Rc<RefCell<SessionInner>>>()
239        {
240            let state = mem::take(&mut s_impl.borrow_mut().state);
241            (s_impl.borrow().status.clone(), state)
242        } else {
243            (SessionStatus::Unchanged, HashMap::new())
244        }
245    }
246
247    pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
248        if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
249            return Session(Rc::clone(s_impl));
250        }
251
252        let inner = Rc::new(RefCell::new(SessionInner::default()));
253        extensions.insert(inner.clone());
254
255        Session(inner)
256    }
257}
258
259/// Extractor implementation for [`Session`]s.
260///
261/// # Examples
262/// ```
263/// # use actix_web::*;
264/// use actix_session::Session;
265///
266/// #[get("/")]
267/// async fn index(session: Session) -> Result<impl Responder> {
268///     // access session data
269///     if let Some(count) = session.get::<i32>("counter")? {
270///         session.insert("counter", count + 1)?;
271///     } else {
272///         session.insert("counter", 1)?;
273///     }
274///
275///     let count = session.get::<i32>("counter")?.unwrap();
276///     Ok(format!("Counter: {}", count))
277/// }
278/// ```
279impl FromRequest for Session {
280    type Error = Error;
281    type Future = Ready<Result<Session, Error>>;
282
283    #[inline]
284    fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
285        ready(Ok(Session::get_session(&mut req.extensions_mut())))
286    }
287}
288
289/// Error returned by [`Session::get`].
290#[derive(Debug, Display, From)]
291#[display("{_0}")]
292pub struct SessionGetError(anyhow::Error);
293
294impl StdError for SessionGetError {
295    fn source(&self) -> Option<&(dyn StdError + 'static)> {
296        Some(self.0.as_ref())
297    }
298}
299
300impl ResponseError for SessionGetError {
301    fn error_response(&self) -> HttpResponse<BoxBody> {
302        HttpResponse::new(self.status_code())
303    }
304}
305
306/// Error returned by [`Session::insert`].
307#[derive(Debug, Display, From)]
308#[display("{_0}")]
309pub struct SessionInsertError(anyhow::Error);
310
311impl StdError for SessionInsertError {
312    fn source(&self) -> Option<&(dyn StdError + 'static)> {
313        Some(self.0.as_ref())
314    }
315}
316
317impl ResponseError for SessionInsertError {
318    fn error_response(&self) -> HttpResponse<BoxBody> {
319        HttpResponse::new(self.status_code())
320    }
321}