1use super::{AsyncPushSender, HandleContainer, RedisFuture};
2#[cfg(feature = "cache-aio")]
3use crate::caching::CacheManager;
4use crate::{
5 aio::{ConnectionLike, MultiplexedConnection, Runtime},
6 check_resp3,
7 client::{DEFAULT_CONNECTION_TIMEOUT, DEFAULT_RESPONSE_TIMEOUT},
8 cmd,
9 errors::RedisError,
10 subscription_tracker::{SubscriptionAction, SubscriptionTracker},
11 types::{RedisResult, Value},
12 AsyncConnectionConfig, Client, Cmd, Pipeline, PushInfo, PushKind, ToRedisArgs,
13};
14use arc_swap::ArcSwap;
15use backon::{ExponentialBuilder, Retryable};
16use futures_channel::oneshot;
17use futures_util::future::{self, BoxFuture, FutureExt, Shared};
18use std::sync::{Arc, Weak};
19use std::time::Duration;
20use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
21use tokio::sync::Mutex;
22
23type OptionalPushSender = Option<Arc<dyn AsyncPushSender>>;
24
25#[derive(Clone)]
27pub struct ConnectionManagerConfig {
28 exponent_base: f32,
31 min_delay: Duration,
33 max_delay: Option<Duration>,
35 number_of_retries: usize,
37 response_timeout: Option<Duration>,
39 connection_timeout: Option<Duration>,
41 push_sender: Option<Arc<dyn AsyncPushSender>>,
43 resubscribe_automatically: bool,
45 #[cfg(feature = "cache-aio")]
46 pub(crate) cache_config: Option<crate::caching::CacheConfig>,
47}
48
49impl std::fmt::Debug for ConnectionManagerConfig {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
51 let &Self {
52 exponent_base,
53 min_delay,
54 number_of_retries,
55 max_delay,
56 response_timeout,
57 connection_timeout,
58 push_sender,
59 resubscribe_automatically,
60 #[cfg(feature = "cache-aio")]
61 cache_config,
62 } = &self;
63 let mut str = f.debug_struct("ConnectionManagerConfig");
64 str.field("exponent_base", &exponent_base)
65 .field("min_delay", &min_delay)
66 .field("max_delay", &max_delay)
67 .field("number_of_retries", &number_of_retries)
68 .field("response_timeout", &response_timeout)
69 .field("connection_timeout", &connection_timeout)
70 .field("resubscribe_automatically", &resubscribe_automatically)
71 .field(
72 "push_sender",
73 if push_sender.is_some() {
74 &"set"
75 } else {
76 &"not set"
77 },
78 );
79
80 #[cfg(feature = "cache-aio")]
81 str.field("cache_config", &cache_config);
82
83 str.finish()
84 }
85}
86
87impl ConnectionManagerConfig {
88 const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: f32 = 2.0;
89 const DEFAULT_CONNECTION_RETRY_MIN_DELAY: Duration = Duration::from_millis(100);
90 const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;
91
92 pub fn new() -> Self {
94 Self::default()
95 }
96
97 pub fn min_delay(&self) -> Duration {
99 self.min_delay
100 }
101
102 pub fn max_delay(&self) -> Option<Duration> {
104 self.max_delay
105 }
106
107 pub fn exponent_base(&self) -> f32 {
109 self.exponent_base
110 }
111
112 pub fn number_of_retries(&self) -> usize {
114 self.number_of_retries
115 }
116
117 pub fn response_timeout(&self) -> Option<Duration> {
121 self.response_timeout
122 }
123
124 pub fn connection_timeout(&self) -> Option<Duration> {
128 self.connection_timeout
129 }
130
131 pub fn automatic_resubscription(&self) -> bool {
133 self.resubscribe_automatically
134 }
135
136 #[cfg(feature = "cache-aio")]
138 pub fn cache_config(&self) -> Option<&crate::caching::CacheConfig> {
139 self.cache_config.as_ref()
140 }
141
142 pub fn set_min_delay(mut self, min_delay: Duration) -> ConnectionManagerConfig {
144 self.min_delay = min_delay;
145 self
146 }
147
148 pub fn set_max_delay(mut self, time: Duration) -> ConnectionManagerConfig {
150 self.max_delay = Some(time);
151 self
152 }
153
154 pub fn set_exponent_base(mut self, base: f32) -> ConnectionManagerConfig {
157 self.exponent_base = base;
158 self
159 }
160
161 pub fn set_number_of_retries(mut self, amount: usize) -> ConnectionManagerConfig {
163 self.number_of_retries = amount;
164 self
165 }
166
167 pub fn set_response_timeout(mut self, duration: Option<Duration>) -> ConnectionManagerConfig {
171 self.response_timeout = duration;
172 self
173 }
174
175 pub fn set_connection_timeout(mut self, duration: Option<Duration>) -> ConnectionManagerConfig {
179 self.connection_timeout = duration;
180 self
181 }
182
183 pub fn set_push_sender(mut self, sender: impl AsyncPushSender) -> Self {
209 self.push_sender = Some(Arc::new(sender));
210 self
211 }
212
213 pub fn set_automatic_resubscription(mut self) -> Self {
215 self.resubscribe_automatically = true;
216 self
217 }
218
219 #[cfg(feature = "cache-aio")]
221 pub fn set_cache_config(self, cache_config: crate::caching::CacheConfig) -> Self {
222 Self {
223 cache_config: Some(cache_config),
224 ..self
225 }
226 }
227}
228
229impl Default for ConnectionManagerConfig {
230 fn default() -> Self {
231 Self {
232 exponent_base: Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
233 min_delay: Self::DEFAULT_CONNECTION_RETRY_MIN_DELAY,
234 max_delay: None,
235 number_of_retries: Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
236 response_timeout: DEFAULT_RESPONSE_TIMEOUT,
237 connection_timeout: DEFAULT_CONNECTION_TIMEOUT,
238 push_sender: None,
239 resubscribe_automatically: false,
240 #[cfg(feature = "cache-aio")]
241 cache_config: None,
242 }
243 }
244}
245
/// State shared by every clone of a `ConnectionManager`.
struct Internals {
    /// Client used for the initial connection and for every reconnection.
    client: Client,
    /// The current connection, stored as a shared future so that concurrent
    /// callers await the same (re)connection attempt. Replaced by
    /// compare-and-swap in `ConnectionManager::reconnect`.
    connection: ArcSwap<SharedRedisFuture<MultiplexedConnection>>,

    /// Runtime used to spawn reconnection futures.
    runtime: Runtime,
    /// Exponential backoff applied to every connection attempt.
    retry_strategy: ExponentialBuilder,
    /// Configuration cloned into every new connection.
    connection_config: AsyncConnectionConfig,
    /// `Some` only when automatic resubscription is enabled; its subscription
    /// pipeline is replayed on each new connection.
    subscription_tracker: Option<Mutex<SubscriptionTracker>>,
    /// Base cache manager; each reconnection uses `clone_and_increase_epoch()` of it.
    #[cfg(feature = "cache-aio")]
    cache_manager: Option<CacheManager>,
    /// Keeps the background push-monitoring task (spawned in `new_with_config`)
    /// tied to the manager's lifetime.
    /// NOTE(review): presumably aborts that task on drop — confirm in `HandleContainer`.
    // Declared last: struct fields are dropped in declaration order.
    _task_handle: HandleContainer,
}
263
/// A cheaply cloneable handle to a multiplexed Redis connection that
/// reconnects automatically.
///
/// All clones share one `Internals` through an `Arc`, so they use the same
/// underlying connection. A reconnect is scheduled when:
/// - awaiting the current connection fails with an I/O error;
/// - a command fails with an error for which `is_unrecoverable_error()` is true;
/// - a `PushKind::Disconnection` push arrives (RESP3 only).
///
/// The failing request's error is still returned to the caller; commands are
/// not retried by the manager.
#[derive(Clone)]
pub struct ConnectionManager(Arc<Internals>);
294
295impl std::fmt::Debug for ConnectionManager {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 f.debug_struct("ConnectionManager")
298 .field("client", &self.0.client)
299 .field("retry_strategy", &self.0.retry_strategy)
300 .finish()
301 }
302}
303
304type SharedRedisFuture<T> = Shared<BoxFuture<'static, RedisResult<T>>>;
306
/// Schedules a reconnect when `$outcome` (a reference to a command result) is
/// an error that reports `is_unrecoverable_error()`. `$snapshot` is the
/// connection guard the command ran on; it is consumed only in that case.
/// Must be invoked inside an `impl` that provides `Self::reconnect`.
macro_rules! reconnect_if_dropped {
    ($manager:expr, $outcome:expr, $snapshot:expr) => {{
        let connection_broken = matches!($outcome, Err(ref err) if err.is_unrecoverable_error());
        if connection_broken {
            Self::reconnect(Arc::downgrade(&$manager.0), $snapshot);
        }
    }};
}
317
/// If `$outcome` (an owned result of awaiting the connection) is an error,
/// returns it from the enclosing function. Before returning, it schedules a
/// reconnect against `$snapshot` when the error is an I/O error. On `Ok`,
/// `$outcome` stays usable after the macro.
/// Must be invoked inside an `impl` that provides `Self::reconnect`.
macro_rules! reconnect_if_io_error {
    ($manager:expr, $outcome:expr, $snapshot:expr) => {
        if let Err(err) = $outcome {
            // Only I/O failures indicate the connection itself is gone.
            let connection_lost = err.is_io_error();
            if connection_lost {
                Self::reconnect(Arc::downgrade(&$manager.0), $snapshot);
            }
            return Err(err);
        }
    };
}
330
331impl ConnectionManager {
332 pub async fn new(client: Client) -> RedisResult<Self> {
337 let config = ConnectionManagerConfig::new();
338
339 Self::new_with_config(client, config).await
340 }
341
342 pub async fn new_with_config(
354 client: Client,
355 config: ConnectionManagerConfig,
356 ) -> RedisResult<Self> {
357 let runtime = Runtime::locate();
359
360 if config.resubscribe_automatically && config.push_sender.is_none() {
361 return Err((crate::ErrorKind::Client, "Cannot set resubscribe_automatically without setting a push sender to receive messages.").into());
362 }
363
364 let mut retry_strategy = ExponentialBuilder::default()
365 .with_factor(config.exponent_base)
366 .with_min_delay(config.min_delay)
367 .with_max_times(config.number_of_retries)
368 .with_jitter();
369 if let Some(max_delay) = config.max_delay {
370 retry_strategy = retry_strategy.with_max_delay(max_delay);
371 }
372
373 let mut connection_config = AsyncConnectionConfig::new()
374 .set_connection_timeout(config.connection_timeout)
375 .set_response_timeout(config.response_timeout);
376
377 #[cfg(feature = "cache-aio")]
378 let cache_manager = config
379 .cache_config
380 .as_ref()
381 .map(|cache_config| CacheManager::new(*cache_config));
382 #[cfg(feature = "cache-aio")]
383 if let Some(cache_manager) = cache_manager.as_ref() {
384 connection_config = connection_config.set_cache_manager(cache_manager.clone());
385 }
386
387 let (oneshot_sender, oneshot_receiver) = oneshot::channel();
388 let _task_handle = HandleContainer::new(
389 runtime.spawn(Self::check_for_disconnect_pushes(oneshot_receiver)),
390 );
391
392 let mut components_for_reconnection_on_push = None;
393 if let Some(push_sender) = config.push_sender.clone() {
394 check_resp3!(
395 client.connection_info.redis.protocol,
396 "Can only pass push sender to a connection using RESP3"
397 );
398
399 let (internal_sender, internal_receiver) = unbounded_channel();
400 components_for_reconnection_on_push = Some((internal_receiver, Some(push_sender)));
401
402 connection_config =
403 connection_config.set_push_sender_internal(Arc::new(internal_sender));
404 } else if client.connection_info.redis.protocol.supports_resp3() {
405 let (internal_sender, internal_receiver) = unbounded_channel();
406 components_for_reconnection_on_push = Some((internal_receiver, None));
407
408 connection_config =
409 connection_config.set_push_sender_internal(Arc::new(internal_sender));
410 }
411
412 let connection =
413 Self::new_connection(&client, retry_strategy, &connection_config, None).await?;
414 let subscription_tracker = if config.resubscribe_automatically {
415 Some(Mutex::new(SubscriptionTracker::default()))
416 } else {
417 None
418 };
419
420 let new_self = Self(Arc::new(Internals {
421 client,
422 connection: ArcSwap::from_pointee(future::ok(connection).boxed().shared()),
423 runtime,
424 retry_strategy,
425 connection_config,
426 subscription_tracker,
427 #[cfg(feature = "cache-aio")]
428 cache_manager,
429 _task_handle,
430 }));
431
432 if let Some((internal_receiver, external_sender)) = components_for_reconnection_on_push {
433 oneshot_sender
434 .send((
435 Arc::downgrade(&new_self.0),
436 internal_receiver,
437 external_sender,
438 ))
439 .map_err(|_| {
440 crate::RedisError::from((
441 crate::ErrorKind::Client,
442 "Failed to set automatic resubscription",
443 ))
444 })?;
445 };
446
447 Ok(new_self)
448 }
449
450 async fn new_connection(
451 client: &Client,
452 exponential_backoff: ExponentialBuilder,
453 connection_config: &AsyncConnectionConfig,
454 additional_commands: Option<Pipeline>,
455 ) -> RedisResult<MultiplexedConnection> {
456 let connection_config = connection_config.clone();
457 let get_conn = || async {
458 client
459 .get_multiplexed_async_connection_with_config(&connection_config)
460 .await
461 };
462 let mut conn = get_conn
463 .retry(exponential_backoff)
464 .sleep(|duration| async move { Runtime::locate().sleep(duration).await })
465 .await?;
466 if let Some(pipeline) = additional_commands {
467 let _ = pipeline.exec_async(&mut conn).await;
469 }
470 Ok(conn)
471 }
472
473 fn reconnect(
478 internals: Weak<Internals>,
479 current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>,
480 ) {
481 let Some(internals) = internals.upgrade() else {
482 return;
483 };
484 let internals_clone = internals.clone();
485 #[cfg(not(feature = "cache-aio"))]
486 let connection_config = internals.connection_config.clone();
487 #[cfg(feature = "cache-aio")]
488 let mut connection_config = internals.connection_config.clone();
489 #[cfg(feature = "cache-aio")]
490 if let Some(manager) = internals.cache_manager.as_ref() {
491 let new_cache_manager = manager.clone_and_increase_epoch();
492 connection_config = connection_config.set_cache_manager(new_cache_manager);
493 }
494 let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
495 let additional_commands = match &internals_clone.subscription_tracker {
496 Some(subscription_tracker) => Some(
497 subscription_tracker
498 .lock()
499 .await
500 .get_subscription_pipeline(),
501 ),
502 None => None,
503 };
504
505 let con = Self::new_connection(
506 &internals_clone.client,
507 internals_clone.retry_strategy,
508 &connection_config,
509 additional_commands,
510 )
511 .await?;
512 Ok(con)
513 }
514 .boxed()
515 .shared();
516
517 let new_connection_arc = Arc::new(new_connection.clone());
519 let prev = internals
520 .connection
521 .compare_and_swap(¤t, new_connection_arc);
522
523 if Arc::ptr_eq(&prev, ¤t) {
525 internals.runtime.spawn(new_connection.map(|_| ())).detach();
527 }
528 }
529
530 async fn check_for_disconnect_pushes(
531 receiver: oneshot::Receiver<(
532 Weak<Internals>,
533 UnboundedReceiver<PushInfo>,
534 OptionalPushSender,
535 )>,
536 ) {
537 let Ok((this, mut internal_receiver, external_sender)) = receiver.await else {
538 return;
539 };
540 while let Some(push_info) = internal_receiver.recv().await {
541 if push_info.kind == PushKind::Disconnection {
542 let Some(internals) = this.upgrade() else {
543 return;
544 };
545 Self::reconnect(Arc::downgrade(&internals), internals.connection.load());
546 }
547 if let Some(sender) = external_sender.as_ref() {
548 let _ = sender.send(push_info);
549 }
550 }
551 }
552
553 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
556 let guard = self.0.connection.load();
558 let connection_result = (**guard).clone().await.map_err(|e| e.clone());
559 reconnect_if_io_error!(self, connection_result, guard);
560 let result = connection_result?.send_packed_command(cmd).await;
561 reconnect_if_dropped!(self, &result, guard);
562 result
563 }
564
565 pub async fn send_packed_commands(
569 &mut self,
570 cmd: &crate::Pipeline,
571 offset: usize,
572 count: usize,
573 ) -> RedisResult<Vec<Value>> {
574 let guard = self.0.connection.load();
576 let connection_result = (**guard).clone().await.map_err(|e| e.clone());
577 reconnect_if_io_error!(self, connection_result, guard);
578 let result = connection_result?
579 .send_packed_commands(cmd, offset, count)
580 .await;
581 reconnect_if_dropped!(self, &result, guard);
582 result
583 }
584
585 async fn update_subscription_tracker(
586 &self,
587 action: SubscriptionAction,
588 args: impl ToRedisArgs,
589 ) {
590 let Some(subscription_tracker) = &self.0.subscription_tracker else {
591 return;
592 };
593 let args = args.to_redis_args().into_iter();
594 subscription_tracker
595 .lock()
596 .await
597 .update_with_request(action, args);
598 }
599
600 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
621 check_resp3!(self.0.client.connection_info.redis.protocol);
622 let mut cmd = cmd("SUBSCRIBE");
623 cmd.arg(&channel_name);
624 cmd.exec_async(self).await?;
625 self.update_subscription_tracker(SubscriptionAction::Subscribe, channel_name)
626 .await;
627
628 Ok(())
629 }
630
631 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
635 check_resp3!(self.0.client.connection_info.redis.protocol);
636 let mut cmd = cmd("UNSUBSCRIBE");
637 cmd.arg(&channel_name);
638 cmd.exec_async(self).await?;
639 self.update_subscription_tracker(SubscriptionAction::Unsubscribe, channel_name)
640 .await;
641 Ok(())
642 }
643
644 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
665 check_resp3!(self.0.client.connection_info.redis.protocol);
666 let mut cmd = cmd("PSUBSCRIBE");
667 cmd.arg(&channel_pattern);
668 cmd.exec_async(self).await?;
669 self.update_subscription_tracker(SubscriptionAction::PSubscribe, channel_pattern)
670 .await;
671 Ok(())
672 }
673
674 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
678 check_resp3!(self.0.client.connection_info.redis.protocol);
679 let mut cmd = cmd("PUNSUBSCRIBE");
680 cmd.arg(&channel_pattern);
681 cmd.exec_async(self).await?;
682 self.update_subscription_tracker(SubscriptionAction::PUnsubscribe, channel_pattern)
683 .await;
684 Ok(())
685 }
686
687 #[cfg(feature = "cache-aio")]
689 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
690 pub fn get_cache_statistics(&self) -> Option<crate::caching::CacheStatistics> {
691 self.0.cache_manager.as_ref().map(|cm| cm.statistics())
692 }
693}
694
695impl ConnectionLike for ConnectionManager {
696 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
697 (async move { self.send_packed_command(cmd).await }).boxed()
698 }
699
700 fn req_packed_commands<'a>(
701 &'a mut self,
702 cmd: &'a crate::Pipeline,
703 offset: usize,
704 count: usize,
705 ) -> RedisFuture<'a, Vec<Value>> {
706 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
707 }
708
709 fn get_db(&self) -> i64 {
710 self.0.client.connection_info().redis.db
711 }
712}