1use crate::aio::Runtime;
2use crate::parser::ValueCodec;
3use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
4use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
5use ::tokio::{
6 io::{AsyncRead, AsyncWrite},
7 sync::oneshot,
8};
9use futures_util::{
10 future::{Future, FutureExt},
11 ready,
12 sink::Sink,
13 stream::{self, Stream, StreamExt},
14};
15use pin_project_lite::pin_project;
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::task::{self, Poll};
19use tokio::sync::mpsc::unbounded_channel;
20use tokio::sync::mpsc::UnboundedSender;
21use tokio_util::codec::Decoder;
22
23use super::{setup_connection, SharedHandleContainer};
24
25type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
27
28struct PipelineMessage {
30 input: Vec<u8>,
31 output: RequestResultSender,
32}
33
34#[derive(Clone)]
44pub struct PubSubSink {
45 sender: UnboundedSender<PipelineMessage>,
46}
47
48pin_project! {
49 pub struct PubSubStream {
59 #[pin]
60 receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
61 _task_handle: Option<SharedHandleContainer>,
63 }
64}
65
66pin_project! {
67 struct PipelineSink<T> {
68 #[pin]
70 sink_stream: T,
71 in_flight: VecDeque<RequestResultSender>,
73 sender: UnboundedSender<Msg>,
75 }
76}
77
78impl<T> PipelineSink<T>
79where
80 T: Stream<Item = RedisResult<Value>> + 'static,
81{
82 fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
83 where
84 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
85 {
86 PipelineSink {
87 sink_stream,
88 in_flight: VecDeque::new(),
89 sender,
90 }
91 }
92
93 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
95 loop {
96 let self_ = self.as_mut().project();
97 if self_.sender.is_closed() {
98 return Poll::Ready(Err(()));
99 }
100
101 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
102 Some(result) => result,
103 None => return Poll::Ready(Err(())),
106 };
107 self.as_mut().handle_message(item)?;
108 }
109 }
110
111 fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
112 let self_ = self.project();
113
114 match result {
115 Ok(Value::Array(value)) => {
116 if let Some(Value::BulkString(kind)) = value.first() {
117 if matches!(
118 kind.as_slice(),
119 b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
120 ) {
121 if let Some(entry) = self_.in_flight.pop_front() {
122 let _ = entry.send(Ok(Value::Array(value)));
123 };
124 return Ok(());
125 }
126 }
127
128 if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
129 let _ = self_.sender.send(msg);
130 Ok(())
131 } else {
132 Err(())
133 }
134 }
135
136 Ok(Value::Push { kind, data }) => {
137 if kind.has_reply() {
138 if let Some(entry) = self_.in_flight.pop_front() {
139 let _ = entry.send(Ok(Value::Push { kind, data }));
140 };
141 return Ok(());
142 }
143
144 if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
145 let _ = self_.sender.send(msg);
146 Ok(())
147 } else {
148 Err(())
149 }
150 }
151
152 Err(err) if err.is_unrecoverable_error() => Err(()),
153
154 _ => {
155 if let Some(entry) = self_.in_flight.pop_front() {
156 let _ = entry.send(result);
157 Ok(())
158 } else {
159 Err(())
160 }
161 }
162 }
163 }
164}
165
166impl<T> Sink<PipelineMessage> for PipelineSink<T>
167where
168 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
169{
170 type Error = ();
171
172 fn poll_ready(
174 mut self: Pin<&mut Self>,
175 cx: &mut task::Context,
176 ) -> Poll<Result<(), Self::Error>> {
177 self.as_mut()
178 .project()
179 .sink_stream
180 .poll_ready(cx)
181 .map_err(|_| ())
182 }
183
184 fn start_send(
185 mut self: Pin<&mut Self>,
186 PipelineMessage { input, output }: PipelineMessage,
187 ) -> Result<(), Self::Error> {
188 let self_ = self.as_mut().project();
189
190 match self_.sink_stream.start_send(input) {
191 Ok(()) => {
192 self_.in_flight.push_back(output);
193 Ok(())
194 }
195 Err(err) => {
196 let _ = output.send(Err(err));
197 Err(())
198 }
199 }
200 }
201
202 fn poll_flush(
203 mut self: Pin<&mut Self>,
204 cx: &mut task::Context,
205 ) -> Poll<Result<(), Self::Error>> {
206 ready!(self
207 .as_mut()
208 .project()
209 .sink_stream
210 .poll_flush(cx)
211 .map_err(|err| {
212 let _ = self.as_mut().handle_message(Err(err));
213 }))?;
214 self.poll_read(cx)
215 }
216
217 fn poll_close(
218 mut self: Pin<&mut Self>,
219 cx: &mut task::Context,
220 ) -> Poll<Result<(), Self::Error>> {
221 if !self.in_flight.is_empty() {
224 ready!(self.as_mut().poll_flush(cx))?;
225 }
226 let this = self.as_mut().project();
227
228 if this.sender.is_closed() {
229 return Poll::Ready(Ok(()));
230 }
231
232 match ready!(this.sink_stream.poll_next(cx)) {
233 Some(result) => {
234 let _ = self.handle_message(result);
235 Poll::Pending
236 }
237 None => Poll::Ready(Ok(())),
238 }
239 }
240}
241
242impl PubSubSink {
243 fn new<T>(
244 sink_stream: T,
245 messages_sender: UnboundedSender<Msg>,
246 ) -> (Self, impl Future<Output = ()>)
247 where
248 T: Sink<Vec<u8>, Error = RedisError>,
249 T: Stream<Item = RedisResult<Value>>,
250 T: Unpin + Send + 'static,
251 {
252 let (sender, mut receiver) = unbounded_channel();
253 let sink = PipelineSink::new(sink_stream, messages_sender);
254 let f = stream::poll_fn(move |cx| {
255 let res = receiver.poll_recv(cx);
256 match res {
257 Poll::Ready(None) => Poll::Pending,
259 _ => res,
260 }
261 })
262 .map(Ok)
263 .forward(sink)
264 .map(|_| ());
265 (PubSubSink { sender }, f)
266 }
267
268 async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
269 let (sender, receiver) = oneshot::channel();
270
271 self.sender
272 .send(PipelineMessage {
273 input,
274 output: sender,
275 })
276 .map_err(|_| closed_connection_error())?;
277 match receiver.await {
278 Ok(result) => result,
279 Err(_) => Err(closed_connection_error()),
280 }
281 }
282
283 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
296 let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
297 self.send_recv(cmd).await.map(|_| ())
298 }
299
300 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
313 let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
314 self.send_recv(cmd).await.map(|_| ())
315 }
316
317 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
330 let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
331 self.send_recv(cmd).await.map(|_| ())
332 }
333
334 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
347 let cmd = cmd("PUNSUBSCRIBE")
348 .arg(channel_pattern)
349 .get_packed_command();
350 self.send_recv(cmd).await.map(|_| ())
351 }
352
353 pub async fn ping_message<T: FromRedisValue>(
355 &mut self,
356 message: impl ToRedisArgs,
357 ) -> RedisResult<T> {
358 let cmd = cmd("PING").arg(message).get_packed_command();
359 let response = self.send_recv(cmd).await?;
360 from_owned_redis_value(response)
361 }
362
363 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
365 let cmd = cmd("PING").get_packed_command();
366 let response = self.send_recv(cmd).await?;
367 from_owned_redis_value(response)
368 }
369}
370
371pub struct PubSub {
373 sink: PubSubSink,
374 stream: PubSubStream,
375}
376
377impl PubSub {
378 pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
381 where
382 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
383 {
384 let mut codec = ValueCodec::default().framed(stream);
385 setup_connection(
386 &mut codec,
387 connection_info,
388 #[cfg(feature = "cache-aio")]
389 None,
390 )
391 .await?;
392 let (sender, receiver) = unbounded_channel();
393 let (sink, driver) = PubSubSink::new(codec, sender);
394 let handle = Runtime::locate().spawn(driver);
395 let _task_handle = Some(SharedHandleContainer::new(handle));
396 let stream = PubSubStream {
397 receiver,
398 _task_handle,
399 };
400 let con = PubSub { sink, stream };
401 Ok(con)
402 }
403
404 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
418 self.sink.subscribe(channel_name).await
419 }
420
421 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
435 self.sink.unsubscribe(channel_name).await
436 }
437
438 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
451 self.sink.psubscribe(channel_pattern).await
452 }
453
454 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
467 self.sink.punsubscribe(channel_pattern).await
468 }
469
470 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
472 self.sink.ping().await
473 }
474
475 pub async fn ping_message<T: FromRedisValue>(
477 &mut self,
478 message: impl ToRedisArgs,
479 ) -> RedisResult<T> {
480 self.sink.ping_message(message).await
481 }
482
483 pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
488 &mut self.stream
489 }
490
491 pub fn into_on_message(self) -> PubSubStream {
498 self.stream
499 }
500
501 pub fn split(self) -> (PubSubSink, PubSubStream) {
504 (self.sink, self.stream)
505 }
506}
507
508impl Stream for PubSubStream {
509 type Item = Msg;
510
511 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
512 self.project().receiver.poll_recv(cx)
513 }
514}