1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// A type-keyed container that holds at most one value per concrete Rust type.
///
/// Each value is stored behind an [`Arc`] under its [`TypeId`], so inserting a
/// second value of the same type replaces the first. Distinct types never
/// collide: for example `&'static str` and `String` occupy separate slots.
/// Stored values must be `Send + Sync + 'static`.
///
/// Values are only reachable through shared references (or shared [`Arc`]
/// handles). To mutate a value after insertion, store it inside a type with
/// interior mutability such as [`std::sync::Mutex`].
///
/// Cloning a `Context` is shallow: the clone gets its own map, but both maps
/// point at the same `Arc`-shared values.
#[derive(Clone, Debug)]
pub struct Context {
    /// One entry per stored type. Every key is the `TypeId` of the concrete
    /// type of the value it maps to; the insert methods below uphold this.
    type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Default for Context {
    /// Equivalent to [`Context::new`]: an empty context.
    fn default() -> Self {
        Self::new()
    }
}

impl Context {
    /// Creates an empty context.
    #[must_use]
    pub fn new() -> Self {
        Self {
            type_map: HashMap::new(),
        }
    }

    /// Stores `entity` as the value for type `E`, returning the previously
    /// stored `E`, if any.
    ///
    /// # Panics
    ///
    /// Only if the internal key/type invariant is broken, which the public
    /// API cannot do.
    pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .insert(TypeId::of::<E>(), Arc::new(entity))
            .map(Self::downcast_entry::<E>)
    }

    /// Stores `entity` as the value for type `E`, silently dropping this
    /// context's handle to any previous `E`, and returns `self` so calls can
    /// be chained.
    pub fn insert<E>(&mut self, entity: E) -> &mut Self
    where
        E: Send + Sync + 'static,
    {
        self.type_map.insert(TypeId::of::<E>(), Arc::new(entity));

        self
    }

    /// Removes the value stored for type `E` and returns it, or `None` if no
    /// `E` is stored.
    ///
    /// # Panics
    ///
    /// Only if the internal key/type invariant is broken, which the public
    /// API cannot do.
    pub fn remove<E>(&mut self) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .remove(&TypeId::of::<E>())
            .map(Self::downcast_entry::<E>)
    }

    /// Borrows the value stored for type `E`, or returns `None` if no `E` is
    /// stored.
    #[must_use]
    pub fn get<E>(&self) -> Option<&E>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .get(&TypeId::of::<E>())
            .and_then(|item| item.downcast_ref())
    }

    /// Returns a shared handle to the value stored for type `E`, or `None` if
    /// no `E` is stored.
    ///
    /// Unlike [`Context::get`], the returned [`Arc`] is not tied to a borrow
    /// of `self`. It keeps the value alive even if the entry is later replaced
    /// or removed.
    #[must_use]
    pub fn get_arc<E>(&self) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .get(&TypeId::of::<E>())
            .map(|entry| Self::downcast_entry::<E>(Arc::clone(entry)))
    }

    /// Returns `true` if a value of type `E` is stored.
    #[must_use]
    pub fn contains<E>(&self) -> bool
    where
        E: Send + Sync + 'static,
    {
        self.type_map.contains_key(&TypeId::of::<E>())
    }

    /// Number of distinct types currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.type_map.len()
    }

    /// Returns `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.type_map.is_empty()
    }

    /// Restores the concrete type of an entry taken from `type_map`.
    ///
    /// Callers only pass entries looked up under `TypeId::of::<E>()`, and
    /// entries are only ever stored under the `TypeId` of their own type, so
    /// the downcast cannot fail.
    fn downcast_entry<E>(entry: Arc<dyn Any + Send + Sync>) -> Arc<E>
    where
        E: Send + Sync + 'static,
    {
        entry
            .downcast::<E>()
            .expect("Context invariant violated: entry's TypeId key does not match its value type")
    }
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    #[test]
    fn insert_get_string() {
        let mut context = Context::new();
        context.insert_or_replace("pollo".to_string());
        assert_eq!(Some(&"pollo".to_string()), context.get());
    }

    #[test]
    fn insert_get_custom_structs() {
        #[derive(Debug, PartialEq, Eq)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq)]
        struct S2 {}

        let mut context = Context::new();
        context.insert_or_replace(S1 {});
        context.insert_or_replace(S2 {});

        assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
        assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));

        assert_eq!(Some(&S1 {}), context.get());
        assert_eq!(Some(&S2 {}), context.get());
    }

    #[test]
    fn insert_fluent_syntax() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S2 {}

        let mut context = Context::new();

        context
            .insert("static str")
            .insert("a String".to_string())
            .insert(S1::default())
            .insert(S1::default())
            .insert(S2::default());

        // The second `S1` replaces the first, so only four distinct types are stored.
        assert_eq!(4, context.len());
        assert_eq!(Some(&"static str"), context.get());
        assert_eq!(Some(&"a String".to_string()), context.get());
        assert_eq!(Some(&S1::default()), context.get());
        assert_eq!(Some(&S2::default()), context.get());
    }

    #[test]
    fn new_and_default_are_empty() {
        assert!(Context::new().is_empty());
        assert!(Context::default().is_empty());
        assert_eq!(0, Context::default().len());
    }

    #[test]
    fn remove_returns_stored_value() {
        let mut context = Context::new();
        context.insert(5u8).insert("keep".to_string());

        assert_eq!(Some(Arc::new(5u8)), context.remove::<u8>());
        assert_eq!(None, context.get::<u8>());
        assert_eq!(None, context.remove::<u8>());

        // Removing one type leaves the other entries untouched.
        assert_eq!(1, context.len());
        assert_eq!(Some(&"keep".to_string()), context.get());

        assert_eq!(
            Some(Arc::new("keep".to_string())),
            context.remove::<String>()
        );
        assert!(context.is_empty());
    }

    #[test]
    fn clone_shares_values_but_not_map() {
        let mut context = Context::new();
        context.insert(Mutex::new(1u8));
        let copy = context.clone();

        // Both contexts point at the same `Arc`-shared value.
        *context.get::<Mutex<u8>>().unwrap().lock().unwrap() = 9;
        assert_eq!(9, *copy.get::<Mutex<u8>>().unwrap().lock().unwrap());

        // ...but each clone owns its own map.
        assert!(context.remove::<Mutex<u8>>().is_some());
        assert!(context.is_empty());
        assert_eq!(1, copy.len());
    }

    fn require_send_sync<T: Send + Sync>(_: &T) {}

    #[test]
    fn test_require_send_sync() {
        require_send_sync(&Context::new());
    }

    #[test]
    fn mutability() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {
            num: u8,
        }
        let mut context = Context::new();
        context.insert_or_replace(Mutex::new(S1::default()));

        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;

        assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        let displaced = context
            .insert_or_replace(Mutex::new(S1::default()))
            .unwrap();

        assert_eq!(42, displaced.lock().unwrap().num);

        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);

        context.insert_or_replace(Mutex::new(33u32));
        *context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
        assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
    }
}