redis/aio/
mod.rs

1//! Adds async IO support to redis.
2use crate::cmd::Cmd;
3use crate::connection::{
4    check_connection_setup, connection_setup_pipeline, AuthResult, ConnectionSetupComponents,
5    RedisConnectionInfo,
6};
7use crate::io::AsyncDNSResolver;
8use crate::types::{RedisFuture, RedisResult, Value};
9use crate::{errors::closed_connection_error, ErrorKind, PushInfo, RedisError};
10use ::tokio::io::{AsyncRead, AsyncWrite};
11use futures_util::{
12    future::{Future, FutureExt},
13    sink::{Sink, SinkExt},
14    stream::{Stream, StreamExt},
15};
16pub use monitor::Monitor;
17use std::net::SocketAddr;
18#[cfg(unix)]
19use std::path::Path;
20use std::pin::Pin;
21
22mod monitor;
23
24#[cfg(any(feature = "tls-rustls", feature = "tls-native-tls"))]
25use crate::connection::TlsConnParams;
26
27/// Enables the smol compatibility
28#[cfg(feature = "smol-comp")]
29#[cfg_attr(docsrs, doc(cfg(feature = "smol-comp")))]
30pub mod smol;
31/// Enables the tokio compatibility
32#[cfg(feature = "tokio-comp")]
33#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
34pub mod tokio;
35
36mod pubsub;
37pub use pubsub::{PubSub, PubSubSink, PubSubStream};
38
/// Represents the ability of connecting via TCP or via Unix socket.
///
/// An implementor is itself the connected stream (see the `AsyncStream`
/// supertrait): each `connect_*` constructor yields `Self` on success.
pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
    /// Performs a TCP connection to `socket_addr`, using `tcp_settings`.
    ///
    /// Returns the connected stream, or the error reported by the implementor.
    async fn connect_tcp(
        socket_addr: SocketAddr,
        tcp_settings: &crate::io::tcp::TcpSettings,
    ) -> RedisResult<Self>;

    /// Performs a TCP TLS connection to `socket_addr`.
    ///
    /// Only available when one of the TLS features is enabled.
    /// NOTE(review): `hostname` is presumably used for TLS server-name
    /// verification and `insecure` presumably relaxes certificate checks —
    /// confirm against the implementors.
    #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
    async fn connect_tcp_tls(
        hostname: &str,
        socket_addr: SocketAddr,
        insecure: bool,
        tls_params: &Option<TlsConnParams>,
        tcp_settings: &crate::io::tcp::TcpSettings,
    ) -> RedisResult<Self>;

    /// Performs a UNIX socket connection to `path`.
    ///
    /// Only available on unix targets.
    #[cfg(unix)]
    async fn connect_unix(path: &Path) -> RedisResult<Self>;

    /// Hands the future `f` to the runtime and returns the `TaskHandle`
    /// produced for it.
    fn spawn(f: impl Future<Output = ()> + Send + 'static) -> TaskHandle;

    /// Moves the stream into a pinned heap allocation and erases its concrete
    /// type behind the `AsyncStream` trait object.
    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
        Box::pin(self)
    }
}
67
68/// Trait for objects that implements `AsyncRead` and `AsyncWrite`
69pub trait AsyncStream: AsyncRead + AsyncWrite {}
70impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
71
/// An async abstraction over connections.
///
/// Implementors send commands and hand back the decoded replies as boxed
/// futures (`RedisFuture`), borrowing both the connection and the command
/// for the lifetime `'a` of the returned future.
pub trait ConnectionLike {
    /// Sends an already encoded (packed) command into the TCP socket and
    /// reads the single response from it.
    ///
    /// The returned future resolves to the reply `Value`.
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;

    /// Sends multiple already encoded (packed) commands into the TCP socket
    /// and reads `count` responses from it.  This is used to implement
    /// pipelining.
    ///
    /// Important - this function is meant for internal usage, since it's
    /// easy to pass incorrect `offset` & `count` parameters, which might
    /// cause the connection to enter an erroneous state. Users shouldn't
    /// call it, instead using the Pipeline::query_async function.
    ///
    /// NOTE(review): `offset` is presumably the number of leading responses
    /// to discard before the `count` returned ones — confirm against the
    /// implementors.
    #[doc(hidden)]
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>>;

    /// Returns the database this connection is bound to.  Note that this
    /// information might be unreliable because it's initially cached and
    /// also might be incorrect if the connection like object is not
    /// actually connected.
    fn get_db(&self) -> i64;
}
99
100async fn execute_connection_pipeline<T>(
101    codec: &mut T,
102    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
103) -> RedisResult<AuthResult>
104where
105    T: Sink<Vec<u8>, Error = RedisError>,
106    T: Stream<Item = RedisResult<Value>>,
107    T: Unpin + Send + 'static,
108{
109    let count = pipeline.len();
110    if count == 0 {
111        return Ok(AuthResult::Succeeded);
112    }
113    codec.send(pipeline.get_packed_pipeline()).await?;
114
115    let mut results = Vec::with_capacity(count);
116    for _ in 0..count {
117        let value = codec.next().await.ok_or_else(closed_connection_error)??;
118        results.push(value);
119    }
120
121    check_connection_setup(results, instructions)
122}
123
124pub(super) async fn setup_connection<T>(
125    codec: &mut T,
126    connection_info: &RedisConnectionInfo,
127    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
128) -> RedisResult<()>
129where
130    T: Sink<Vec<u8>, Error = RedisError>,
131    T: Stream<Item = RedisResult<Value>>,
132    T: Unpin + Send + 'static,
133{
134    if execute_connection_pipeline(
135        codec,
136        connection_setup_pipeline(
137            connection_info,
138            true,
139            #[cfg(feature = "cache-aio")]
140            cache_config,
141        ),
142    )
143    .await?
144        == AuthResult::ShouldRetryWithoutUsername
145    {
146        execute_connection_pipeline(
147            codec,
148            connection_setup_pipeline(
149                connection_info,
150                false,
151                #[cfg(feature = "cache-aio")]
152                cache_config,
153            ),
154        )
155        .await?;
156    }
157
158    Ok(())
159}
160
161mod connection;
162pub(crate) use connection::connect_simple;
163mod multiplexed_connection;
164pub use multiplexed_connection::*;
165#[cfg(feature = "connection-manager")]
166mod connection_manager;
167#[cfg(feature = "connection-manager")]
168#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
169pub use connection_manager::*;
170mod runtime;
171#[cfg(all(feature = "smol-comp", feature = "tokio-comp"))]
172pub use runtime::prefer_smol;
173#[cfg(all(feature = "tokio-comp", feature = "smol-comp"))]
174pub use runtime::prefer_tokio;
175pub(super) use runtime::*;
176
/// An error indicating that a push message could not be delivered to its
/// receiver.
///
/// This is the error type of [`AsyncPushSender::send`]. It carries no
/// payload: the implementations in this module collapse whatever error the
/// underlying channel or callback reported into this single value.
///
/// `Debug` is derived so that a `Result<(), SendError>` can be used with
/// `unwrap`/`expect` and printed with `{:?}`.
#[derive(Debug)]
pub struct SendError;
179
/// A trait for the sender part of a channel that can be used for sending push
/// messages from an async connection.
///
/// Implementors must be `Send + Sync + 'static`. This module provides
/// implementations for several channel senders, for closures taking a
/// `PushInfo`, and for `Arc`-wrapped senders.
pub trait AsyncPushSender: Send + Sync + 'static {
    /// Sends `info` to the receiving side.
    ///
    /// The sender must send without blocking, otherwise it will block the sending connection.
    ///
    /// # Errors
    ///
    /// Returns `SendError` when the message could not be handed off.
    fn send(&self, info: PushInfo) -> Result<(), SendError>;
}
186
187impl AsyncPushSender for ::tokio::sync::mpsc::UnboundedSender<PushInfo> {
188    fn send(&self, info: PushInfo) -> Result<(), SendError> {
189        match self.send(info) {
190            Ok(_) => Ok(()),
191            Err(_) => Err(SendError),
192        }
193    }
194}
195
196impl AsyncPushSender for ::tokio::sync::broadcast::Sender<PushInfo> {
197    fn send(&self, info: PushInfo) -> Result<(), SendError> {
198        match self.send(info) {
199            Ok(_) => Ok(()),
200            Err(_) => Err(SendError),
201        }
202    }
203}
204
205impl<T, Func: Fn(PushInfo) -> Result<(), T> + Send + Sync + 'static> AsyncPushSender for Func {
206    fn send(&self, info: PushInfo) -> Result<(), SendError> {
207        match self(info) {
208            Ok(_) => Ok(()),
209            Err(_) => Err(SendError),
210        }
211    }
212}
213
214impl AsyncPushSender for std::sync::mpsc::Sender<PushInfo> {
215    fn send(&self, info: PushInfo) -> Result<(), SendError> {
216        match self.send(info) {
217            Ok(_) => Ok(()),
218            Err(_) => Err(SendError),
219        }
220    }
221}
222
223impl<T> AsyncPushSender for std::sync::Arc<T>
224where
225    T: AsyncPushSender,
226{
227    fn send(&self, info: PushInfo) -> Result<(), SendError> {
228        self.as_ref().send(info)
229    }
230}
231
232/// Default DNS resolver which uses the system's DNS resolver.
233#[derive(Clone)]
234pub(crate) struct DefaultAsyncDNSResolver;
235
236impl AsyncDNSResolver for DefaultAsyncDNSResolver {
237    fn resolve<'a, 'b: 'a>(
238        &'a self,
239        host: &'b str,
240        port: u16,
241    ) -> RedisFuture<'a, Box<dyn Iterator<Item = SocketAddr> + Send + 'a>> {
242        Box::pin(get_socket_addrs(host, port).map(|vec| {
243            Ok(Box::new(vec?.into_iter()) as Box<dyn Iterator<Item = SocketAddr> + Send>)
244        }))
245    }
246}
247
248async fn get_socket_addrs(host: &str, port: u16) -> RedisResult<Vec<SocketAddr>> {
249    let socket_addrs: Vec<_> = match Runtime::locate() {
250        #[cfg(feature = "tokio-comp")]
251        Runtime::Tokio => ::tokio::net::lookup_host((host, port))
252            .await
253            .map_err(RedisError::from)
254            .map(|iter| iter.collect()),
255
256        #[cfg(feature = "smol-comp")]
257        Runtime::Smol => ::smol::net::resolve((host, port))
258            .await
259            .map_err(RedisError::from),
260    }?;
261
262    if socket_addrs.is_empty() {
263        Err(RedisError::from((
264            ErrorKind::InvalidClientConfig,
265            "No address found for host",
266        )))
267    } else {
268        Ok(socket_addrs)
269    }
270}