1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::io::tcp::{stream_with_settings, TcpSettings};
13use crate::parser::Parser;
14use crate::pipeline::Pipeline;
15use crate::types::{
16 from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult,
17 ServerError, ServerErrorKind, SyncPushSender, ToRedisArgs, Value,
18};
19use crate::{from_owned_redis_value, ProtocolVersion};
20
21#[cfg(unix)]
22use std::os::unix::net::UnixStream;
23
24use crate::commands::resp3_hello;
25#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
26use native_tls::{TlsConnector, TlsStream};
27
28#[cfg(feature = "tls-rustls")]
29use rustls::{RootCertStore, StreamOwned};
30#[cfg(feature = "tls-rustls")]
31use std::sync::Arc;
32
33use crate::PushInfo;
34
35#[cfg(all(
36 feature = "tls-rustls",
37 not(feature = "tls-native-tls"),
38 not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::ClientTlsParams;
44
/// Extra TLS settings carried by a [`ConnectionAddr::TcpTls`] address.
///
/// Which fields exist depends on the enabled TLS features.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct TlsConnParams {
    /// Client certificate chain and key used for client authentication, if any.
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Root certificate store to use; when `None`, `create_rustls_config`
    /// falls back to the store it builds itself.
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// When `true`, the TLS backend is configured to accept certificates
    /// whose hostname does not match.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
56
/// Port used for TCP connections when the URL does not specify one.
static DEFAULT_PORT: u16 = 6379;
58
59#[inline(always)]
60fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
61 let socket = TcpStream::connect(addr)?;
62 stream_with_settings(socket, &TcpSettings::default())
63}
64
65#[inline(always)]
66fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
67 let socket = TcpStream::connect_timeout(addr, timeout)?;
68 stream_with_settings(socket, &TcpSettings::default())
69}
70
71pub fn parse_redis_url(input: &str) -> Option<url::Url> {
76 match url::Url::parse(input) {
77 Ok(result) => match result.scheme() {
78 "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
79 Some(result)
80 }
81 _ => None,
82 },
83 Err(_) => None,
84 }
85}
86
/// TLS mode of a connection, derived from the `insecure` flag of a
/// [`ConnectionAddr::TcpTls`] address.
#[derive(Clone, Copy, PartialEq)]
pub enum TlsMode {
    /// TLS with the `insecure` flag unset.
    Secure,
    /// TLS with the `insecure` flag set.
    Insecure,
}
97
/// Address of the server to connect to.
///
/// Not every variant is usable in every build; see
/// [`ConnectionAddr::is_supported`].
#[derive(Clone, Debug)]
pub enum ConnectionAddr {
    /// Plain TCP: host and port.
    Tcp(String, u16),
    /// TCP wrapped in TLS.
    TcpTls {
        /// Host name.
        host: String,
        /// Port.
        port: u16,
        /// Whether the connection is made in insecure mode. With this set the
        /// TLS backends are configured without certificate verification.
        insecure: bool,

        /// Optional additional TLS parameters.
        tls_params: Option<TlsConnParams>,
    },
    /// Unix domain socket at the given path.
    Unix(PathBuf),
}
129
130impl PartialEq for ConnectionAddr {
131 fn eq(&self, other: &Self) -> bool {
132 match (self, other) {
133 (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
134 host1 == host2 && port1 == port2
135 }
136 (
137 ConnectionAddr::TcpTls {
138 host: host1,
139 port: port1,
140 insecure: insecure1,
141 tls_params: _,
142 },
143 ConnectionAddr::TcpTls {
144 host: host2,
145 port: port2,
146 insecure: insecure2,
147 tls_params: _,
148 },
149 ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
150 (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
151 _ => false,
152 }
153 }
154}
155
156impl Eq for ConnectionAddr {}
157
158impl ConnectionAddr {
159 pub fn is_supported(&self) -> bool {
170 match *self {
171 ConnectionAddr::Tcp(_, _) => true,
172 ConnectionAddr::TcpTls { .. } => {
173 cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
174 }
175 ConnectionAddr::Unix(_) => cfg!(unix),
176 }
177 }
178
179 #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
188 pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
189 if let ConnectionAddr::TcpTls { tls_params, .. } = self {
190 if let Some(ref mut params) = tls_params {
191 params.danger_accept_invalid_hostnames = insecure;
192 } else if insecure {
193 *tls_params = Some(TlsConnParams {
194 #[cfg(feature = "tls-rustls")]
195 client_tls_params: None,
196 #[cfg(feature = "tls-rustls")]
197 root_cert_store: None,
198 danger_accept_invalid_hostnames: insecure,
199 });
200 }
201 }
202 }
203
204 #[cfg(feature = "cluster")]
205 pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
206 match self {
207 ConnectionAddr::TcpTls { insecure, .. } => {
208 if *insecure {
209 Some(TlsMode::Insecure)
210 } else {
211 Some(TlsMode::Secure)
212 }
213 }
214 _ => None,
215 }
216 }
217}
218
219impl fmt::Display for ConnectionAddr {
220 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221 match *self {
223 ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
224 ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
225 ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
226 }
227 }
228}
229
/// Everything needed to open a connection: where to connect and how to set
/// the connection up once it is open.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// Address of the server.
    pub addr: ConnectionAddr,

    /// Server-level settings (database, credentials, protocol).
    pub redis: RedisConnectionInfo,
}
239
/// Server-level connection settings applied after the transport is open.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// Database number; a `SELECT` is only issued when this is non-zero.
    pub db: i64,
    /// Optional username used for authentication.
    pub username: Option<String>,
    /// Optional password used for authentication.
    pub password: Option<String>,
    /// Protocol version to use on the connection.
    pub protocol: ProtocolVersion,
}
252
253impl FromStr for ConnectionInfo {
254 type Err = RedisError;
255
256 fn from_str(s: &str) -> Result<Self, Self::Err> {
257 s.into_connection_info()
258 }
259}
260
/// Conversion of a value into a [`ConnectionInfo`].
///
/// This file implements it for `ConnectionInfo` itself, URL strings
/// (`&str`, `String`), `(host, port)` tuples and `url::Url`.
pub trait IntoConnectionInfo {
    /// Consumes `self` and produces the connection info, or an error when the
    /// value does not describe a valid connection.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
268
/// Identity conversion: a `ConnectionInfo` is returned unchanged.
impl IntoConnectionInfo for ConnectionInfo {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        Ok(self)
    }
}
274
275impl IntoConnectionInfo for &str {
285 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
286 match parse_redis_url(self) {
287 Some(u) => u.into_connection_info(),
288 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
289 }
290 }
291}
292
293impl<T> IntoConnectionInfo for (T, u16)
294where
295 T: Into<String>,
296{
297 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
298 Ok(ConnectionInfo {
299 addr: ConnectionAddr::Tcp(self.0.into(), self.1),
300 redis: RedisConnectionInfo::default(),
301 })
302 }
303}
304
305impl IntoConnectionInfo for String {
315 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
316 match parse_redis_url(&self) {
317 Some(u) => u.into_connection_info(),
318 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
319 }
320 }
321}
322
323fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
324 Ok(match query.get("protocol") {
325 Some(protocol) => {
326 if protocol == "2" || protocol == "resp2" {
327 ProtocolVersion::RESP2
328 } else if protocol == "3" || protocol == "resp3" {
329 ProtocolVersion::RESP3
330 } else {
331 fail!((
332 ErrorKind::InvalidClientConfig,
333 "Invalid protocol version",
334 protocol.to_string()
335 ))
336 }
337 }
338 None => ProtocolVersion::RESP2,
339 })
340}
341
342fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
343 let host = match url.host() {
344 Some(host) => {
345 match host {
357 url::Host::Domain(path) => path.to_string(),
358 url::Host::Ipv4(v4) => v4.to_string(),
359 url::Host::Ipv6(v6) => v6.to_string(),
360 }
361 }
362 None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
363 };
364 let port = url.port().unwrap_or(DEFAULT_PORT);
365 let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
366 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
367 {
368 match url.fragment() {
369 Some("insecure") => ConnectionAddr::TcpTls {
370 host,
371 port,
372 insecure: true,
373 tls_params: None,
374 },
375 Some(_) => fail!((
376 ErrorKind::InvalidClientConfig,
377 "only #insecure is supported as URL fragment"
378 )),
379 _ => ConnectionAddr::TcpTls {
380 host,
381 port,
382 insecure: false,
383 tls_params: None,
384 },
385 }
386 }
387
388 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
389 fail!((
390 ErrorKind::InvalidClientConfig,
391 "can't connect with TLS, the feature is not enabled"
392 ));
393 } else {
394 ConnectionAddr::Tcp(host, port)
395 };
396 let query: HashMap<_, _> = url.query_pairs().collect();
397 Ok(ConnectionInfo {
398 addr,
399 redis: RedisConnectionInfo {
400 db: match url.path().trim_matches('/') {
401 "" => 0,
402 path => path.parse::<i64>().map_err(|_| -> RedisError {
403 (ErrorKind::InvalidClientConfig, "Invalid database number").into()
404 })?,
405 },
406 username: if url.username().is_empty() {
407 None
408 } else {
409 match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
410 Ok(decoded) => Some(decoded.into_owned()),
411 Err(_) => fail!((
412 ErrorKind::InvalidClientConfig,
413 "Username is not valid UTF-8 string"
414 )),
415 }
416 },
417 password: match url.password() {
418 Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
419 Ok(decoded) => Some(decoded.into_owned()),
420 Err(_) => fail!((
421 ErrorKind::InvalidClientConfig,
422 "Password is not valid UTF-8 string"
423 )),
424 },
425 None => None,
426 },
427 protocol: parse_protocol(&query)?,
428 },
429 })
430}
431
432#[cfg(unix)]
433fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
434 let query: HashMap<_, _> = url.query_pairs().collect();
435 Ok(ConnectionInfo {
436 addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
437 (ErrorKind::InvalidClientConfig, "Missing path").into()
438 })?),
439 redis: RedisConnectionInfo {
440 db: match query.get("db") {
441 Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
442 (ErrorKind::InvalidClientConfig, "Invalid database number").into()
443 })?,
444
445 None => 0,
446 },
447 username: query.get("user").map(|username| username.to_string()),
448 password: query.get("pass").map(|password| password.to_string()),
449 protocol: parse_protocol(&query)?,
450 },
451 })
452}
453
/// Non-unix stand-in for the Unix-socket URL conversion: always fails with
/// `InvalidClientConfig`.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
    fail!((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform."
    ));
}
461
462impl IntoConnectionInfo for url::Url {
463 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
464 match self.scheme() {
465 "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
466 "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
467 _ => fail!((
468 ErrorKind::InvalidClientConfig,
469 "URL provided is not a redis URL"
470 )),
471 }
472 }
473}
474
/// Plain TCP transport.
struct TcpConnection {
    /// The underlying stream; used for writing as well, despite the name.
    reader: TcpStream,
    /// Cleared when a write fails with an unrecoverable error.
    open: bool,
}
479
/// TLS-over-TCP transport backed by `native_tls`.
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream; used for writing as well, despite the name.
    reader: TlsStream<TcpStream>,
    /// Cleared when a write fails with an unrecoverable error.
    open: bool,
}
485
/// TLS-over-TCP transport backed by `rustls`.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The TLS stream; used for writing as well, despite the name.
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// Cleared when a write fails with an unrecoverable error.
    open: bool,
}
491
/// Unix domain socket transport.
#[cfg(unix)]
struct UnixConnection {
    /// The underlying socket.
    sock: UnixStream,
    /// Cleared when a write fails with an unrecoverable error.
    open: bool,
}
497
/// The concrete transport behind a connection; available variants depend on
/// the enabled features and the target platform.
enum ActualConnection {
    /// Plain TCP.
    Tcp(TcpConnection),
    /// TLS via `native_tls` (only when `tls-rustls` is not enabled).
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    /// TLS via `rustls`.
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    /// Unix domain socket.
    #[cfg(unix)]
    Unix(UnixConnection),
}
507
/// Certificate verifier that accepts every certificate and signature.
/// Installed by `create_rustls_config` for insecure connections.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Algorithms whose schemes are advertised by `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}
512
/// Verifier implementation that performs no checks at all: every certificate
/// and every handshake signature is reported as valid.
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Accepts any server certificate without inspecting the arguments.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Advertises the schemes of the stored supported algorithms.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}
548
/// Debug output is just the type name; the `supported` field is not shown.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NoCertificateVerification")
    }
}
555
/// Certificate verifier that delegates to a WebPKI verifier but tolerates
/// hostname-mismatch errors.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    /// The verifier that performs all actual checks.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}
562
/// Returns `true` when `err` is an invalid-certificate error caused by the
/// certificate not being valid for the requested name.
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    let rustls::Error::InvalidCertificate(reason) = err else {
        return false;
    };
    matches!(
        reason,
        rustls::CertificateError::NotValidForName
            | rustls::CertificateError::NotValidForNameContext { .. }
    )
}
573
/// Delegates every check to the inner verifier and turns hostname-mismatch
/// failures into success; all other failures are passed through unchanged.
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Verifies the server certificate, ignoring hostname mismatches.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let outcome = self.inner.verify_server_cert(
            end_entity,
            intermediates,
            server_name,
            ocsp_response,
            now,
        );
        match outcome {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::ServerCertVerified::assertion())
            }
            other => other,
        }
    }

    /// Verifies a TLS 1.2 handshake signature, ignoring hostname mismatches.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls12_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Verifies a TLS 1.3 handshake signature, ignoring hostname mismatches.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls13_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Reports the signature schemes supported by the inner verifier.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
633
/// A connection to a server: the transport plus parser and per-connection
/// state.
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Parser for server replies.
    parser: Parser,
    /// Database number associated with this connection.
    db: i64,

    // NOTE(review): presumably tracks whether the connection is in pub/sub
    // mode; the code that reads it is outside this view.
    pubsub: bool,

    /// Protocol version used on this connection.
    protocol: ProtocolVersion,

    /// Optional sender for push messages.
    push_sender: Option<SyncPushSender>,

    // NOTE(review): presumably the number of replies still to be discarded;
    // confirm against the code that updates it.
    messages_to_skip: usize,
}
656
/// Pub/sub view over a mutably borrowed [`Connection`].
pub struct PubSub<'a> {
    /// The borrowed connection.
    con: &'a mut Connection,
    /// Queue of messages held by this view.
    waiting_messages: VecDeque<Msg>,
}
662
/// A pub/sub message.
#[derive(Debug, Clone)]
pub struct Msg {
    /// Message payload.
    payload: Value,
    /// Channel value of the message.
    channel: Value,
    /// Pattern value, when present.
    pattern: Option<Value>,
}
670
671impl ActualConnection {
672 pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
673 Ok(match *addr {
674 ConnectionAddr::Tcp(ref host, ref port) => {
675 let addr = (host.as_str(), *port);
676 let tcp = match timeout {
677 None => connect_tcp(addr)?,
678 Some(timeout) => {
679 let mut tcp = None;
680 let mut last_error = None;
681 for addr in addr.to_socket_addrs()? {
682 match connect_tcp_timeout(&addr, timeout) {
683 Ok(l) => {
684 tcp = Some(l);
685 break;
686 }
687 Err(e) => {
688 last_error = Some(e);
689 }
690 };
691 }
692 match (tcp, last_error) {
693 (Some(tcp), _) => tcp,
694 (None, Some(e)) => {
695 fail!(e);
696 }
697 (None, None) => {
698 fail!((
699 ErrorKind::InvalidClientConfig,
700 "could not resolve to any addresses"
701 ));
702 }
703 }
704 }
705 };
706 ActualConnection::Tcp(TcpConnection {
707 reader: tcp,
708 open: true,
709 })
710 }
711 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
712 ConnectionAddr::TcpTls {
713 ref host,
714 port,
715 insecure,
716 ref tls_params,
717 } => {
718 let tls_connector = if insecure {
719 TlsConnector::builder()
720 .danger_accept_invalid_certs(true)
721 .danger_accept_invalid_hostnames(true)
722 .use_sni(false)
723 .build()?
724 } else if let Some(params) = tls_params {
725 TlsConnector::builder()
726 .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
727 .build()?
728 } else {
729 TlsConnector::new()?
730 };
731 let addr = (host.as_str(), port);
732 let tls = match timeout {
733 None => {
734 let tcp = connect_tcp(addr)?;
735 match tls_connector.connect(host, tcp) {
736 Ok(res) => res,
737 Err(e) => {
738 fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
739 }
740 }
741 }
742 Some(timeout) => {
743 let mut tcp = None;
744 let mut last_error = None;
745 for addr in (host.as_str(), port).to_socket_addrs()? {
746 match connect_tcp_timeout(&addr, timeout) {
747 Ok(l) => {
748 tcp = Some(l);
749 break;
750 }
751 Err(e) => {
752 last_error = Some(e);
753 }
754 };
755 }
756 match (tcp, last_error) {
757 (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
758 (None, Some(e)) => {
759 fail!(e);
760 }
761 (None, None) => {
762 fail!((
763 ErrorKind::InvalidClientConfig,
764 "could not resolve to any addresses"
765 ));
766 }
767 }
768 }
769 };
770 ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
771 reader: tls,
772 open: true,
773 }))
774 }
775 #[cfg(feature = "tls-rustls")]
776 ConnectionAddr::TcpTls {
777 ref host,
778 port,
779 insecure,
780 ref tls_params,
781 } => {
782 let host: &str = host;
783 let config = create_rustls_config(insecure, tls_params.clone())?;
784 let conn = rustls::ClientConnection::new(
785 Arc::new(config),
786 rustls::pki_types::ServerName::try_from(host)?.to_owned(),
787 )?;
788 let reader = match timeout {
789 None => {
790 let tcp = connect_tcp((host, port))?;
791 StreamOwned::new(conn, tcp)
792 }
793 Some(timeout) => {
794 let mut tcp = None;
795 let mut last_error = None;
796 for addr in (host, port).to_socket_addrs()? {
797 match connect_tcp_timeout(&addr, timeout) {
798 Ok(l) => {
799 tcp = Some(l);
800 break;
801 }
802 Err(e) => {
803 last_error = Some(e);
804 }
805 };
806 }
807 match (tcp, last_error) {
808 (Some(tcp), _) => StreamOwned::new(conn, tcp),
809 (None, Some(e)) => {
810 fail!(e);
811 }
812 (None, None) => {
813 fail!((
814 ErrorKind::InvalidClientConfig,
815 "could not resolve to any addresses"
816 ));
817 }
818 }
819 }
820 };
821
822 ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
823 }
824 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
825 ConnectionAddr::TcpTls { .. } => {
826 fail!((
827 ErrorKind::InvalidClientConfig,
828 "Cannot connect to TCP with TLS without the tls feature"
829 ));
830 }
831 #[cfg(unix)]
832 ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
833 sock: UnixStream::connect(path)?,
834 open: true,
835 }),
836 #[cfg(not(unix))]
837 ConnectionAddr::Unix(ref _path) => {
838 fail!((
839 ErrorKind::InvalidClientConfig,
840 "Cannot connect to unix sockets \
841 on this platform"
842 ));
843 }
844 })
845 }
846
847 pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
848 match *self {
849 ActualConnection::Tcp(ref mut connection) => {
850 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
851 match res {
852 Err(e) => {
853 if e.is_unrecoverable_error() {
854 connection.open = false;
855 }
856 Err(e)
857 }
858 Ok(_) => Ok(Value::Okay),
859 }
860 }
861 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
862 ActualConnection::TcpNativeTls(ref mut connection) => {
863 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
864 match res {
865 Err(e) => {
866 if e.is_unrecoverable_error() {
867 connection.open = false;
868 }
869 Err(e)
870 }
871 Ok(_) => Ok(Value::Okay),
872 }
873 }
874 #[cfg(feature = "tls-rustls")]
875 ActualConnection::TcpRustls(ref mut connection) => {
876 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
877 match res {
878 Err(e) => {
879 if e.is_unrecoverable_error() {
880 connection.open = false;
881 }
882 Err(e)
883 }
884 Ok(_) => Ok(Value::Okay),
885 }
886 }
887 #[cfg(unix)]
888 ActualConnection::Unix(ref mut connection) => {
889 let result = connection.sock.write_all(bytes).map_err(RedisError::from);
890 match result {
891 Err(e) => {
892 if e.is_unrecoverable_error() {
893 connection.open = false;
894 }
895 Err(e)
896 }
897 Ok(_) => Ok(Value::Okay),
898 }
899 }
900 }
901 }
902
903 pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
904 match *self {
905 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
906 reader.set_write_timeout(dur)?;
907 }
908 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
909 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
910 let reader = &(boxed_tls_connection.reader);
911 reader.get_ref().set_write_timeout(dur)?;
912 }
913 #[cfg(feature = "tls-rustls")]
914 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
915 let reader = &(boxed_tls_connection.reader);
916 reader.get_ref().set_write_timeout(dur)?;
917 }
918 #[cfg(unix)]
919 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
920 sock.set_write_timeout(dur)?;
921 }
922 }
923 Ok(())
924 }
925
926 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
927 match *self {
928 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
929 reader.set_read_timeout(dur)?;
930 }
931 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
932 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
933 let reader = &(boxed_tls_connection.reader);
934 reader.get_ref().set_read_timeout(dur)?;
935 }
936 #[cfg(feature = "tls-rustls")]
937 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
938 let reader = &(boxed_tls_connection.reader);
939 reader.get_ref().set_read_timeout(dur)?;
940 }
941 #[cfg(unix)]
942 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
943 sock.set_read_timeout(dur)?;
944 }
945 }
946 Ok(())
947 }
948
949 pub fn is_open(&self) -> bool {
950 match *self {
951 ActualConnection::Tcp(TcpConnection { open, .. }) => open,
952 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
953 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
954 #[cfg(feature = "tls-rustls")]
955 ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
956 #[cfg(unix)]
957 ActualConnection::Unix(UnixConnection { open, .. }) => open,
958 }
959 }
960}
961
/// Builds the `rustls` client configuration for a TLS connection.
///
/// * The default root store starts empty and is filled from `webpki_roots`
///   or the native certificate store, depending on features. A root store
///   supplied in `tls_params` replaces it.
/// * Client authentication is configured when `tls_params` carries client
///   TLS parameters.
/// * With `danger_accept_invalid_hostnames` set (and `insecure` unset), a
///   verifier that tolerates hostname mismatches is installed; this needs the
///   `tls-rustls-insecure` feature.
/// * With `insecure` set, SNI is disabled and certificate verification is
///   switched off entirely; this also needs `tls-rustls-insecure`.
///
/// # Errors
///
/// Returns an error when native certificates cannot be loaded, the client
/// certificate is rejected, no crypto provider is available, or an insecure
/// mode is requested without the `tls-rustls-insecure` feature.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    // `mut` is only needed for some feature combinations.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        // Any error reported while loading native certificates aborts the setup.
        let mut certificate_result = load_native_certs();
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        // A caller-provided root store takes precedence over the default one.
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because the store is needed again below for the hostname-tolerant verifier.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Only relevant when not fully insecure: the insecure branch at the
        // end of the function replaces the verifier anyway.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    // Second tuple element is a compile-time constant telling whether the
    // insecure verifier is available in this build.
    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1079
1080fn authenticate_cmd(
1081 connection_info: &RedisConnectionInfo,
1082 check_username: bool,
1083 password: &str,
1084) -> Cmd {
1085 let mut command = cmd("AUTH");
1086 if check_username {
1087 if let Some(username) = &connection_info.username {
1088 command.arg(username);
1089 }
1090 }
1091 command.arg(password);
1092 command
1093}
1094
1095pub fn connect(
1096 connection_info: &ConnectionInfo,
1097 timeout: Option<Duration>,
1098) -> RedisResult<Connection> {
1099 let start = Instant::now();
1100 let con: ActualConnection = ActualConnection::new(&connection_info.addr, timeout)?;
1101
1102 let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1104 if timeout.is_some() && remaining_timeout.is_none() {
1106 return Err(RedisError::from(std::io::Error::new(
1107 std::io::ErrorKind::TimedOut,
1108 "Connection timed out",
1109 )));
1110 }
1111 con.set_read_timeout(remaining_timeout)?;
1112 con.set_write_timeout(remaining_timeout)?;
1113
1114 let con = setup_connection(
1115 con,
1116 &connection_info.redis,
1117 #[cfg(feature = "cache-aio")]
1118 None,
1119 )?;
1120
1121 con.set_read_timeout(None)?;
1123 con.set_write_timeout(None)?;
1124
1125 Ok(con)
1126}
1127
/// Positions, within the setup pipeline's results, of the responses that have
/// to be checked. `None` means the corresponding command was not sent.
pub(crate) struct ConnectionSetupComponents {
    /// Index of the RESP3 `HELLO` response.
    resp3_auth_cmd_idx: Option<usize>,
    /// Index of the RESP2 `AUTH` response.
    resp2_auth_cmd_idx: Option<usize>,
    /// Index of the `SELECT` response.
    select_cmd_idx: Option<usize>,
    /// Index of the `CLIENT TRACKING` response.
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1135
1136pub(crate) fn connection_setup_pipeline(
1137 connection_info: &RedisConnectionInfo,
1138 check_username: bool,
1139 #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1140) -> (crate::Pipeline, ConnectionSetupComponents) {
1141 let mut last_cmd_index = 0;
1142
1143 let mut get_next_command_index = |condition| {
1144 if condition {
1145 last_cmd_index += 1;
1146 Some(last_cmd_index - 1)
1147 } else {
1148 None
1149 }
1150 };
1151
1152 let authenticate_with_resp3_cmd_index =
1153 get_next_command_index(connection_info.protocol != ProtocolVersion::RESP2);
1154 let authenticate_with_resp2_cmd_index = get_next_command_index(
1155 authenticate_with_resp3_cmd_index.is_none() && connection_info.password.is_some(),
1156 );
1157 let select_db_cmd_index = get_next_command_index(connection_info.db != 0);
1158 #[cfg(feature = "cache-aio")]
1159 let cache_cmd_index = get_next_command_index(
1160 connection_info.protocol != ProtocolVersion::RESP2 && cache_config.is_some(),
1161 );
1162
1163 let mut pipeline = pipe();
1164
1165 if authenticate_with_resp3_cmd_index.is_some() {
1166 pipeline.add_command(resp3_hello(connection_info));
1167 } else if authenticate_with_resp2_cmd_index.is_some() {
1168 pipeline.add_command(authenticate_cmd(
1169 connection_info,
1170 check_username,
1171 connection_info.password.as_ref().unwrap(),
1172 ));
1173 }
1174
1175 if select_db_cmd_index.is_some() {
1176 pipeline.cmd("SELECT").arg(connection_info.db);
1177 }
1178
1179 #[cfg(not(feature = "disable-client-setinfo"))]
1182 pipeline
1183 .cmd("CLIENT")
1184 .arg("SETINFO")
1185 .arg("LIB-NAME")
1186 .arg("redis-rs")
1187 .ignore();
1188 #[cfg(not(feature = "disable-client-setinfo"))]
1189 pipeline
1190 .cmd("CLIENT")
1191 .arg("SETINFO")
1192 .arg("LIB-VER")
1193 .arg(env!("CARGO_PKG_VERSION"))
1194 .ignore();
1195
1196 #[cfg(feature = "cache-aio")]
1197 if cache_cmd_index.is_some() {
1198 let cache_config = cache_config.expect(
1199 "It's expected to have cache_config if cache_cmd_index is Some, please create an issue about this.",
1200 );
1201 pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1202 match cache_config.mode {
1203 crate::caching::CacheMode::All => {}
1204 crate::caching::CacheMode::OptIn => {
1205 pipeline.arg("OPTIN");
1206 }
1207 }
1208 }
1209
1210 (
1211 pipeline,
1212 ConnectionSetupComponents {
1213 resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1214 resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1215 select_cmd_idx: select_db_cmd_index,
1216 #[cfg(feature = "cache-aio")]
1217 cache_cmd_idx: cache_cmd_index,
1218 },
1219 )
1220}
1221
1222fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1223 if let Value::ServerError(err) = result {
1224 return Err(get_resp3_hello_command_error(err.clone().into()));
1225 }
1226 Ok(())
1227}
1228
/// Outcome of checking an authentication response.
#[derive(PartialEq)]
pub(crate) enum AuthResult {
    /// The server accepted the credentials.
    Succeeded,
    /// The server rejected the argument count of `AUTH`; authentication
    /// should be retried without sending the username.
    ShouldRetryWithoutUsername,
}
1234
1235fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1236 let err = match result {
1237 Value::Okay => {
1238 return Ok(AuthResult::Succeeded);
1239 }
1240 Value::ServerError(err) => err,
1241 _ => {
1242 return Err((
1243 ErrorKind::ResponseError,
1244 "Redis server refused to authenticate, returns Ok() != Value::Okay",
1245 )
1246 .into());
1247 }
1248 };
1249
1250 let err_msg = err.details().ok_or((
1251 ErrorKind::AuthenticationFailed,
1252 "Password authentication failed",
1253 ))?;
1254 if !err_msg.contains("wrong number of arguments for 'auth' command") {
1255 return Err((
1256 ErrorKind::AuthenticationFailed,
1257 "Password authentication failed",
1258 )
1259 .into());
1260 }
1261 Ok(AuthResult::ShouldRetryWithoutUsername)
1262}
1263
1264fn check_db_select(value: &Value) -> RedisResult<()> {
1265 let Value::ServerError(err) = value else {
1266 return Ok(());
1267 };
1268
1269 match err.details() {
1270 Some(err_msg) => Err((
1271 ErrorKind::ResponseError,
1272 "Redis server refused to switch database",
1273 err_msg.to_string(),
1274 )
1275 .into()),
1276 None => Err((
1277 ErrorKind::ResponseError,
1278 "Redis server refused to switch database",
1279 )
1280 .into()),
1281 }
1282}
1283
/// Checks the reply to the `CLIENT TRACKING` command: anything other than
/// `Value::Okay` is reported as a `ResponseError`.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if matches!(result, Value::Okay) {
        return Ok(());
    }
    Err((
        ErrorKind::ResponseError,
        "Client-side caching returned unknown response",
    )
        .into())
}
1295
1296pub(crate) fn check_connection_setup(
1297 results: Vec<Value>,
1298 ConnectionSetupComponents {
1299 resp3_auth_cmd_idx,
1300 resp2_auth_cmd_idx,
1301 select_cmd_idx,
1302 #[cfg(feature = "cache-aio")]
1303 cache_cmd_idx,
1304 }: ConnectionSetupComponents,
1305) -> RedisResult<AuthResult> {
1306 assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1308
1309 if let Some(index) = resp3_auth_cmd_idx {
1310 let Some(value) = results.get(index) else {
1311 return Err((ErrorKind::ClientError, "Missing RESP3 auth response").into());
1312 };
1313 check_resp3_auth(value)?;
1314 } else if let Some(index) = resp2_auth_cmd_idx {
1315 let Some(value) = results.get(index) else {
1316 return Err((ErrorKind::ClientError, "Missing RESP2 auth response").into());
1317 };
1318 if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1319 return Ok(AuthResult::ShouldRetryWithoutUsername);
1320 }
1321 }
1322
1323 if let Some(index) = select_cmd_idx {
1324 let Some(value) = results.get(index) else {
1325 return Err((ErrorKind::ClientError, "Missing SELECT DB response").into());
1326 };
1327 check_db_select(value)?;
1328 }
1329
1330 #[cfg(feature = "cache-aio")]
1331 if let Some(index) = cache_cmd_idx {
1332 let Some(value) = results.get(index) else {
1333 return Err((ErrorKind::ClientError, "Missing Caching response").into());
1334 };
1335 check_caching(value)?;
1336 }
1337
1338 Ok(AuthResult::Succeeded)
1339}
1340
1341fn execute_connection_pipeline(
1342 rv: &mut Connection,
1343 (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1344) -> RedisResult<AuthResult> {
1345 if pipeline.is_empty() {
1346 return Ok(AuthResult::Succeeded);
1347 }
1348 let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1349
1350 check_connection_setup(results, instructions)
1351}
1352
1353fn setup_connection(
1354 con: ActualConnection,
1355 connection_info: &RedisConnectionInfo,
1356 #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1357) -> RedisResult<Connection> {
1358 let mut rv = Connection {
1359 con,
1360 parser: Parser::new(),
1361 db: connection_info.db,
1362 pubsub: false,
1363 protocol: connection_info.protocol,
1364 push_sender: None,
1365 messages_to_skip: 0,
1366 };
1367
1368 if execute_connection_pipeline(
1369 &mut rv,
1370 connection_setup_pipeline(
1371 connection_info,
1372 true,
1373 #[cfg(feature = "cache-aio")]
1374 cache_config,
1375 ),
1376 )? == AuthResult::ShouldRetryWithoutUsername
1377 {
1378 execute_connection_pipeline(
1379 &mut rv,
1380 connection_setup_pipeline(
1381 connection_info,
1382 false,
1383 #[cfg(feature = "cache-aio")]
1384 cache_config,
1385 ),
1386 )?;
1387 }
1388
1389 Ok(rv)
1390}
1391
1392pub trait ConnectionLike {
1404 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1407
1408 #[doc(hidden)]
1416 fn req_packed_commands(
1417 &mut self,
1418 cmd: &[u8],
1419 offset: usize,
1420 count: usize,
1421 ) -> RedisResult<Vec<Value>>;
1422
1423 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1425 let pcmd = cmd.get_packed_command();
1426 self.req_packed_command(&pcmd)
1427 }
1428
1429 fn get_db(&self) -> i64;
1434
1435 #[doc(hidden)]
1437 fn supports_pipelining(&self) -> bool {
1438 true
1439 }
1440
1441 fn check_connection(&mut self) -> bool;
1443
1444 fn is_open(&self) -> bool;
1452}
1453
impl Connection {
    /// Sends an already-packed command to the server without reading a reply.
    ///
    /// The reply, if any, has to be fetched separately, e.g. with
    /// `recv_response`.
    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
        self.send_bytes(cmd)?;
        Ok(())
    }

    /// Reads the next value from the connection (delegates to `read(true)`).
    pub fn recv_response(&mut self) -> RedisResult<Value> {
        self.read(true)
    }

    /// Forwards `dur` as the write timeout to the underlying connection.
    // NOTE(review): `None` presumably disables the timeout, as with std
    // sockets - confirm against `ActualConnection::set_write_timeout`.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_write_timeout(dur)
    }

    /// Forwards `dur` as the read timeout to the underlying connection.
    // NOTE(review): `None` presumably disables the timeout, as with std
    // sockets - confirm against `ActualConnection::set_read_timeout`.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_read_timeout(dur)
    }

    /// Creates a `PubSub` handle that mutably borrows this connection.
    ///
    /// Dropping the handle calls `exit_pubsub` on the connection.
    pub fn as_pubsub(&mut self) -> PubSub<'_> {
        PubSub::new(self)
    }

    /// Tries to clear all active subscriptions and updates the `pubsub` flag.
    ///
    /// On success the flag is cleared. On failure the flag is raised, so the
    /// next request issued through `ConnectionLike` attempts the cleanup again
    /// before sending its command.
    fn exit_pubsub(&mut self) -> RedisResult<()> {
        let res = self.clear_active_subscriptions();
        if res.is_ok() {
            self.pubsub = false;
        } else {
            self.pubsub = true;
        }

        res
    }

    /// Sends argument-less `UNSUBSCRIBE` and `PUNSUBSCRIBE` commands and drains
    /// replies until both kinds were acknowledged and the reported count is 0.
    ///
    /// Three reply shapes are understood:
    /// * `Value::Push` whose second data element is an integer count,
    /// * `Value::Array` decoded as `(kind, (), count)`,
    /// * `Value::ServerError` of kind `NoSub`, which acknowledges a command
    ///   without carrying a count.
    ///
    /// Any other server error or value aborts with an error.
    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
        {
            // Both commands are written before any reply is read.
            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();

            self.send_bytes(&unsubscribe)?;
            self.send_bytes(&punsubscribe)?;
        }

        // Track which of the two commands has been acknowledged so far.
        let mut received_unsub = false;
        let mut received_punsub = false;

        loop {
            let resp = self.recv_response()?;

            match resp {
                Value::Push { kind, data } => {
                    // Push values without an integer in second position are ignored.
                    if data.len() >= 2 {
                        if let Value::Int(num) = data[1] {
                            if resp3_is_pub_sub_state_cleared(
                                &mut received_unsub,
                                &mut received_punsub,
                                &kind,
                                num as isize,
                            ) {
                                break;
                            }
                        }
                    }
                }
                Value::ServerError(err) => {
                    // A `NoSub` error still counts as an acknowledgement of the
                    // command named in its details.
                    if err.kind() == Some(ServerErrorKind::NoSub) {
                        if no_sub_err_is_pub_sub_state_cleared(
                            &mut received_unsub,
                            &mut received_punsub,
                            &err,
                        ) {
                            break;
                        } else {
                            continue;
                        }
                    }

                    return Err(err.into());
                }
                Value::Array(vec) => {
                    // Decoded as (kind, ignored middle element, count).
                    let res: (Vec<u8>, (), isize) = from_owned_redis_value(Value::Array(vec))?;
                    if resp2_is_pub_sub_state_cleared(
                        &mut received_unsub,
                        &mut received_punsub,
                        &res.0,
                        res.2,
                    ) {
                        break;
                    }
                }
                _ => {
                    return Err((
                        ErrorKind::ClientError,
                        "Unexpected unsubscribe response",
                        format!("{resp:?}"),
                    )
                        .into())
                }
            }
        }

        Ok(())
    }

    /// Forwards `push` to the registered push sender, if there is one.
    /// A failed send is deliberately ignored.
    fn send_push(&self, push: PushInfo) {
        if let Some(sender) = &self.push_sender {
            let _ = sender.send(push);
        }
    }

    /// Forwards `value` to the push sender when it is a successfully parsed
    /// `Value::Push`; everything else is ignored.
    fn try_send(&self, value: &RedisResult<Value>) {
        if let Ok(Value::Push { kind, data }) = value {
            self.send_push(PushInfo {
                kind: kind.clone(),
                data: data.clone(),
            });
        }
    }

    /// Notifies the push sender that the connection was lost.
    fn send_disconnect(&self) {
        self.send_push(PushInfo::disconnect())
    }

    /// Announces the disconnect, shuts down the underlying stream and marks the
    /// connection as no longer open. Shutdown errors are ignored.
    fn close_connection(&mut self) {
        self.send_disconnect();
        match self.con {
            ActualConnection::Tcp(ref mut connection) => {
                let _ = connection.reader.shutdown(net::Shutdown::Both);
                connection.open = false;
            }
            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
            ActualConnection::TcpNativeTls(ref mut connection) => {
                let _ = connection.reader.shutdown();
                connection.open = false;
            }
            #[cfg(feature = "tls-rustls")]
            ActualConnection::TcpRustls(ref mut connection) => {
                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
                connection.open = false;
            }
            #[cfg(unix)]
            ActualConnection::Unix(ref mut connection) => {
                let _ = connection.sock.shutdown(net::Shutdown::Both);
                connection.open = false;
            }
        }
    }

    /// Parses the next value from the underlying stream.
    ///
    /// Every parsed result is first offered to the push sender via `try_send`.
    /// While `messages_to_skip` is non-zero, successfully parsed values and
    /// non-I/O errors are discarded (decrementing the counter) and reading
    /// continues.
    ///
    /// On an I/O error the error is returned; in addition an `UnexpectedEof`
    /// closes the connection, and any other I/O error increments
    /// `messages_to_skip` when `is_response` is `true`.
    // NOTE(review): the skip counter presumably accounts for replies that
    // arrive late after a failed (e.g. timed-out) read - confirm.
    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
        loop {
            let result = match self.con {
                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
                    self.parser.parse_value(reader)
                }
                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
                    let reader = &mut boxed_tls_connection.reader;
                    self.parser.parse_value(reader)
                }
                #[cfg(feature = "tls-rustls")]
                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
                    let reader = &mut boxed_tls_connection.reader;
                    self.parser.parse_value(reader)
                }
                #[cfg(unix)]
                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
                    self.parser.parse_value(sock)
                }
            };
            self.try_send(&result);

            // Successfully parsed value: return it unless it has to be skipped.
            let Err(err) = &result else {
                if self.messages_to_skip > 0 {
                    self.messages_to_skip -= 1;
                    continue;
                }
                return result;
            };
            // Errors that are not I/O errors are subject to the same skipping.
            let Some(io_error) = err.as_io_error() else {
                if self.messages_to_skip > 0 {
                    self.messages_to_skip -= 1;
                    continue;
                }
                return result;
            };
            if io_error.kind() == io::ErrorKind::UnexpectedEof {
                self.close_connection();
            } else if is_response {
                self.messages_to_skip += 1;
            }

            return result;
        }
    }

    /// Registers the sender that receives push messages and disconnect
    /// notifications, replacing any previously registered one.
    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
        self.push_sender = Some(sender);
    }

    /// Writes `bytes` to the underlying connection.
    ///
    /// When the protocol is not RESP2 and the write fails with an error that
    /// reports a dropped connection, a disconnect notification is pushed before
    /// the error is returned.
    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
        let result = self.con.send_bytes(bytes);
        if self.protocol != ProtocolVersion::RESP2 {
            if let Err(e) = &result {
                if e.is_connection_dropped() {
                    self.send_disconnect();
                }
            }
        }
        result
    }
}
1717
1718impl ConnectionLike for Connection {
1719 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1721 let pcmd = cmd.get_packed_command();
1722 if self.pubsub {
1723 self.exit_pubsub()?;
1724 }
1725
1726 self.send_bytes(&pcmd)?;
1727 if cmd.is_no_response() {
1728 return Ok(Value::Nil);
1729 }
1730 loop {
1731 match self.read(true)? {
1732 Value::Push {
1733 kind: _kind,
1734 data: _data,
1735 } => continue,
1736 val => return Ok(val),
1737 }
1738 }
1739 }
1740 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1741 if self.pubsub {
1742 self.exit_pubsub()?;
1743 }
1744
1745 self.send_bytes(cmd)?;
1746 loop {
1747 match self.read(true)? {
1748 Value::Push {
1749 kind: _kind,
1750 data: _data,
1751 } => continue,
1752 val => return Ok(val),
1753 }
1754 }
1755 }
1756
1757 fn req_packed_commands(
1758 &mut self,
1759 cmd: &[u8],
1760 offset: usize,
1761 count: usize,
1762 ) -> RedisResult<Vec<Value>> {
1763 if self.pubsub {
1764 self.exit_pubsub()?;
1765 }
1766 self.send_bytes(cmd)?;
1767 let mut rv = vec![];
1768 let mut first_err = None;
1769 let mut count = count;
1770 let mut idx = 0;
1771 while idx < (offset + count) {
1772 let response = self.read(true);
1777 match response {
1778 Ok(Value::ServerError(err)) => {
1779 if idx < offset {
1780 if first_err.is_none() {
1781 first_err = Some(err.into());
1782 }
1783 } else {
1784 rv.push(Value::ServerError(err));
1785 }
1786 }
1787 Ok(item) => {
1788 if let Value::Push {
1790 kind: _kind,
1791 data: _data,
1792 } = item
1793 {
1794 count += 1;
1796 } else if idx >= offset {
1797 rv.push(item);
1798 }
1799 }
1800 Err(err) => {
1801 if first_err.is_none() {
1802 first_err = Some(err);
1803 }
1804 }
1805 }
1806 idx += 1;
1807 }
1808
1809 first_err.map_or(Ok(rv), Err)
1810 }
1811
1812 fn get_db(&self) -> i64 {
1813 self.db
1814 }
1815
1816 fn check_connection(&mut self) -> bool {
1817 cmd("PING").query::<String>(self).is_ok()
1818 }
1819
1820 fn is_open(&self) -> bool {
1821 self.con.is_open()
1822 }
1823}
1824
1825impl<C, T> ConnectionLike for T
1826where
1827 C: ConnectionLike,
1828 T: DerefMut<Target = C>,
1829{
1830 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1831 self.deref_mut().req_packed_command(cmd)
1832 }
1833
1834 fn req_packed_commands(
1835 &mut self,
1836 cmd: &[u8],
1837 offset: usize,
1838 count: usize,
1839 ) -> RedisResult<Vec<Value>> {
1840 self.deref_mut().req_packed_commands(cmd, offset, count)
1841 }
1842
1843 fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1844 self.deref_mut().req_command(cmd)
1845 }
1846
1847 fn get_db(&self) -> i64 {
1848 self.deref().get_db()
1849 }
1850
1851 fn supports_pipelining(&self) -> bool {
1852 self.deref().supports_pipelining()
1853 }
1854
1855 fn check_connection(&mut self) -> bool {
1856 self.deref_mut().check_connection()
1857 }
1858
1859 fn is_open(&self) -> bool {
1860 self.deref().is_open()
1861 }
1862}
1863
1864impl<'a> PubSub<'a> {
1886 fn new(con: &'a mut Connection) -> Self {
1887 Self {
1888 con,
1889 waiting_messages: VecDeque::new(),
1890 }
1891 }
1892
1893 fn cache_messages_until_received_response(
1894 &mut self,
1895 cmd: &mut Cmd,
1896 is_sub_unsub: bool,
1897 ) -> RedisResult<Value> {
1898 let ignore_response = self.con.protocol != ProtocolVersion::RESP2 && is_sub_unsub;
1899 cmd.set_no_response(ignore_response);
1900
1901 self.con.send_packed_command(&cmd.get_packed_command())?;
1902
1903 loop {
1904 let response = self.con.recv_response()?;
1905 if let Some(msg) = Msg::from_value(&response) {
1906 self.waiting_messages.push_back(msg);
1907 } else {
1908 return Ok(response);
1909 }
1910 }
1911 }
1912
1913 pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1915 self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
1916 Ok(())
1917 }
1918
1919 pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1921 self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
1922 Ok(())
1923 }
1924
1925 pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1927 self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
1928 Ok(())
1929 }
1930
1931 pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1933 self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
1934 Ok(())
1935 }
1936
1937 pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
1939 from_owned_redis_value(
1940 self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
1941 )
1942 }
1943 pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
1945 from_owned_redis_value(
1946 self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
1947 )
1948 }
1949
1950 pub fn get_message(&mut self) -> RedisResult<Msg> {
1957 if let Some(msg) = self.waiting_messages.pop_front() {
1958 return Ok(msg);
1959 }
1960 loop {
1961 if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
1962 return Ok(msg);
1963 } else {
1964 continue;
1965 }
1966 }
1967 }
1968
1969 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1975 self.con.set_read_timeout(dur)
1976 }
1977}
1978
/// Leaves pubsub mode when the handle goes out of scope.
impl Drop for PubSub<'_> {
    fn drop(&mut self) {
        // Best effort: `drop` cannot report errors. On failure `exit_pubsub`
        // keeps the connection's `pubsub` flag raised, so the next request on
        // the connection attempts the cleanup again.
        let _ = self.con.exit_pubsub();
    }
}
1984
1985impl Msg {
1988 pub fn from_value(value: &Value) -> Option<Self> {
1990 Self::from_owned_value(value.clone())
1991 }
1992
1993 pub fn from_owned_value(value: Value) -> Option<Self> {
1995 let mut pattern = None;
1996 let payload;
1997 let channel;
1998
1999 if let Value::Push { kind, data } = value {
2000 return Self::from_push_info(PushInfo { kind, data });
2001 } else {
2002 let raw_msg: Vec<Value> = from_owned_redis_value(value).ok()?;
2003 let mut iter = raw_msg.into_iter();
2004 let msg_type: String = from_owned_redis_value(iter.next()?).ok()?;
2005 if msg_type == "message" {
2006 channel = iter.next()?;
2007 payload = iter.next()?;
2008 } else if msg_type == "pmessage" {
2009 pattern = Some(iter.next()?);
2010 channel = iter.next()?;
2011 payload = iter.next()?;
2012 } else {
2013 return None;
2014 }
2015 };
2016 Some(Msg {
2017 payload,
2018 channel,
2019 pattern,
2020 })
2021 }
2022
2023 pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2025 let mut pattern = None;
2026 let payload;
2027 let channel;
2028
2029 let mut iter = push_info.data.into_iter();
2030 if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2031 channel = iter.next()?;
2032 payload = iter.next()?;
2033 } else if push_info.kind == PushKind::PMessage {
2034 pattern = Some(iter.next()?);
2035 channel = iter.next()?;
2036 payload = iter.next()?;
2037 } else {
2038 return None;
2039 }
2040
2041 Some(Msg {
2042 payload,
2043 channel,
2044 pattern,
2045 })
2046 }
2047
2048 pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2050 from_redis_value(&self.channel)
2051 }
2052
2053 pub fn get_channel_name(&self) -> &str {
2058 match self.channel {
2059 Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2060 _ => "?",
2061 }
2062 }
2063
2064 pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2066 from_redis_value(&self.payload)
2067 }
2068
2069 pub fn get_payload_bytes(&self) -> &[u8] {
2073 match self.payload {
2074 Value::BulkString(ref bytes) => bytes,
2075 _ => b"",
2076 }
2077 }
2078
2079 #[allow(clippy::wrong_self_convention)]
2082 pub fn from_pattern(&self) -> bool {
2083 self.pattern.is_some()
2084 }
2085
2086 pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2091 match self.pattern {
2092 None => from_redis_value(&Value::Nil),
2093 Some(ref x) => from_redis_value(x),
2094 }
2095 }
2096}
2097
2098pub fn transaction<
2131 C: ConnectionLike,
2132 K: ToRedisArgs,
2133 T,
2134 F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2135>(
2136 con: &mut C,
2137 keys: &[K],
2138 func: F,
2139) -> RedisResult<T> {
2140 let mut func = func;
2141 loop {
2142 cmd("WATCH").arg(keys).exec(con)?;
2143 let mut p = pipe();
2144 let response: Option<T> = func(con, p.atomic())?;
2145 match response {
2146 None => {
2147 continue;
2148 }
2149 Some(response) => {
2150 cmd("UNWATCH").exec(con)?;
2153 return Ok(response);
2154 }
2155 }
2156 }
2157}
/// Records an array-style unsubscribe reply and reports whether the pubsub
/// state is cleared.
///
/// A `kind` starting with `u` marks the unsubscribe as received, one starting
/// with `p` the pattern unsubscribe. The state counts as cleared once both
/// flags are set and `num` is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    if let Some(&first) = kind.first() {
        if first == b'u' {
            *received_unsub = true;
        } else if first == b'p' {
            *received_punsub = true;
        }
    }
    num == 0 && *received_unsub && *received_punsub
}
2174
2175pub fn resp3_is_pub_sub_state_cleared(
2177 received_unsub: &mut bool,
2178 received_punsub: &mut bool,
2179 kind: &PushKind,
2180 num: isize,
2181) -> bool {
2182 match kind {
2183 PushKind::Unsubscribe => *received_unsub = true,
2184 PushKind::PUnsubscribe => *received_punsub = true,
2185 _ => (),
2186 };
2187 *received_unsub && *received_punsub && num == 0
2188}
2189
2190pub fn no_sub_err_is_pub_sub_state_cleared(
2191 received_unsub: &mut bool,
2192 received_punsub: &mut bool,
2193 err: &ServerError,
2194) -> bool {
2195 let details = err.details();
2196 *received_unsub = *received_unsub
2197 || details
2198 .map(|details| details.starts_with("'unsub"))
2199 .unwrap_or_default();
2200 *received_punsub = *received_punsub
2201 || details
2202 .map(|details| details.starts_with("'punsub"))
2203 .unwrap_or_default();
2204 *received_unsub && *received_punsub
2205}
2206
2207pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2209 if let Some(detail) = err.detail() {
2210 if detail.starts_with("unknown command `HELLO`") {
2211 return (
2212 ErrorKind::RESP3NotSupported,
2213 "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2214 )
2215 .into();
2216 }
2217 }
2218 err
2219}
2220
#[cfg(test)]
mod tests {
    use super::*;

    // Every supported URL scheme must be recognised; unrelated schemes must be rejected.
    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("rediss://127.0.0.1", true),
            ("rediss://[::1]", true),
            ("valkey://127.0.0.1", true),
            ("valkey://[::1]", true),
            ("valkeys://127.0.0.1", true),
            ("valkeys://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("valkey+unix:///run/valkey.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    // TCP URLs: host, port, db, percent-decoded credentials and protocol version.
    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=` cases above are only meaningful if the parsed
            // protocol is compared as well.
            assert!(
                res.redis.protocol == expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }

    // Invalid TCP URLs must fail with `InvalidClientConfig` and the expected
    // description/detail.
    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (
                url::Url::parse("redis://").unwrap(),
                "Missing hostname",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
                None,
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
                None,
            ),
            (
                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
                "Invalid protocol version",
                Some("4"),
            ),
        ];
        for (url, expected, detail) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            // `description` is deprecated but is what exposes the static message here.
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), detail, "{}", &res);
        }
    }

    // Unix socket URLs: path, db, credentials from query parameters and
    // protocol version.
    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                        protocol: ProtocolVersion::RESP2,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                        ..Default::default()
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        protocol: ProtocolVersion::RESP3,
                        ..Default::default()
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
            // The `protocol=3` case above is only meaningful if the parsed
            // protocol is compared as well.
            assert!(
                res.redis.protocol == expected.redis.protocol,
                "protocol of {url} is not expected",
            );
        }
    }
}