1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// A type-keyed map that holds at most one value per concrete type.
///
/// Every value is stored behind an [`Arc`], so cloning a `Context` is cheap.
/// It is a *shallow* copy: the clone shares each stored value with the original.
/// Stored types must be `Send + Sync + 'static`, which keeps `Context` itself
/// `Send + Sync`.
#[derive(Clone, Debug)]
pub struct Context {
    /// Each stored value, keyed by the [`TypeId`] of its concrete type.
    type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Default for Context {
    /// Equivalent to [`Context::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Context {
    /// Creates an empty context.
    #[must_use]
    pub fn new() -> Self {
        Self {
            type_map: HashMap::new(),
        }
    }

    /// Stores `entity` as the value for type `E`.
    ///
    /// Returns the value previously stored for `E`, if there was one.
    ///
    /// # Panics
    ///
    /// Not in practice. Entries are keyed by `TypeId::of::<E>()`, so a displaced
    /// entry is always an `E`.
    pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .insert(TypeId::of::<E>(), Arc::new(entity))
            .map(Self::downcast_entry::<E>)
    }

    /// Stores `entity` as the value for type `E`, discarding any previous one.
    ///
    /// Returns `&mut Self` so that calls can be chained.
    pub fn insert<E>(&mut self, entity: E) -> &mut Self
    where
        E: Send + Sync + 'static,
    {
        self.type_map.insert(TypeId::of::<E>(), Arc::new(entity));

        self
    }

    /// Removes the value stored for type `E` and returns it, if there was one.
    ///
    /// # Panics
    ///
    /// Not in practice. Entries are keyed by `TypeId::of::<E>()`, so a removed
    /// entry is always an `E`.
    pub fn remove<E>(&mut self) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .remove(&TypeId::of::<E>())
            .map(Self::downcast_entry::<E>)
    }

    /// Returns a shared reference to the value stored for type `E`, if any.
    ///
    /// To mutate a stored value through this reference, store it inside an
    /// interior-mutability type such as `Mutex`.
    #[must_use]
    pub fn get<E>(&self) -> Option<&E>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .get(&TypeId::of::<E>())
            .and_then(|item| item.downcast_ref())
    }

    /// Returns `true` if a value is stored for type `E`.
    #[must_use]
    pub fn contains<E>(&self) -> bool
    where
        E: Send + Sync + 'static,
    {
        self.type_map.contains_key(&TypeId::of::<E>())
    }

    /// Returns the number of stored values, which is one per distinct type.
    #[must_use]
    pub fn len(&self) -> usize {
        self.type_map.len()
    }

    /// Returns `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.type_map.is_empty()
    }

    /// Converts an entry taken out of `type_map` under key `TypeId::of::<E>()`
    /// back into its concrete `Arc<E>`.
    ///
    /// The type-keying invariant means the downcast cannot fail. The `expect`
    /// only guards against that invariant being broken in the future.
    fn downcast_entry<E>(entry: Arc<dyn Any + Send + Sync>) -> Arc<E>
    where
        E: Send + Sync + 'static,
    {
        entry.downcast().expect("failed to unwrap downcast")
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    #[test]
    fn insert_get_string() {
        let mut context = Context::new();
        context.insert_or_replace("pollo".to_string());
        assert_eq!(Some(&"pollo".to_string()), context.get());
    }

    #[test]
    fn insert_get_custom_structs() {
        #[derive(Debug, PartialEq, Eq)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq)]
        struct S2 {}

        let mut context = Context::new();
        context.insert_or_replace(S1 {});
        context.insert_or_replace(S2 {});

        assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
        assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));

        assert_eq!(Some(&S1 {}), context.get());
        assert_eq!(Some(&S2 {}), context.get());
    }

    #[test]
    fn insert_fluent_syntax() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S2 {}

        let mut context = Context::new();

        // Inserting S1 twice replaces it, so only four distinct types remain.
        context
            .insert("static str")
            .insert("a String".to_string())
            .insert(S1::default())
            .insert(S1::default())
            .insert(S2::default());

        assert_eq!(4, context.len());
        assert_eq!(Some(&"static str"), context.get());
    }

    #[test]
    fn default_is_empty() {
        let context = Context::default();
        assert!(context.is_empty());
        assert_eq!(0, context.len());
        assert_eq!(None, context.get::<u32>());
    }

    #[test]
    fn remove_returns_value_and_evicts_entry() {
        let mut context = Context::new();
        context.insert(7u32).insert("other");

        assert_eq!(Some(Arc::new(7u32)), context.remove::<u32>());
        assert_eq!(None, context.get::<u32>());
        // A second remove finds nothing.
        assert_eq!(None, context.remove::<u32>());

        assert_eq!(1, context.len());
        assert!(!context.is_empty());
    }

    #[test]
    fn clone_shares_stored_values() {
        let mut context = Context::new();
        context.insert(Mutex::new(1u8));
        let cloned = context.clone();

        // Values live behind Arc, so a mutation through one handle is visible
        // through the other.
        *context.get::<Mutex<u8>>().unwrap().lock().unwrap() = 5;
        assert_eq!(5, *cloned.get::<Mutex<u8>>().unwrap().lock().unwrap());
    }

    fn require_send_sync<T: Send + Sync>(_: &T) {}

    #[test]
    fn test_require_send_sync() {
        require_send_sync(&Context::new());
    }

    #[test]
    fn mutability() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {
            num: u8,
        }
        let mut context = Context::new();
        context.insert_or_replace(Mutex::new(S1::default()));

        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;

        assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        let displaced = context
            .insert_or_replace(Mutex::new(S1::default()))
            .unwrap();

        assert_eq!(42, displaced.lock().unwrap().num);

        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.insert_or_replace(Mutex::new(33u32));
        *context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
        assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
    }
}