1use std::{
2 cell::{Ref, RefCell},
3 collections::HashMap,
4 error::Error as StdError,
5 mem,
6 rc::Rc,
7};
8
9use actix_utils::future::{ready, Ready};
10use actix_web::{
11 body::BoxBody,
12 dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
13 error::Error,
14 FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
15};
16use anyhow::Context;
17use derive_more::derive::{Display, From};
18use serde::{de::DeserializeOwned, Serialize};
19
20#[derive(Clone)]
46pub struct Session(Rc<RefCell<SessionInner>>);
47
/// What has happened to a session during the current request.
///
/// A freshly created session starts as [`SessionStatus::Unchanged`]; the mutating methods
/// on [`Session`] move it to one of the other variants.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum SessionStatus {
    /// Entries were modified through `insert`, `remove` or `clear`.
    Changed,

    /// The session was emptied through `purge`. Once in this state, `insert`, `remove`,
    /// `clear` and `renew` leave both the data and the status untouched.
    Purged,

    /// The session was flagged through `renew`. Later `insert`/`remove`/`clear` calls still
    /// edit the data but keep this status rather than replacing it with `Changed`.
    Renewed,

    /// No mutation has been recorded; this is the default.
    #[default]
    Unchanged,
}
70
71#[derive(Default)]
72struct SessionInner {
73 state: HashMap<String, String>,
74 status: SessionStatus,
75}
76
77impl Session {
78 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
82 if let Some(val_str) = self.0.borrow().state.get(key) {
83 Ok(Some(
84 serde_json::from_str(val_str)
85 .with_context(|| {
86 format!(
87 "Failed to deserialize the JSON-encoded session data attached to key \
88 `{}` as a `{}` type",
89 key,
90 std::any::type_name::<T>()
91 )
92 })
93 .map_err(SessionGetError)?,
94 ))
95 } else {
96 Ok(None)
97 }
98 }
99
100 pub fn entries(&self) -> Ref<'_, HashMap<String, String>> {
104 Ref::map(self.0.borrow(), |inner| &inner.state)
105 }
106
107 pub fn status(&self) -> SessionStatus {
109 Ref::map(self.0.borrow(), |inner| &inner.status).clone()
110 }
111
112 pub fn insert<T: Serialize>(
119 &self,
120 key: impl Into<String>,
121 value: T,
122 ) -> Result<(), SessionInsertError> {
123 let mut inner = self.0.borrow_mut();
124
125 if inner.status != SessionStatus::Purged {
126 if inner.status != SessionStatus::Renewed {
127 inner.status = SessionStatus::Changed;
128 }
129
130 let key = key.into();
131 let val = serde_json::to_string(&value)
132 .with_context(|| {
133 format!(
134 "Failed to serialize the provided `{}` type instance as JSON in order to \
135 attach as session data to the `{}` key",
136 std::any::type_name::<T>(),
137 &key
138 )
139 })
140 .map_err(SessionInsertError)?;
141
142 inner.state.insert(key, val);
143 }
144
145 Ok(())
146 }
147
148 pub fn remove(&self, key: &str) -> Option<String> {
152 let mut inner = self.0.borrow_mut();
153
154 if inner.status != SessionStatus::Purged {
155 if inner.status != SessionStatus::Renewed {
156 inner.status = SessionStatus::Changed;
157 }
158 return inner.state.remove(key);
159 }
160
161 None
162 }
163
164 pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, String>> {
169 self.remove(key)
170 .map(|val_str| match serde_json::from_str(&val_str) {
171 Ok(val) => Ok(val),
172 Err(_err) => {
173 tracing::debug!(
174 "Removed value (key: {}) could not be deserialized as {}",
175 key,
176 std::any::type_name::<T>()
177 );
178
179 Err(val_str)
180 }
181 })
182 }
183
184 pub fn clear(&self) {
186 let mut inner = self.0.borrow_mut();
187
188 if inner.status != SessionStatus::Purged {
189 if inner.status != SessionStatus::Renewed {
190 inner.status = SessionStatus::Changed;
191 }
192 inner.state.clear()
193 }
194 }
195
196 pub fn purge(&self) {
198 let mut inner = self.0.borrow_mut();
199 inner.status = SessionStatus::Purged;
200 inner.state.clear();
201 }
202
203 pub fn renew(&self) {
205 let mut inner = self.0.borrow_mut();
206
207 if inner.status != SessionStatus::Purged {
208 inner.status = SessionStatus::Renewed;
209 }
210 }
211
212 #[allow(clippy::needless_pass_by_ref_mut)]
217 pub(crate) fn set_session(
218 req: &mut ServiceRequest,
219 data: impl IntoIterator<Item = (String, String)>,
220 ) {
221 let session = Session::get_session(&mut req.extensions_mut());
222 let mut inner = session.0.borrow_mut();
223 inner.state.extend(data);
224 }
225
226 #[allow(clippy::needless_pass_by_ref_mut)]
232 pub(crate) fn get_changes<B>(
233 res: &mut ServiceResponse<B>,
234 ) -> (SessionStatus, HashMap<String, String>) {
235 if let Some(s_impl) = res
236 .request()
237 .extensions()
238 .get::<Rc<RefCell<SessionInner>>>()
239 {
240 let state = mem::take(&mut s_impl.borrow_mut().state);
241 (s_impl.borrow().status.clone(), state)
242 } else {
243 (SessionStatus::Unchanged, HashMap::new())
244 }
245 }
246
247 pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
248 if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
249 return Session(Rc::clone(s_impl));
250 }
251
252 let inner = Rc::new(RefCell::new(SessionInner::default()));
253 extensions.insert(inner.clone());
254
255 Session(inner)
256 }
257}
258
259impl FromRequest for Session {
280 type Error = Error;
281 type Future = Ready<Result<Session, Error>>;
282
283 #[inline]
284 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
285 ready(Ok(Session::get_session(&mut req.extensions_mut())))
286 }
287}
288
289#[derive(Debug, Display, From)]
291#[display("{_0}")]
292pub struct SessionGetError(anyhow::Error);
293
294impl StdError for SessionGetError {
295 fn source(&self) -> Option<&(dyn StdError + 'static)> {
296 Some(self.0.as_ref())
297 }
298}
299
300impl ResponseError for SessionGetError {
301 fn error_response(&self) -> HttpResponse<BoxBody> {
302 HttpResponse::new(self.status_code())
303 }
304}
305
306#[derive(Debug, Display, From)]
308#[display("{_0}")]
309pub struct SessionInsertError(anyhow::Error);
310
311impl StdError for SessionInsertError {
312 fn source(&self) -> Option<&(dyn StdError + 'static)> {
313 Some(self.0.as_ref())
314 }
315}
316
317impl ResponseError for SessionInsertError {
318 fn error_response(&self) -> HttpResponse<BoxBody> {
319 HttpResponse::new(self.status_code())
320 }
321}