redis/aio/
mod.rs

1//! Adds async IO support to redis.
2use crate::cmd::Cmd;
3use crate::connection::{
4    check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents,
5    RedisConnectionInfo,
6};
7use crate::io::AsyncDNSResolver;
8use crate::types::{closed_connection_error, RedisFuture, RedisResult, Value};
9use crate::{ErrorKind, PushInfo, RedisError};
10use ::tokio::io::{AsyncRead, AsyncWrite};
11use futures_util::{
12    future::{Future, FutureExt},
13    sink::{Sink, SinkExt},
14    stream::{Stream, StreamExt},
15};
16pub use monitor::Monitor;
17use std::net::SocketAddr;
18#[cfg(unix)]
19use std::path::Path;
20use std::pin::Pin;
21
22mod monitor;
23
24/// Enables the async_std compatibility
25#[cfg(feature = "async-std-comp")]
26#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
27pub mod async_std;
28
29#[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
30use crate::connection::TlsConnParams;
31
32/// Enables the smol compatibility
33#[cfg(feature = "smol-comp")]
34#[cfg_attr(docsrs, doc(cfg(feature = "smol-comp")))]
35pub mod smol;
36/// Enables the tokio compatibility
37#[cfg(feature = "tokio-comp")]
38#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
39pub mod tokio;
40
41mod pubsub;
42pub use pubsub::{PubSub, PubSubSink, PubSubStream};
43
44/// Represents the ability of connecting via TCP or via Unix socket
45pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
46    /// Performs a TCP connection
47    async fn connect_tcp(
48        socket_addr: SocketAddr,
49        tcp_settings: &crate::io::tcp::TcpSettings,
50    ) -> RedisResult<Self>;
51
52    // Performs a TCP TLS connection
53    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
54    async fn connect_tcp_tls(
55        hostname: &str,
56        socket_addr: SocketAddr,
57        insecure: bool,
58        tls_params: &Option<TlsConnParams>,
59        tcp_settings: &crate::io::tcp::TcpSettings,
60    ) -> RedisResult<Self>;
61
62    /// Performs a UNIX connection
63    #[cfg(unix)]
64    async fn connect_unix(path: &Path) -> RedisResult<Self>;
65
66    fn spawn(f: impl Future<Output = ()> + Send + 'static) -> TaskHandle;
67
68    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
69        Box::pin(self)
70    }
71}
72
73/// Trait for objects that implements `AsyncRead` and `AsyncWrite`
74pub trait AsyncStream: AsyncRead + AsyncWrite {}
75impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
76
77/// An async abstraction over connections.
78pub trait ConnectionLike {
79    /// Sends an already encoded (packed) command into the TCP socket and
80    /// reads the single response from it.
81    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
82
83    /// Sends multiple already encoded (packed) command into the TCP socket
84    /// and reads `count` responses from it.  This is used to implement
85    /// pipelining.
86    /// Important - this function is meant for internal usage, since it's
87    /// easy to pass incorrect `offset` & `count` parameters, which might
88    /// cause the connection to enter an erroneous state. Users shouldn't
89    /// call it, instead using the Pipeline::query_async function.
90    #[doc(hidden)]
91    fn req_packed_commands<'a>(
92        &'a mut self,
93        cmd: &'a crate::Pipeline,
94        offset: usize,
95        count: usize,
96    ) -> RedisFuture<'a, Vec<Value>>;
97
98    /// Returns the database this connection is bound to.  Note that this
99    /// information might be unreliable because it's initially cached and
100    /// also might be incorrect if the connection like object is not
101    /// actually connected.
102    fn get_db(&self) -> i64;
103}
104
105async fn execute_connection_pipeline<T>(
106    codec: &mut T,
107    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
108) -> RedisResult<AuthResult>
109where
110    T: Sink<Vec<u8>, Error = RedisError>,
111    T: Stream<Item = RedisResult<Value>>,
112    T: Unpin + Send + 'static,
113{
114    let count = pipeline.len();
115    if count == 0 {
116        return Ok(AuthResult::Succeeded);
117    }
118    codec.send(pipeline.get_packed_pipeline()).await?;
119
120    let mut results = Vec::with_capacity(count);
121    for _ in 0..count {
122        let value = codec.next().await.ok_or_else(closed_connection_error)??;
123        results.push(value);
124    }
125
126    check_connection_setup(results, instructions)
127}
128
129pub(super) async fn setup_connection<T>(
130    codec: &mut T,
131    connection_info: &RedisConnectionInfo,
132    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
133) -> RedisResult<()>
134where
135    T: Sink<Vec<u8>, Error = RedisError>,
136    T: Stream<Item = RedisResult<Value>>,
137    T: Unpin + Send + 'static,
138{
139    if execute_connection_pipeline(
140        codec,
141        connection_setup_pipeline(
142            connection_info,
143            true,
144            #[cfg(feature = "cache-aio")]
145            cache_config,
146        ),
147    )
148    .await?
149        == AuthResult::ShouldRetryWithoutUsername
150    {
151        execute_connection_pipeline(
152            codec,
153            connection_setup_pipeline(
154                connection_info,
155                false,
156                #[cfg(feature = "cache-aio")]
157                cache_config,
158            ),
159        )
160        .await?;
161    }
162
163    Ok(())
164}
165
166mod connection;
167pub(crate) use connection::connect_simple;
168mod multiplexed_connection;
169pub use multiplexed_connection::*;
170#[cfg(feature = "connection-manager")]
171mod connection_manager;
172#[cfg(feature = "connection-manager")]
173#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
174pub use connection_manager::*;
175mod runtime;
176#[cfg(all(
177    feature = "async-std-comp",
178    any(feature = "smol-comp", feature = "tokio-comp")
179))]
180pub use runtime::prefer_async_std;
181#[cfg(all(
182    feature = "smol-comp",
183    any(feature = "async-std-comp", feature = "tokio-comp")
184))]
185pub use runtime::prefer_smol;
186#[cfg(all(
187    feature = "tokio-comp",
188    any(feature = "async-std-comp", feature = "smol-comp")
189))]
190pub use runtime::prefer_tokio;
191pub(super) use runtime::*;
192
/// Returns early from the enclosing function with an `ErrorKind::InvalidClientConfig` error
/// when `$protocol` equals `ProtocolVersion::RESP2`.
///
/// The optional second argument replaces the default message
/// (`"RESP3 is required for this command"`).
///
/// The expansion contains a bare `return Err(...)`, so the enclosing function (or async block)
/// must return `Result<_, RedisError>`. Crate items are referenced through `$crate`, so call
/// sites do not need `RedisError` or `ErrorKind` imported.
///
/// NOTE(review): the expansion also emits a `use` of `ProtocolVersion` into the caller's block.
/// Invoking the macro twice in the same block would therefore likely be a duplicate-import
/// error (E0252).
macro_rules! check_resp3 {
    ($protocol: expr) => {
        use $crate::types::ProtocolVersion;
        if $protocol == ProtocolVersion::RESP2 {
            return Err($crate::RedisError::from((
                $crate::ErrorKind::InvalidClientConfig,
                "RESP3 is required for this command",
            )));
        }
    };

    ($protocol: expr, $message: expr) => {
        use $crate::types::ProtocolVersion;
        if $protocol == ProtocolVersion::RESP2 {
            return Err($crate::RedisError::from((
                $crate::ErrorKind::InvalidClientConfig,
                $message,
            )));
        }
    };
}

pub(crate) use check_resp3;
216
/// Error returned by [`AsyncPushSender::send`] when a push message could not be delivered,
/// for example because the receiving side of the channel has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError;
219
220/// A trait for sender parts of a channel that can be used for sending push messages from async
221/// connection.
222pub trait AsyncPushSender: Send + Sync + 'static {
223    /// The sender must send without blocking, otherwise it will block the sending connection.
224    fn send(&self, info: PushInfo) -> Result<(), SendError>;
225}
226
227impl AsyncPushSender for ::tokio::sync::mpsc::UnboundedSender<PushInfo> {
228    fn send(&self, info: PushInfo) -> Result<(), SendError> {
229        match self.send(info) {
230            Ok(_) => Ok(()),
231            Err(_) => Err(SendError),
232        }
233    }
234}
235
236impl AsyncPushSender for ::tokio::sync::broadcast::Sender<PushInfo> {
237    fn send(&self, info: PushInfo) -> Result<(), SendError> {
238        match self.send(info) {
239            Ok(_) => Ok(()),
240            Err(_) => Err(SendError),
241        }
242    }
243}
244
245impl<T, Func: Fn(PushInfo) -> Result<(), T> + Send + Sync + 'static> AsyncPushSender for Func {
246    fn send(&self, info: PushInfo) -> Result<(), SendError> {
247        match self(info) {
248            Ok(_) => Ok(()),
249            Err(_) => Err(SendError),
250        }
251    }
252}
253
254impl AsyncPushSender for std::sync::mpsc::Sender<PushInfo> {
255    fn send(&self, info: PushInfo) -> Result<(), SendError> {
256        match self.send(info) {
257            Ok(_) => Ok(()),
258            Err(_) => Err(SendError),
259        }
260    }
261}
262
263impl<T> AsyncPushSender for std::sync::Arc<T>
264where
265    T: AsyncPushSender,
266{
267    fn send(&self, info: PushInfo) -> Result<(), SendError> {
268        self.as_ref().send(info)
269    }
270}
271
272/// Default DNS resolver which uses the system's DNS resolver.
273#[derive(Clone)]
274pub(crate) struct DefaultAsyncDNSResolver;
275
276impl AsyncDNSResolver for DefaultAsyncDNSResolver {
277    fn resolve<'a, 'b: 'a>(
278        &'a self,
279        host: &'b str,
280        port: u16,
281    ) -> RedisFuture<'a, Box<dyn Iterator<Item = SocketAddr> + Send + 'a>> {
282        Box::pin(get_socket_addrs(host, port).map(|vec| {
283            Ok(Box::new(vec?.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
284        }))
285    }
286}
287
288async fn get_socket_addrs(host: &str, port: u16) -> RedisResult<Vec<SocketAddr>> {
289    let socket_addrs: Vec<_> = match Runtime::locate() {
290        #[cfg(feature = "tokio-comp")]
291        Runtime::Tokio => ::tokio::net::lookup_host((host, port))
292            .await
293            .map_err(RedisError::from)
294            .map(|iter| iter.collect()),
295        #[cfg(feature = "async-std-comp")]
296        Runtime::AsyncStd => Ok::<_, RedisError>(
297            ::async_std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
298                .await
299                .map(|iter| iter.collect())?,
300        ),
301        #[cfg(feature = "smol-comp")]
302        Runtime::Smol => ::smol::net::resolve((host, port))
303            .await
304            .map_err(RedisError::from),
305    }?;
306
307    if socket_addrs.is_empty() {
308        Err(RedisError::from((
309            ErrorKind::InvalidClientConfig,
310            "No address found for host",
311        )))
312    } else {
313        Ok(socket_addrs)
314    }
315}