1use std::{
2 cell::{Ref, RefCell},
3 collections::HashMap,
4 error::Error as StdError,
5 mem,
6 rc::Rc,
7};
8
9use actix_utils::future::{ready, Ready};
10use actix_web::{
11 body::BoxBody,
12 dev::{Extensions, Payload, ServiceRequest, ServiceResponse},
13 error::Error,
14 FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError,
15};
16use anyhow::Context;
17use derive_more::derive::{Display, From};
18use serde::{de::DeserializeOwned, Serialize};
19
/// Handle to the session data attached to the current request.
///
/// The data lives behind an `Rc<RefCell<..>>`, so cloning a `Session` only bumps a reference
/// count and every clone observes (and mutates) the same state. Because of the `Rc`, the handle
/// is neither `Send` nor `Sync`.
///
/// A `Session` can be obtained as a request extractor (see its `FromRequest` impl).
#[derive(Clone)]
pub struct Session(Rc<RefCell<SessionInner>>);
50
/// Lifecycle status of a [`Session`] during the current request.
///
/// The status is reported together with the session state by `Session::get_changes`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionStatus {
    /// The session state was modified, e.g. through [`Session::insert`], [`Session::remove`]
    /// or [`Session::clear`].
    Changed,

    /// The session was purged via [`Session::purge`], which also empties its state.
    ///
    /// Once purged, `insert`, `remove`, `clear` and `renew` leave the session untouched.
    Purged,

    /// The session was renewed via [`Session::renew`].
    ///
    /// Subsequent modifications keep this status instead of switching to `Changed`; a later
    /// `purge` still overrides it.
    Renewed,

    /// Default status: no status-changing method has been called on the session.
    #[default]
    Unchanged,
}
73
74#[derive(Default)]
75struct SessionInner {
76 state: HashMap<String, String>,
77 status: SessionStatus,
78}
79
80impl Session {
81 pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, SessionGetError> {
85 if let Some(val_str) = self.0.borrow().state.get(key) {
86 Ok(Some(
87 serde_json::from_str(val_str)
88 .with_context(|| {
89 format!(
90 "Failed to deserialize the JSON-encoded session data attached to key \
91 `{}` as a `{}` type",
92 key,
93 std::any::type_name::<T>()
94 )
95 })
96 .map_err(SessionGetError)?,
97 ))
98 } else {
99 Ok(None)
100 }
101 }
102
103 pub fn contains_key(&self, key: &str) -> bool {
105 self.0.borrow().state.contains_key(key)
106 }
107
108 pub fn entries(&self) -> Ref<'_, HashMap<String, String>> {
112 Ref::map(self.0.borrow(), |inner| &inner.state)
113 }
114
115 pub fn status(&self) -> SessionStatus {
117 Ref::map(self.0.borrow(), |inner| &inner.status).clone()
118 }
119
120 pub fn insert<T: Serialize>(
129 &self,
130 key: impl Into<String>,
131 value: T,
132 ) -> Result<(), SessionInsertError> {
133 let mut inner = self.0.borrow_mut();
134
135 if inner.status != SessionStatus::Purged {
136 if inner.status != SessionStatus::Renewed {
137 inner.status = SessionStatus::Changed;
138 }
139
140 let key = key.into();
141 let val = serde_json::to_string(&value)
142 .with_context(|| {
143 format!(
144 "Failed to serialize the provided `{}` type instance as JSON in order to \
145 attach as session data to the `{key}` key",
146 std::any::type_name::<T>(),
147 )
148 })
149 .map_err(SessionInsertError)?;
150
151 inner.state.insert(key, val);
152 }
153
154 Ok(())
155 }
156
157 pub fn update<T: Serialize + DeserializeOwned, F>(
169 &self,
170 key: impl Into<String>,
171 updater: F,
172 ) -> Result<(), SessionUpdateError>
173 where
174 F: FnOnce(T) -> T,
175 {
176 let mut inner = self.0.borrow_mut();
177 let key_str = key.into();
178
179 if let Some(val_str) = inner.state.get(&key_str) {
180 let value = serde_json::from_str(val_str)
181 .with_context(|| {
182 format!(
183 "Failed to deserialize the JSON-encoded session data attached to key \
184 `{key_str}` as a `{}` type",
185 std::any::type_name::<T>()
186 )
187 })
188 .map_err(SessionUpdateError)?;
189
190 let val = serde_json::to_string(&updater(value))
191 .with_context(|| {
192 format!(
193 "Failed to serialize the provided `{}` type instance as JSON in order to \
194 attach as session data to the `{key_str}` key",
195 std::any::type_name::<T>(),
196 )
197 })
198 .map_err(SessionUpdateError)?;
199
200 inner.state.insert(key_str, val);
201 }
202
203 Ok(())
204 }
205
206 pub fn update_or<T: Serialize + DeserializeOwned, F>(
218 &self,
219 key: &str,
220 default_value: T,
221 updater: F,
222 ) -> Result<(), SessionUpdateError>
223 where
224 F: FnOnce(T) -> T,
225 {
226 if self.contains_key(key) {
227 self.update(key, updater)
228 } else {
229 self.insert(key, default_value)
230 .map_err(|err| SessionUpdateError(err.into()))
231 }
232 }
233
234 pub fn remove(&self, key: &str) -> Option<String> {
238 let mut inner = self.0.borrow_mut();
239
240 if inner.status != SessionStatus::Purged {
241 if inner.status != SessionStatus::Renewed {
242 inner.status = SessionStatus::Changed;
243 }
244 return inner.state.remove(key);
245 }
246
247 None
248 }
249
250 pub fn remove_as<T: DeserializeOwned>(&self, key: &str) -> Option<Result<T, String>> {
255 self.remove(key)
256 .map(|val_str| match serde_json::from_str(&val_str) {
257 Ok(val) => Ok(val),
258 Err(_err) => {
259 tracing::debug!(
260 "Removed value (key: {}) could not be deserialized as {}",
261 key,
262 std::any::type_name::<T>()
263 );
264
265 Err(val_str)
266 }
267 })
268 }
269
270 pub fn clear(&self) {
272 let mut inner = self.0.borrow_mut();
273
274 if inner.status != SessionStatus::Purged {
275 if inner.status != SessionStatus::Renewed {
276 inner.status = SessionStatus::Changed;
277 }
278 inner.state.clear()
279 }
280 }
281
282 pub fn purge(&self) {
284 let mut inner = self.0.borrow_mut();
285 inner.status = SessionStatus::Purged;
286 inner.state.clear();
287 }
288
289 pub fn renew(&self) {
291 let mut inner = self.0.borrow_mut();
292
293 if inner.status != SessionStatus::Purged {
294 inner.status = SessionStatus::Renewed;
295 }
296 }
297
298 #[allow(clippy::needless_pass_by_ref_mut)]
303 pub(crate) fn set_session(
304 req: &mut ServiceRequest,
305 data: impl IntoIterator<Item = (String, String)>,
306 ) {
307 let session = Session::get_session(&mut req.extensions_mut());
308 let mut inner = session.0.borrow_mut();
309 inner.state.extend(data);
310 }
311
312 #[allow(clippy::needless_pass_by_ref_mut)]
318 pub(crate) fn get_changes<B>(
319 res: &mut ServiceResponse<B>,
320 ) -> (SessionStatus, HashMap<String, String>) {
321 if let Some(s_impl) = res
322 .request()
323 .extensions()
324 .get::<Rc<RefCell<SessionInner>>>()
325 {
326 let state = mem::take(&mut s_impl.borrow_mut().state);
327 (s_impl.borrow().status.clone(), state)
328 } else {
329 (SessionStatus::Unchanged, HashMap::new())
330 }
331 }
332
333 pub(crate) fn get_session(extensions: &mut Extensions) -> Session {
334 if let Some(s_impl) = extensions.get::<Rc<RefCell<SessionInner>>>() {
335 return Session(Rc::clone(s_impl));
336 }
337
338 let inner = Rc::new(RefCell::new(SessionInner::default()));
339 extensions.insert(inner.clone());
340
341 Session(inner)
342 }
343}
344
345impl FromRequest for Session {
366 type Error = Error;
367 type Future = Ready<Result<Session, Error>>;
368
369 #[inline]
370 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
371 ready(Ok(Session::get_session(&mut req.extensions_mut())))
372 }
373}
374
375#[derive(Debug, Display, From)]
377#[display("{_0}")]
378pub struct SessionGetError(anyhow::Error);
379
380impl StdError for SessionGetError {
381 fn source(&self) -> Option<&(dyn StdError + 'static)> {
382 Some(self.0.as_ref())
383 }
384}
385
386impl ResponseError for SessionGetError {
387 fn error_response(&self) -> HttpResponse<BoxBody> {
388 HttpResponse::new(self.status_code())
389 }
390}
391
392#[derive(Debug, Display, From)]
394#[display("{_0}")]
395pub struct SessionInsertError(anyhow::Error);
396
397impl StdError for SessionInsertError {
398 fn source(&self) -> Option<&(dyn StdError + 'static)> {
399 Some(self.0.as_ref())
400 }
401}
402
403impl ResponseError for SessionInsertError {
404 fn error_response(&self) -> HttpResponse<BoxBody> {
405 HttpResponse::new(self.status_code())
406 }
407}
408
409#[derive(Debug, Display, From)]
411#[display("{_0}")]
412pub struct SessionUpdateError(anyhow::Error);
413
414impl StdError for SessionUpdateError {
415 fn source(&self) -> Option<&(dyn StdError + 'static)> {
416 Some(self.0.as_ref())
417 }
418}
419
420impl ResponseError for SessionUpdateError {
421 fn error_response(&self) -> HttpResponse<BoxBody> {
422 HttpResponse::new(self.status_code())
423 }
424}