1use super::{AsyncPushSender, ConnectionLike, Runtime, SharedHandleContainer, TaskHandle};
2#[cfg(feature = "cache-aio")]
3use crate::caching::{CacheManager, CacheStatistics, PrepareCacheResult};
4use crate::{
5 aio::setup_connection,
6 check_resp3, cmd,
7 cmd::Cmd,
8 errors::{closed_connection_error, RedisError},
9 parser::ValueCodec,
10 types::{RedisFuture, RedisResult, Value},
11 AsyncConnectionConfig, ProtocolVersion, PushInfo, RedisConnectionInfo, ServerError,
12 ToRedisArgs,
13};
14use ::tokio::{
15 io::{AsyncRead, AsyncWrite},
16 sync::{mpsc, oneshot},
17};
18use futures_util::{
19 future::{Future, FutureExt},
20 ready,
21 sink::Sink,
22 stream::{self, Stream, StreamExt},
23};
24use pin_project_lite::pin_project;
25use std::collections::VecDeque;
26use std::fmt;
27use std::fmt::Debug;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{self, Poll};
31use std::time::Duration;
32use tokio_util::codec::Decoder;
33
34type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
36
37enum ErrorOrErrors {
38 Errors(Vec<(usize, ServerError)>),
39 FirstError(RedisError),
41}
42
43enum ResponseAggregate {
44 SingleCommand,
45 Pipeline {
46 buffer: Vec<Value>,
47 error_or_errors: ErrorOrErrors,
48 expectation: PipelineResponseExpectation,
49 },
50}
51
/// Bookkeeping that tells the sink which replies of a pipeline to drop, how
/// many to keep, and how far it has progressed.
struct PipelineResponseExpectation {
    /// Whether the pipeline is a transaction; only then are server errors in
    /// skipped replies recorded.
    is_transaction: bool,
    /// Number of leading replies still to be dropped; counted down to zero.
    skipped_response_count: usize,
    /// Number of replies that must be buffered before the result is sent.
    expected_response_count: usize,
    /// Number of replies routed to this pipeline so far (skipped or kept).
    seen_responses: usize,
}
62
63impl ResponseAggregate {
64 fn new(expectation: Option<PipelineResponseExpectation>) -> Self {
65 match expectation {
66 Some(expectation) => ResponseAggregate::Pipeline {
67 buffer: Vec::new(),
68 error_or_errors: ErrorOrErrors::Errors(Vec::new()),
69 expectation,
70 },
71 None => ResponseAggregate::SingleCommand,
72 }
73 }
74}
75
76struct InFlight {
77 output: Option<PipelineOutput>,
78 response_aggregate: ResponseAggregate,
79}
80
81struct PipelineMessage {
83 input: Vec<u8>,
84 output: Option<PipelineOutput>,
86 expectation: Option<PipelineResponseExpectation>,
90}
91
92#[derive(Clone)]
97struct Pipeline {
98 sender: mpsc::Sender<PipelineMessage>,
99}
100
101impl Debug for Pipeline {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 f.debug_tuple("Pipeline").field(&self.sender).finish()
104 }
105}
106
// Adapter around the connection transport, compiled when the `cache-aio`
// feature is enabled. Requests written through it are queued in `in_flight`,
// and replies read from `sink_stream` are matched to the queued requests in
// order. Plain comments are used here because the struct is declared through
// the `pin_project!` macro.
#[cfg(feature = "cache-aio")]
pin_project! {
    struct PipelineSink<T> {
        // Underlying transport; pinned because it is polled in place.
        #[pin]
        sink_stream: T,
        // Requests already written, oldest first.
        in_flight: VecDeque<InFlight>,
        // Error reported by the transport while polling for readiness; it is
        // handed to the next request that is submitted.
        error: Option<RedisError>,
        // Optional receiver of push messages and disconnect notifications.
        push_sender: Option<Arc<dyn AsyncPushSender>>,
        // Optional cache manager that is shown push messages which are not
        // replies.
        cache_manager: Option<CacheManager>,
    }
}
118
119#[cfg(not(feature = "cache-aio"))]
120pin_project! {
121 struct PipelineSink<T> {
122 #[pin]
123 sink_stream: T,
124 in_flight: VecDeque<InFlight>,
125 error: Option<RedisError>,
126 push_sender: Option<Arc<dyn AsyncPushSender>>,
127 }
128}
129
130fn send_push(push_sender: &Option<Arc<dyn AsyncPushSender>>, info: PushInfo) {
131 if let Some(sender) = push_sender {
132 let _ = sender.send(info);
133 };
134}
135
136pub(crate) fn send_disconnect(push_sender: &Option<Arc<dyn AsyncPushSender>>) {
137 send_push(push_sender, PushInfo::disconnect());
138}
139
140impl<T> PipelineSink<T>
141where
142 T: Stream<Item = RedisResult<Value>> + 'static,
143{
144 fn new(
145 sink_stream: T,
146 push_sender: Option<Arc<dyn AsyncPushSender>>,
147 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
148 ) -> Self
149 where
150 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
151 {
152 PipelineSink {
153 sink_stream,
154 in_flight: VecDeque::new(),
155 error: None,
156 push_sender,
157 #[cfg(feature = "cache-aio")]
158 cache_manager,
159 }
160 }
161
162 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
164 loop {
165 let item = ready!(self.as_mut().project().sink_stream.poll_next(cx));
166 let item = match item {
167 Some(result) => result,
168 None => Err(closed_connection_error()),
170 };
171
172 let is_unrecoverable = item.as_ref().is_err_and(|err| err.is_unrecoverable_error());
173 self.as_mut().send_result(item);
174 if is_unrecoverable {
175 let self_ = self.project();
176 send_disconnect(self_.push_sender);
177 return Poll::Ready(Err(()));
178 }
179 }
180 }
181
182 fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
183 let self_ = self.project();
184 let result = match result {
185 Ok(Value::Push { kind, data }) if !kind.has_reply() => {
187 #[cfg(feature = "cache-aio")]
188 if let Some(cache_manager) = &self_.cache_manager {
189 cache_manager.handle_push_value(&kind, &data);
190 }
191 send_push(self_.push_sender, PushInfo { kind, data });
192
193 return;
194 }
195 Ok(Value::Push { kind, data }) if kind.has_reply() => {
197 send_push(
198 self_.push_sender,
199 PushInfo {
200 kind: kind.clone(),
201 data: data.clone(),
202 },
203 );
204 Ok(Value::Push { kind, data })
205 }
206 _ => result,
207 };
208
209 let mut entry = match self_.in_flight.pop_front() {
210 Some(entry) => entry,
211 None => return,
212 };
213
214 match &mut entry.response_aggregate {
215 ResponseAggregate::SingleCommand => {
216 if let Some(output) = entry.output.take() {
217 _ = output.send(result);
218 }
219 }
220 ResponseAggregate::Pipeline {
221 buffer,
222 error_or_errors,
223 expectation:
224 PipelineResponseExpectation {
225 expected_response_count,
226 skipped_response_count,
227 is_transaction,
228 seen_responses,
229 },
230 } => {
231 *seen_responses += 1;
232 if *skipped_response_count > 0 {
233 if *is_transaction {
236 if let ErrorOrErrors::Errors(errs) = error_or_errors {
237 match result {
238 Ok(Value::ServerError(err)) => {
239 errs.push((*seen_responses - 2, err)); }
241 Err(err) => *error_or_errors = ErrorOrErrors::FirstError(err),
242 _ => {}
243 }
244 }
245 }
246
247 *skipped_response_count -= 1;
248 self_.in_flight.push_front(entry);
249 return;
250 }
251
252 match result {
253 Ok(item) => {
254 buffer.push(item);
255 }
256 Err(err) => {
257 if matches!(error_or_errors, ErrorOrErrors::Errors(_)) {
258 *error_or_errors = ErrorOrErrors::FirstError(err)
259 }
260 }
261 }
262
263 if buffer.len() < *expected_response_count {
264 self_.in_flight.push_front(entry);
266 return;
267 }
268
269 let response =
270 match std::mem::replace(error_or_errors, ErrorOrErrors::Errors(Vec::new())) {
271 ErrorOrErrors::Errors(errors) => {
272 if errors.is_empty() {
273 Ok(Value::Array(std::mem::take(buffer)))
274 } else {
275 Err(RedisError::make_aborted_transaction(errors))
276 }
277 }
278 ErrorOrErrors::FirstError(redis_error) => Err(redis_error),
279 };
280
281 if let Some(output) = entry.output.take() {
285 _ = output.send(response);
286 }
287 }
288 }
289 }
290}
291
292impl<T> Sink<PipelineMessage> for PipelineSink<T>
293where
294 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
295{
296 type Error = ();
297
298 fn poll_ready(
300 mut self: Pin<&mut Self>,
301 cx: &mut task::Context,
302 ) -> Poll<Result<(), Self::Error>> {
303 match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
304 Ok(()) => Ok(()).into(),
305 Err(err) => {
306 *self.project().error = Some(err);
307 Ok(()).into()
308 }
309 }
310 }
311
312 fn start_send(
313 mut self: Pin<&mut Self>,
314 PipelineMessage {
315 input,
316 mut output,
317 expectation,
318 }: PipelineMessage,
319 ) -> Result<(), Self::Error> {
320 if output.as_ref().is_some_and(|output| output.is_closed()) {
324 return Ok(());
325 }
326
327 let self_ = self.as_mut().project();
328
329 if let Some(err) = self_.error.take() {
330 if let Some(output) = output.take() {
331 _ = output.send(Err(err));
332 }
333 return Err(());
334 }
335
336 match self_.sink_stream.start_send(input) {
337 Ok(()) => {
338 let response_aggregate = ResponseAggregate::new(expectation);
339 let entry = InFlight {
340 output,
341 response_aggregate,
342 };
343
344 self_.in_flight.push_back(entry);
345 Ok(())
346 }
347 Err(err) => {
348 if let Some(output) = output.take() {
349 _ = output.send(Err(err));
350 }
351 Err(())
352 }
353 }
354 }
355
356 fn poll_flush(
357 mut self: Pin<&mut Self>,
358 cx: &mut task::Context,
359 ) -> Poll<Result<(), Self::Error>> {
360 ready!(self
361 .as_mut()
362 .project()
363 .sink_stream
364 .poll_flush(cx)
365 .map_err(|err| {
366 self.as_mut().send_result(Err(err));
367 }))?;
368 self.poll_read(cx)
369 }
370
371 fn poll_close(
372 mut self: Pin<&mut Self>,
373 cx: &mut task::Context,
374 ) -> Poll<Result<(), Self::Error>> {
375 if !self.in_flight.is_empty() {
378 ready!(self.as_mut().poll_flush(cx))?;
379 }
380 let this = self.as_mut().project();
381 this.sink_stream.poll_close(cx).map_err(|err| {
382 self.send_result(Err(err));
383 })
384 }
385}
386
387impl Pipeline {
388 fn new<T>(
389 sink_stream: T,
390 push_sender: Option<Arc<dyn AsyncPushSender>>,
391 #[cfg(feature = "cache-aio")] cache_manager: Option<CacheManager>,
392 ) -> (Self, impl Future<Output = ()>)
393 where
394 T: Sink<Vec<u8>, Error = RedisError>,
395 T: Stream<Item = RedisResult<Value>>,
396 T: Unpin + Send + 'static,
397 {
398 const BUFFER_SIZE: usize = 50;
399 let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
400
401 let sink = PipelineSink::new(
402 sink_stream,
403 push_sender,
404 #[cfg(feature = "cache-aio")]
405 cache_manager,
406 );
407 let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
408 .map(Ok)
409 .forward(sink)
410 .map(|_| ());
411 (Pipeline { sender }, f)
412 }
413
414 async fn send_recv(
415 &mut self,
416 input: Vec<u8>,
417 expectation: Option<PipelineResponseExpectation>,
420 timeout: Option<Duration>,
421 skip_response: bool,
422 ) -> Result<Value, RedisError> {
423 if input.is_empty() {
424 return Err(RedisError::make_empty_command());
425 }
426
427 let request = async {
428 if skip_response {
429 self.sender
430 .send(PipelineMessage {
431 input,
432 expectation,
433 output: None,
434 })
435 .await
436 .map_err(|_| None)?;
437
438 return Ok(Value::Nil);
439 }
440
441 let (sender, receiver) = oneshot::channel();
442
443 self.sender
444 .send(PipelineMessage {
445 input,
446 expectation,
447 output: Some(sender),
448 })
449 .await
450 .map_err(|_| None)?;
451
452 receiver.await
453 .map_err(|_| None)
456 .and_then(|res| res.map_err(Some))
457 };
458
459 match timeout {
460 Some(timeout) => match Runtime::locate().timeout(timeout, request).await {
461 Ok(res) => res,
462 Err(elapsed) => Err(Some(elapsed.into())),
463 },
464 None => request.await,
465 }
466 .map_err(|err| err.unwrap_or_else(closed_connection_error))
467 }
468}
469
470#[derive(Clone)]
483pub struct MultiplexedConnection {
484 pipeline: Pipeline,
485 db: i64,
486 response_timeout: Option<Duration>,
487 protocol: ProtocolVersion,
488 _task_handle: Option<SharedHandleContainer>,
492 #[cfg(feature = "cache-aio")]
493 pub(crate) cache_manager: Option<CacheManager>,
494}
495
496impl Debug for MultiplexedConnection {
497 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
498 f.debug_struct("MultiplexedConnection")
499 .field("pipeline", &self.pipeline)
500 .field("db", &self.db)
501 .finish()
502 }
503}
504
505impl MultiplexedConnection {
506 pub async fn new<C>(
509 connection_info: &RedisConnectionInfo,
510 stream: C,
511 ) -> RedisResult<(Self, impl Future<Output = ()>)>
512 where
513 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
514 {
515 Self::new_with_config(connection_info, stream, AsyncConnectionConfig::default()).await
516 }
517
518 pub async fn new_with_config<C>(
521 connection_info: &RedisConnectionInfo,
522 stream: C,
523 config: AsyncConnectionConfig,
524 ) -> RedisResult<(Self, impl Future<Output = ()>)>
525 where
526 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
527 {
528 let mut codec = ValueCodec::default().framed(stream);
529 if config.push_sender.is_some() {
530 check_resp3!(
531 connection_info.protocol,
532 "Can only pass push sender to a connection using RESP3"
533 );
534 }
535
536 #[cfg(feature = "cache-aio")]
537 let cache_config = config.cache.as_ref().map(|cache| match cache {
538 crate::client::Cache::Config(cache_config) => *cache_config,
539 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
540 crate::client::Cache::Manager(cache_manager) => cache_manager.cache_config,
541 });
542 #[cfg(feature = "cache-aio")]
543 let cache_manager_opt = config
544 .cache
545 .map(|cache| {
546 check_resp3!(
547 connection_info.protocol,
548 "Can only enable client side caching in a connection using RESP3"
549 );
550 match cache {
551 crate::client::Cache::Config(cache_config) => {
552 Ok(CacheManager::new(cache_config))
553 }
554 #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
555 crate::client::Cache::Manager(cache_manager) => Ok(cache_manager),
556 }
557 })
558 .transpose()?;
559
560 setup_connection(
561 &mut codec,
562 connection_info,
563 #[cfg(feature = "cache-aio")]
564 cache_config,
565 )
566 .await?;
567 if config.push_sender.is_some() {
568 check_resp3!(
569 connection_info.protocol,
570 "Can only pass push sender to a connection using RESP3"
571 );
572 }
573
574 let (pipeline, driver) = Pipeline::new(
575 codec,
576 config.push_sender,
577 #[cfg(feature = "cache-aio")]
578 cache_manager_opt.clone(),
579 );
580 let con = MultiplexedConnection {
581 pipeline,
582 db: connection_info.db,
583 response_timeout: config.response_timeout,
584 protocol: connection_info.protocol,
585 _task_handle: None,
586 #[cfg(feature = "cache-aio")]
587 cache_manager: cache_manager_opt,
588 };
589
590 Ok((con, driver))
591 }
592
593 pub(crate) fn set_task_handle(&mut self, handle: TaskHandle) {
596 self._task_handle = Some(SharedHandleContainer::new(handle));
597 }
598
599 pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
601 self.response_timeout = Some(timeout);
602 }
603
604 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
607 #[cfg(feature = "cache-aio")]
608 if let Some(cache_manager) = &self.cache_manager {
609 match cache_manager.get_cached_cmd(cmd) {
610 PrepareCacheResult::Cached(value) => return Ok(value),
611 PrepareCacheResult::NotCached(cacheable_command) => {
612 let mut pipeline = crate::Pipeline::new();
613 cacheable_command.pack_command(cache_manager, &mut pipeline);
614
615 let result = self
616 .pipeline
617 .send_recv(
618 pipeline.get_packed_pipeline(),
619 Some(PipelineResponseExpectation {
620 skipped_response_count: 0,
621 expected_response_count: pipeline.commands.len(),
622 is_transaction: false,
623 seen_responses: 0,
624 }),
625 self.response_timeout,
626 cmd.is_no_response(),
627 )
628 .await?;
629 let replies: Vec<Value> = crate::types::from_redis_value(result)?;
630 return cacheable_command.resolve(cache_manager, replies.into_iter());
631 }
632 _ => (),
633 }
634 }
635 self.pipeline
636 .send_recv(
637 cmd.get_packed_command(),
638 None,
639 self.response_timeout,
640 cmd.is_no_response(),
641 )
642 .await
643 }
644
645 pub async fn send_packed_commands(
649 &mut self,
650 cmd: &crate::Pipeline,
651 offset: usize,
652 count: usize,
653 ) -> RedisResult<Vec<Value>> {
654 #[cfg(feature = "cache-aio")]
655 if let Some(cache_manager) = &self.cache_manager {
656 let (cacheable_pipeline, pipeline, (skipped_response_count, expected_response_count)) =
657 cache_manager.get_cached_pipeline(cmd);
658 if pipeline.is_empty() {
659 return cacheable_pipeline.resolve(cache_manager, Value::Array(Vec::new()));
660 }
661 let result = self
662 .pipeline
663 .send_recv(
664 pipeline.get_packed_pipeline(),
665 Some(PipelineResponseExpectation {
666 skipped_response_count,
667 expected_response_count,
668 is_transaction: cacheable_pipeline.transaction_mode,
669 seen_responses: 0,
670 }),
671 self.response_timeout,
672 false,
673 )
674 .await?;
675
676 return cacheable_pipeline.resolve(cache_manager, result);
677 }
678 let value = self
679 .pipeline
680 .send_recv(
681 cmd.get_packed_pipeline(),
682 Some(PipelineResponseExpectation {
683 skipped_response_count: offset,
684 expected_response_count: count,
685 is_transaction: cmd.is_transaction(),
686 seen_responses: 0,
687 }),
688 self.response_timeout,
689 false,
690 )
691 .await?;
692 match value {
693 Value::Array(values) => Ok(values),
694 _ => Ok(vec![value]),
695 }
696 }
697
698 #[cfg(feature = "cache-aio")]
700 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
701 pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
702 self.cache_manager.as_ref().map(|cm| cm.statistics())
703 }
704}
705
706impl ConnectionLike for MultiplexedConnection {
707 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
708 (async move { self.send_packed_command(cmd).await }).boxed()
709 }
710
711 fn req_packed_commands<'a>(
712 &'a mut self,
713 cmd: &'a crate::Pipeline,
714 offset: usize,
715 count: usize,
716 ) -> RedisFuture<'a, Vec<Value>> {
717 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
718 }
719
720 fn get_db(&self) -> i64 {
721 self.db
722 }
723}
724
725impl MultiplexedConnection {
726 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
743 check_resp3!(self.protocol);
744 let mut cmd = cmd("SUBSCRIBE");
745 cmd.arg(channel_name);
746 cmd.exec_async(self).await?;
747 Ok(())
748 }
749
750 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
765 check_resp3!(self.protocol);
766 let mut cmd = cmd("UNSUBSCRIBE");
767 cmd.arg(channel_name);
768 cmd.exec_async(self).await?;
769 Ok(())
770 }
771
772 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
791 check_resp3!(self.protocol);
792 let mut cmd = cmd("PSUBSCRIBE");
793 cmd.arg(channel_pattern);
794 cmd.exec_async(self).await?;
795 Ok(())
796 }
797
798 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
802 check_resp3!(self.protocol);
803 let mut cmd = cmd("PUNSUBSCRIBE");
804 cmd.arg(channel_pattern);
805 cmd.exec_async(self).await?;
806 Ok(())
807 }
808}