redis/aio/
pubsub.rs

1use crate::types::{RedisResult, Value};
2use crate::{
3    aio::Runtime, cmd, errors::closed_connection_error, errors::RedisError, from_redis_value,
4    parser::ValueCodec, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs,
5};
6use ::tokio::{
7    io::{AsyncRead, AsyncWrite},
8    sync::oneshot,
9};
10use futures_util::{
11    future::{Future, FutureExt},
12    ready,
13    sink::Sink,
14    stream::{self, Stream, StreamExt},
15};
16use pin_project_lite::pin_project;
17use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{self, Poll};
20use tokio::sync::mpsc::unbounded_channel;
21use tokio::sync::mpsc::UnboundedSender;
22use tokio_util::codec::Decoder;
23
24use super::{setup_connection, SharedHandleContainer};
25
26// A signal that a un/subscribe request has completed.
27type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
28
29// A single message sent through the pipeline
30struct PipelineMessage {
31    input: Vec<u8>,
32    output: RequestResultSender,
33}
34
35/// The sink part of a split async Pubsub.
36///
37/// The sink is used to subscribe and unsubscribe from
38/// channels.
39/// The stream part is independent from the sink,
40/// and dropping the sink doesn't cause the stream part to
41/// stop working.
42/// The sink isn't independent from the stream - dropping
43/// the stream will cause the sink to return errors on requests.
44#[derive(Clone)]
45pub struct PubSubSink {
46    sender: UnboundedSender<PipelineMessage>,
47}
48
49pin_project! {
50    /// The stream part of a split async Pubsub.
51    ///
52    /// The sink is used to subscribe and unsubscribe from
53    /// channels.
54    /// The stream part is independent from the sink,
55    /// and dropping the sink doesn't cause the stream part to
56    /// stop working.
57    /// The sink isn't independent from the stream - dropping
58    /// the stream will cause the sink to return errors on requests.
59    pub struct PubSubStream {
60        #[pin]
61        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
62        // This handle ensures that once the stream will be dropped, the underlying task will stop.
63        _task_handle: Option<SharedHandleContainer>,
64    }
65}
66
67pin_project! {
68    struct PipelineSink<T> {
69        // The `Sink + Stream` that sends requests and receives values from the server.
70        #[pin]
71        sink_stream: T,
72        // The requests that were sent and are awaiting a response.
73        in_flight: VecDeque<RequestResultSender>,
74        // A sender for the push messages received from the server.
75        sender: UnboundedSender<Msg>,
76    }
77}
78
79impl<T> PipelineSink<T>
80where
81    T: Stream<Item = RedisResult<Value>> + 'static,
82{
83    fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
84    where
85        T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
86    {
87        PipelineSink {
88            sink_stream,
89            in_flight: VecDeque::new(),
90            sender,
91        }
92    }
93
94    // Read messages from the stream and handle them.
95    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
96        loop {
97            let self_ = self.as_mut().project();
98            if self_.sender.is_closed() {
99                return Poll::Ready(Err(()));
100            }
101
102            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
103                Some(result) => result,
104                // The redis response stream is not going to produce any more items so we `Err`
105                // to break out of the `forward` combinator and stop handling requests
106                None => return Poll::Ready(Err(())),
107            };
108            self.as_mut().handle_message(item)?;
109        }
110    }
111
112    fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
113        let self_ = self.project();
114
115        match result {
116            Ok(Value::Array(value)) => {
117                if let Some(Value::BulkString(kind)) = value.first() {
118                    if matches!(
119                        kind.as_slice(),
120                        b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
121                    ) {
122                        if let Some(entry) = self_.in_flight.pop_front() {
123                            let _ = entry.send(Ok(Value::Array(value)));
124                        };
125                        return Ok(());
126                    }
127                }
128
129                if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
130                    let _ = self_.sender.send(msg);
131                    Ok(())
132                } else {
133                    Err(())
134                }
135            }
136
137            Ok(Value::Push { kind, data }) => {
138                if kind.has_reply() {
139                    if let Some(entry) = self_.in_flight.pop_front() {
140                        let _ = entry.send(Ok(Value::Push { kind, data }));
141                    };
142                    return Ok(());
143                }
144
145                if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
146                    let _ = self_.sender.send(msg);
147                    Ok(())
148                } else {
149                    Err(())
150                }
151            }
152
153            Err(err) if err.is_unrecoverable_error() => Err(()),
154
155            _ => {
156                if let Some(entry) = self_.in_flight.pop_front() {
157                    let _ = entry.send(result);
158                    Ok(())
159                } else {
160                    Err(())
161                }
162            }
163        }
164    }
165}
166
167impl<T> Sink<PipelineMessage> for PipelineSink<T>
168where
169    T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
170{
171    type Error = ();
172
173    // Retrieve incoming messages and write them to the sink
174    fn poll_ready(
175        mut self: Pin<&mut Self>,
176        cx: &mut task::Context,
177    ) -> Poll<Result<(), Self::Error>> {
178        self.as_mut()
179            .project()
180            .sink_stream
181            .poll_ready(cx)
182            .map_err(|_| ())
183    }
184
185    fn start_send(
186        mut self: Pin<&mut Self>,
187        PipelineMessage { input, output }: PipelineMessage,
188    ) -> Result<(), Self::Error> {
189        let self_ = self.as_mut().project();
190
191        match self_.sink_stream.start_send(input) {
192            Ok(()) => {
193                self_.in_flight.push_back(output);
194                Ok(())
195            }
196            Err(err) => {
197                let _ = output.send(Err(err));
198                Err(())
199            }
200        }
201    }
202
203    fn poll_flush(
204        mut self: Pin<&mut Self>,
205        cx: &mut task::Context,
206    ) -> Poll<Result<(), Self::Error>> {
207        ready!(self
208            .as_mut()
209            .project()
210            .sink_stream
211            .poll_flush(cx)
212            .map_err(|err| {
213                let _ = self.as_mut().handle_message(Err(err));
214            }))?;
215        self.poll_read(cx)
216    }
217
218    fn poll_close(
219        mut self: Pin<&mut Self>,
220        cx: &mut task::Context,
221    ) -> Poll<Result<(), Self::Error>> {
222        // No new requests will come in after the first call to `close` but we need to complete any
223        // in progress requests before closing
224        if !self.in_flight.is_empty() {
225            ready!(self.as_mut().poll_flush(cx))?;
226        }
227        let this = self.as_mut().project();
228
229        if this.sender.is_closed() {
230            return Poll::Ready(Ok(()));
231        }
232
233        match ready!(this.sink_stream.poll_next(cx)) {
234            Some(result) => {
235                let _ = self.handle_message(result);
236                Poll::Pending
237            }
238            None => Poll::Ready(Ok(())),
239        }
240    }
241}
242
243impl PubSubSink {
244    fn new<T>(
245        sink_stream: T,
246        messages_sender: UnboundedSender<Msg>,
247    ) -> (Self, impl Future<Output = ()>)
248    where
249        T: Sink<Vec<u8>, Error = RedisError>,
250        T: Stream<Item = RedisResult<Value>>,
251        T: Unpin + Send + 'static,
252    {
253        let (sender, mut receiver) = unbounded_channel();
254        let sink = PipelineSink::new(sink_stream, messages_sender);
255        let f = stream::poll_fn(move |cx| {
256            let res = receiver.poll_recv(cx);
257            match res {
258                // We don't want to stop the backing task for the stream, even if the sink was closed.
259                Poll::Ready(None) => Poll::Pending,
260                _ => res,
261            }
262        })
263        .map(Ok)
264        .forward(sink)
265        .map(|_| ());
266        (PubSubSink { sender }, f)
267    }
268
269    async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
270        let (sender, receiver) = oneshot::channel();
271
272        self.sender
273            .send(PipelineMessage {
274                input,
275                output: sender,
276            })
277            .map_err(|_| closed_connection_error())?;
278        match receiver.await {
279            Ok(result) => result,
280            Err(_) => Err(closed_connection_error()),
281        }
282    }
283
284    /// Subscribes to a new channel(s).
285    ///
286    /// ```rust,no_run
287    /// # #[cfg(feature = "aio")]
288    /// # async fn do_something() -> redis::RedisResult<()> {
289    /// let client = redis::Client::open("redis://127.0.0.1/")?;
290    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
291    /// sink.subscribe("channel_1").await?;
292    /// sink.subscribe(&["channel_2", "channel_3"]).await?;
293    /// # Ok(())
294    /// # }
295    /// ```
296    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
297        let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
298        self.send_recv(cmd).await.map(|_| ())
299    }
300
301    /// Unsubscribes from channel(s).
302    ///
303    /// ```rust,no_run
304    /// # #[cfg(feature = "aio")]
305    /// # async fn do_something() -> redis::RedisResult<()> {
306    /// let client = redis::Client::open("redis://127.0.0.1/")?;
307    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
308    /// sink.subscribe(&["channel_1", "channel_2"]).await?;
309    /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
310    /// # Ok(())
311    /// # }
312    /// ```
313    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
314        let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
315        self.send_recv(cmd).await.map(|_| ())
316    }
317
318    /// Subscribes to new channel(s) with pattern(s).
319    ///
320    /// ```rust,no_run
321    /// # #[cfg(feature = "aio")]
322    /// # async fn do_something() -> redis::RedisResult<()> {
323    /// let client = redis::Client::open("redis://127.0.0.1/")?;
324    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
325    /// sink.psubscribe("channel*_1").await?;
326    /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
327    /// # Ok(())
328    /// # }
329    /// ```
330    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
331        let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
332        self.send_recv(cmd).await.map(|_| ())
333    }
334
335    /// Unsubscribes from channel pattern(s).
336    ///
337    /// ```rust,no_run
338    /// # #[cfg(feature = "aio")]
339    /// # async fn do_something() -> redis::RedisResult<()> {
340    /// let client = redis::Client::open("redis://127.0.0.1/")?;
341    /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
342    /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
343    /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
344    /// # Ok(())
345    /// # }
346    /// ```
347    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
348        let cmd = cmd("PUNSUBSCRIBE")
349            .arg(channel_pattern)
350            .get_packed_command();
351        self.send_recv(cmd).await.map(|_| ())
352    }
353
354    /// Sends a ping with a message to the server
355    pub async fn ping_message<T: FromRedisValue>(
356        &mut self,
357        message: impl ToRedisArgs,
358    ) -> RedisResult<T> {
359        let cmd = cmd("PING").arg(message).get_packed_command();
360        let response = self.send_recv(cmd).await?;
361        Ok(from_redis_value(response)?)
362    }
363
364    /// Sends a ping to the server
365    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
366        let cmd = cmd("PING").get_packed_command();
367        let response = self.send_recv(cmd).await?;
368        Ok(from_redis_value(response)?)
369    }
370}
371
372/// A connection dedicated to RESP2 pubsub messages.
373///
374/// If you're using a DB that supports RESP3, consider using a regular connection and setting a [crate::aio::AsyncPushSender] on it using [crate::client::AsyncConnectionConfig::set_push_sender].
375pub struct PubSub {
376    sink: PubSubSink,
377    stream: PubSubStream,
378}
379
380impl PubSub {
381    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
382    /// and a `ConnectionInfo`
383    pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
384    where
385        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
386    {
387        let mut codec = ValueCodec::default().framed(stream);
388        setup_connection(
389            &mut codec,
390            connection_info,
391            #[cfg(feature = "cache-aio")]
392            None,
393        )
394        .await?;
395        let (sender, receiver) = unbounded_channel();
396        let (sink, driver) = PubSubSink::new(codec, sender);
397        let handle = Runtime::locate().spawn(driver);
398        let _task_handle = Some(SharedHandleContainer::new(handle));
399        let stream = PubSubStream {
400            receiver,
401            _task_handle,
402        };
403        let con = PubSub { sink, stream };
404        Ok(con)
405    }
406
407    /// Subscribes to a new channel(s).
408    ///
409    /// ```rust,no_run
410    /// # #[cfg(feature = "aio")]
411    /// # #[cfg(feature = "aio")]
412    /// # async fn do_something() -> redis::RedisResult<()> {
413    /// let client = redis::Client::open("redis://127.0.0.1/")?;
414    /// let mut pubsub = client.get_async_pubsub().await?;
415    /// pubsub.subscribe("channel_1").await?;
416    /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
417    /// # Ok(())
418    /// # }
419    /// ```
420    pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
421        self.sink.subscribe(channel_name).await
422    }
423
424    /// Unsubscribes from channel(s).
425    ///
426    /// ```rust,no_run
427    /// # #[cfg(feature = "aio")]
428    /// # #[cfg(feature = "aio")]
429    /// # async fn do_something() -> redis::RedisResult<()> {
430    /// let client = redis::Client::open("redis://127.0.0.1/")?;
431    /// let mut pubsub = client.get_async_pubsub().await?;
432    /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
433    /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
434    /// # Ok(())
435    /// # }
436    /// ```
437    pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
438        self.sink.unsubscribe(channel_name).await
439    }
440
441    /// Subscribes to new channel(s) with pattern(s).
442    ///
443    /// ```rust,no_run
444    /// # #[cfg(feature = "aio")]
445    /// # async fn do_something() -> redis::RedisResult<()> {
446    /// let client = redis::Client::open("redis://127.0.0.1/")?;
447    /// let mut pubsub = client.get_async_pubsub().await?;
448    /// pubsub.psubscribe("channel*_1").await?;
449    /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
450    /// # Ok(())
451    /// # }
452    /// ```
453    pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
454        self.sink.psubscribe(channel_pattern).await
455    }
456
457    /// Unsubscribes from channel pattern(s).
458    ///
459    /// ```rust,no_run
460    /// # #[cfg(feature = "aio")]
461    /// # async fn do_something() -> redis::RedisResult<()> {
462    /// let client = redis::Client::open("redis://127.0.0.1/")?;
463    /// let mut pubsub = client.get_async_pubsub().await?;
464    /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
465    /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
466    /// # Ok(())
467    /// # }
468    /// ```
469    pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
470        self.sink.punsubscribe(channel_pattern).await
471    }
472
473    /// Sends a ping to the server
474    pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
475        self.sink.ping().await
476    }
477
478    /// Sends a ping with a message to the server
479    pub async fn ping_message<T: FromRedisValue>(
480        &mut self,
481        message: impl ToRedisArgs,
482    ) -> RedisResult<T> {
483        self.sink.ping_message(message).await
484    }
485
486    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
487    ///
488    /// The message itself is still generic and can be converted into an appropriate type through
489    /// the helper methods on it.
490    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
491        &mut self.stream
492    }
493
494    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
495    ///
496    /// The message itself is still generic and can be converted into an appropriate type through
497    /// the helper methods on it.
498    /// This can be useful in cases where the stream needs to be returned or held by something other
499    /// than the [`PubSub`].
500    pub fn into_on_message(self) -> PubSubStream {
501        self.stream
502    }
503
504    /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
505    /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
506    pub fn split(self) -> (PubSubSink, PubSubStream) {
507        (self.sink, self.stream)
508    }
509}
510
511impl Stream for PubSubStream {
512    type Item = Msg;
513
514    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
515        self.project().receiver.poll_recv(cx)
516    }
517}