1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//! Redis cache wrapper.

use redis::{AsyncCommands, Client, ToRedisArgs};
use serde::{de::DeserializeOwned, Serialize};
use std::time::Duration;

/// Wrapper for accessing a redis cache.
pub struct Cache {
    client: Client,
}

impl Cache {
    /// # Panics
    /// If the URL is malformed.
    pub async fn new(redis_url: &str) -> Self {
        let client =
            Client::open(redis_url).unwrap_or_else(|_| panic!("Malformed url: {redis_url}"));
        Self { client }
    }

    /// Stores the given value in the redis cache as JSON (`Vec<u8>`).
    pub async fn cache_json<V>(
        &self,
        key: impl ToRedisArgs + Send + Sync,
        value: &V,
        expires_in: Duration,
    ) -> bool
    where
        V: Serialize,
    {
        match self.client.get_async_connection().await {
            Ok(mut conn) => {
                let Ok(value) = serde_json::to_vec(value) else {
                    return false;
                };
                match conn
                    .set_ex::<_, _, ()>(key, value, expires_in.as_secs())
                    .await
                {
                    Ok(_) => true,
                    Err(err) => {
                        error!("Error caching json: {err:#}");
                        false
                    }
                }
            }
            Err(err) => {
                error!("Error connecting to redis: {err:#}");
                false
            }
        }
    }

    /// Retrieves and deserializes the corresponding value for the key stored with `cache_json`.
    pub async fn get_json<V>(&self, key: impl ToRedisArgs + Send + Sync) -> Option<V>
    where
        V: DeserializeOwned,
    {
        match self.client.get_async_connection().await {
            Ok(mut conn) => match conn.get::<_, Vec<u8>>(key).await {
                Ok(bytes) => {
                    let value = serde_json::from_slice(bytes.as_slice()).ok()?;
                    Some(value)
                }
                Err(err) => {
                    error!("Error fetching json from cache: {err:#}");
                    None
                }
            },
            Err(err) => {
                error!("Error connecting to redis: {err:#}");
                None
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use serde::Deserialize;

    /// Round-trips a value through a live Redis server, then checks that an
    /// unwritten key reads back as a miss.
    ///
    /// Uses `REDIS_URL` if set, otherwise the in-cluster default below.
    #[tokio::test]
    async fn caches() {
        // `try_init` instead of `init`: `init` panics if a global subscriber has
        // already been installed, e.g. by another test in the same binary.
        let _ = tracing_subscriber::fmt().try_init();
        let redis_url = std::env::var("REDIS_URL")
            .unwrap_or_else(|_| "redis://redis.default.svc.cluster.local/1".to_string());
        info!("Redis URL: {redis_url}");

        #[derive(Deserialize, Serialize)]
        struct S {
            field: String,
        }

        let cache = Cache::new(&redis_url).await;
        let value = S {
            field: "value".to_string(),
        };
        assert!(
            cache
                .cache_json("key", &value, Duration::from_secs(10))
                .await
        );
        let value = cache.get_json::<S>("key").await.unwrap();
        assert_eq!(value.field, "value");

        // A key that was never written must come back as a miss.
        assert!(cache
            .get_json::<S>("cache-test-never-written")
            .await
            .is_none());
    }
}