1use super::{AsyncPushSender, HandleContainer, RedisFuture};
2#[cfg(feature = "cache-aio")]
3use crate::caching::CacheManager;
4use crate::{
5 aio::{check_resp3, ConnectionLike, MultiplexedConnection, Runtime},
6 cmd,
7 subscription_tracker::{SubscriptionAction, SubscriptionTracker},
8 types::{RedisError, RedisResult, Value},
9 AsyncConnectionConfig, Client, Cmd, Pipeline, ProtocolVersion, PushInfo, PushKind, ToRedisArgs,
10};
11use arc_swap::ArcSwap;
12use backon::{ExponentialBuilder, Retryable};
13use futures_channel::oneshot;
14use futures_util::future::{self, BoxFuture, FutureExt, Shared};
15use std::sync::{Arc, Weak};
16use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
17use tokio::sync::Mutex;
18
19type OptionalPushSender = Option<Arc<dyn AsyncPushSender>>;
20
21#[derive(Clone)]
23pub struct ConnectionManagerConfig {
24 exponent_base: u64,
27 factor: u64,
31 number_of_retries: usize,
33 max_delay: Option<u64>,
35 response_timeout: Option<std::time::Duration>,
37 connection_timeout: Option<std::time::Duration>,
39 push_sender: Option<Arc<dyn AsyncPushSender>>,
41 resubscribe_automatically: bool,
43 tcp_settings: crate::io::tcp::TcpSettings,
44 #[cfg(feature = "cache-aio")]
45 pub(crate) cache_config: Option<crate::caching::CacheConfig>,
46}
47
48impl std::fmt::Debug for ConnectionManagerConfig {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
50 let &Self {
51 exponent_base,
52 factor,
53 number_of_retries,
54 max_delay,
55 response_timeout,
56 connection_timeout,
57 push_sender,
58 resubscribe_automatically,
59 tcp_settings,
60 #[cfg(feature = "cache-aio")]
61 cache_config,
62 } = &self;
63 let mut str = f.debug_struct("ConnectionManagerConfig");
64 str.field("exponent_base", &exponent_base)
65 .field("factor", &factor)
66 .field("number_of_retries", &number_of_retries)
67 .field("max_delay", &max_delay)
68 .field("response_timeout", &response_timeout)
69 .field("connection_timeout", &connection_timeout)
70 .field("resubscribe_automatically", &resubscribe_automatically)
71 .field(
72 "push_sender",
73 if push_sender.is_some() {
74 &"set"
75 } else {
76 &"not set"
77 },
78 )
79 .field("tcp_settings", &tcp_settings);
80
81 #[cfg(feature = "cache-aio")]
82 str.field("cache_config", &cache_config);
83
84 str.finish()
85 }
86}
87
88impl ConnectionManagerConfig {
89 const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2;
90 const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100;
91 const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;
92 const DEFAULT_RESPONSE_TIMEOUT: Option<std::time::Duration> = None;
93 const DEFAULT_CONNECTION_TIMEOUT: Option<std::time::Duration> = None;
94
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn set_factor(mut self, factor: u64) -> ConnectionManagerConfig {
104 self.factor = factor;
105 self
106 }
107
108 pub fn set_max_delay(mut self, time: u64) -> ConnectionManagerConfig {
110 self.max_delay = Some(time);
111 self
112 }
113
114 pub fn set_exponent_base(mut self, base: u64) -> ConnectionManagerConfig {
117 self.exponent_base = base;
118 self
119 }
120
121 pub fn set_number_of_retries(mut self, amount: usize) -> ConnectionManagerConfig {
123 self.number_of_retries = amount;
124 self
125 }
126
127 pub fn set_response_timeout(
129 mut self,
130 duration: std::time::Duration,
131 ) -> ConnectionManagerConfig {
132 self.response_timeout = Some(duration);
133 self
134 }
135
136 pub fn set_connection_timeout(
138 mut self,
139 duration: std::time::Duration,
140 ) -> ConnectionManagerConfig {
141 self.connection_timeout = Some(duration);
142 self
143 }
144
145 pub fn set_push_sender(mut self, sender: impl AsyncPushSender) -> Self {
171 self.push_sender = Some(Arc::new(sender));
172 self
173 }
174
175 pub fn set_automatic_resubscription(mut self) -> Self {
177 self.resubscribe_automatically = true;
178 self
179 }
180
181 pub fn set_tcp_settings(self, tcp_settings: crate::io::tcp::TcpSettings) -> Self {
183 Self {
184 tcp_settings,
185 ..self
186 }
187 }
188
189 #[cfg(feature = "cache-aio")]
191 pub fn set_cache_config(self, cache_config: crate::caching::CacheConfig) -> Self {
192 Self {
193 cache_config: Some(cache_config),
194 ..self
195 }
196 }
197}
198
199impl Default for ConnectionManagerConfig {
200 fn default() -> Self {
201 Self {
202 exponent_base: Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
203 factor: Self::DEFAULT_CONNECTION_RETRY_FACTOR,
204 number_of_retries: Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
205 response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
206 connection_timeout: Self::DEFAULT_CONNECTION_TIMEOUT,
207 max_delay: None,
208 push_sender: None,
209 resubscribe_automatically: false,
210 tcp_settings: Default::default(),
211 #[cfg(feature = "cache-aio")]
212 cache_config: None,
213 }
214 }
215}
216
/// State shared by every clone of a `ConnectionManager`.
///
/// NOTE(review): struct fields drop in declaration order, so `_task_handle` is
/// dropped last. Keep that in mind before reordering fields.
struct Internals {
    /// Client used to open every new connection.
    client: Client,
    /// Future resolving to the current connection. A reconnect atomically swaps in
    /// a new future. Requests await whichever future is current, so they wait for
    /// an in-flight reconnect instead of failing immediately.
    connection: ArcSwap<SharedRedisFuture<MultiplexedConnection>>,

    /// Runtime used to spawn background reconnection work.
    runtime: Runtime,
    /// Backoff policy for (re)connection attempts.
    retry_strategy: ExponentialBuilder,
    /// Settings passed to every new connection (timeouts, TCP, push sender, cache).
    connection_config: AsyncConnectionConfig,
    /// Present only when automatic resubscription is enabled. Records the active
    /// subscriptions so they can be replayed on a new connection.
    subscription_tracker: Option<Mutex<SubscriptionTracker>>,
    /// Client-side cache manager. Each reconnect derives a new-epoch clone from it.
    #[cfg(feature = "cache-aio")]
    cache_manager: Option<CacheManager>,
    /// Handle of the background task that watches push messages
    /// (presumably aborts the task on drop — confirm in `HandleContainer`).
    _task_handle: HandleContainer,
}
234
/// A cloneable handle to a multiplexed connection that reconnects automatically.
///
/// All clones share the same internal state, and a reconnect runs in the
/// background using the configured backoff. A reconnect is triggered when:
/// - a request fails with an unrecoverable error;
/// - awaiting the current connection yields an I/O error;
/// - a disconnection push message arrives (on non-RESP2 connections).
#[derive(Clone)]
pub struct ConnectionManager(Arc<Internals>);

/// A result whose error is reference-counted so it can be cloned out of a `Shared` future.
type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;

/// A boxed, shareable future; every clone observes the same `CloneableRedisResult`.
type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;
271
/// Schedules a reconnect when `$result` holds an unrecoverable error.
///
/// Must be expanded inside `impl ConnectionManager`: it calls `Self::reconnect` and
/// reads `$self.0`. Callers in this file pass `&result`, so the result is only
/// inspected, never consumed. `$current` is the guard of the connection the result
/// came from; `reconnect` replaces the connection only if it still equals `$current`.
macro_rules! reconnect_if_dropped {
    ($self:expr, $result:expr, $current:expr) => {
        if let Err(ref e) = $result {
            if e.is_unrecoverable_error() {
                Self::reconnect(Arc::downgrade(&$self.0), $current);
            }
        }
    };
}
282
/// Returns early from the enclosing function when `$result` is an error.
///
/// If that error is an I/O error, a reconnect is scheduled first. On `Ok`, nothing
/// happens and `$result` remains usable: the error is moved out only on the
/// returning path. Must be expanded inside `impl ConnectionManager`, because it
/// calls `Self::reconnect`.
macro_rules! reconnect_if_io_error {
    ($self:expr, $result:expr, $current:expr) => {
        if let Err(e) = $result {
            if e.is_io_error() {
                Self::reconnect(Arc::downgrade(&$self.0), $current);
            }
            return Err(e);
        }
    };
}
295
296impl ConnectionManager {
297 pub async fn new(client: Client) -> RedisResult<Self> {
302 let config = ConnectionManagerConfig::new();
303
304 Self::new_with_config(client, config).await
305 }
306
307 #[deprecated(note = "Use `new_with_config`")]
316 pub async fn new_with_backoff(
317 client: Client,
318 exponent_base: u64,
319 factor: u64,
320 number_of_retries: usize,
321 ) -> RedisResult<Self> {
322 let config = ConnectionManagerConfig::new()
323 .set_exponent_base(exponent_base)
324 .set_factor(factor)
325 .set_number_of_retries(number_of_retries);
326 Self::new_with_config(client, config).await
327 }
328
329 #[deprecated(note = "Use `new_with_config`")]
341 pub async fn new_with_backoff_and_timeouts(
342 client: Client,
343 exponent_base: u64,
344 factor: u64,
345 number_of_retries: usize,
346 response_timeout: std::time::Duration,
347 connection_timeout: std::time::Duration,
348 ) -> RedisResult<Self> {
349 let config = ConnectionManagerConfig::new()
350 .set_exponent_base(exponent_base)
351 .set_factor(factor)
352 .set_number_of_retries(number_of_retries)
353 .set_response_timeout(response_timeout)
354 .set_connection_timeout(connection_timeout);
355
356 Self::new_with_config(client, config).await
357 }
358
359 pub async fn new_with_config(
373 client: Client,
374 config: ConnectionManagerConfig,
375 ) -> RedisResult<Self> {
376 let runtime = Runtime::locate();
378
379 if config.resubscribe_automatically && config.push_sender.is_none() {
380 return Err((crate::ErrorKind::ClientError, "Cannot set resubscribe_automatically without setting a push sender to receive messages.").into());
381 }
382
383 let mut retry_strategy = ExponentialBuilder::default()
384 .with_factor(config.factor as f32)
385 .with_max_times(config.number_of_retries)
386 .with_jitter();
387 if let Some(max_delay) = config.max_delay {
388 retry_strategy =
389 retry_strategy.with_max_delay(std::time::Duration::from_millis(max_delay));
390 }
391
392 let mut connection_config = AsyncConnectionConfig::new();
393 if let Some(connection_timeout) = config.connection_timeout {
394 connection_config = connection_config.set_connection_timeout(connection_timeout);
395 }
396 if let Some(response_timeout) = config.response_timeout {
397 connection_config = connection_config.set_response_timeout(response_timeout);
398 }
399 connection_config = connection_config.set_tcp_settings(config.tcp_settings);
400 #[cfg(feature = "cache-aio")]
401 let cache_manager = config
402 .cache_config
403 .as_ref()
404 .map(|cache_config| CacheManager::new(*cache_config));
405 #[cfg(feature = "cache-aio")]
406 if let Some(cache_manager) = cache_manager.as_ref() {
407 connection_config = connection_config.set_cache_manager(cache_manager.clone());
408 }
409
410 let (oneshot_sender, oneshot_receiver) = oneshot::channel();
411 let _task_handle = HandleContainer::new(
412 runtime.spawn(Self::check_for_disconnect_pushes(oneshot_receiver)),
413 );
414
415 let mut components_for_reconnection_on_push = None;
416 if let Some(push_sender) = config.push_sender.clone() {
417 check_resp3!(
418 client.connection_info.redis.protocol,
419 "Can only pass push sender to a connection using RESP3"
420 );
421
422 let (internal_sender, internal_receiver) = unbounded_channel();
423 components_for_reconnection_on_push = Some((internal_receiver, Some(push_sender)));
424
425 connection_config =
426 connection_config.set_push_sender_internal(Arc::new(internal_sender));
427 } else if client.connection_info.redis.protocol != ProtocolVersion::RESP2 {
428 let (internal_sender, internal_receiver) = unbounded_channel();
429 components_for_reconnection_on_push = Some((internal_receiver, None));
430
431 connection_config =
432 connection_config.set_push_sender_internal(Arc::new(internal_sender));
433 }
434
435 let connection =
436 Self::new_connection(&client, retry_strategy, &connection_config, None).await?;
437 let subscription_tracker = if config.resubscribe_automatically {
438 Some(Mutex::new(SubscriptionTracker::default()))
439 } else {
440 None
441 };
442
443 let new_self = Self(Arc::new(Internals {
444 client,
445 connection: ArcSwap::from_pointee(future::ok(connection).boxed().shared()),
446 runtime,
447 retry_strategy,
448 connection_config,
449 subscription_tracker,
450 #[cfg(feature = "cache-aio")]
451 cache_manager,
452 _task_handle,
453 }));
454
455 if let Some((internal_receiver, external_sender)) = components_for_reconnection_on_push {
456 oneshot_sender
457 .send((
458 Arc::downgrade(&new_self.0),
459 internal_receiver,
460 external_sender,
461 ))
462 .map_err(|_| {
463 crate::RedisError::from((
464 crate::ErrorKind::ClientError,
465 "Failed to set automatic resubscription",
466 ))
467 })?;
468 };
469
470 Ok(new_self)
471 }
472
473 async fn new_connection(
474 client: &Client,
475 exponential_backoff: ExponentialBuilder,
476 connection_config: &AsyncConnectionConfig,
477 additional_commands: Option<Pipeline>,
478 ) -> RedisResult<MultiplexedConnection> {
479 let connection_config = connection_config.clone();
480 let get_conn = || async {
481 client
482 .get_multiplexed_async_connection_with_config(&connection_config)
483 .await
484 };
485 let mut conn = get_conn
486 .retry(exponential_backoff)
487 .sleep(|duration| async move { Runtime::locate().sleep(duration).await })
488 .await?;
489 if let Some(pipeline) = additional_commands {
490 let _ = pipeline.exec_async(&mut conn).await;
492 }
493 Ok(conn)
494 }
495
496 fn reconnect(
501 internals: Weak<Internals>,
502 current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>,
503 ) {
504 let Some(internals) = internals.upgrade() else {
505 return;
506 };
507 let internals_clone = internals.clone();
508 #[cfg(not(feature = "cache-aio"))]
509 let connection_config = internals.connection_config.clone();
510 #[cfg(feature = "cache-aio")]
511 let mut connection_config = internals.connection_config.clone();
512 #[cfg(feature = "cache-aio")]
513 if let Some(manager) = internals.cache_manager.as_ref() {
514 let new_cache_manager = manager.clone_and_increase_epoch();
515 connection_config = connection_config.set_cache_manager(new_cache_manager);
516 }
517 let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
518 let additional_commands = match &internals_clone.subscription_tracker {
519 Some(subscription_tracker) => Some(
520 subscription_tracker
521 .lock()
522 .await
523 .get_subscription_pipeline(),
524 ),
525 None => None,
526 };
527
528 let con = Self::new_connection(
529 &internals_clone.client,
530 internals_clone.retry_strategy,
531 &connection_config,
532 additional_commands,
533 )
534 .await?;
535 Ok(con)
536 }
537 .boxed()
538 .shared();
539
540 let new_connection_arc = Arc::new(new_connection.clone());
542 let prev = internals
543 .connection
544 .compare_and_swap(¤t, new_connection_arc);
545
546 if Arc::ptr_eq(&prev, ¤t) {
548 internals.runtime.spawn(new_connection.map(|_| ())).detach();
550 }
551 }
552
553 async fn check_for_disconnect_pushes(
554 receiver: oneshot::Receiver<(
555 Weak<Internals>,
556 UnboundedReceiver<PushInfo>,
557 OptionalPushSender,
558 )>,
559 ) {
560 let Ok((this, mut internal_receiver, external_sender)) = receiver.await else {
561 return;
562 };
563 while let Some(push_info) = internal_receiver.recv().await {
564 if push_info.kind == PushKind::Disconnection {
565 let Some(internals) = this.upgrade() else {
566 return;
567 };
568 Self::reconnect(Arc::downgrade(&internals), internals.connection.load());
569 }
570 if let Some(sender) = external_sender.as_ref() {
571 let _ = sender.send(push_info);
572 }
573 }
574 }
575
576 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
579 let guard = self.0.connection.load();
581 let connection_result = (**guard)
582 .clone()
583 .await
584 .map_err(|e| e.clone_mostly("Reconnecting failed"));
585 reconnect_if_io_error!(self, connection_result, guard);
586 let result = connection_result?.send_packed_command(cmd).await;
587 reconnect_if_dropped!(self, &result, guard);
588 result
589 }
590
591 pub async fn send_packed_commands(
595 &mut self,
596 cmd: &crate::Pipeline,
597 offset: usize,
598 count: usize,
599 ) -> RedisResult<Vec<Value>> {
600 let guard = self.0.connection.load();
602 let connection_result = (**guard)
603 .clone()
604 .await
605 .map_err(|e| e.clone_mostly("Reconnecting failed"));
606 reconnect_if_io_error!(self, connection_result, guard);
607 let result = connection_result?
608 .send_packed_commands(cmd, offset, count)
609 .await;
610 reconnect_if_dropped!(self, &result, guard);
611 result
612 }
613
614 async fn update_subscription_tracker(
615 &self,
616 action: SubscriptionAction,
617 args: impl ToRedisArgs,
618 ) {
619 let Some(subscription_tracker) = &self.0.subscription_tracker else {
620 return;
621 };
622 let args = args.to_redis_args().into_iter();
623 subscription_tracker
624 .lock()
625 .await
626 .update_with_request(action, args);
627 }
628
629 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
638 check_resp3!(self.0.client.connection_info.redis.protocol);
639 let mut cmd = cmd("SUBSCRIBE");
640 cmd.arg(&channel_name);
641 cmd.exec_async(self).await?;
642 self.update_subscription_tracker(SubscriptionAction::Subscribe, channel_name)
643 .await;
644
645 Ok(())
646 }
647
648 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
652 check_resp3!(self.0.client.connection_info.redis.protocol);
653 let mut cmd = cmd("UNSUBSCRIBE");
654 cmd.arg(&channel_name);
655 cmd.exec_async(self).await?;
656 self.update_subscription_tracker(SubscriptionAction::Unsubscribe, channel_name)
657 .await;
658 Ok(())
659 }
660
661 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
670 check_resp3!(self.0.client.connection_info.redis.protocol);
671 let mut cmd = cmd("PSUBSCRIBE");
672 cmd.arg(&channel_pattern);
673 cmd.exec_async(self).await?;
674 self.update_subscription_tracker(SubscriptionAction::PSubscribe, channel_pattern)
675 .await;
676 Ok(())
677 }
678
679 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
683 check_resp3!(self.0.client.connection_info.redis.protocol);
684 let mut cmd = cmd("PUNSUBSCRIBE");
685 cmd.arg(&channel_pattern);
686 cmd.exec_async(self).await?;
687 self.update_subscription_tracker(SubscriptionAction::PUnsubscribe, channel_pattern)
688 .await;
689 Ok(())
690 }
691
692 #[cfg(feature = "cache-aio")]
694 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
695 pub fn get_cache_statistics(&self) -> Option<crate::caching::CacheStatistics> {
696 self.0.cache_manager.as_ref().map(|cm| cm.statistics())
697 }
698}
699
700impl ConnectionLike for ConnectionManager {
701 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
702 (async move { self.send_packed_command(cmd).await }).boxed()
703 }
704
705 fn req_packed_commands<'a>(
706 &'a mut self,
707 cmd: &'a crate::Pipeline,
708 offset: usize,
709 count: usize,
710 ) -> RedisFuture<'a, Vec<Value>> {
711 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
712 }
713
714 fn get_db(&self) -> i64 {
715 self.0.client.connection_info().redis.db
716 }
717}