redis/aio/pubsub.rs
1use crate::types::{RedisResult, Value};
2use crate::{
3 aio::Runtime, cmd, errors::closed_connection_error, errors::RedisError, from_redis_value,
4 parser::ValueCodec, FromRedisValue, Msg, RedisConnectionInfo, ToRedisArgs,
5};
6use ::tokio::{
7 io::{AsyncRead, AsyncWrite},
8 sync::oneshot,
9};
10use futures_util::{
11 future::{Future, FutureExt},
12 ready,
13 sink::Sink,
14 stream::{self, Stream, StreamExt},
15};
16use pin_project_lite::pin_project;
17use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{self, Poll};
20use tokio::sync::mpsc::unbounded_channel;
21use tokio::sync::mpsc::UnboundedSender;
22use tokio_util::codec::Decoder;
23
24use super::{setup_connection, SharedHandleContainer};
25
// The sending half of a oneshot channel through which the outcome of a single
// request (subscribe, unsubscribe or ping) is reported back to the caller
// that issued it.
type RequestResultSender = oneshot::Sender<RedisResult<Value>>;
28
// A single request sent through the pipeline.
struct PipelineMessage {
    // The packed command bytes that are handed to the underlying sink.
    input: Vec<u8>,
    // Where the reply (or the error) for this request is delivered.
    output: RequestResultSender,
}
34
/// The sink part of a split async Pubsub.
///
/// The sink is used to subscribe to and unsubscribe from
/// channels.
/// The stream part is independent from the sink,
/// and dropping the sink doesn't cause the stream part to
/// stop working.
/// The sink isn't independent from the stream - dropping
/// the stream will cause the sink to return errors on requests.
#[derive(Clone)]
pub struct PubSubSink {
    // Queues requests for the pipeline that talks to the server. Every clone of the sink
    // shares the same queue.
    sender: UnboundedSender<PipelineMessage>,
}
48
pin_project! {
    /// The stream part of a split async Pubsub.
    ///
    /// The stream yields the [`Msg`]s that were received from the server, while the
    /// sink is used to subscribe to and unsubscribe from
    /// channels.
    /// The stream part is independent from the sink,
    /// and dropping the sink doesn't cause the stream part to
    /// stop working.
    /// The sink isn't independent from the stream - dropping
    /// the stream will cause the sink to return errors on requests.
    pub struct PubSubStream {
        // Receives the messages forwarded by the pipeline.
        #[pin]
        receiver: tokio::sync::mpsc::UnboundedReceiver<Msg>,
        // This handle ensures that once the stream is dropped, the underlying task will stop.
        _task_handle: Option<SharedHandleContainer>,
    }
}
66
pin_project! {
    // Pairs the connection with the bookkeeping needed to route every value read from it:
    // replies go to the request that is waiting for them, everything else is forwarded as a
    // message.
    struct PipelineSink<T> {
        // The `Sink + Stream` that sends requests and receives values from the server.
        #[pin]
        sink_stream: T,
        // The requests that were sent and are awaiting a response, oldest first.
        in_flight: VecDeque<RequestResultSender>,
        // A sender for the push messages received from the server.
        sender: UnboundedSender<Msg>,
    }
}
78
79impl<T> PipelineSink<T>
80where
81 T: Stream<Item = RedisResult<Value>> + 'static,
82{
83 fn new(sink_stream: T, sender: UnboundedSender<Msg>) -> Self
84 where
85 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
86 {
87 PipelineSink {
88 sink_stream,
89 in_flight: VecDeque::new(),
90 sender,
91 }
92 }
93
94 // Read messages from the stream and handle them.
95 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
96 loop {
97 let self_ = self.as_mut().project();
98 if self_.sender.is_closed() {
99 return Poll::Ready(Err(()));
100 }
101
102 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
103 Some(result) => result,
104 // The redis response stream is not going to produce any more items so we `Err`
105 // to break out of the `forward` combinator and stop handling requests
106 None => return Poll::Ready(Err(())),
107 };
108 self.as_mut().handle_message(item)?;
109 }
110 }
111
112 fn handle_message(self: Pin<&mut Self>, result: RedisResult<Value>) -> Result<(), ()> {
113 let self_ = self.project();
114
115 match result {
116 Ok(Value::Array(value)) => {
117 if let Some(Value::BulkString(kind)) = value.first() {
118 if matches!(
119 kind.as_slice(),
120 b"subscribe" | b"psubscribe" | b"unsubscribe" | b"punsubscribe" | b"pong"
121 ) {
122 if let Some(entry) = self_.in_flight.pop_front() {
123 let _ = entry.send(Ok(Value::Array(value)));
124 };
125 return Ok(());
126 }
127 }
128
129 if let Some(msg) = Msg::from_owned_value(Value::Array(value)) {
130 let _ = self_.sender.send(msg);
131 Ok(())
132 } else {
133 Err(())
134 }
135 }
136
137 Ok(Value::Push { kind, data }) => {
138 if kind.has_reply() {
139 if let Some(entry) = self_.in_flight.pop_front() {
140 let _ = entry.send(Ok(Value::Push { kind, data }));
141 };
142 return Ok(());
143 }
144
145 if let Some(msg) = Msg::from_push_info(crate::PushInfo { kind, data }) {
146 let _ = self_.sender.send(msg);
147 Ok(())
148 } else {
149 Err(())
150 }
151 }
152
153 Err(err) if err.is_unrecoverable_error() => Err(()),
154
155 _ => {
156 if let Some(entry) = self_.in_flight.pop_front() {
157 let _ = entry.send(result);
158 Ok(())
159 } else {
160 Err(())
161 }
162 }
163 }
164 }
165}
166
167impl<T> Sink<PipelineMessage> for PipelineSink<T>
168where
169 T: Sink<Vec<u8>, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
170{
171 type Error = ();
172
173 // Retrieve incoming messages and write them to the sink
174 fn poll_ready(
175 mut self: Pin<&mut Self>,
176 cx: &mut task::Context,
177 ) -> Poll<Result<(), Self::Error>> {
178 self.as_mut()
179 .project()
180 .sink_stream
181 .poll_ready(cx)
182 .map_err(|_| ())
183 }
184
185 fn start_send(
186 mut self: Pin<&mut Self>,
187 PipelineMessage { input, output }: PipelineMessage,
188 ) -> Result<(), Self::Error> {
189 let self_ = self.as_mut().project();
190
191 match self_.sink_stream.start_send(input) {
192 Ok(()) => {
193 self_.in_flight.push_back(output);
194 Ok(())
195 }
196 Err(err) => {
197 let _ = output.send(Err(err));
198 Err(())
199 }
200 }
201 }
202
203 fn poll_flush(
204 mut self: Pin<&mut Self>,
205 cx: &mut task::Context,
206 ) -> Poll<Result<(), Self::Error>> {
207 ready!(self
208 .as_mut()
209 .project()
210 .sink_stream
211 .poll_flush(cx)
212 .map_err(|err| {
213 let _ = self.as_mut().handle_message(Err(err));
214 }))?;
215 self.poll_read(cx)
216 }
217
218 fn poll_close(
219 mut self: Pin<&mut Self>,
220 cx: &mut task::Context,
221 ) -> Poll<Result<(), Self::Error>> {
222 // No new requests will come in after the first call to `close` but we need to complete any
223 // in progress requests before closing
224 if !self.in_flight.is_empty() {
225 ready!(self.as_mut().poll_flush(cx))?;
226 }
227 let this = self.as_mut().project();
228
229 if this.sender.is_closed() {
230 return Poll::Ready(Ok(()));
231 }
232
233 match ready!(this.sink_stream.poll_next(cx)) {
234 Some(result) => {
235 let _ = self.handle_message(result);
236 Poll::Pending
237 }
238 None => Poll::Ready(Ok(())),
239 }
240 }
241}
242
243impl PubSubSink {
244 fn new<T>(
245 sink_stream: T,
246 messages_sender: UnboundedSender<Msg>,
247 ) -> (Self, impl Future<Output = ()>)
248 where
249 T: Sink<Vec<u8>, Error = RedisError>,
250 T: Stream<Item = RedisResult<Value>>,
251 T: Unpin + Send + 'static,
252 {
253 let (sender, mut receiver) = unbounded_channel();
254 let sink = PipelineSink::new(sink_stream, messages_sender);
255 let f = stream::poll_fn(move |cx| {
256 let res = receiver.poll_recv(cx);
257 match res {
258 // We don't want to stop the backing task for the stream, even if the sink was closed.
259 Poll::Ready(None) => Poll::Pending,
260 _ => res,
261 }
262 })
263 .map(Ok)
264 .forward(sink)
265 .map(|_| ());
266 (PubSubSink { sender }, f)
267 }
268
269 async fn send_recv(&mut self, input: Vec<u8>) -> Result<Value, RedisError> {
270 let (sender, receiver) = oneshot::channel();
271
272 self.sender
273 .send(PipelineMessage {
274 input,
275 output: sender,
276 })
277 .map_err(|_| closed_connection_error())?;
278 match receiver.await {
279 Ok(result) => result,
280 Err(_) => Err(closed_connection_error()),
281 }
282 }
283
284 /// Subscribes to a new channel(s).
285 ///
286 /// ```rust,no_run
287 /// # #[cfg(feature = "aio")]
288 /// # async fn do_something() -> redis::RedisResult<()> {
289 /// let client = redis::Client::open("redis://127.0.0.1/")?;
290 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
291 /// sink.subscribe("channel_1").await?;
292 /// sink.subscribe(&["channel_2", "channel_3"]).await?;
293 /// # Ok(())
294 /// # }
295 /// ```
296 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
297 let cmd = cmd("SUBSCRIBE").arg(channel_name).get_packed_command();
298 self.send_recv(cmd).await.map(|_| ())
299 }
300
301 /// Unsubscribes from channel(s).
302 ///
303 /// ```rust,no_run
304 /// # #[cfg(feature = "aio")]
305 /// # async fn do_something() -> redis::RedisResult<()> {
306 /// let client = redis::Client::open("redis://127.0.0.1/")?;
307 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
308 /// sink.subscribe(&["channel_1", "channel_2"]).await?;
309 /// sink.unsubscribe(&["channel_1", "channel_2"]).await?;
310 /// # Ok(())
311 /// # }
312 /// ```
313 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
314 let cmd = cmd("UNSUBSCRIBE").arg(channel_name).get_packed_command();
315 self.send_recv(cmd).await.map(|_| ())
316 }
317
318 /// Subscribes to new channel(s) with pattern(s).
319 ///
320 /// ```rust,no_run
321 /// # #[cfg(feature = "aio")]
322 /// # async fn do_something() -> redis::RedisResult<()> {
323 /// let client = redis::Client::open("redis://127.0.0.1/")?;
324 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
325 /// sink.psubscribe("channel*_1").await?;
326 /// sink.psubscribe(&["channel*_2", "channel*_3"]).await?;
327 /// # Ok(())
328 /// # }
329 /// ```
330 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
331 let cmd = cmd("PSUBSCRIBE").arg(channel_pattern).get_packed_command();
332 self.send_recv(cmd).await.map(|_| ())
333 }
334
335 /// Unsubscribes from channel pattern(s).
336 ///
337 /// ```rust,no_run
338 /// # #[cfg(feature = "aio")]
339 /// # async fn do_something() -> redis::RedisResult<()> {
340 /// let client = redis::Client::open("redis://127.0.0.1/")?;
341 /// let (mut sink, _stream) = client.get_async_pubsub().await?.split();
342 /// sink.psubscribe(&["channel_1", "channel_2"]).await?;
343 /// sink.punsubscribe(&["channel_1", "channel_2"]).await?;
344 /// # Ok(())
345 /// # }
346 /// ```
347 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
348 let cmd = cmd("PUNSUBSCRIBE")
349 .arg(channel_pattern)
350 .get_packed_command();
351 self.send_recv(cmd).await.map(|_| ())
352 }
353
354 /// Sends a ping with a message to the server
355 pub async fn ping_message<T: FromRedisValue>(
356 &mut self,
357 message: impl ToRedisArgs,
358 ) -> RedisResult<T> {
359 let cmd = cmd("PING").arg(message).get_packed_command();
360 let response = self.send_recv(cmd).await?;
361 Ok(from_redis_value(response)?)
362 }
363
364 /// Sends a ping to the server
365 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
366 let cmd = cmd("PING").get_packed_command();
367 let response = self.send_recv(cmd).await?;
368 Ok(from_redis_value(response)?)
369 }
370}
371
/// A connection dedicated to RESP2 pubsub messages.
///
/// If you're using a DB that supports RESP3, consider using a regular connection and setting a
/// [crate::aio::AsyncPushSender] on it using
/// [crate::client::AsyncConnectionConfig::set_push_sender].
pub struct PubSub {
    // Used to issue subscribe, unsubscribe and ping requests.
    sink: PubSubSink,
    // Yields the messages received from the server.
    stream: PubSubStream,
}
379
380impl PubSub {
381 /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
382 /// and a `ConnectionInfo`
383 pub async fn new<C>(connection_info: &RedisConnectionInfo, stream: C) -> RedisResult<Self>
384 where
385 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
386 {
387 let mut codec = ValueCodec::default().framed(stream);
388 setup_connection(
389 &mut codec,
390 connection_info,
391 #[cfg(feature = "cache-aio")]
392 None,
393 )
394 .await?;
395 let (sender, receiver) = unbounded_channel();
396 let (sink, driver) = PubSubSink::new(codec, sender);
397 let handle = Runtime::locate().spawn(driver);
398 let _task_handle = Some(SharedHandleContainer::new(handle));
399 let stream = PubSubStream {
400 receiver,
401 _task_handle,
402 };
403 let con = PubSub { sink, stream };
404 Ok(con)
405 }
406
407 /// Subscribes to a new channel(s).
408 ///
409 /// ```rust,no_run
410 /// # #[cfg(feature = "aio")]
411 /// # #[cfg(feature = "aio")]
412 /// # async fn do_something() -> redis::RedisResult<()> {
413 /// let client = redis::Client::open("redis://127.0.0.1/")?;
414 /// let mut pubsub = client.get_async_pubsub().await?;
415 /// pubsub.subscribe("channel_1").await?;
416 /// pubsub.subscribe(&["channel_2", "channel_3"]).await?;
417 /// # Ok(())
418 /// # }
419 /// ```
420 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
421 self.sink.subscribe(channel_name).await
422 }
423
424 /// Unsubscribes from channel(s).
425 ///
426 /// ```rust,no_run
427 /// # #[cfg(feature = "aio")]
428 /// # #[cfg(feature = "aio")]
429 /// # async fn do_something() -> redis::RedisResult<()> {
430 /// let client = redis::Client::open("redis://127.0.0.1/")?;
431 /// let mut pubsub = client.get_async_pubsub().await?;
432 /// pubsub.subscribe(&["channel_1", "channel_2"]).await?;
433 /// pubsub.unsubscribe(&["channel_1", "channel_2"]).await?;
434 /// # Ok(())
435 /// # }
436 /// ```
437 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
438 self.sink.unsubscribe(channel_name).await
439 }
440
441 /// Subscribes to new channel(s) with pattern(s).
442 ///
443 /// ```rust,no_run
444 /// # #[cfg(feature = "aio")]
445 /// # async fn do_something() -> redis::RedisResult<()> {
446 /// let client = redis::Client::open("redis://127.0.0.1/")?;
447 /// let mut pubsub = client.get_async_pubsub().await?;
448 /// pubsub.psubscribe("channel*_1").await?;
449 /// pubsub.psubscribe(&["channel*_2", "channel*_3"]).await?;
450 /// # Ok(())
451 /// # }
452 /// ```
453 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
454 self.sink.psubscribe(channel_pattern).await
455 }
456
457 /// Unsubscribes from channel pattern(s).
458 ///
459 /// ```rust,no_run
460 /// # #[cfg(feature = "aio")]
461 /// # async fn do_something() -> redis::RedisResult<()> {
462 /// let client = redis::Client::open("redis://127.0.0.1/")?;
463 /// let mut pubsub = client.get_async_pubsub().await?;
464 /// pubsub.psubscribe(&["channel_1", "channel_2"]).await?;
465 /// pubsub.punsubscribe(&["channel_1", "channel_2"]).await?;
466 /// # Ok(())
467 /// # }
468 /// ```
469 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
470 self.sink.punsubscribe(channel_pattern).await
471 }
472
473 /// Sends a ping to the server
474 pub async fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
475 self.sink.ping().await
476 }
477
478 /// Sends a ping with a message to the server
479 pub async fn ping_message<T: FromRedisValue>(
480 &mut self,
481 message: impl ToRedisArgs,
482 ) -> RedisResult<T> {
483 self.sink.ping_message(message).await
484 }
485
486 /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
487 ///
488 /// The message itself is still generic and can be converted into an appropriate type through
489 /// the helper methods on it.
490 pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
491 &mut self.stream
492 }
493
494 /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
495 ///
496 /// The message itself is still generic and can be converted into an appropriate type through
497 /// the helper methods on it.
498 /// This can be useful in cases where the stream needs to be returned or held by something other
499 /// than the [`PubSub`].
500 pub fn into_on_message(self) -> PubSubStream {
501 self.stream
502 }
503
504 /// Splits the PubSub into separate sink and stream components, so that subscriptions could be
505 /// updated through the `Sink` while concurrently waiting for new messages on the `Stream`.
506 pub fn split(self) -> (PubSubSink, PubSubStream) {
507 (self.sink, self.stream)
508 }
509}
510
511impl Stream for PubSubStream {
512 type Item = Msg;
513
514 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
515 self.project().receiver.poll_recv(cx)
516 }
517}