// headless_lms_utils/cache.rs

//! Redis cache wrapper: JSON values stored in Redis with an expiry, degrading to no-ops
//! while Redis is unreachable.

3use crate::prelude::*;
4use redis::{AsyncCommands, Client, ToRedisArgs, aio::ConnectionManager};
5use serde::{Serialize, de::DeserializeOwned};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::OnceCell;
9
10/// Wrapper for accessing a redis cache.
11/// Operations are non-blocking and fail gracefully when Redis is unavailable.
12pub struct Cache {
13    connection_manager: Arc<OnceCell<ConnectionManager>>,
14    initial_connection_successful: Arc<std::sync::atomic::AtomicBool>,
15}
16
17impl Clone for Cache {
18    fn clone(&self) -> Self {
19        Self {
20            connection_manager: self.connection_manager.clone(),
21            initial_connection_successful: self.initial_connection_successful.clone(),
22        }
23    }
24}
25
26impl std::fmt::Debug for Cache {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("Cache")
29            .field("client", &"<Client>")
30            .field("connection", &"<ConnectionManager>")
31            .field(
32                "initial_connection_successful",
33                &self
34                    .initial_connection_successful
35                    .load(std::sync::atomic::Ordering::SeqCst),
36            )
37            .finish()
38    }
39}
40
41impl Cache {
42    /// Creates a new Redis cache instance.
43    ///
44    /// This will succeed even if Redis is unavailable.
45    /// Cache operations will be no-ops if Redis cannot be connected to.
46    /// Will retry connecting in the background if initial connection fails.
47    pub fn new(redis_url: &str) -> UtilResult<Self> {
48        let client = Client::open(redis_url).map_err(|e| {
49            UtilError::new(
50                UtilErrorType::Other,
51                format!("Failed to create Redis client: {e}"),
52                Some(e.into()),
53            )
54        })?;
55
56        let cache = Self {
57            connection_manager: Arc::new(OnceCell::new()),
58            initial_connection_successful: Arc::new(std::sync::atomic::AtomicBool::new(false)),
59        };
60
61        let client_clone = client.clone();
62        let cache_clone = cache.clone();
63
64        tokio::spawn(async move {
65            let mut backoff = Duration::from_secs(1);
66            const MAX_BACKOFF: Duration = Duration::from_secs(6000);
67            const MAX_ATTEMPTS: usize = 1000;
68            let mut attempt = 0;
69
70            while attempt < MAX_ATTEMPTS {
71                attempt += 1;
72                info!("Attempting to establish Redis connection... (attempt {attempt})");
73                let config = redis::aio::ConnectionManagerConfig::new()
74                    .set_connection_timeout(Duration::from_secs(5))
75                    .set_response_timeout(Duration::from_secs(2))
76                    .set_number_of_retries(3);
77
78                match ConnectionManager::new_with_config(client_clone.clone(), config).await {
79                    Ok(conn_manager) => {
80                        info!("Successfully established Redis connection");
81                        match cache_clone.connection_manager.set(conn_manager) {
82                            Ok(_) => {
83                                info!("Connection manager set successfully");
84                                cache_clone
85                                    .initial_connection_successful
86                                    .store(true, std::sync::atomic::Ordering::SeqCst);
87                                trace!("Connection flag set to true");
88                                break;
89                            }
90                            Err(_) => {
91                                warn!("Failed to set connection manager - will retry");
92                                continue;
93                            }
94                        }
95                    }
96                    Err(e) => {
97                        warn!(
98                            "Failed to establish Redis connection: {e}. Will retry in {:?}.",
99                            backoff
100                        );
101                        tokio::time::sleep(backoff).await;
102                        backoff = std::cmp::min(backoff * 2, MAX_BACKOFF);
103                    }
104                }
105            }
106
107            if attempt >= MAX_ATTEMPTS {
108                error!("Failed to establish Redis connection after {MAX_ATTEMPTS} attempts");
109                cache_clone
110                    .initial_connection_successful
111                    .store(false, std::sync::atomic::Ordering::SeqCst);
112            }
113        });
114
115        Ok(cache)
116    }
117
118    /// Retrieves a value from cache, or executes the provided function to generate and cache the value.
119    ///
120    /// First checks if the key exists in the cache. If it does, returns the cached value.
121    /// If not, executes the provided async function, caches its result, and returns it.
122    /// Operations are non-blocking - if Redis is unavailable, the function is executed immediately.
123    pub async fn get_or_set<V, F, Fut, K>(
124        &self,
125        key: K,
126        expires_in: Duration,
127        f: F,
128    ) -> UtilResult<V>
129    where
130        V: DeserializeOwned + Serialize,
131        F: FnOnce() -> Fut,
132        Fut: std::future::Future<Output = UtilResult<V>>,
133        K: ToRedisArgs + Send + Sync + Clone + std::fmt::Debug,
134    {
135        if let Some(cached) = self.get_json::<V, K>(key.clone()).await {
136            info!("Cache hit for key: {:?}", key);
137            return Ok(cached);
138        }
139
140        info!("Cache miss for key: {:?}", key);
141
142        // If not in cache or connection not available, execute the function
143        let start = std::time::Instant::now();
144        let value = f().await?;
145        let duration = start.elapsed();
146        info!("Generated value for key {:?} in {:?}", key, duration);
147
148        self.cache_json(key.clone(), &value, expires_in).await;
149
150        Ok(value)
151    }
152
153    /// Stores the given value in the redis cache as JSON.
154    /// If Redis is unavailable, this function silently does nothing.
155    /// This is a non-blocking operation.
156    pub async fn cache_json<V, K>(&self, key: K, value: &V, expires_in: Duration) -> bool
157    where
158        V: Serialize,
159        K: ToRedisArgs + Send + Sync,
160    {
161        if !self.initial_connection_successful() {
162            warn!("Skipping cache_json because initial connection not successful");
163            return false;
164        }
165
166        let mut connection = match self.connection_manager.get() {
167            Some(conn) => conn.clone(),
168            None => {
169                warn!("Skipping cache_json because no connection manager");
170                return false;
171            }
172        };
173
174        let value = match serde_json::to_vec(value) {
175            Ok(v) => v,
176            Err(e) => {
177                error!("Failed to serialize value to JSON: {e}");
178                return false;
179            }
180        };
181
182        match connection
183            .set_ex::<_, _, ()>(key, value, expires_in.as_secs())
184            .await
185        {
186            Ok(_) => {
187                debug!("Successfully cached value");
188                true
189            }
190            Err(e) => {
191                error!("Failed to cache value: {e}");
192                false
193            }
194        }
195    }
196
197    /// Retrieves and deserializes a value from cache.
198    /// If Redis is unavailable, returns None.
199    /// This is a non-blocking operation.
200    pub async fn get_json<V, K>(&self, key: K) -> Option<V>
201    where
202        V: DeserializeOwned,
203        K: ToRedisArgs + Send + Sync,
204    {
205        if !self.initial_connection_successful() {
206            warn!("Skipping get_json because initial connection not successful");
207            return None;
208        }
209
210        let mut connection = match self.connection_manager.get() {
211            Some(conn) => conn.clone(),
212            None => {
213                warn!("Skipping get_json because no connection manager");
214                return None;
215            }
216        };
217
218        match connection.get::<_, Option<Vec<u8>>>(key).await {
219            Ok(Some(bytes)) => match serde_json::from_slice(&bytes) {
220                Ok(value) => {
221                    debug!("Successfully retrieved and deserialized value from cache");
222                    Some(value)
223                }
224                Err(e) => {
225                    error!("Failed to deserialize value from cache: {e}");
226                    None
227                }
228            },
229            Ok(None) => {
230                debug!("Key not found in cache");
231                None
232            }
233            Err(e) => {
234                if e.to_string().contains("response was nil") {
235                    debug!("Got nil response from Redis");
236                    return None;
237                }
238                error!("Failed to get value from cache: {e}");
239                None
240            }
241        }
242    }
243
244    /// Delete a key from the cache.
245    /// Returns true if the key was deleted, false otherwise.
246    /// This is a non-blocking operation.
247    pub async fn invalidate<K>(&self, key: K) -> bool
248    where
249        K: ToRedisArgs + Send + Sync,
250    {
251        if !self.initial_connection_successful() {
252            return false;
253        }
254
255        let mut connection = match self.connection_manager.get() {
256            Some(conn) => conn.clone(),
257            None => return false,
258        };
259
260        match connection.del::<_, i64>(key).await {
261            Ok(1) => true,
262            Ok(0) => false,
263            Ok(_) => true,
264            Err(e) => {
265                error!("Failed to invalidate cache key: {e}");
266                false
267            }
268        }
269    }
270
271    /// Returns whether the initial connection was successful
272    pub fn initial_connection_successful(&self) -> bool {
273        self.initial_connection_successful
274            .load(std::sync::atomic::Ordering::SeqCst)
275    }
276
277    /// Waits for the initial connection to be established, with a timeout.
278    /// Returns true if connected, false if timed out.
279    /// This is primarily intended for testing.
280    #[cfg(test)]
281    pub async fn wait_for_initial_connection(&self, timeout: Duration) -> bool {
282        let start = std::time::Instant::now();
283        while !self.initial_connection_successful() || self.connection_manager.get().is_none() {
284            if start.elapsed() >= timeout {
285                error!(
286                    "Timed out waiting for Redis connection after {:?}. Flag is: {}, Connection manager: {}",
287                    timeout,
288                    self.initial_connection_successful(),
289                    self.connection_manager.get().is_some()
290                );
291                return false;
292            }
293            tokio::time::sleep(Duration::from_millis(100)).await;
294        }
295        true
296    }
297}
298
#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;

    /// Integration test against a live Redis instance (`REDIS_URL`, or the in-cluster
    /// default). Exercises `cache_json`/`get_json`, `get_or_set` and `invalidate`.
    #[tokio::test]
    async fn caches() {
        // `init()` panics if a global subscriber is already installed (e.g. by another test).
        let _ = tracing_subscriber::fmt().try_init();
        let redis_url = std::env::var("REDIS_URL")
            .unwrap_or_else(|_| "redis://redis.default.svc.cluster.local/1".to_string());
        info!("Redis URL: {redis_url}");

        #[derive(Deserialize, Serialize, Debug, PartialEq)]
        struct S {
            field: String,
        }

        let cache = Cache::new(&redis_url).unwrap();
        let value = S {
            field: "value".to_string(),
        };

        // Wait for connection to be established
        assert!(
            cache
                .wait_for_initial_connection(Duration::from_secs(10))
                .await,
            "Failed to connect to Redis within timeout"
        );

        // Test cache_json and get_json
        assert!(
            cache
                .cache_json("key", &value, Duration::from_secs(10))
                .await,
            "cache_json failed"
        );

        let retrieved = cache.get_json::<S, _>("key").await;
        assert_eq!(
            retrieved,
            Some(S {
                field: "value".to_string()
            })
        );

        // Test get_or_set. Clear the key first so the closure really runs; a previous run
        // may have left it cached for up to 10 s.
        cache.invalidate("test_key").await;
        let result = cache
            .get_or_set("test_key", Duration::from_secs(10), || async {
                Ok(S {
                    field: "computed".to_string(),
                })
            })
            .await
            .expect("get_or_set failed");

        assert_eq!(
            result,
            S {
                field: "computed".to_string()
            }
        );

        // A second call must be served from the cache, not by running the new closure.
        let cached = cache
            .get_or_set("test_key", Duration::from_secs(10), || async {
                Ok(S {
                    field: "recomputed".to_string(),
                })
            })
            .await
            .expect("get_or_set failed");

        assert_eq!(
            cached,
            S {
                field: "computed".to_string()
            }
        );

        // Test invalidate
        assert!(cache.invalidate("key").await);
        let retrieved = cache.get_json::<S, _>("key").await;
        assert_eq!(retrieved, None);
    }
}
364}