1use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
2use crate::aio::{check_resp3, setup_connection};
3#[cfg(feature = "cache-aio")]
4use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
5use crate::cmd::Cmd;
6use crate::parser::ValueCodec;
7use crate::types::{closed_connection_error, RedisError, RedisFuture, RedisResult, Value};
8use crate::{
9 cmd, AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ToRedisArgs,
10};
11use ::tokio::{
12 io::{AsyncRead, AsyncWrite},
13 sync::{mpsc, oneshot},
14};
15use futures_util::{
16 future::{Future, FutureExt},
17 ready,
18 sink::Sink,
19 stream::{self, Stream, StreamExt},
20};
21use pin_project_lite::pin_project;
22use std::collections::VecDeque;
23use std::fmt;
24use std::fmt::Debug;
25use std::pin::Pin;
26use std::sync::Arc;
27use std::task::{self, Poll};
28use std::time::Duration;
29use tokio_util::codec::Decoder;
30
/// Channel end on which the final result of a request is delivered back to the caller.
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;

/// Describes how responses read from the connection are gathered for one in-flight request.
enum ResponseAggregate {
    /// The first response read is forwarded as-is.
    SingleCommand,
    /// Responses are accumulated until `expectation` is satisfied, then sent as one array
    /// (or as the first error encountered).
    Pipeline {
        // Responses collected so far (skipped responses are not stored).
        buffer: Vec<Value>,
        // First error seen; when set, the request completes with this error instead of `buffer`.
        first_err: Option<RedisError>,
        // How many responses to skip and to collect.
        expectation: PipelineResponseExpectation,
    },
}

/// Shape of the responses expected for a pipelined request.
struct PipelineResponseExpectation {
    /// Number of leading responses to consume without adding them to the returned array.
    skipped_response_count: usize,
    /// Number of responses collected into the returned array after the skipped ones.
    expected_response_count: usize,
    /// When true, errors found among the skipped responses still fail the whole request.
    is_transaction: bool,
}
52
53impl ResponseAggregate {
54 fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
55 match expectation {
56 Some(expectation) => ResponseAggregate::Pipeline {
57 buffer: Vec::new(),
58 first_err: None,
59 expectation,
60 },
61 None => ResponseAggregate::SingleCommand,
62 }
63 }
64}
65
/// A request that has been written to the connection and is waiting for its response(s).
struct InFlight {
    /// Where the final result is delivered.
    output: PipelineOutput,
    /// How incoming responses are gathered before `output` is completed.
    response_aggregate: ResponseAggregate,
}

/// A request queued on the `Pipeline` channel and consumed by the driver's `PipelineSink`.
struct PipelineMessage {
    /// Already-encoded request bytes to write to the connection.
    input: Vec<u8>,
    /// Receives the result once all responses for this request have been read.
    output: PipelineOutput,
    /// `None` for a single command; `Some` for a pipeline, describing which responses to collect.
    expectation: Option<PipelineResponseExpectation>,
}

/// Cloneable handle that submits requests to the connection's driver over a bounded channel.
#[derive(Clone)]
struct Pipeline {
    sender: mpsc::Sender<PipelineMessage>,
}
89
90impl Debug for Pipeline {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_tuple("Pipeline").field(&self.sender).finish()
93 }
94}
95
// Sink/stream adapter driven by the connection task: it writes requests to `sink_stream`
// and routes each response read back to the matching entry of `in_flight`.
#[cfg(feature = "cache-aio")]
pin_project! {
    struct PipelineSink<T> {
        // Framed connection: a sink of encoded requests and a stream of decoded responses.
        #[pin]
        sink_stream: T,
        // Requests awaiting responses, in the order they were written.
        in_flight: VecDeque<InFlight>,
        // Error from `poll_ready`, reported to the next request passed to `start_send`.
        error: Option<RedisError>,
        // Optional receiver for push messages (and disconnect notifications).
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        // Client-side cache, told about push values that do not answer a request.
        cache_manager: Option<CacheManager>,
    }
}

// Same as above, without the client-side cache.
#[cfg(not(feature = "cache-aio"))]
pin_project! {
    struct PipelineSink<T> {
        #[pin]
        sink_stream: T,
        in_flight: VecDeque<InFlight>,
        error: Option<RedisError>,
        push_sender: Option<Arc<dyn AsyncPushSender>>,
    }
}
118
119fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
120 if let Some(sender) = push_sender {
121 let _ = sender.send(info);
122 };
123}
124
125pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
126 send_push(push_sender, PushInfo::disconnect());
127}
128
129impl<T> PipelineSink<T>
130where
131 T: Stream<Item = RedisResult<Value>> + 'static,
132{
133 fn new(
134 sink_stream: T,
135 push_sender: Option<Arc<dyn AsyncPushSender>>,
136 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
137 ) -> Self
138 where
139 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
140 {
141 PipelineSink {
142 sink_stream,
143 in_flight: VecDeque::new(),
144 error: None,
145 push_sender,
146 #[cfg(feature = "cache-aio")]
147 cache_manager,
148 }
149 }
150
151 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
153 loop {
154 let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
155 let item = match item {
156 Some(result) => result,
157 None => Err(closed_connection_error()),
159 };
160
161 let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
162 self.as_mut().send_result(item);
163 if is_unrecoverable {
164 let self_ = self.project();
165 send_disconnect(self_.push_sender);
166 return Poll::Ready(Err(()));
167 }
168 }
169 }
170
171 fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
172 let self_ = self.project();
173 let result = match result {
174 Ok(Value::Push { kind, data }) if !kind.has_reply() => {
176 #[cfg(feature = "cache-aio")]
177 if let Some(cache_manager) = &self_.cache_manager {
178 cache_manager.handle_push_value(&kind, &data);
179 }
180 send_push(self_.push_sender, PushInfo { kind, data });
181
182 return;
183 }
184 Ok(Value::Push { kind, data }) if kind.has_reply() => {
186 send_push(
187 self_.push_sender,
188 PushInfo {
189 kind: kind.clone(),
190 data: data.clone(),
191 },
192 );
193 Ok(Value::Push { kind, data })
194 }
195 _ => result,
196 };
197
198 let mut entry = match self_.in_flight.pop_front() {
199 Some(entry) => entry,
200 None => return,
201 };
202
203 match &mut entry.response_aggregate {
204 ResponseAggregate::SingleCommand => {
205 entry.output.send(result).ok();
206 }
207 ResponseAggregate::Pipeline {
208 buffer,
209 first_err,
210 expectation:
211 PipelineResponseExpectation {
212 expected_response_count,
213 skipped_response_count,
214 is_transaction,
215 },
216 } => {
217 if *skipped_response_count > 0 {
218 if first_err.is_none() && *is_transaction {
221 *first_err = result.and_then(Value::extract_error).err();
222 }
223
224 *skipped_response_count -= 1;
225 self_.in_flight.push_front(entry);
226 return;
227 }
228
229 match result {
230 Ok(item) => {
231 buffer.push(item);
232 }
233 Err(err) => {
234 if first_err.is_none() {
235 *first_err = Some(err);
236 }
237 }
238 }
239
240 if buffer.len() < *expected_response_count {
241 self_.in_flight.push_front(entry);
243 return;
244 }
245
246 let response = match first_err.take() {
247 Some(err) => Err(err),
248 None => Ok(Value::Array(std::mem::take(buffer))),
249 };
250
251 entry.output.send(response).ok();
255 }
256 }
257 }
258}
259
/// Accepts requests from the `Pipeline` channel, writes them to the connection and, on
/// flush, reads back responses. An `Err(())` from any method ends the driving `forward`.
impl<T> Sink<PipelineMessage> for PipelineSink<T>
where
    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
    type Error = ();

    /// Waits for the underlying sink to accept data.
    ///
    /// An error is not returned here. It is stored and reported by `start_send` to the
    /// request that would have been written, so that the caller receives it.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    /// Writes one request and registers it as in flight.
    ///
    /// Returns `Err(())`, after sending the error to the request's output, when an error is
    /// pending from `poll_ready` or when the write fails.
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            expectation,
        }: PipelineMessage,
    ) -> Result<(), Self::Error> {
        // The caller already dropped its receiver (e.g. it was cancelled or timed out), so
        // the request is not written at all.
        if output.is_closed() {
            return Ok(());
        }

        let self_ = self.as_mut().project();

        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                let response_aggregate = ResponseAggregate::new(expectation);
                let entry = InFlight {
                    output,
                    response_aggregate,
                };

                self_.in_flight.push_back(entry);
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    /// Flushes written requests, then reads responses until none are ready (`poll_read`).
    ///
    /// A flush error is delivered to the oldest in-flight request and ends the sink.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    /// Closes the underlying sink.
    ///
    /// While requests are still in flight, it first flushes and reads their responses.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // Only drive reads when some request is still waiting for a response.
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
350
351impl Pipeline {
352 fn new<T>(
353 sink_stream: T,
354 push_sender: Option<Arc<dyn AsyncPushSender>>,
355 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
356 ) -> (Self, impl Future<Output = ()>)
357 where
358 T: Sink<Vec<u8>, Error = RedisError>,
359 T: Stream<Item = RedisResult<Value>>,
360 T: Unpin + Send + 'static,
361 {
362 const BUFFER_SIZE: usize = 50;
363 let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
364
365 let sink = PipelineSink::new(
366 sink_stream,
367 push_sender,
368 #[cfg(feature = "cache-aio")]
369 cache_manager,
370 );
371 let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
372 .map(Ok)
373 .forward(sink)
374 .map(|_| ());
375 (Pipeline { sender }, f)
376 }
377
378 async fn send_recv(
379 &mut self,
380 input: Vec<u8>,
381 expectation: Option<PipelineResponseExpectation>,
384 timeout: Option<Duration>,
385 ) -> Result<Value, RedisError> {
386 let (sender, receiver) = oneshot::channel();
387
388 let request = async {
389 self.sender
390 .send(PipelineMessage {
391 input,
392 expectation,
393 output: sender,
394 })
395 .await
396 .map_err(|_| None)?;
397
398 receiver.await
399 .map_err(|_| None)
402 .and_then(|res| res.map_err(Some))
403 };
404
405 match timeout {
406 Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
407 Ok(res) => res,
408 Err(elapsed) => Err(Some(elapsed.into())),
409 },
410 None => request.await,
411 }
412 .map_err(|err| err.unwrap_or_else(closed_connection_error))
413 }
414}
415
/// A connection object which can be cloned, allowing requests to be sent concurrently on the
/// same underlying connection.
///
/// All clones share one driver: requests are written in the order they are submitted, and
/// responses are matched back to requests in that same order.
#[derive(Clone)]
pub struct MultiplexedConnection {
    // Handle used to submit requests to the driver.
    pipeline: Pipeline,
    // Database index taken from the connection info; reported by `get_db`.
    db: i64,
    // Per-request timeout; `None` waits indefinitely.
    response_timeout: Option<Duration>,
    // Protocol negotiated for this connection; the pub/sub helpers require RESP3.
    protocol: ProtocolVersion,
    // Handle of the driver task when it was spawned by this crate. NOTE(review): presumably
    // the task is stopped when the last clone drops this container — confirm in
    // `SharedHandleContainer`.
    _task_handle: Option<SharedHandleContainer>,
    // Client-side cache, when enabled.
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_manager: Option<CacheManager>,
}
441
442impl Debug for MultiplexedConnection {
443 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
444 f.debug_struct("MultiplexedConnection")
445 .field("pipeline", &self.pipeline)
446 .field("db", &self.db)
447 .finish()
448 }
449}
450
451impl MultiplexedConnection {
452 pub async fn new<C>(
455 connection_info: &RedisConnectionInfo,
456 stream: C,
457 ) -> RedisResult<(Self, impl Future<Output = ()>)>
458 where
459 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
460 {
461 Self::new_with_response_timeout(connection_info, stream, None).await
462 }
463
464 pub async fn new_with_response_timeout<C>(
467 connection_info: &RedisConnectionInfo,
468 stream: C,
469 response_timeout: Option<std::time::Duration>,
470 ) -> RedisResult<(Self, impl Future<Output = ()>)>
471 where
472 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
473 {
474 Self::new_with_config(
475 connection_info,
476 stream,
477 AsyncConnectionConfig {
478 response_timeout,
479 ..Default::default()
480 },
481 )
482 .await
483 }
484
485 pub async fn new_with_config<C>(
488 connection_info: &RedisConnectionInfo,
489 stream: C,
490 config: AsyncConnectionConfig,
491 ) -> RedisResult<(Self, impl Future<Output = ()>)>
492 where
493 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
494 {
495 let mut codec = ValueCodec::default().framed(stream);
496 if config.push_sender.is_some() {
497 check_resp3!(
498 connection_info.protocol,
499 "Can only pass push sender to a connection using RESP3"
500 );
501 }
502
503 #[cfg(feature = "cache-aio")]
504 let cache_config = config.cache.as_ref().map(|cache| match cache {
505 crate::client::Cache::Config(cache_config) => *cache_config,
506 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
507 crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
508 });
509 #[cfg(feature = "cache-aio")]
510 let cache_manager_opt = config
511 .cache
512 .map(|cache| {
513 check_resp3!(
514 connection_info.protocol,
515 "Can only enable client side caching in a connection using RESP3"
516 );
517 match cache {
518 crate::client::Cache::Config(cache_config) => {
519 Ok(CacheManager::new(cache_config))
520 }
521 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
522 crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
523 }
524 })
525 .transpose()?;
526
527 setup_connection(
528 &mut codec,
529 connection_info,
530 #[cfg(feature = "cache-aio")]
531 cache_config,
532 )
533 .await?;
534 if config.push_sender.is_some() {
535 check_resp3!(
536 connection_info.protocol,
537 "Can only pass push sender to a connection using RESP3"
538 );
539 }
540
541 let (pipeline, driver) = Pipeline::new(
542 codec,
543 config.push_sender,
544 #[cfg(feature = "cache-aio")]
545 cache_manager_opt.clone(),
546 );
547 let con = MultiplexedConnection {
548 pipeline,
549 db: connection_info.db,
550 response_timeout: config.response_timeout,
551 protocol: connection_info.protocol,
552 _task_handle: None,
553 #[cfg(feature = "cache-aio")]
554 cache_manager: cache_manager_opt,
555 };
556
557 Ok((con, driver))
558 }
559
560 pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
563 self._task_handle = Some(SharedHandleContainer::new(handle));
564 }
565
566 pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
568 self.response_timeout = Some(timeout);
569 }
570
571 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
574 #[cfg(feature = "cache-aio")]
575 if let Some(cache_manager) = &self.cache_manager {
576 match cache_manager.get_cached_cmd(cmd) {
577 PrepareCacheResult::Cached(value) => return Ok(value),
578 PrepareCacheResult::NotCached(cacheable_command) => {
579 let mut pipeline = crate::Pipeline::new();
580 cacheable_command.pack_command(cache_manager, &mut pipeline);
581
582 let result = self
583 .pipeline
584 .send_recv(
585 pipeline.get_packed_pipeline(),
586 Some(PipelineResponseExpectation {
587 skipped_response_count: 0,
588 expected_response_count: pipeline.commands.len(),
589 is_transaction: false,
590 }),
591 self.response_timeout,
592 )
593 .await?;
594 let replies: Vec<Value> = crate::types::from_owned_redis_value(result)?;
595 return cacheable_command.resolve(cache_manager, replies.into_iter());
596 }
597 _ => (),
598 }
599 }
600 self.pipeline
601 .send_recv(cmd.get_packed_command(), None, self.response_timeout)
602 .await
603 }
604
605 pub async fn send_packed_commands(
609 &mut self,
610 cmd: &crate::Pipeline,
611 offset: usize,
612 count: usize,
613 ) -> RedisResult<Vec<Value>> {
614 #[cfg(feature = "cache-aio")]
615 if let Some(cache_manager) = &self.cache_manager {
616 let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
617 cache_manager.get_cached_pipeline(cmd);
618 let result = self
619 .pipeline
620 .send_recv(
621 pipeline.get_packed_pipeline(),
622 Some(PipelineResponseExpectation {
623 skipped_response_count,
624 expected_response_count,
625 is_transaction: cacheable_pipeline.transaction_mode,
626 }),
627 self.response_timeout,
628 )
629 .await?;
630
631 return cacheable_pipeline.resolve(cache_manager, result);
632 }
633 let value = self
634 .pipeline
635 .send_recv(
636 cmd.get_packed_pipeline(),
637 Some(PipelineResponseExpectation {
638 skipped_response_count: offset,
639 expected_response_count: count,
640 is_transaction: cmd.is_transaction(),
641 }),
642 self.response_timeout,
643 )
644 .await?;
645 match value {
646 Value::Array(values) => Ok(values),
647 _ => Ok(vec![value]),
648 }
649 }
650
651 #[cfg(feature = "cache-aio")]
653 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
654 pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
655 self.cache_manager.as_ref().map(|cm| cm.statistics())
656 }
657}
658
659impl ConnectionLike for MultiplexedConnection {
660 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
661 (async move { self.send_packed_command(cmd).await }).boxed()
662 }
663
664 fn req_packed_commands<'a>(
665 &'a mut self,
666 cmd: &'a crate::Pipeline,
667 offset: usize,
668 count: usize,
669 ) -> RedisFuture<'a, Vec<Value>> {
670 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
671 }
672
673 fn get_db(&self) -> i64 {
674 self.db
675 }
676}
677
678impl MultiplexedConnection {
679 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
696 check_resp3!(self.protocol);
697 let mut cmd = cmd("SUBSCRIBE");
698 cmd.arg(channel_name);
699 cmd.exec_async(self).await?;
700 Ok(())
701 }
702
703 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
718 check_resp3!(self.protocol);
719 let mut cmd = cmd("UNSUBSCRIBE");
720 cmd.arg(channel_name);
721 cmd.exec_async(self).await?;
722 Ok(())
723 }
724
725 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
743 check_resp3!(self.protocol);
744 let mut cmd = cmd("PSUBSCRIBE");
745 cmd.arg(channel_pattern);
746 cmd.exec_async(self).await?;
747 Ok(())
748 }
749
750 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
754 check_resp3!(self.protocol);
755 let mut cmd = cmd("PUNSUBSCRIBE");
756 cmd.arg(channel_pattern);
757 cmd.exec_async(self).await?;
758 Ok(())
759 }
760}