redis/
connection.rs

1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::io::tcp::{stream_with_settings, TcpSettings};
13use crate::parser::Parser;
14use crate::pipeline::Pipeline;
15use crate::types::{
16    from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
17    ServerError, ServerErrorKind, SyncPushSender, ToRedisArgs, Value,
18};
19use crate::{from_owned_redis_value, ProtocolVersion};
20
21#[cfg(unix)]
22use std::os::unix::net::UnixStream;
23
24use crate::commands::resp3_hello;
25#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
26use native_tls::{TlsConnector, TlsStream};
27
28#[cfg(feature = "tls-rustls")]
29use rustls::{RootCertStore, StreamOwned};
30#[cfg(feature = "tls-rustls")]
31use std::sync::Arc;
32
33use crate::PushInfo;
34
35#[cfg(all(
36    feature = "tls-rustls",
37    not(feature = "tls-native-tls"),
38    not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::ClientTlsParams;
44
/// Additional TLS settings carried by [`ConnectionAddr::TcpTls`].
///
/// Which fields exist depends on the enabled TLS features.
// Non-exhaustive to prevent construction outside this crate
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct TlsConnParams {
    /// Client-side TLS material (rustls only).
    /// NOTE(review): presumably the client certificate and key described by
    /// `crate::tls::ClientTlsParams` — confirm there.
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Root certificate store used when verifying the server (rustls only).
    /// NOTE(review): the meaning of `None` (default roots?) is decided outside this
    /// chunk, presumably in `create_rustls_config` — confirm.
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// When `true`, a certificate whose name does not match the host is still accepted.
    /// Set through `ConnectionAddr::set_danger_accept_invalid_hostnames`; the native-tls
    /// backend forwards it to `TlsConnector::builder().danger_accept_invalid_hostnames`.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
56
/// Port used when a `redis://`-style URL does not name one explicitly.
const DEFAULT_PORT: u16 = 6379;
58
59#[inline(always)]
60fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
61    let socket = TcpStream::connect(addr)?;
62    stream_with_settings(socket, &TcpSettings::default())
63}
64
65#[inline(always)]
66fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
67    let socket = TcpStream::connect_timeout(addr, timeout)?;
68    stream_with_settings(socket, &TcpSettings::default())
69}
70
71/// This function takes a redis URL string and parses it into a URL
72/// as used by rust-url.
73///
74/// This is necessary as the default parser does not understand how redis URLs function.
75pub fn parse_redis_url(input: &str) -> Option<url::Url> {
76    match url::Url::parse(input) {
77        Ok(result) => match result.scheme() {
78            "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
79                Some(result)
80            }
81            _ => None,
82        },
83        Err(_) => None,
84    }
85}
86
/// Whether the server's TLS certificate is verified.
///
/// Mirrors the `insecure` flag of [`ConnectionAddr::TcpTls`]: `insecure == true`
/// corresponds to [`TlsMode::Insecure`], `false` to [`TlsMode::Secure`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Verify the server certificate.
    Secure,
    /// Do not verify the server certificate.
    Insecure,
}
97
98/// Defines the connection address.
99///
100/// Not all connection addresses are supported on all platforms.  For instance
101/// to connect to a unix socket you need to run this on an operating system
102/// that supports them.
103#[derive(Clone, Debug)]
104pub enum ConnectionAddr {
105    /// Format for this is `(host, port)`.
106    Tcp(String, u16),
107    /// Format for this is `(host, port)`.
108    TcpTls {
109        /// Hostname
110        host: String,
111        /// Port
112        port: u16,
113        /// Disable hostname verification when connecting.
114        ///
115        /// # Warning
116        ///
117        /// You should think very carefully before you use this method. If hostname
118        /// verification is not used, any valid certificate for any site will be
119        /// trusted for use from any other. This introduces a significant
120        /// vulnerability to man-in-the-middle attacks.
121        insecure: bool,
122
123        /// TLS certificates and client key.
124        tls_params: Option<TlsConnParams>,
125    },
126    /// Format for this is the path to the unix socket.
127    Unix(PathBuf),
128}
129
130impl PartialEq for ConnectionAddr {
131    fn eq(&self, other: &Self) -> bool {
132        match (self, other) {
133            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
134                host1 == host2 && port1 == port2
135            }
136            (
137                ConnectionAddr::TcpTls {
138                    host: host1,
139                    port: port1,
140                    insecure: insecure1,
141                    tls_params: _,
142                },
143                ConnectionAddr::TcpTls {
144                    host: host2,
145                    port: port2,
146                    insecure: insecure2,
147                    tls_params: _,
148                },
149            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
150            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
151            _ => false,
152        }
153    }
154}
155
156impl Eq for ConnectionAddr {}
157
158impl ConnectionAddr {
159    /// Checks if this address is supported.
160    ///
161    /// Because not all platforms support all connection addresses this is a
162    /// quick way to figure out if a connection method is supported. Currently
163    /// this affects:
164    ///
165    /// - Unix socket addresses, which are supported only on Unix
166    ///
167    /// - TLS addresses, which are supported only if a TLS feature is enabled
168    ///   (either `tls-native-tls` or `tls-rustls`).
169    pub fn is_supported(&self) -> bool {
170        match *self {
171            ConnectionAddr::Tcp(_, _) => true,
172            ConnectionAddr::TcpTls { .. } => {
173                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
174            }
175            ConnectionAddr::Unix(_) => cfg!(unix),
176        }
177    }
178
179    /// Configure this address to connect without checking certificate hostnames.
180    ///
181    /// # Warning
182    ///
183    /// You should think very carefully before you use this method. If hostname
184    /// verification is not used, any valid certificate for any site will be
185    /// trusted for use from any other. This introduces a significant
186    /// vulnerability to man-in-the-middle attacks.
187    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
188    pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
189        if let ConnectionAddr::TcpTls { tls_params, .. } = self {
190            if let Some(ref mut params) = tls_params {
191                params.danger_accept_invalid_hostnames = insecure;
192            } else if insecure {
193                *tls_params = Some(TlsConnParams {
194                    #[cfg(feature = "tls-rustls")]
195                    client_tls_params: None,
196                    #[cfg(feature = "tls-rustls")]
197                    root_cert_store: None,
198                    danger_accept_invalid_hostnames: insecure,
199                });
200            }
201        }
202    }
203
204    #[cfg(feature = "cluster")]
205    pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
206        match self {
207            ConnectionAddr::TcpTls { insecure, .. } => {
208                if *insecure {
209                    Some(TlsMode::Insecure)
210                } else {
211                    Some(TlsMode::Secure)
212                }
213            }
214            _ => None,
215        }
216    }
217}
218
219impl fmt::Display for ConnectionAddr {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        // Cluster::get_connection_info depends on the return value from this function
222        match *self {
223            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
224            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
225            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
226        }
227    }
228}
229
/// Holds the connection information that redis should use for connecting.
///
/// Usually produced through [`IntoConnectionInfo`] (e.g. from a `redis://` URL) or by
/// `str::parse`, which uses the `FromStr` implementation.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// A connection address for where to connect to.
    pub addr: ConnectionAddr,

    /// Redis-level settings (database, credentials, protocol) for how to handshake with redis.
    pub redis: RedisConnectionInfo,
}
239
/// Redis specific/connection independent information used to establish a connection to redis.
///
/// `Default` yields database `0`, no username or password, and `ProtocolVersion::default()`.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use.  This is usually `0`.
    ///
    /// From a TCP URL it is the `/<db>` path segment; from a Unix-socket URL it is the
    /// `db` query parameter.
    pub db: i64,
    /// Optionally a username that should be used for connection.
    ///
    /// Percent-decoded from the URL user-info (TCP URLs) or taken from the `user` query
    /// parameter (Unix-socket URLs).
    pub username: Option<String>,
    /// Optionally a password that should be used for connection.
    ///
    /// Percent-decoded from the URL user-info (TCP URLs) or taken from the `pass` query
    /// parameter (Unix-socket URLs).
    pub password: Option<String>,
    /// Version of the protocol to use.
    ///
    /// From URLs this comes from `?protocol=` (`2`/`resp2` or `3`/`resp3`), defaulting to RESP2.
    pub protocol: ProtocolVersion,
}
252
253impl FromStr for ConnectionInfo {
254    type Err = RedisError;
255
256    fn from_str(s: &str) -> Result<Self, Self::Err> {
257        s.into_connection_info()
258    }
259}
260
/// Converts an object into a connection info struct.  This allows the
/// constructor of the client to accept connection information in a
/// range of different formats.
///
/// Implementations include [`ConnectionInfo`] itself, `&str` and `String` URLs,
/// `(host, port)` tuples and `url::Url`.
pub trait IntoConnectionInfo {
    /// Converts the object into a connection info object.
    ///
    /// # Errors
    ///
    /// Returns an error when the input does not describe a valid connection
    /// (for the URL-based implementations, an `InvalidClientConfig` error).
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}

/// Identity conversion: an existing [`ConnectionInfo`] is returned unchanged and never fails.
impl IntoConnectionInfo for ConnectionInfo {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        Ok(self)
    }
}
274
275/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
276///
277/// - Basic: `redis://127.0.0.1:6379`
278/// - Username & Password: `redis://user:password@127.0.0.1:6379`
279/// - Password only: `redis://:password@127.0.0.1:6379`
280/// - Specifying DB: `redis://127.0.0.1:6379/0`
281/// - Enabling TLS: `rediss://127.0.0.1:6379`
282/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
283/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
284impl IntoConnectionInfo for &str {
285    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
286        match parse_redis_url(self) {
287            Some(u) => u.into_connection_info(),
288            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
289        }
290    }
291}
292
293impl<T> IntoConnectionInfo for (T, u16)
294where
295    T: Into<String>,
296{
297    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
298        Ok(ConnectionInfo {
299            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
300            redis: RedisConnectionInfo::default(),
301        })
302    }
303}
304
305/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
306///
307/// - Basic: `redis://127.0.0.1:6379`
308/// - Username & Password: `redis://user:password@127.0.0.1:6379`
309/// - Password only: `redis://:password@127.0.0.1:6379`
310/// - Specifying DB: `redis://127.0.0.1:6379/0`
311/// - Enabling TLS: `rediss://127.0.0.1:6379`
312/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
313/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
314impl IntoConnectionInfo for String {
315    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
316        match parse_redis_url(&self) {
317            Some(u) => u.into_connection_info(),
318            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
319        }
320    }
321}
322
323fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
324    Ok(match query.get("protocol") {
325        Some(protocol) => {
326            if protocol == "2" || protocol == "resp2" {
327                ProtocolVersion::RESP2
328            } else if protocol == "3" || protocol == "resp3" {
329                ProtocolVersion::RESP3
330            } else {
331                fail!((
332                    ErrorKind::InvalidClientConfig,
333                    "Invalid protocol version",
334                    protocol.to_string()
335                ))
336            }
337        }
338        None => ProtocolVersion::RESP2,
339    })
340}
341
342fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
343    let host = match url.host() {
344        Some(host) => {
345            // Here we manually match host's enum arms and call their to_string().
346            // Because url.host().to_string() will add `[` and `]` for ipv6:
347            // https://docs.rs/url/latest/src/url/host.rs.html#170
348            // And these brackets will break host.parse::<Ipv6Addr>() when
349            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
350            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
351            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
352            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
353            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
354            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
355            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
356            match host {
357                url::Host::Domain(path) => path.to_string(),
358                url::Host::Ipv4(v4) => v4.to_string(),
359                url::Host::Ipv6(v6) => v6.to_string(),
360            }
361        }
362        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
363    };
364    let port = url.port().unwrap_or(DEFAULT_PORT);
365    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
366        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
367        {
368            match url.fragment() {
369                Some("insecure") => ConnectionAddr::TcpTls {
370                    host,
371                    port,
372                    insecure: true,
373                    tls_params: None,
374                },
375                Some(_) => fail!((
376                    ErrorKind::InvalidClientConfig,
377                    "only #insecure is supported as URL fragment"
378                )),
379                _ => ConnectionAddr::TcpTls {
380                    host,
381                    port,
382                    insecure: false,
383                    tls_params: None,
384                },
385            }
386        }
387
388        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
389        fail!((
390            ErrorKind::InvalidClientConfig,
391            "can't connect with TLS, the feature is not enabled"
392        ));
393    } else {
394        ConnectionAddr::Tcp(host, port)
395    };
396    let query: HashMap<_, _> = url.query_pairs().collect();
397    Ok(ConnectionInfo {
398        addr,
399        redis: RedisConnectionInfo {
400            db: match url.path().trim_matches('/') {
401                "" => 0,
402                path => path.parse::<i64>().map_err(|_| -> RedisError {
403                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
404                })?,
405            },
406            username: if url.username().is_empty() {
407                None
408            } else {
409                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
410                    Ok(decoded) => Some(decoded.into_owned()),
411                    Err(_) => fail!((
412                        ErrorKind::InvalidClientConfig,
413                        "Username is not valid UTF-8 string"
414                    )),
415                }
416            },
417            password: match url.password() {
418                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
419                    Ok(decoded) => Some(decoded.into_owned()),
420                    Err(_) => fail!((
421                        ErrorKind::InvalidClientConfig,
422                        "Password is not valid UTF-8 string"
423                    )),
424                },
425                None => None,
426            },
427            protocol: parse_protocol(&query)?,
428        },
429    })
430}
431
432#[cfg(unix)]
433fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
434    let query: HashMap<_, _> = url.query_pairs().collect();
435    Ok(ConnectionInfo {
436        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
437            (ErrorKind::InvalidClientConfig, "Missing path").into()
438        })?),
439        redis: RedisConnectionInfo {
440            db: match query.get("db") {
441                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
442                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
443                })?,
444
445                None => 0,
446            },
447            username: query.get("user").map(|username| username.to_string()),
448            password: query.get("pass").map(|password| password.to_string()),
449            protocol: parse_protocol(&query)?,
450        },
451    })
452}
453
/// Non-unix stand-in: Unix-socket URLs are always rejected with `InvalidClientConfig`.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_url: url::Url) -> RedisResult<ConnectionInfo> {
    Err(RedisError::from((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform.",
    )))
}
461
462impl IntoConnectionInfo for url::Url {
463    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
464        match self.scheme() {
465            "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
466            "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
467            _ => fail!((
468                ErrorKind::InvalidClientConfig,
469                "URL provided is not a redis URL"
470            )),
471        }
472    }
473}
474
/// Plain (unencrypted) TCP transport.
struct TcpConnection {
    /// The socket. Despite the name it is also written to (see `send_bytes`).
    reader: TcpStream,
    /// Cleared when a write fails with an unrecoverable error (see `send_bytes`).
    open: bool,
}

/// TCP transport wrapped in native-tls; only compiled when `tls-rustls` is not also enabled.
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream over the TCP socket, used for both directions.
    reader: TlsStream<TcpStream>,
    /// NOTE(review): presumably cleared on unrecoverable errors as in `TcpConnection` —
    /// confirm in `send_bytes`.
    open: bool,
}

/// TCP transport wrapped in rustls.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The rustls client session together with the TCP socket it runs over.
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// NOTE(review): presumably cleared on unrecoverable errors as in `TcpConnection`.
    open: bool,
}

/// Unix domain socket transport (unix platforms only).
#[cfg(unix)]
struct UnixConnection {
    /// The connected Unix socket.
    sock: UnixStream,
    /// NOTE(review): presumably cleared on unrecoverable errors as in `TcpConnection`.
    open: bool,
}

/// The concrete transport behind a [`Connection`].
///
/// Which variants exist depends on the enabled TLS features and the platform.
/// NOTE(review): the TLS variants are boxed, presumably to keep the enum small.
enum ActualConnection {
    /// Plain TCP.
    Tcp(TcpConnection),
    /// TLS via native-tls (only when `tls-rustls` is not enabled).
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    /// TLS via rustls.
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    /// Unix domain socket.
    #[cfg(unix)]
    Unix(UnixConnection),
}
507
/// rustls `ServerCertVerifier` that performs no verification at all: every server
/// certificate and every handshake signature is reported as valid.
///
/// Only compiled with the `tls-rustls-insecure` feature. NOTE(review): presumably
/// installed by `create_rustls_config` for `insecure` addresses — confirm there.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Signature schemes reported by `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Accepts any certificate chain for any server name without inspecting it.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertises the signature schemes from `self.supported`.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}

/// Prints only the type name; the `supported` field is omitted.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NoCertificateVerification").finish()
    }
}
555
/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`.
///
/// Every check is delegated to the wrapped WebPKI verifier. Only errors classified by
/// `is_hostname_error` (certificate name mismatches) are turned into success; any other
/// error from the inner verifier is returned unchanged.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    /// The standard verifier whose verdicts are forwarded.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}
562
/// Returns `true` when `err` is a certificate name-mismatch error
/// (`NotValidForName` or `NotValidForNameContext`), and `false` for anything else.
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    match err {
        rustls::Error::InvalidCertificate(cert_err) => matches!(
            cert_err,
            rustls::CertificateError::NotValidForName
                | rustls::CertificateError::NotValidForNameContext { .. }
        ),
        _ => false,
    }
}
573
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Runs the full WebPKI check, forgiving only hostname mismatches.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let verdict = self.inner.verify_server_cert(
            end_entity,
            intermediates,
            server_name,
            ocsp_response,
            now,
        );
        match verdict {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::ServerCertVerified::assertion())
            }
            other => other,
        }
    }

    /// Delegates TLS 1.2 signature verification, applying the same hostname-error fallback.
    // NOTE(review): signature checks presumably never yield a name-mismatch error, so the
    // fallback looks purely defensive — confirm against rustls before removing it.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls12_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Delegates TLS 1.3 signature verification, applying the same hostname-error fallback.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls13_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Advertises exactly the schemes the inner verifier supports.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
633
/// Represents a stateful redis connection.
///
/// The underlying transport is plain TCP, TLS or a Unix socket, depending on the
/// [`ConnectionAddr`] it was opened with and the enabled features.
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Parser used for server replies.
    /// NOTE(review): presumably holds decoding state between reads — confirm in `crate::parser`.
    parser: Parser,
    /// Database number associated with this connection.
    /// NOTE(review): presumably taken from `RedisConnectionInfo::db` at connect time — confirm.
    db: i64,

    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
    ///
    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
    /// exit the pubsub state before executing the new request.
    pubsub: bool,

    /// Field indicating which protocol to use for server communications.
    protocol: ProtocolVersion,

    /// This is used to manage Push messages in RESP3 mode.
    push_sender: Option<SyncPushSender>,

    /// The number of messages that are expected to be returned from the server,
    /// but the user no longer waits for - answers for requests that already returned a transient error.
    messages_to_skip: usize,
}
656
/// Represents a pubsub connection.
///
/// Holds an exclusive borrow of the [`Connection`] it was created from for its whole lifetime.
pub struct PubSub<'a> {
    /// The borrowed connection.
    con: &'a mut Connection,
    /// Queue of pending messages.
    /// NOTE(review): presumably messages already read from the server but not yet returned
    /// to the caller — confirm against the `PubSub` methods.
    waiting_messages: VecDeque<Msg>,
}
662
/// Represents a pubsub message.
#[derive(Debug, Clone)]
pub struct Msg {
    /// The message body, kept as a raw [`Value`].
    payload: Value,
    /// The channel name, kept as a raw [`Value`].
    channel: Value,
    /// NOTE(review): presumably the matching pattern for pattern subscriptions and `None`
    /// otherwise — confirm where `Msg` is constructed.
    pattern: Option<Value>,
}
670
671impl ActualConnection {
672    pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
673        Ok(match *addr {
674            ConnectionAddr::Tcp(ref host, ref port) => {
675                let addr = (host.as_str(), *port);
676                let tcp = match timeout {
677                    None => connect_tcp(addr)?,
678                    Some(timeout) => {
679                        let mut tcp = None;
680                        let mut last_error = None;
681                        for addr in addr.to_socket_addrs()? {
682                            match connect_tcp_timeout(&addr, timeout) {
683                                Ok(l) => {
684                                    tcp = Some(l);
685                                    break;
686                                }
687                                Err(e) => {
688                                    last_error = Some(e);
689                                }
690                            };
691                        }
692                        match (tcp, last_error) {
693                            (Some(tcp), _) => tcp,
694                            (None, Some(e)) => {
695                                fail!(e);
696                            }
697                            (None, None) => {
698                                fail!((
699                                    ErrorKind::InvalidClientConfig,
700                                    "could not resolve to any addresses"
701                                ));
702                            }
703                        }
704                    }
705                };
706                ActualConnection::Tcp(TcpConnection {
707                    reader: tcp,
708                    open: true,
709                })
710            }
711            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
712            ConnectionAddr::TcpTls {
713                ref host,
714                port,
715                insecure,
716                ref tls_params,
717            } => {
718                let tls_connector = if insecure {
719                    TlsConnector::builder()
720                        .danger_accept_invalid_certs(true)
721                        .danger_accept_invalid_hostnames(true)
722                        .use_sni(false)
723                        .build()?
724                } else if let Some(params) = tls_params {
725                    TlsConnector::builder()
726                        .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
727                        .build()?
728                } else {
729                    TlsConnector::new()?
730                };
731                let addr = (host.as_str(), port);
732                let tls = match timeout {
733                    None => {
734                        let tcp = connect_tcp(addr)?;
735                        match tls_connector.connect(host, tcp) {
736                            Ok(res) => res,
737                            Err(e) => {
738                                fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
739                            }
740                        }
741                    }
742                    Some(timeout) => {
743                        let mut tcp = None;
744                        let mut last_error = None;
745                        for addr in (host.as_str(), port).to_socket_addrs()? {
746                            match connect_tcp_timeout(&addr, timeout) {
747                                Ok(l) => {
748                                    tcp = Some(l);
749                                    break;
750                                }
751                                Err(e) => {
752                                    last_error = Some(e);
753                                }
754                            };
755                        }
756                        match (tcp, last_error) {
757                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
758                            (None, Some(e)) => {
759                                fail!(e);
760                            }
761                            (None, None) => {
762                                fail!((
763                                    ErrorKind::InvalidClientConfig,
764                                    "could not resolve to any addresses"
765                                ));
766                            }
767                        }
768                    }
769                };
770                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
771                    reader: tls,
772                    open: true,
773                }))
774            }
775            #[cfg(feature = "tls-rustls")]
776            ConnectionAddr::TcpTls {
777                ref host,
778                port,
779                insecure,
780                ref tls_params,
781            } => {
782                let host: &str = host;
783                let config = create_rustls_config(insecure, tls_params.clone())?;
784                let conn = rustls::ClientConnection::new(
785                    Arc::new(config),
786                    rustls::pki_types::ServerName::try_from(host)?.to_owned(),
787                )?;
788                let reader = match timeout {
789                    None => {
790                        let tcp = connect_tcp((host, port))?;
791                        StreamOwned::new(conn, tcp)
792                    }
793                    Some(timeout) => {
794                        let mut tcp = None;
795                        let mut last_error = None;
796                        for addr in (host, port).to_socket_addrs()? {
797                            match connect_tcp_timeout(&addr, timeout) {
798                                Ok(l) => {
799                                    tcp = Some(l);
800                                    break;
801                                }
802                                Err(e) => {
803                                    last_error = Some(e);
804                                }
805                            };
806                        }
807                        match (tcp, last_error) {
808                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
809                            (None, Some(e)) => {
810                                fail!(e);
811                            }
812                            (None, None) => {
813                                fail!((
814                                    ErrorKind::InvalidClientConfig,
815                                    "could not resolve to any addresses"
816                                ));
817                            }
818                        }
819                    }
820                };
821
822                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
823            }
824            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
825            ConnectionAddr::TcpTls { .. } => {
826                fail!((
827                    ErrorKind::InvalidClientConfig,
828                    "Cannot connect to TCP with TLS without the tls feature"
829                ));
830            }
831            #[cfg(unix)]
832            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
833                sock: UnixStream::connect(path)?,
834                open: true,
835            }),
836            #[cfg(not(unix))]
837            ConnectionAddr::Unix(ref _path) => {
838                fail!((
839                    ErrorKind::InvalidClientConfig,
840                    "Cannot connect to unix sockets \
841                     on this platform"
842                ));
843            }
844        })
845    }
846
847    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
848        match *self {
849            ActualConnection::Tcp(ref mut connection) => {
850                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
851                match res {
852                    Err(e) => {
853                        if e.is_unrecoverable_error() {
854                            connection.open = false;
855                        }
856                        Err(e)
857                    }
858                    Ok(_) => Ok(Value::Okay),
859                }
860            }
861            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
862            ActualConnection::TcpNativeTls(ref mut connection) => {
863                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
864                match res {
865                    Err(e) => {
866                        if e.is_unrecoverable_error() {
867                            connection.open = false;
868                        }
869                        Err(e)
870                    }
871                    Ok(_) => Ok(Value::Okay),
872                }
873            }
874            #[cfg(feature = "tls-rustls")]
875            ActualConnection::TcpRustls(ref mut connection) => {
876                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
877                match res {
878                    Err(e) => {
879                        if e.is_unrecoverable_error() {
880                            connection.open = false;
881                        }
882                        Err(e)
883                    }
884                    Ok(_) => Ok(Value::Okay),
885                }
886            }
887            #[cfg(unix)]
888            ActualConnection::Unix(ref mut connection) => {
889                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
890                match result {
891                    Err(e) => {
892                        if e.is_unrecoverable_error() {
893                            connection.open = false;
894                        }
895                        Err(e)
896                    }
897                    Ok(_) => Ok(Value::Okay),
898                }
899            }
900        }
901    }
902
903    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
904        match *self {
905            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
906                reader.set_write_timeout(dur)?;
907            }
908            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
909            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
910                let reader = &(boxed_tls_connection.reader);
911                reader.get_ref().set_write_timeout(dur)?;
912            }
913            #[cfg(feature = "tls-rustls")]
914            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
915                let reader = &(boxed_tls_connection.reader);
916                reader.get_ref().set_write_timeout(dur)?;
917            }
918            #[cfg(unix)]
919            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
920                sock.set_write_timeout(dur)?;
921            }
922        }
923        Ok(())
924    }
925
926    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
927        match *self {
928            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
929                reader.set_read_timeout(dur)?;
930            }
931            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
932            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
933                let reader = &(boxed_tls_connection.reader);
934                reader.get_ref().set_read_timeout(dur)?;
935            }
936            #[cfg(feature = "tls-rustls")]
937            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
938                let reader = &(boxed_tls_connection.reader);
939                reader.get_ref().set_read_timeout(dur)?;
940            }
941            #[cfg(unix)]
942            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
943                sock.set_read_timeout(dur)?;
944            }
945        }
946        Ok(())
947    }
948
949    pub fn is_open(&self) -> bool {
950        match *self {
951            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
952            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
953            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
954            #[cfg(feature = "tls-rustls")]
955            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
956            #[cfg(unix)]
957            ActualConnection::Unix(UnixConnection { open, .. }) => open,
958        }
959    }
960}
961
/// Builds the rustls `ClientConfig` used for TLS connections (e.g. by the
/// `ConnectionAddr::TcpTls` branch of `ActualConnection::new`).
///
/// Default root certificates:
/// - with `tls-rustls-webpki-roots`, the bundled webpki roots are added;
/// - with neither `tls-native-tls` nor `tls-rustls-webpki-roots`, the platform's native
///   certificates are loaded, and any error reported while loading them fails the call;
/// - with `tls-native-tls` enabled and `tls-rustls-webpki-roots` disabled, neither block
///   runs and the default store stays empty.
///
/// A `root_cert_store` supplied through `tls_params` replaces the default store.
///
/// # Parameters
/// - `insecure`: disables SNI and certificate verification entirely. This requires the
///   `tls-rustls-insecure` feature; without it an `InvalidClientConfig` error is returned.
/// - `tls_params`: optional client-certificate authentication, custom root store, and the
///   `danger_accept_invalid_hostnames` flag.
///
/// # Errors
/// `InvalidClientConfig` for unsupported insecure options, unusable client-auth material,
/// or (insecure mode only) a missing default crypto provider. Errors reported while loading
/// or adding root certificates are converted and propagated.
///
/// NOTE(review): the default roots are loaded (and native-cert errors returned) even when
/// `tls_params` carries its own `root_cert_store`, which then discards them. Consider loading
/// the defaults only when they are actually used.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    // `mut` is unused when neither root-loading block below is compiled in.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        // Any error reported by the native loader aborts, even if some certs were found.
        let mut certificate_result = load_native_certs();
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because the insecure-hostname verifier below needs its own copy of the store.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Implement `danger_accept_invalid_hostnames`.
        //
        // The strange cfg here is to handle a specific unusual combination of features: if
        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
        // application tries to use the danger flag.
        //
        // When `insecure` is set this relaxation is skipped: the final match below either
        // replaces the verifier completely or rejects the configuration.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                // This code should not enable an insecure mode if the `insecure` feature is not
                // set, but it shouldn't silently ignore the flag either. So return an error.
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                // Wrap a WebPKI verifier over the same root store in a verifier that accepts
                // invalid hostnames.
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        // Insecure mode: SNI is switched off and certificate verification is replaced by
        // `NoCertificateVerification`, which is given the default provider's signature
        // verification algorithms.
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        // `insecure` was requested but the feature that permits it is not compiled in.
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1079
1080fn authenticate_cmd(
1081    connection_info: &RedisConnectionInfo,
1082    check_username: bool,
1083    password: &str,
1084) -> Cmd {
1085    let mut command = cmd("AUTH");
1086    if check_username {
1087        if let Some(username) = &connection_info.username {
1088            command.arg(username);
1089        }
1090    }
1091    command.arg(password);
1092    command
1093}
1094
1095pub fn connect(
1096    connection_info: &ConnectionInfo,
1097    timeout: Option<Duration>,
1098) -> RedisResult<Connection> {
1099    let start = Instant::now();
1100    let con: ActualConnection = ActualConnection::new(&connection_info.addr, timeout)?;
1101
1102    // we temporarily set the timeout, and will remove it after finishing setup.
1103    let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1104    // TLS could run logic that doesn't contain a timeout, and should fail if it takes too long.
1105    if timeout.is_some() && remaining_timeout.is_none() {
1106        return Err(RedisError::from(std::io::Error::new(
1107            std::io::ErrorKind::TimedOut,
1108            "Connection timed out",
1109        )));
1110    }
1111    con.set_read_timeout(remaining_timeout)?;
1112    con.set_write_timeout(remaining_timeout)?;
1113
1114    let con = setup_connection(
1115        con,
1116        &connection_info.redis,
1117        #[cfg(feature = "cache-aio")]
1118        None,
1119    )?;
1120
1121    // remove the temporary timeout.
1122    con.set_read_timeout(None)?;
1123    con.set_write_timeout(None)?;
1124
1125    Ok(con)
1126}
1127
/// Positions of the handshake replies that `check_connection_setup` validates, within the
/// pipeline built by `connection_setup_pipeline`.
///
/// `None` means the corresponding command is not part of the pipeline.
pub(crate) struct ConnectionSetupComponents {
    /// Reply to the command built by `resp3_hello` (protocols other than RESP2).
    resp3_auth_cmd_idx: Option<usize>,
    /// Reply to `AUTH` (RESP2 with a password).
    resp2_auth_cmd_idx: Option<usize>,
    /// Reply to `SELECT` (non-zero database).
    select_cmd_idx: Option<usize>,
    /// Reply to `CLIENT TRACKING ON`.
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1135
1136pub(crate) fn connection_setup_pipeline(
1137    connection_info: &RedisConnectionInfo,
1138    check_username: bool,
1139    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1140) -> (crate::Pipeline, ConnectionSetupComponents) {
1141    let mut last_cmd_index = 0;
1142
1143    let mut get_next_command_index = |condition| {
1144        if condition {
1145            last_cmd_index += 1;
1146            Some(last_cmd_index - 1)
1147        } else {
1148            None
1149        }
1150    };
1151
1152    let authenticate_with_resp3_cmd_index =
1153        get_next_command_index(connection_info.protocol != ProtocolVersion::RESP2);
1154    let authenticate_with_resp2_cmd_index = get_next_command_index(
1155        authenticate_with_resp3_cmd_index.is_none() && connection_info.password.is_some(),
1156    );
1157    let select_db_cmd_index = get_next_command_index(connection_info.db != 0);
1158    #[cfg(feature = "cache-aio")]
1159    let cache_cmd_index = get_next_command_index(
1160        connection_info.protocol != ProtocolVersion::RESP2 && cache_config.is_some(),
1161    );
1162
1163    let mut pipeline = pipe();
1164
1165    if authenticate_with_resp3_cmd_index.is_some() {
1166        pipeline.add_command(resp3_hello(connection_info));
1167    } else if authenticate_with_resp2_cmd_index.is_some() {
1168        pipeline.add_command(authenticate_cmd(
1169            connection_info,
1170            check_username,
1171            connection_info.password.as_ref().unwrap(),
1172        ));
1173    }
1174
1175    if select_db_cmd_index.is_some() {
1176        pipeline.cmd("SELECT").arg(connection_info.db);
1177    }
1178
1179    // result is ignored, as per the command's instructions.
1180    // https://redis.io/commands/client-setinfo/
1181    #[cfg(not(feature = "disable-client-setinfo"))]
1182    pipeline
1183        .cmd("CLIENT")
1184        .arg("SETINFO")
1185        .arg("LIB-NAME")
1186        .arg("redis-rs")
1187        .ignore();
1188    #[cfg(not(feature = "disable-client-setinfo"))]
1189    pipeline
1190        .cmd("CLIENT")
1191        .arg("SETINFO")
1192        .arg("LIB-VER")
1193        .arg(env!("CARGO_PKG_VERSION"))
1194        .ignore();
1195
1196    #[cfg(feature = "cache-aio")]
1197    if cache_cmd_index.is_some() {
1198        let cache_config = cache_config.expect(
1199            "It's expected to have cache_config if cache_cmd_index is Some, please create an issue about this.",
1200        );
1201        pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1202        match cache_config.mode {
1203            crate::caching::CacheMode::All => {}
1204            crate::caching::CacheMode::OptIn => {
1205                pipeline.arg("OPTIN");
1206            }
1207        }
1208    }
1209
1210    (
1211        pipeline,
1212        ConnectionSetupComponents {
1213            resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1214            resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1215            select_cmd_idx: select_db_cmd_index,
1216            #[cfg(feature = "cache-aio")]
1217            cache_cmd_idx: cache_cmd_index,
1218        },
1219    )
1220}
1221
1222fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1223    if let Value::ServerError(err) = result {
1224        return Err(get_resp3_hello_command_error(err.clone().into()));
1225    }
1226    Ok(())
1227}
1228
/// Outcome of validating the connection handshake.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AuthResult {
    /// Every checked handshake reply was acceptable.
    Succeeded,
    /// The server rejected `AUTH <username> <password>` with a "wrong number of arguments"
    /// error, so the handshake should be retried without the username.
    ShouldRetryWithoutUsername,
}
1234
1235fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1236    let err = match result {
1237        Value::Okay => {
1238            return Ok(AuthResult::Succeeded);
1239        }
1240        Value::ServerError(err) => err,
1241        _ => {
1242            return Err((
1243                ErrorKind::ResponseError,
1244                "Redis server refused to authenticate, returns Ok() != Value::Okay",
1245            )
1246                .into());
1247        }
1248    };
1249
1250    let err_msg = err.details().ok_or((
1251        ErrorKind::AuthenticationFailed,
1252        "Password authentication failed",
1253    ))?;
1254    if !err_msg.contains("wrong number of arguments for 'auth' command") {
1255        return Err((
1256            ErrorKind::AuthenticationFailed,
1257            "Password authentication failed",
1258        )
1259            .into());
1260    }
1261    Ok(AuthResult::ShouldRetryWithoutUsername)
1262}
1263
1264fn check_db_select(value: &Value) -> RedisResult<()> {
1265    let Value::ServerError(err) = value else {
1266        return Ok(());
1267    };
1268
1269    match err.details() {
1270        Some(err_msg) => Err((
1271            ErrorKind::ResponseError,
1272            "Redis server refused to switch database",
1273            err_msg.to_string(),
1274        )
1275            .into()),
1276        None => Err((
1277            ErrorKind::ResponseError,
1278            "Redis server refused to switch database",
1279        )
1280            .into()),
1281    }
1282}
1283
/// Validates the reply to `CLIENT TRACKING ON`. Anything other than `OK` is an error.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if matches!(result, Value::Okay) {
        return Ok(());
    }
    Err((
        ErrorKind::ResponseError,
        "Client-side caching returned unknown response",
    )
        .into())
}
1295
1296pub(crate) fn check_connection_setup(
1297    results: Vec<Value>,
1298    ConnectionSetupComponents {
1299        resp3_auth_cmd_idx,
1300        resp2_auth_cmd_idx,
1301        select_cmd_idx,
1302        #[cfg(feature = "cache-aio")]
1303        cache_cmd_idx,
1304    }: ConnectionSetupComponents,
1305) -> RedisResult<AuthResult> {
1306    // can't have both values set
1307    assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1308
1309    if let Some(index) = resp3_auth_cmd_idx {
1310        let Some(value) = results.get(index) else {
1311            return Err((ErrorKind::ClientError, "Missing RESP3 auth response").into());
1312        };
1313        check_resp3_auth(value)?;
1314    } else if let Some(index) = resp2_auth_cmd_idx {
1315        let Some(value) = results.get(index) else {
1316            return Err((ErrorKind::ClientError, "Missing RESP2 auth response").into());
1317        };
1318        if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1319            return Ok(AuthResult::ShouldRetryWithoutUsername);
1320        }
1321    }
1322
1323    if let Some(index) = select_cmd_idx {
1324        let Some(value) = results.get(index) else {
1325            return Err((ErrorKind::ClientError, "Missing SELECT DB response").into());
1326        };
1327        check_db_select(value)?;
1328    }
1329
1330    #[cfg(feature = "cache-aio")]
1331    if let Some(index) = cache_cmd_idx {
1332        let Some(value) = results.get(index) else {
1333            return Err((ErrorKind::ClientError, "Missing Caching response").into());
1334        };
1335        check_caching(value)?;
1336    }
1337
1338    Ok(AuthResult::Succeeded)
1339}
1340
1341fn execute_connection_pipeline(
1342    rv: &mut Connection,
1343    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1344) -> RedisResult<AuthResult> {
1345    if pipeline.is_empty() {
1346        return Ok(AuthResult::Succeeded);
1347    }
1348    let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1349
1350    check_connection_setup(results, instructions)
1351}
1352
1353fn setup_connection(
1354    con: ActualConnection,
1355    connection_info: &RedisConnectionInfo,
1356    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1357) -> RedisResult<Connection> {
1358    let mut rv = Connection {
1359        con,
1360        parser: Parser::new(),
1361        db: connection_info.db,
1362        pubsub: false,
1363        protocol: connection_info.protocol,
1364        push_sender: None,
1365        messages_to_skip: 0,
1366    };
1367
1368    if execute_connection_pipeline(
1369        &mut rv,
1370        connection_setup_pipeline(
1371            connection_info,
1372            true,
1373            #[cfg(feature = "cache-aio")]
1374            cache_config,
1375        ),
1376    )? == AuthResult::ShouldRetryWithoutUsername
1377    {
1378        execute_connection_pipeline(
1379            &mut rv,
1380            connection_setup_pipeline(
1381                connection_info,
1382                false,
1383                #[cfg(feature = "cache-aio")]
1384                cache_config,
1385            ),
1386        )?;
1387    }
1388
1389    Ok(rv)
1390}
1391
1392/// Implements the "stateless" part of the connection interface that is used by the
1393/// different objects in redis-rs.
1394///
1395/// Primarily it obviously applies to `Connection` object but also some other objects
1396///  implement the interface (for instance whole clients or certain redis results).
1397///
1398/// Generally clients and connections (as well as redis results of those) implement
1399/// this trait.  Actual connections provide more functionality which can be used
1400/// to implement things like `PubSub` but they also can modify the intrinsic
1401/// state of the TCP connection.  This is not possible with `ConnectionLike`
1402/// implementors because that functionality is not exposed.
1403pub trait ConnectionLike {
1404    /// Sends an already encoded (packed) command into the TCP socket and
1405    /// reads the single response from it.
1406    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1407
1408    /// Sends multiple already encoded (packed) command into the TCP socket
1409    /// and reads `count` responses from it.  This is used to implement
1410    /// pipelining.
1411    /// Important - this function is meant for internal usage, since it's
1412    /// easy to pass incorrect `offset` & `count` parameters, which might
1413    /// cause the connection to enter an erroneous state. Users shouldn't
1414    /// call it, instead using the Pipeline::query function.
1415    #[doc(hidden)]
1416    fn req_packed_commands(
1417        &mut self,
1418        cmd: &[u8],
1419        offset: usize,
1420        count: usize,
1421    ) -> RedisResult<Vec<Value>>;
1422
1423    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1424    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1425        let pcmd = cmd.get_packed_command();
1426        self.req_packed_command(&pcmd)
1427    }
1428
1429    /// Returns the database this connection is bound to.  Note that this
1430    /// information might be unreliable because it's initially cached and
1431    /// also might be incorrect if the connection like object is not
1432    /// actually connected.
1433    fn get_db(&self) -> i64;
1434
1435    /// Does this connection support pipelining?
1436    #[doc(hidden)]
1437    fn supports_pipelining(&self) -> bool {
1438        true
1439    }
1440
1441    /// Check that all connections it has are available (`PING` internally).
1442    fn check_connection(&mut self) -> bool;
1443
1444    /// Returns the connection status.
1445    ///
1446    /// The connection is open until any `read` call received an
1447    /// invalid response from the server (most likely a closed or dropped
1448    /// connection, otherwise a Redis protocol error). When using unix
1449    /// sockets the connection is open until writing a command failed with a
1450    /// `BrokenPipe` error.
1451    fn is_open(&self) -> bool;
1452}
1453
1454/// A connection is an object that represents a single redis connection.  It
1455/// provides basic support for sending encoded commands into a redis connection
1456/// and to read a response from it.  It's bound to a single database and can
1457/// only be created from the client.
1458///
/// You generally do not do much with this object other than passing it to
1460/// `Cmd` objects.
1461impl Connection {
1462    /// Sends an already encoded (packed) command into the TCP socket and
1463    /// does not read a response.  This is useful for commands like
1464    /// `MONITOR` which yield multiple items.  This needs to be used with
1465    /// care because it changes the state of the connection.
1466    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1467        self.send_bytes(cmd)?;
1468        Ok(())
1469    }
1470
1471    /// Fetches a single response from the connection.  This is useful
1472    /// if used in combination with `send_packed_command`.
1473    pub fn recv_response(&mut self) -> RedisResult<Value> {
1474        self.read(true)
1475    }
1476
1477    /// Sets the write timeout for the connection.
1478    ///
1479    /// If the provided value is `None`, then `send_packed_command` call will
1480    /// block indefinitely. It is an error to pass the zero `Duration` to this
1481    /// method.
1482    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1483        self.con.set_write_timeout(dur)
1484    }
1485
1486    /// Sets the read timeout for the connection.
1487    ///
1488    /// If the provided value is `None`, then `recv_response` call will
1489    /// block indefinitely. It is an error to pass the zero `Duration` to this
1490    /// method.
1491    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1492        self.con.set_read_timeout(dur)
1493    }
1494
1495    /// Creates a [`PubSub`] instance for this connection.
1496    pub fn as_pubsub(&mut self) -> PubSub<'_> {
1497        // NOTE: The pubsub flag is intentionally not raised at this time since
1498        // running commands within the pubsub state should not try and exit from
1499        // the pubsub state.
1500        PubSub::new(self)
1501    }
1502
1503    fn exit_pubsub(&mut self) -> RedisResult<()> {
1504        let res = self.clear_active_subscriptions();
1505        if res.is_ok() {
1506            self.pubsub = false;
1507        } else {
1508            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1509            self.pubsub = true;
1510        }
1511
1512        res
1513    }
1514
1515    /// Get the inner connection out of a PubSub
1516    ///
1517    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
1518    /// dropped.
1519    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1520        // Responses to unsubscribe commands return in a 3-tuple with values
1521        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
1522        // The "count of remaining subs" includes both pattern subscriptions and non pattern
1523        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
1524        // server, both commands need to be executed at once.
1525        {
1526            // Prepare both unsubscribe commands
1527            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1528            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1529
1530            // Execute commands
1531            self.send_bytes(&unsubscribe)?;
1532            self.send_bytes(&punsubscribe)?;
1533        }
1534
1535        // Receive responses
1536        //
1537        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
1538        // commands. There may be more responses if there are active subscriptions. In this case,
1539        // messages are received until the _subscription count_ in the responses reach zero.
1540        let mut received_unsub = false;
1541        let mut received_punsub = false;
1542
1543        loop {
1544            let resp = self.recv_response()?;
1545
1546            match resp {
1547                Value::Push { kind, data } => {
1548                    if data.len() >= 2 {
1549                        if let Value::Int(num) = data[1] {
1550                            if resp3_is_pub_sub_state_cleared(
1551                                &mut received_unsub,
1552                                &mut received_punsub,
1553                                &kind,
1554                                num as isize,
1555                            ) {
1556                                break;
1557                            }
1558                        }
1559                    }
1560                }
1561                Value::ServerError(err) => {
1562                    // a new error behavior, introduced in valkey 8.
1563                    // https://github.com/valkey-io/valkey/pull/759
1564                    if err.kind() == Some(ServerErrorKind::NoSub) {
1565                        if no_sub_err_is_pub_sub_state_cleared(
1566                            &mut received_unsub,
1567                            &mut received_punsub,
1568                            &err,
1569                        ) {
1570                            break;
1571                        } else {
1572                            continue;
1573                        }
1574                    }
1575
1576                    return Err(err.into());
1577                }
1578                Value::Array(vec) => {
1579                    let res: (Vec<u8>, (), isize) = from_owned_redis_value(Value::Array(vec))?;
1580                    if resp2_is_pub_sub_state_cleared(
1581                        &mut received_unsub,
1582                        &mut received_punsub,
1583                        &res.0,
1584                        res.2,
1585                    ) {
1586                        break;
1587                    }
1588                }
1589                _ => {
1590                    return Err((
1591                        ErrorKind::ClientError,
1592                        "Unexpected unsubscribe response",
1593                        format!("{resp:?}"),
1594                    )
1595                        .into())
1596                }
1597            }
1598        }
1599
1600        // Finally, the connection is back in its normal state since all subscriptions were
1601        // cancelled *and* all unsubscribe messages were received.
1602        Ok(())
1603    }
1604
1605    fn send_push(&self, push: PushInfo) {
1606        if let Some(sender) = &self.push_sender {
1607            let _ = sender.send(push);
1608        }
1609    }
1610
1611    fn try_send(&self, value: &RedisResult<Value>) {
1612        if let Ok(Value::Push { kind, data }) = value {
1613            self.send_push(PushInfo {
1614                kind: kind.clone(),
1615                data: data.clone(),
1616            });
1617        }
1618    }
1619
1620    fn send_disconnect(&self) {
1621        self.send_push(PushInfo::disconnect())
1622    }
1623
1624    fn close_connection(&mut self) {
1625        // Notify the PushManager that the connection was lost
1626        self.send_disconnect();
1627        match self.con {
1628            ActualConnection::Tcp(ref mut connection) => {
1629                let _ = connection.reader.shutdown(net::Shutdown::Both);
1630                connection.open = false;
1631            }
1632            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1633            ActualConnection::TcpNativeTls(ref mut connection) => {
1634                let _ = connection.reader.shutdown();
1635                connection.open = false;
1636            }
1637            #[cfg(feature = "tls-rustls")]
1638            ActualConnection::TcpRustls(ref mut connection) => {
1639                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1640                connection.open = false;
1641            }
1642            #[cfg(unix)]
1643            ActualConnection::Unix(ref mut connection) => {
1644                let _ = connection.sock.shutdown(net::Shutdown::Both);
1645                connection.open = false;
1646            }
1647        }
1648    }
1649
1650    /// Fetches a single message from the connection. If the message is a response,
1651    /// increment `messages_to_skip` if it wasn't received before a timeout.
1652    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
1653        loop {
1654            let result = match self.con {
1655                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1656                    self.parser.parse_value(reader)
1657                }
1658                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1659                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1660                    let reader = &mut boxed_tls_connection.reader;
1661                    self.parser.parse_value(reader)
1662                }
1663                #[cfg(feature = "tls-rustls")]
1664                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1665                    let reader = &mut boxed_tls_connection.reader;
1666                    self.parser.parse_value(reader)
1667                }
1668                #[cfg(unix)]
1669                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1670                    self.parser.parse_value(sock)
1671                }
1672            };
1673            self.try_send(&result);
1674
1675            let Err(err) = &result else {
1676                if self.messages_to_skip > 0 {
1677                    self.messages_to_skip -= 1;
1678                    continue;
1679                }
1680                return result;
1681            };
1682            let Some(io_error) = err.as_io_error() else {
1683                if self.messages_to_skip > 0 {
1684                    self.messages_to_skip -= 1;
1685                    continue;
1686                }
1687                return result;
1688            };
1689            // shutdown connection on protocol error
1690            if io_error.kind() == io::ErrorKind::UnexpectedEof {
1691                self.close_connection();
1692            } else if is_response {
1693                self.messages_to_skip += 1;
1694            }
1695
1696            return result;
1697        }
1698    }
1699
1700    /// Sets sender channel for push values.
1701    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1702        self.push_sender = Some(sender);
1703    }
1704
1705    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1706        let result = self.con.send_bytes(bytes);
1707        if self.protocol != ProtocolVersion::RESP2 {
1708            if let Err(e) = &result {
1709                if e.is_connection_dropped() {
1710                    self.send_disconnect();
1711                }
1712            }
1713        }
1714        result
1715    }
1716}
1717
1718impl ConnectionLike for Connection {
1719    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1720    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1721        let pcmd = cmd.get_packed_command();
1722        if self.pubsub {
1723            self.exit_pubsub()?;
1724        }
1725
1726        self.send_bytes(&pcmd)?;
1727        if cmd.is_no_response() {
1728            return Ok(Value::Nil);
1729        }
1730        loop {
1731            match self.read(true)? {
1732                Value::Push {
1733                    kind: _kind,
1734                    data: _data,
1735                } => continue,
1736                val => return Ok(val),
1737            }
1738        }
1739    }
1740    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1741        if self.pubsub {
1742            self.exit_pubsub()?;
1743        }
1744
1745        self.send_bytes(cmd)?;
1746        loop {
1747            match self.read(true)? {
1748                Value::Push {
1749                    kind: _kind,
1750                    data: _data,
1751                } => continue,
1752                val => return Ok(val),
1753            }
1754        }
1755    }
1756
1757    fn req_packed_commands(
1758        &mut self,
1759        cmd: &[u8],
1760        offset: usize,
1761        count: usize,
1762    ) -> RedisResult<Vec<Value>> {
1763        if self.pubsub {
1764            self.exit_pubsub()?;
1765        }
1766        self.send_bytes(cmd)?;
1767        let mut rv = vec![];
1768        let mut first_err = None;
1769        let mut count = count;
1770        let mut idx = 0;
1771        while idx < (offset + count) {
1772            // When processing a transaction, some responses may be errors.
1773            // We need to keep processing the rest of the responses in that case,
1774            // so bailing early with `?` would not be correct.
1775            // See: https://github.com/redis-rs/redis-rs/issues/436
1776            let response = self.read(true);
1777            match response {
1778                Ok(Value::ServerError(err)) => {
1779                    if idx < offset {
1780                        if first_err.is_none() {
1781                            first_err = Some(err.into());
1782                        }
1783                    } else {
1784                        rv.push(Value::ServerError(err));
1785                    }
1786                }
1787                Ok(item) => {
1788                    // RESP3 can insert push data between command replies
1789                    if let Value::Push {
1790                        kind: _kind,
1791                        data: _data,
1792                    } = item
1793                    {
1794                        // if that is the case we have to extend the loop and handle push data
1795                        count += 1;
1796                    } else if idx >= offset {
1797                        rv.push(item);
1798                    }
1799                }
1800                Err(err) => {
1801                    if first_err.is_none() {
1802                        first_err = Some(err);
1803                    }
1804                }
1805            }
1806            idx += 1;
1807        }
1808
1809        first_err.map_or(Ok(rv), Err)
1810    }
1811
1812    fn get_db(&self) -> i64 {
1813        self.db
1814    }
1815
1816    fn check_connection(&mut self) -> bool {
1817        cmd("PING").query::<String>(self).is_ok()
1818    }
1819
1820    fn is_open(&self) -> bool {
1821        self.con.is_open()
1822    }
1823}
1824
1825impl<C, T> ConnectionLike for T
1826where
1827    C: ConnectionLike,
1828    T: DerefMut<Target = C>,
1829{
1830    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1831        self.deref_mut().req_packed_command(cmd)
1832    }
1833
1834    fn req_packed_commands(
1835        &mut self,
1836        cmd: &[u8],
1837        offset: usize,
1838        count: usize,
1839    ) -> RedisResult<Vec<Value>> {
1840        self.deref_mut().req_packed_commands(cmd, offset, count)
1841    }
1842
1843    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1844        self.deref_mut().req_command(cmd)
1845    }
1846
1847    fn get_db(&self) -> i64 {
1848        self.deref().get_db()
1849    }
1850
1851    fn supports_pipelining(&self) -> bool {
1852        self.deref().supports_pipelining()
1853    }
1854
1855    fn check_connection(&mut self) -> bool {
1856        self.deref_mut().check_connection()
1857    }
1858
1859    fn is_open(&self) -> bool {
1860        self.deref().is_open()
1861    }
1862}
1863
1864/// The pubsub object provides convenient access to the redis pubsub
1865/// system.  Once created you can subscribe and unsubscribe from channels
1866/// and listen in on messages.
1867///
1868/// Example:
1869///
1870/// ```rust,no_run
1871/// # fn do_something() -> redis::RedisResult<()> {
1872/// let client = redis::Client::open("redis://127.0.0.1/")?;
1873/// let mut con = client.get_connection()?;
1874/// let mut pubsub = con.as_pubsub();
1875/// pubsub.subscribe("channel_1")?;
1876/// pubsub.subscribe("channel_2")?;
1877///
1878/// loop {
1879///     let msg = pubsub.get_message()?;
1880///     let payload : String = msg.get_payload()?;
1881///     println!("channel '{}': {}", msg.get_channel_name(), payload);
1882/// }
1883/// # }
1884/// ```
1885impl<'a> PubSub<'a> {
1886    fn new(con: &'a mut Connection) -> Self {
1887        Self {
1888            con,
1889            waiting_messages: VecDeque::new(),
1890        }
1891    }
1892
1893    fn cache_messages_until_received_response(
1894        &mut self,
1895        cmd: &mut Cmd,
1896        is_sub_unsub: bool,
1897    ) -> RedisResult<Value> {
1898        let ignore_response = self.con.protocol != ProtocolVersion::RESP2 && is_sub_unsub;
1899        cmd.set_no_response(ignore_response);
1900
1901        self.con.send_packed_command(&cmd.get_packed_command())?;
1902
1903        loop {
1904            let response = self.con.recv_response()?;
1905            if let Some(msg) = Msg::from_value(&response) {
1906                self.waiting_messages.push_back(msg);
1907            } else {
1908                return Ok(response);
1909            }
1910        }
1911    }
1912
1913    /// Subscribes to a new channel(s).    
1914    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1915        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
1916        Ok(())
1917    }
1918
1919    /// Subscribes to new channel(s) with pattern(s).
1920    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1921        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
1922        Ok(())
1923    }
1924
1925    /// Unsubscribes from a channel(s).
1926    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1927        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
1928        Ok(())
1929    }
1930
1931    /// Unsubscribes from channel pattern(s).
1932    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1933        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
1934        Ok(())
1935    }
1936
1937    /// Sends a ping with a message to the server
1938    pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
1939        from_owned_redis_value(
1940            self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
1941        )
1942    }
1943    /// Sends a ping to the server
1944    pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
1945        from_owned_redis_value(
1946            self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
1947        )
1948    }
1949
1950    /// Fetches the next message from the pubsub connection.  Blocks until
1951    /// a message becomes available.  This currently does not provide a
1952    /// wait not to block :(
1953    ///
1954    /// The message itself is still generic and can be converted into an
1955    /// appropriate type through the helper methods on it.
1956    pub fn get_message(&mut self) -> RedisResult<Msg> {
1957        if let Some(msg) = self.waiting_messages.pop_front() {
1958            return Ok(msg);
1959        }
1960        loop {
1961            if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
1962                return Ok(msg);
1963            } else {
1964                continue;
1965            }
1966        }
1967    }
1968
1969    /// Sets the read timeout for the connection.
1970    ///
1971    /// If the provided value is `None`, then `get_message` call will
1972    /// block indefinitely. It is an error to pass the zero `Duration` to this
1973    /// method.
1974    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1975        self.con.set_read_timeout(dur)
1976    }
1977}
1978
1979impl Drop for PubSub<'_> {
1980    fn drop(&mut self) {
1981        let _ = self.con.exit_pubsub();
1982    }
1983}
1984
1985/// This holds the data that comes from listening to a pubsub
1986/// connection.  It only contains actual message data.
1987impl Msg {
1988    /// Tries to convert provided [`Value`] into [`Msg`].
1989    pub fn from_value(value: &Value) -> Option<Self> {
1990        Self::from_owned_value(value.clone())
1991    }
1992
1993    /// Tries to convert provided [`Value`] into [`Msg`].
1994    pub fn from_owned_value(value: Value) -> Option<Self> {
1995        let mut pattern = None;
1996        let payload;
1997        let channel;
1998
1999        if let Value::Push { kind, data } = value {
2000            return Self::from_push_info(PushInfo { kind, data });
2001        } else {
2002            let raw_msg: Vec<Value> = from_owned_redis_value(value).ok()?;
2003            let mut iter = raw_msg.into_iter();
2004            let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
2005            if msg_type == "message" {
2006                channel = iter.next()?;
2007                payload = iter.next()?;
2008            } else if msg_type == "pmessage" {
2009                pattern = Some(iter.next()?);
2010                channel = iter.next()?;
2011                payload = iter.next()?;
2012            } else {
2013                return None;
2014            }
2015        };
2016        Some(Msg {
2017            payload,
2018            channel,
2019            pattern,
2020        })
2021    }
2022
2023    /// Tries to convert provided [`PushInfo`] into [`Msg`].
2024    pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2025        let mut pattern = None;
2026        let payload;
2027        let channel;
2028
2029        let mut iter = push_info.data.into_iter();
2030        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2031            channel = iter.next()?;
2032            payload = iter.next()?;
2033        } else if push_info.kind == PushKind::PMessage {
2034            pattern = Some(iter.next()?);
2035            channel = iter.next()?;
2036            payload = iter.next()?;
2037        } else {
2038            return None;
2039        }
2040
2041        Some(Msg {
2042            payload,
2043            channel,
2044            pattern,
2045        })
2046    }
2047
2048    /// Returns the channel this message came on.
2049    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2050        from_redis_value(&self.channel)
2051    }
2052
2053    /// Convenience method to get a string version of the channel.  Unless
2054    /// your channel contains non utf-8 bytes you can always use this
2055    /// method.  If the channel is not a valid string (which really should
2056    /// not happen) then the return value is `"?"`.
2057    pub fn get_channel_name(&self) -> &str {
2058        match self.channel {
2059            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2060            _ => "?",
2061        }
2062    }
2063
2064    /// Returns the message's payload in a specific format.
2065    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2066        from_redis_value(&self.payload)
2067    }
2068
2069    /// Returns the bytes that are the message's payload.  This can be used
2070    /// as an alternative to the `get_payload` function if you are interested
2071    /// in the raw bytes in it.
2072    pub fn get_payload_bytes(&self) -> &[u8] {
2073        match self.payload {
2074            Value::BulkString(ref bytes) => bytes,
2075            _ => b"",
2076        }
2077    }
2078
2079    /// Returns true if the message was constructed from a pattern
2080    /// subscription.
2081    #[allow(clippy::wrong_self_convention)]
2082    pub fn from_pattern(&self) -> bool {
2083        self.pattern.is_some()
2084    }
2085
2086    /// If the message was constructed from a message pattern this can be
2087    /// used to find out which one.  It's recommended to match against
2088    /// an `Option<String>` so that you do not need to use `from_pattern`
2089    /// to figure out if a pattern was set.
2090    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2091        match self.pattern {
2092            None => from_redis_value(&Value::Nil),
2093            Some(ref x) => from_redis_value(x),
2094        }
2095    }
2096}
2097
2098/// This function simplifies transaction management slightly.  What it
2099/// does is automatically watching keys and then going into a transaction
2100/// loop util it succeeds.  Once it goes through the results are
2101/// returned.
2102///
2103/// To use the transaction two pieces of information are needed: a list
2104/// of all the keys that need to be watched for modifications and a
2105/// closure with the code that should be execute in the context of the
2106/// transaction.  The closure is invoked with a fresh pipeline in atomic
2107/// mode.  To use the transaction the function needs to return the result
2108/// from querying the pipeline with the connection.
2109///
2110/// The end result of the transaction is then available as the return
2111/// value from the function call.
2112///
2113/// Example:
2114///
2115/// ```rust,no_run
2116/// use redis::Commands;
2117/// # fn do_something() -> redis::RedisResult<()> {
2118/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
2119/// # let mut con = client.get_connection().unwrap();
2120/// let key = "the_key";
2121/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
2122///     let old_val : isize = con.get(key)?;
2123///     pipe
2124///         .set(key, old_val + 1).ignore()
2125///         .get(key).query(con)
2126/// })?;
2127/// println!("The incremented number is: {}", new_val);
2128/// # Ok(()) }
2129/// ```
2130pub fn transaction<
2131    C: ConnectionLike,
2132    K: ToRedisArgs,
2133    T,
2134    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2135>(
2136    con: &mut C,
2137    keys: &[K],
2138    func: F,
2139) -> RedisResult<T> {
2140    let mut func = func;
2141    loop {
2142        cmd("WATCH").arg(keys).exec(con)?;
2143        let mut p = pipe();
2144        let response: Option<T> = func(con, p.atomic())?;
2145        match response {
2146            None => {
2147                continue;
2148            }
2149            Some(response) => {
2150                // make sure no watch is left in the connection, even if
2151                // someone forgot to use the pipeline.
2152                cmd("UNWATCH").exec(con)?;
2153                return Ok(response);
2154            }
2155        }
2156    }
2157}
// TODO: make the RESP2 and RESP3 subscription-clearing logic below also handle sharded channels.
2159
/// Common logic for clearing subscriptions in RESP2 async/sync
///
/// `kind` is the reply type as bytes. A leading `u` marks an unsubscribe reply and a
/// leading `p` marks a punsubscribe reply; the matching flag is set. Returns `true` once
/// both flags are set and `num` (the remaining subscription count) is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    match kind {
        [b'u', ..] => *received_unsub = true,
        [b'p', ..] => *received_punsub = true,
        _ => {}
    }
    let both_seen = *received_unsub && *received_punsub;
    both_seen && num == 0
}
2174
2175/// Common logic for clearing subscriptions in RESP3 async/sync
2176pub fn resp3_is_pub_sub_state_cleared(
2177    received_unsub: &mut bool,
2178    received_punsub: &mut bool,
2179    kind: &PushKind,
2180    num: isize,
2181) -> bool {
2182    match kind {
2183        PushKind::Unsubscribe => *received_unsub = true,
2184        PushKind::PUnsubscribe => *received_punsub = true,
2185        _ => (),
2186    };
2187    *received_unsub && *received_punsub && num == 0
2188}
2189
2190pub fn no_sub_err_is_pub_sub_state_cleared(
2191    received_unsub: &mut bool,
2192    received_punsub: &mut bool,
2193    err: &ServerError,
2194) -> bool {
2195    let details = err.details();
2196    *received_unsub = *received_unsub
2197        || details
2198            .map(|details| details.starts_with("'unsub"))
2199            .unwrap_or_default();
2200    *received_punsub = *received_punsub
2201        || details
2202            .map(|details| details.starts_with("'punsub"))
2203            .unwrap_or_default();
2204    *received_unsub && *received_punsub
2205}
2206
2207/// Common logic for checking real cause of hello3 command error
2208pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2209    if let Some(detail) = err.detail() {
2210        if detail.starts_with("unknown command `HELLO`") {
2211            return (
2212                ErrorKind::RESP3NotSupported,
2213                "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2214            )
2215                .into();
2216        }
2217    }
2218    err
2219}
2220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("rediss://127.0.0.1", true),
            ("rediss://[::1]", true),
            ("valkey://127.0.0.1", true),
            ("valkey://[::1]", true),
            ("valkeys://127.0.0.1", true),
            ("valkeys://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("valkey+unix:///run/valkey.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The cases above deliberately vary `protocol`; make sure it is actually checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (
                url::Url::parse("redis://").unwrap(),
                "Missing hostname",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
                None,
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
                "Invalid protocol version",
                Some("4"),
            ),
        ];
        for (url, expected, detail) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), detail, "{}", &res);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                        protocol: ProtocolVersion::RESP2,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The cases above deliberately vary `protocol`; make sure it is actually checked.
            assert_eq!(
                res.redis.protocol, expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }
}