1use crate::prelude::*;
4use redis::{AsyncCommands, Client, ToRedisArgs, aio::ConnectionManager};
5use serde::{Serialize, de::DeserializeOwned};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::OnceCell;
9
10pub struct Cache {
13 connection_manager: Arc<OnceCell<ConnectionManager>>,
14 initial_connection_successful: Arc<std::sync::atomic::AtomicBool>,
15}
16
17impl Clone for Cache {
18 fn clone(&self) -> Self {
19 Self {
20 connection_manager: self.connection_manager.clone(),
21 initial_connection_successful: self.initial_connection_successful.clone(),
22 }
23 }
24}
25
26impl std::fmt::Debug for Cache {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("Cache")
29 .field("client", &"<Client>")
30 .field("connection", &"<ConnectionManager>")
31 .field(
32 "initial_connection_successful",
33 &self
34 .initial_connection_successful
35 .load(std::sync::atomic::Ordering::SeqCst),
36 )
37 .finish()
38 }
39}
40
41impl Cache {
42 pub fn new(redis_url: &str) -> UtilResult<Self> {
48 let client = Client::open(redis_url).map_err(|e| {
49 UtilError::new(
50 UtilErrorType::Other,
51 format!("Failed to create Redis client: {e}"),
52 Some(e.into()),
53 )
54 })?;
55
56 let cache = Self {
57 connection_manager: Arc::new(OnceCell::new()),
58 initial_connection_successful: Arc::new(std::sync::atomic::AtomicBool::new(false)),
59 };
60
61 let client_clone = client.clone();
62 let cache_clone = cache.clone();
63
64 tokio::spawn(async move {
65 let mut backoff = Duration::from_secs(1);
66 const MAX_BACKOFF: Duration = Duration::from_secs(6000);
67 const MAX_ATTEMPTS: usize = 1000;
68 let mut attempt = 0;
69
70 while attempt < MAX_ATTEMPTS {
71 attempt += 1;
72 info!("Attempting to establish Redis connection... (attempt {attempt})");
73 let config = redis::aio::ConnectionManagerConfig::new()
74 .set_connection_timeout(Duration::from_secs(5))
75 .set_response_timeout(Duration::from_secs(2))
76 .set_number_of_retries(3);
77
78 match ConnectionManager::new_with_config(client_clone.clone(), config).await {
79 Ok(conn_manager) => {
80 info!("Successfully established Redis connection");
81 match cache_clone.connection_manager.set(conn_manager) {
82 Ok(_) => {
83 info!("Connection manager set successfully");
84 cache_clone
85 .initial_connection_successful
86 .store(true, std::sync::atomic::Ordering::SeqCst);
87 trace!("Connection flag set to true");
88 break;
89 }
90 Err(_) => {
91 warn!("Failed to set connection manager - will retry");
92 continue;
93 }
94 }
95 }
96 Err(e) => {
97 warn!(
98 "Failed to establish Redis connection: {e}. Will retry in {:?}.",
99 backoff
100 );
101 tokio::time::sleep(backoff).await;
102 backoff = std::cmp::min(backoff * 2, MAX_BACKOFF);
103 }
104 }
105 }
106
107 if attempt >= MAX_ATTEMPTS {
108 error!("Failed to establish Redis connection after {MAX_ATTEMPTS} attempts");
109 cache_clone
110 .initial_connection_successful
111 .store(false, std::sync::atomic::Ordering::SeqCst);
112 }
113 });
114
115 Ok(cache)
116 }
117
118 pub async fn get_or_set<V, F, Fut, K>(
124 &self,
125 key: K,
126 expires_in: Duration,
127 f: F,
128 ) -> UtilResult<V>
129 where
130 V: DeserializeOwned + Serialize,
131 F: FnOnce() -> Fut,
132 Fut: std::future::Future<Output = UtilResult<V>>,
133 K: ToRedisArgs + Send + Sync + Clone + std::fmt::Debug,
134 {
135 if let Some(cached) = self.get_json::<V, K>(key.clone()).await {
136 info!("Cache hit for key: {:?}", key);
137 return Ok(cached);
138 }
139
140 info!("Cache miss for key: {:?}", key);
141
142 let start = std::time::Instant::now();
144 let value = f().await?;
145 let duration = start.elapsed();
146 info!("Generated value for key {:?} in {:?}", key, duration);
147
148 self.cache_json(key.clone(), &value, expires_in).await;
149
150 Ok(value)
151 }
152
153 pub async fn cache_json<V, K>(&self, key: K, value: &V, expires_in: Duration) -> bool
157 where
158 V: Serialize,
159 K: ToRedisArgs + Send + Sync,
160 {
161 if !self.initial_connection_successful() {
162 warn!("Skipping cache_json because initial connection not successful");
163 return false;
164 }
165
166 let mut connection = match self.connection_manager.get() {
167 Some(conn) => conn.clone(),
168 None => {
169 warn!("Skipping cache_json because no connection manager");
170 return false;
171 }
172 };
173
174 let value = match serde_json::to_vec(value) {
175 Ok(v) => v,
176 Err(e) => {
177 error!("Failed to serialize value to JSON: {e}");
178 return false;
179 }
180 };
181
182 match connection
183 .set_ex::<_, _, ()>(key, value, expires_in.as_secs())
184 .await
185 {
186 Ok(_) => {
187 debug!("Successfully cached value");
188 true
189 }
190 Err(e) => {
191 error!("Failed to cache value: {e}");
192 false
193 }
194 }
195 }
196
197 pub async fn get_json<V, K>(&self, key: K) -> Option<V>
201 where
202 V: DeserializeOwned,
203 K: ToRedisArgs + Send + Sync,
204 {
205 if !self.initial_connection_successful() {
206 warn!("Skipping get_json because initial connection not successful");
207 return None;
208 }
209
210 let mut connection = match self.connection_manager.get() {
211 Some(conn) => conn.clone(),
212 None => {
213 warn!("Skipping get_json because no connection manager");
214 return None;
215 }
216 };
217
218 match connection.get::<_, Option<Vec<u8>>>(key).await {
219 Ok(Some(bytes)) => match serde_json::from_slice(&bytes) {
220 Ok(value) => {
221 debug!("Successfully retrieved and deserialized value from cache");
222 Some(value)
223 }
224 Err(e) => {
225 error!("Failed to deserialize value from cache: {e}");
226 None
227 }
228 },
229 Ok(None) => {
230 debug!("Key not found in cache");
231 None
232 }
233 Err(e) => {
234 if e.to_string().contains("response was nil") {
235 debug!("Got nil response from Redis");
236 return None;
237 }
238 error!("Failed to get value from cache: {e}");
239 None
240 }
241 }
242 }
243
244 pub async fn invalidate<K>(&self, key: K) -> bool
248 where
249 K: ToRedisArgs + Send + Sync,
250 {
251 if !self.initial_connection_successful() {
252 return false;
253 }
254
255 let mut connection = match self.connection_manager.get() {
256 Some(conn) => conn.clone(),
257 None => return false,
258 };
259
260 match connection.del::<_, i64>(key).await {
261 Ok(1) => true,
262 Ok(0) => false,
263 Ok(_) => true,
264 Err(e) => {
265 error!("Failed to invalidate cache key: {e}");
266 false
267 }
268 }
269 }
270
271 pub fn initial_connection_successful(&self) -> bool {
273 self.initial_connection_successful
274 .load(std::sync::atomic::Ordering::SeqCst)
275 }
276
277 #[cfg(test)]
281 pub async fn wait_for_initial_connection(&self, timeout: Duration) -> bool {
282 let start = std::time::Instant::now();
283 while !self.initial_connection_successful() || self.connection_manager.get().is_none() {
284 if start.elapsed() >= timeout {
285 error!(
286 "Timed out waiting for Redis connection after {:?}. Flag is: {}, Connection manager: {}",
287 timeout,
288 self.initial_connection_successful(),
289 self.connection_manager.get().is_some()
290 );
291 return false;
292 }
293 tokio::time::sleep(Duration::from_millis(100)).await;
294 }
295 true
296 }
297}
298
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;

    /// End-to-end round trip against a live Redis instance.
    ///
    /// The server is taken from the `REDIS_URL` environment variable, falling
    /// back to an in-cluster default. The test exercises `cache_json`,
    /// `get_json`, `get_or_set` and `invalidate`.
    #[tokio::test]
    async fn caches() {
        // `try_init` instead of `init`: another test in the same binary may
        // already have installed a global subscriber, and `init` would panic.
        let _ = tracing_subscriber::fmt().try_init();
        let redis_url = std::env::var("REDIS_URL")
            .unwrap_or_else(|_| "redis://redis.default.svc.cluster.local/1".to_string());
        info!("Redis URL: {redis_url}");

        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct S {
            field: String,
        }

        let cache = Cache::new(&redis_url).unwrap();
        let value = S {
            field: "value".to_string(),
        };

        assert!(
            cache
                .wait_for_initial_connection(Duration::from_secs(10))
                .await,
            "Failed to connect to Redis within timeout"
        );

        assert!(
            cache
                .cache_json("key", &value, Duration::from_secs(10))
                .await,
            "Failed to write value to Redis"
        );

        let retrieved = cache.get_json::<S, _>("key").await;
        assert_eq!(
            retrieved,
            Some(S {
                field: "value".to_string()
            })
        );

        // Start from a clean slate so the closure below really is the source
        // of the value, not a leftover entry from an earlier run.
        let _ = cache.invalidate("test_key").await;

        let result = cache
            .get_or_set("test_key", Duration::from_secs(10), || async {
                Ok(S {
                    field: "computed".to_string(),
                })
            })
            .await
            .expect("get_or_set failed");

        assert_eq!(
            result,
            S {
                field: "computed".to_string()
            }
        );

        // The computed value must have been written back to the cache.
        let stored = cache.get_json::<S, _>("test_key").await;
        assert_eq!(
            stored,
            Some(S {
                field: "computed".to_string()
            })
        );

        assert!(cache.invalidate("key").await);
        let retrieved = cache.get_json::<S, _>("key").await;
        assert_eq!(retrieved, None);

        // Do not leave test data behind.
        assert!(cache.invalidate("test_key").await);
    }
}
364}