redis/
script.rs

1#![cfg(feature = "script")]
2use sha1_smol::Sha1;
3
4use crate::{
5    cmd::cmd,
6    connection::ConnectionLike,
7    types::{FromRedisValue, RedisResult, ToRedisArgs},
8    Cmd, ErrorKind,
9};
10
/// Represents a lua script.
#[derive(Clone, Debug)]
pub struct Script {
    /// The Lua source, uploaded verbatim by `SCRIPT LOAD`.
    code: String,
    /// Hexadecimal SHA1 digest of `code`; the identifier passed to `EVALSHA`.
    hash: String,
}
17
18/// The script object represents a lua script that can be executed on the
19/// redis server.  The object itself takes care of automatic uploading and
20/// execution.  The script object itself can be shared and is immutable.
21///
22/// Example:
23///
24/// ```rust,no_run
25/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
26/// # let mut con = client.get_connection().unwrap();
27/// let script = redis::Script::new(r"
28///     return tonumber(ARGV[1]) + tonumber(ARGV[2]);
29/// ");
30/// let result = script.arg(1).arg(2).invoke(&mut con);
31/// assert_eq!(result, Ok(3));
32/// ```
33impl Script {
34    /// Creates a new script object.
35    pub fn new(code: &str) -> Script {
36        let mut hash = Sha1::new();
37        hash.update(code.as_bytes());
38        Script {
39            code: code.to_string(),
40            hash: hash.digest().to_string(),
41        }
42    }
43
44    /// Returns the script's SHA1 hash in hexadecimal format.
45    pub fn get_hash(&self) -> &str {
46        &self.hash
47    }
48
49    /// Returns a command to load the script.
50    pub(crate) fn load_cmd(&self) -> Cmd {
51        let mut cmd = cmd("SCRIPT");
52        cmd.arg("LOAD").arg(self.code.as_bytes());
53        cmd
54    }
55
56    /// Loads the script and returns the SHA1 of it.
57    #[inline]
58    pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
59        let hash: String = self.load_cmd().query(con)?;
60
61        debug_assert_eq!(hash, self.hash);
62
63        Ok(hash)
64    }
65
66    /// Asynchronously loads the script and returns the SHA1 of it.
67    #[inline]
68    #[cfg(feature = "aio")]
69    pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
70    where
71        C: crate::aio::ConnectionLike,
72    {
73        let hash: String = self.load_cmd().query_async(con).await?;
74
75        debug_assert_eq!(hash, self.hash);
76
77        Ok(hash)
78    }
79
80    /// Creates a script invocation object with a key filled in.
81    #[inline]
82    pub fn key<T: ToRedisArgs>(&self, key: T) -> ScriptInvocation<'_> {
83        ScriptInvocation {
84            script: self,
85            args: vec![],
86            keys: key.to_redis_args(),
87        }
88    }
89
90    /// Creates a script invocation object with an argument filled in.
91    #[inline]
92    pub fn arg<T: ToRedisArgs>(&self, arg: T) -> ScriptInvocation<'_> {
93        ScriptInvocation {
94            script: self,
95            args: arg.to_redis_args(),
96            keys: vec![],
97        }
98    }
99
100    /// Returns an empty script invocation object.  This is primarily useful
101    /// for programmatically adding arguments and keys because the type will
102    /// not change.  Normally you can use `arg` and `key` directly.
103    #[inline]
104    pub fn prepare_invoke(&self) -> ScriptInvocation<'_> {
105        ScriptInvocation {
106            script: self,
107            args: vec![],
108            keys: vec![],
109        }
110    }
111
112    /// Invokes the script directly without arguments.
113    #[inline]
114    pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
115        ScriptInvocation {
116            script: self,
117            args: vec![],
118            keys: vec![],
119        }
120        .invoke(con)
121    }
122
123    /// Asynchronously invokes the script without arguments.
124    #[inline]
125    #[cfg(feature = "aio")]
126    pub async fn invoke_async<C, T>(&self, con: &mut C) -> RedisResult<T>
127    where
128        C: crate::aio::ConnectionLike,
129        T: FromRedisValue,
130    {
131        ScriptInvocation {
132            script: self,
133            args: vec![],
134            keys: vec![],
135        }
136        .invoke_async(con)
137        .await
138    }
139}
140
141/// Represents a prepared script call.
142pub struct ScriptInvocation<'a> {
143    script: &'a Script,
144    args: Vec<Vec<u8>>,
145    keys: Vec<Vec<u8>>,
146}
147
148/// This type collects keys and other arguments for the script so that it
149/// can be then invoked.  While the `Script` type itself holds the script,
150/// the `ScriptInvocation` holds the arguments that should be invoked until
151/// it's sent to the server.
152impl<'a> ScriptInvocation<'a> {
153    /// Adds a regular argument to the invocation.  This ends up as `ARGV[i]`
154    /// in the script.
155    #[inline]
156    pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a>
157    where
158        'a: 'b,
159    {
160        arg.write_redis_args(&mut self.args);
161        self
162    }
163
164    /// Adds a key argument to the invocation.  This ends up as `KEYS[i]`
165    /// in the script.
166    #[inline]
167    pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a>
168    where
169        'a: 'b,
170    {
171        key.write_redis_args(&mut self.keys);
172        self
173    }
174
175    /// Invokes the script and returns the result.
176    #[inline]
177    pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
178        let eval_cmd = self.eval_cmd();
179        match eval_cmd.query(con) {
180            Ok(val) => Ok(val),
181            Err(err) => {
182                if err.kind() == ErrorKind::Server(crate::ServerErrorKind::NoScript) {
183                    self.load(con)?;
184                    eval_cmd.query(con)
185                } else {
186                    Err(err)
187                }
188            }
189        }
190    }
191
192    /// Asynchronously invokes the script and returns the result.
193    #[inline]
194    #[cfg(feature = "aio")]
195    pub async fn invoke_async<T: FromRedisValue>(
196        &self,
197        con: &mut impl crate::aio::ConnectionLike,
198    ) -> RedisResult<T> {
199        let eval_cmd = self.eval_cmd();
200        match eval_cmd.query_async(con).await {
201            Ok(val) => {
202                // Return the value from the script evaluation
203                Ok(val)
204            }
205            Err(err) => {
206                // Load the script into Redis if the script hash wasn't there already
207                if err.kind() == ErrorKind::Server(crate::ServerErrorKind::NoScript) {
208                    self.load_async(con).await?;
209                    eval_cmd.query_async(con).await
210                } else {
211                    Err(err)
212                }
213            }
214        }
215    }
216
217    /// Loads the script and returns the SHA1 of it.
218    #[inline]
219    pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
220        self.script.load(con)
221    }
222
223    /// Asynchronously loads the script and returns the SHA1 of it.
224    #[inline]
225    #[cfg(feature = "aio")]
226    pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
227    where
228        C: crate::aio::ConnectionLike,
229    {
230        self.script.load_async(con).await
231    }
232
233    fn estimate_buflen(&self) -> usize {
234        self
235            .keys
236            .iter()
237            .chain(self.args.iter())
238            .fold(0, |acc, e| acc + e.len())
239            + 7 /* "EVALSHA".len() */
240            + self.script.hash.len()
241            + 4 /* Slots reserved for the length of keys. */
242    }
243
244    /// Returns a command to evaluate the script.
245    pub(crate) fn eval_cmd(&self) -> Cmd {
246        let args_len = 3 + self.keys.len() + self.args.len();
247        let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen());
248        cmd.arg("EVALSHA")
249            .arg(self.script.hash.as_bytes())
250            .arg(self.keys.len())
251            .arg(&*self.keys)
252            .arg(&*self.args);
253        cmd
254    }
255}
256
257#[cfg(test)]
258mod tests {
259    use super::Script;
260
261    #[test]
262    fn script_eval_should_work() {
263        let script = Script::new("return KEYS[1]");
264        let invocation = script.key("dummy");
265        let estimated_buflen = invocation.estimate_buflen();
266        let cmd = invocation.eval_cmd();
267        assert!(estimated_buflen >= cmd.capacity().1);
268        let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n";
269        assert_eq!(
270            expected,
271            std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap()
272        );
273    }
274}