redis/aio/pubsub.rs

1use crate::aio::Runtime;
2use crate::parser::ValueCodec;
3use crate::types::{closed_connection_error, RedisError, RedisResult, Value};
4use crate::{cmd, from_owned_redis_value, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs};
5use ::tokio::{
6    io::{AsyncRead, AsyncWrite},
7    sync::oneshot,
8};
9use futures_util::{
10    future::{Future, FutureExt},
11    ready,
12    sink::Sink,
13    stream::{self, Stream, StreamExt},
14};
15use pin_project_lite::pin_project;
16use std::collections::VecDeque;
17use std::pin::Pin;
18use std::task::{self, Poll};
19use tokio::sync::mpsc::unbounded_channel;
20use tokio::sync::mpsc::UnboundedSender;
21use tokio_util::codec::Decoder;
22
23use super::{setup_connection, SharedHandleContainer};
24
25// A signal that a un/subscribe request has completed.
26type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
27
28// A single message sent through the pipeline
29struct PipelineMessage {
30    input: Vec<u8>,
31    output: RequestResultSender,
32}
33
34/// The sink part of a split async Pubsub.
35///
36/// The sink is used to subscribe and unsubscribe from
37/// channels.
38/// The stream part is independent from the sink,
39/// and dropping the sink doesn't cause the stream part to
40/// stop working.
41/// The sink isn't independent from the stream - dropping
42/// the stream will cause the sink to return errors on requests.
43#[derive(Clone)]
44pub struct PubSubSink {
45    sender: UnboundedSender<PipelineMessage>,
46}
47
48pin_project! {
49    /// The stream part of a split async Pubsub.
50    ///
51    /// The sink is used to subscribe and unsubscribe from
52    /// channels.
53    /// The stream part is independent from the sink,
54    /// and dropping the sink doesn't cause the stream part to
55    /// stop working.
56    /// The sink isn't independent from the stream - dropping
57    /// the stream will cause the sink to return errors on requests.
58    pub struct PubSubStream {
59        #[pin]
60        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
61        // This handle ensures that once the stream will be dropped, the underlying task will stop.
62        _task_handle: Option<SharedHandleContainer>,
63    }
64}
65
66pin_project! {
67    struct PipelineSink<T> {
68        // The `Sink + Stream` that sends requests and receives values from the server.
69        #[pin]
70        sink_stream: T,
71        // The requests that were sent and are awaiting a response.
72        in_flight: VecDeque<RequestResultSender>,
73        // A sender for the push messages received from the server.
74        sender: UnboundedSender<Msg>,
75    }
76}
77
78impl<T> PipelineSink<T>
79where
80    T: Stream<Item = RedisResult<Value>> + 'static,
81{
82    fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
83    where
84        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
85    {
86        PipelineSink {
87            sink_stream,
88            in_flight: VecDeque::new(),
89            sender,
90        }
91    }
92
93    // Read messages from the stream and handle them.
94    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
95        loop {
96            let self_ = self.as_mut().project();
97            if self_.sender.is_closed() {
98                return Poll::Ready(Err(()));
99            }
100
101            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
102                Some(result) => result,
103                // The redis response stream is not going to produce any more items so we `Err`
104                // to break out of the `forward` combinator and stop handling requests
105                None => return Poll::Ready(Err(())),
106            };
107            self.as_mut().handle_message(item)?;
108        }
109    }
110
111    fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
112        let self_ = self.project();
113
114        match result {
115            Ok(Value::Array(value)) => {
116                if let Some(Value::BulkString(kind)) = value.first() {
117                    if matches!(
118                        kind.as_slice(),
119                        b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
120                    ) {
121                        if let Some(entry) = self_.in_flight.pop_front() {
122                            let _ = entry.send(Ok(Value::Array(value)));
123                        };
124                        return Ok(());
125                    }
126                }
127
128                if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
129                    let _ = self_.sender.send(msg);
130                    Ok(())
131                } else {
132                    Err(())
133                }
134            }
135
136            Ok(Value::Push { kind, data }) => {
137                if kind.has_reply() {
138                    if let Some(entry) = self_.in_flight.pop_front() {
139                        let _ = entry.send(Ok(Value::Push { kind, data }));
140                    };
141                    return Ok(());
142                }
143
144                if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
145                    let _ = self_.sender.send(msg);
146                    Ok(())
147                } else {
148                    Err(())
149                }
150            }
151
152            Err(err) if err.is_unrecoverable_error() => Err(()),
153
154            _ => {
155                if let Some(entry) = self_.in_flight.pop_front() {
156                    let _ = entry.send(result);
157                    Ok(())
158                } else {
159                    Err(())
160                }
161            }
162        }
163    }
164}
165
166impl<T> Sink<PipelineMessage> for PipelineSink<T>
167where
168    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
169{
170    type Error = ();
171
172    // Retrieve incoming messages and write them to the sink
173    fn poll_ready(
174        mut self: Pin<&mut Self>,
175        cx: &mut task::Context,
176    ) -> Poll<Result<(), Self::Error>> {
177        self.as_mut()
178            .project()
179            .sink_stream
180            .poll_ready(cx)
181            .map_err(|_| ())
182    }
183
184    fn start_send(
185        mut self: Pin<&mut Self>,
186        PipelineMessage { input, output }: PipelineMessage,
187    ) -> Result<(), Self::Error> {
188        let self_ = self.as_mut().project();
189
190        match self_.sink_stream.start_send(input) {
191            Ok(()) => {
192                self_.in_flight.push_back(output);
193                Ok(())
194            }
195            Err(err) => {
196                let _ = output.send(Err(err));
197                Err(())
198            }
199        }
200    }
201
202    fn poll_flush(
203        mut self: Pin<&mut Self>,
204        cx: &mut task::Context,
205    ) -> Poll<Result<(), Self::Error>> {
206        ready!(self
207            .as_mut()
208            .project()
209            .sink_stream
210            .poll_flush(cx)
211            .map_err(|err| {
212                let _ = self.as_mut().handle_message(Err(err));
213            }))?;
214        self.poll_read(cx)
215    }
216
217    fn poll_close(
218        mut self: Pin<&mut Self>,
219        cx: &mut task::Context,
220    ) -> Poll<Result<(), Self::Error>> {
221        // No new requests will come in after the first call to `close` but we need to complete any
222        // in progress requests before closing
223        if !self.in_flight.is_empty() {
224            ready!(self.as_mut().poll_flush(cx))?;
225        }
226        let this = self.as_mut().project();
227
228        if this.sender.is_closed() {
229            return Poll::Ready(Ok(()));
230        }
231
232        match ready!(this.sink_stream.poll_next(cx)) {
233            Some(result) => {
234                let _ = self.handle_message(result);
235                Poll::Pending
236            }
237            None => Poll::Ready(Ok(())),
238        }
239    }
240}
241
242impl PubSubSink {
243    fn new<T>(
244        sink_stream: T,
245        messages_sender: UnboundedSender<Msg>,
246    ) -> (Self, impl Future<Output = ()>)
247    where
248        T: Sink<Vec<u8>, Error = RedisError>,
249        T: Stream<Item = RedisResult<Value>>,
250        T: Unpin + Send + 'static,
251    {
252        let (sender, mut receiver) = unbounded_channel();
253        let sink = PipelineSink::new(sink_stream, messages_sender);
254        let f = stream::poll_fn(move |cx| {
255            let res = receiver.poll_recv(cx);
256            match res {
257                // We don't want to stop the backing task for the stream, even if the sink was closed.
258                Poll::Ready(None) => Poll::Pending,
259                _ => res,
260            }
261        })
262        .map(Ok)
263        .forward(sink)
264        .map(|_| ());
265        (PubSubSink { sender }, f)
266    }
267
268    async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
269        let (sender, receiver) = oneshot::channel();
270
271        self.sender
272            .send(PipelineMessage {
273                input,
274                output: sender,
275            })
276            .map_err(|_| closed_connection_error())?;
277        match receiver.await {
278            Ok(result) => result,
279            Err(_) => Err(closed_connection_error()),
280        }
281    }
282
283    /// Subscribes to a new channel(s).
284    ///
285    /// ```rust,no_run
286    /// # #[cfg(feature = "aio")]
287    /// # async fn do_something() -> redis::RedisResult<()> {
288    /// let client = redis::Client::open("redis://127.0.0.1/")?;
289    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
290    /// sink.subscribe("channel_1").await?;
291    /// sink.subscribe(&["channel_2", "channel_3"]).await?;
292    /// # Ok(())
293    /// # }
294    /// ```
295    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
296        let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
297        self.send_recv(cmd).await.map(|_| ())
298    }
299
300    /// Unsubscribes from channel(s).
301    ///
302    /// ```rust,no_run
303    /// # #[cfg(feature = "aio")]
304    /// # async fn do_something() -> redis::RedisResult<()> {
305    /// let client = redis::Client::open("redis://127.0.0.1/")?;
306    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
307    /// sink.subscribe(&["channel_1", "channel_2"]).await?;
308    /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
309    /// # Ok(())
310    /// # }
311    /// ```
312    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
313        let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
314        self.send_recv(cmd).await.map(|_| ())
315    }
316
317    /// Subscribes to new channel(s) with pattern(s).
318    ///
319    /// ```rust,no_run
320    /// # #[cfg(feature = "aio")]
321    /// # async fn do_something() -> redis::RedisResult<()> {
322    /// let client = redis::Client::open("redis://127.0.0.1/")?;
323    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
324    /// sink.psubscribe("channel*_1").await?;
325    /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
326    /// # Ok(())
327    /// # }
328    /// ```
329    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
330        let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
331        self.send_recv(cmd).await.map(|_| ())
332    }
333
334    /// Unsubscribes from channel pattern(s).
335    ///
336    /// ```rust,no_run
337    /// # #[cfg(feature = "aio")]
338    /// # async fn do_something() -> redis::RedisResult<()> {
339    /// let client = redis::Client::open("redis://127.0.0.1/")?;
340    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
341    /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
342    /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
343    /// # Ok(())
344    /// # }
345    /// ```
346    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
347        let cmd = cmd("PUNSUBSCRIBE")
348            .arg(channel_pattern)
349            .get_packed_command();
350        self.send_recv(cmd).await.map(|_| ())
351    }
352
353    /// Sends a ping with a message to the server
354    pub async fn ping_message<T: FromRedisValue>(
355        &mut self,
356        message: impl ToRedisArgs,
357    ) -> RedisResult<T> {
358        let cmd = cmd("PING").arg(message).get_packed_command();
359        let response = self.send_recv(cmd).await?;
360        from_owned_redis_value(response)
361    }
362
363    /// Sends a ping to the server
364    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
365        let cmd = cmd("PING").get_packed_command();
366        let response = self.send_recv(cmd).await?;
367        from_owned_redis_value(response)
368    }
369}
370
371/// A connection dedicated to pubsub messages.
372pub struct PubSub {
373    sink: PubSubSink,
374    stream: PubSubStream,
375}
376
377impl PubSub {
378    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
379    /// and a `ConnectionInfo`
380    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
381    where
382        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
383    {
384        let mut codec = ValueCodec::default().framed(stream);
385        setup_connection(
386            &mut codec,
387            connection_info,
388            #[cfg(feature = "cache-aio")]
389            None,
390        )
391        .await?;
392        let (sender, receiver) = unbounded_channel();
393        let (sink, driver) = PubSubSink::new(codec, sender);
394        let handle = Runtime::locate().spawn(driver);
395        let _task_handle = Some(SharedHandleContainer::new(handle));
396        let stream = PubSubStream {
397            receiver,
398            _task_handle,
399        };
400        let con = PubSub { sink, stream };
401        Ok(con)
402    }
403
404    /// Subscribes to a new channel(s).
405    ///
406    /// ```rust,no_run
407    /// # #[cfg(feature = "aio")]
408    /// # #[cfg(feature = "aio")]
409    /// # async fn do_something() -> redis::RedisResult<()> {
410    /// let client = redis::Client::open("redis://127.0.0.1/")?;
411    /// let mut pubsub = client.get_async_pubsub().await?;
412    /// pubsub.subscribe("channel_1").await?;
413    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
414    /// # Ok(())
415    /// # }
416    /// ```
417    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
418        self.sink.subscribe(channel_name).await
419    }
420
421    /// Unsubscribes from channel(s).
422    ///
423    /// ```rust,no_run
424    /// # #[cfg(feature = "aio")]
425    /// # #[cfg(feature = "aio")]
426    /// # async fn do_something() -> redis::RedisResult<()> {
427    /// let client = redis::Client::open("redis://127.0.0.1/")?;
428    /// let mut pubsub = client.get_async_pubsub().await?;
429    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
430    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
431    /// # Ok(())
432    /// # }
433    /// ```
434    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
435        self.sink.unsubscribe(channel_name).await
436    }
437
438    /// Subscribes to new channel(s) with pattern(s).
439    ///
440    /// ```rust,no_run
441    /// # #[cfg(feature = "aio")]
442    /// # async fn do_something() -> redis::RedisResult<()> {
443    /// let client = redis::Client::open("redis://127.0.0.1/")?;
444    /// let mut pubsub = client.get_async_pubsub().await?;
445    /// pubsub.psubscribe("channel*_1").await?;
446    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
447    /// # Ok(())
448    /// # }
449    /// ```
450    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
451        self.sink.psubscribe(channel_pattern).await
452    }
453
454    /// Unsubscribes from channel pattern(s).
455    ///
456    /// ```rust,no_run
457    /// # #[cfg(feature = "aio")]
458    /// # async fn do_something() -> redis::RedisResult<()> {
459    /// let client = redis::Client::open("redis://127.0.0.1/")?;
460    /// let mut pubsub = client.get_async_pubsub().await?;
461    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
462    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
463    /// # Ok(())
464    /// # }
465    /// ```
466    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
467        self.sink.punsubscribe(channel_pattern).await
468    }
469
470    /// Sends a ping to the server
471    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
472        self.sink.ping().await
473    }
474
475    /// Sends a ping with a message to the server
476    pub async fn ping_message<T: FromRedisValue>(
477        &mut self,
478        message: impl ToRedisArgs,
479    ) -> RedisResult<T> {
480        self.sink.ping_message(message).await
481    }
482
483    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
484    ///
485    /// The message itself is still generic and can be converted into an appropriate type through
486    /// the helper methods on it.
487    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
488        &mut self.stream
489    }
490
491    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
492    ///
493    /// The message itself is still generic and can be converted into an appropriate type through
494    /// the helper methods on it.
495    /// This can be useful in cases where the stream needs to be returned or held by something other
496    /// than the [`PubSub`].
497    pub fn into_on_message(self) -> PubSubStream {
498        self.stream
499    }
500
501    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
502    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
503    pub fn split(self) -> (PubSubSink, PubSubStream) {
504        (self.sink, self.stream)
505    }
506}
507
508impl Stream for PubSubStream {
509    type Item = Msg;
510
511    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
512        self.project().receiver.poll_recv(cx)
513    }
514}