1#![cfg(feature = "script")]
2use sha1_smol::Sha1;
3
4use crate::{
5 cmd::cmd,
6 connection::ConnectionLike,
7 types::{FromRedisValue, RedisResult, ToRedisArgs},
8 Cmd, ErrorKind,
9};
10
/// A server-side script together with the SHA-1 digest of its source.
///
/// The digest is computed once in [`Script::new`] and reused whenever the
/// script is invoked by hash.
#[derive(Clone, Debug)]
pub struct Script {
    /// Source text of the script, sent to the server by `SCRIPT LOAD`.
    code: String,
    /// String form of the SHA-1 digest of `code`, sent with `EVALSHA`.
    hash: String,
}
17
18impl Script {
34 pub fn new(code: &str) -> Script {
36 let mut hash = Sha1::new();
37 hash.update(code.as_bytes());
38 Script {
39 code: code.to_string(),
40 hash: hash.digest().to_string(),
41 }
42 }
43
44 pub fn get_hash(&self) -> &str {
46 &self.hash
47 }
48
49 pub(crate) fn load_cmd(&self) -> Cmd {
51 let mut cmd = cmd("SCRIPT");
52 cmd.arg("LOAD").arg(self.code.as_bytes());
53 cmd
54 }
55
56 #[inline]
58 pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
59 let hash: String = self.load_cmd().query(con)?;
60
61 debug_assert_eq!(hash, self.hash);
62
63 Ok(hash)
64 }
65
66 #[inline]
68 #[cfg(feature = "aio")]
69 pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
70 where
71 C: crate::aio::ConnectionLike,
72 {
73 let hash: String = self.load_cmd().query_async(con).await?;
74
75 debug_assert_eq!(hash, self.hash);
76
77 Ok(hash)
78 }
79
80 #[inline]
82 pub fn key<T: ToRedisArgs>(&self, key: T) -> ScriptInvocation<'_> {
83 ScriptInvocation {
84 script: self,
85 args: vec![],
86 keys: key.to_redis_args(),
87 }
88 }
89
90 #[inline]
92 pub fn arg<T: ToRedisArgs>(&self, arg: T) -> ScriptInvocation<'_> {
93 ScriptInvocation {
94 script: self,
95 args: arg.to_redis_args(),
96 keys: vec![],
97 }
98 }
99
100 #[inline]
104 pub fn prepare_invoke(&self) -> ScriptInvocation<'_> {
105 ScriptInvocation {
106 script: self,
107 args: vec![],
108 keys: vec![],
109 }
110 }
111
112 #[inline]
114 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
115 ScriptInvocation {
116 script: self,
117 args: vec![],
118 keys: vec![],
119 }
120 .invoke(con)
121 }
122
123 #[inline]
125 #[cfg(feature = "aio")]
126 pub async fn invoke_async<C, T>(&self, con: &mut C) -> RedisResult<T>
127 where
128 C: crate::aio::ConnectionLike,
129 T: FromRedisValue,
130 {
131 ScriptInvocation {
132 script: self,
133 args: vec![],
134 keys: vec![],
135 }
136 .invoke_async(con)
137 .await
138 }
139}
140
141pub struct ScriptInvocation<'a> {
143 script: &'a Script,
144 args: Vec<Vec<u8>>,
145 keys: Vec<Vec<u8>>,
146}
147
148impl<'a> ScriptInvocation<'a> {
153 #[inline]
156 pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a>
157 where
158 'a: 'b,
159 {
160 arg.write_redis_args(&mut self.args);
161 self
162 }
163
164 #[inline]
167 pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a>
168 where
169 'a: 'b,
170 {
171 key.write_redis_args(&mut self.keys);
172 self
173 }
174
175 #[inline]
177 pub fn invoke<T: FromRedisValue>(&self, con: &mut dyn ConnectionLike) -> RedisResult<T> {
178 let eval_cmd = self.eval_cmd();
179 match eval_cmd.query(con) {
180 Ok(val) => Ok(val),
181 Err(err) => {
182 if err.kind() == ErrorKind::Server(crate::ServerErrorKind::NoScript) {
183 self.load(con)?;
184 eval_cmd.query(con)
185 } else {
186 Err(err)
187 }
188 }
189 }
190 }
191
192 #[inline]
194 #[cfg(feature = "aio")]
195 pub async fn invoke_async<T: FromRedisValue>(
196 &self,
197 con: &mut impl crate::aio::ConnectionLike,
198 ) -> RedisResult<T> {
199 let eval_cmd = self.eval_cmd();
200 match eval_cmd.query_async(con).await {
201 Ok(val) => {
202 Ok(val)
204 }
205 Err(err) => {
206 if err.kind() == ErrorKind::Server(crate::ServerErrorKind::NoScript) {
208 self.load_async(con).await?;
209 eval_cmd.query_async(con).await
210 } else {
211 Err(err)
212 }
213 }
214 }
215 }
216
217 #[inline]
219 pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult<String> {
220 self.script.load(con)
221 }
222
223 #[inline]
225 #[cfg(feature = "aio")]
226 pub async fn load_async<C>(&self, con: &mut C) -> RedisResult<String>
227 where
228 C: crate::aio::ConnectionLike,
229 {
230 self.script.load_async(con).await
231 }
232
233 fn estimate_buflen(&self) -> usize {
234 self
235 .keys
236 .iter()
237 .chain(self.args.iter())
238 .fold(0, |acc, e| acc + e.len())
239 + 7 + self.script.hash.len()
241 + 4 }
243
244 pub(crate) fn eval_cmd(&self) -> Cmd {
246 let args_len = 3 + self.keys.len() + self.args.len();
247 let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen());
248 cmd.arg("EVALSHA")
249 .arg(self.script.hash.as_bytes())
250 .arg(self.keys.len())
251 .arg(&*self.keys)
252 .arg(&*self.args);
253 cmd
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::Script;

    /// Builds a one-key invocation and checks both the buffer-size estimate
    /// and the exact packed `EVALSHA` command.
    #[test]
    fn script_eval_should_work() {
        let script = Script::new("return KEYS[1]");
        let call = script.key("dummy");

        let estimate = call.estimate_buflen();
        let built = call.eval_cmd();
        // The estimate must be at least the second value reported by
        // `capacity()` (presumably the data-buffer capacity — confirm against
        // `Cmd::capacity`).
        assert!(estimate >= built.capacity().1);

        let packed = built.get_packed_command();
        let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n";
        assert_eq!(expected, std::str::from_utf8(packed.as_slice()).unwrap());
    }
}