use super::ConnectionLike;
use crate::aio::setup_connection;
use crate::cmd::Cmd;
use crate::connection::RedisConnectionInfo;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::types::{RedisError, RedisFuture, RedisResult, Value};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
};
use futures_util::{
future::{Future, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt, TryStreamExt as _},
};
use pin_project_lite::pin_project;
use std::collections::VecDeque;
use std::fmt;
use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::task::{self, Poll};
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use tokio_util::codec::Decoder;
type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;
struct InFlight<O, E> {
output: PipelineOutput<O, E>,
expected_response_count: usize,
current_response_count: usize,
buffer: Vec<O>,
first_err: Option<E>,
}
impl<O, E> InFlight<O, E> {
fn new(output: PipelineOutput<O, E>, expected_response_count: usize) -> Self {
Self {
output,
expected_response_count,
current_response_count: 0,
buffer: Vec::new(),
first_err: None,
}
}
}
struct PipelineMessage<S, I, E> {
input: S,
output: PipelineOutput<I, E>,
response_count: usize,
}
struct Pipeline<SinkItem, I, E>(mpsc::Sender<PipelineMessage<SinkItem, I, E>>);
impl<SinkItem, I, E> Clone for Pipeline<SinkItem, I, E> {
fn clone(&self) -> Self {
Pipeline(self.0.clone())
}
}
impl<SinkItem, I, E> Debug for Pipeline<SinkItem, I, E>
where
SinkItem: Debug,
I: Debug,
E: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Pipeline").field(&self.0).finish()
}
}
// Sink side of the pipeline driver: requests are written to `sink_stream`, and the
// responses it yields are matched, in order, to the requests recorded in `in_flight`.
pin_project! {
    struct PipelineSink<T, I, E> {
        // Combined transport: requests are written to it and responses are read from it.
        #[pin]
        sink_stream: T,
        // Requests that were written but are not yet fully answered, oldest first.
        in_flight: VecDeque<InFlight<I, E>>,
        // Transport error captured in `poll_ready`; handed to the next request in `start_send`.
        error: Option<E>,
    }
}
impl<T, I, E> PipelineSink<T, I, E>
where
    T: Stream<Item = Result<I, E>> + 'static,
{
    /// Wraps the transport with an empty in-flight queue and no pending error.
    fn new<SinkItem>(sink_stream: T) -> Self
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
    {
        PipelineSink {
            sink_stream,
            in_flight: VecDeque::new(),
            error: None,
        }
    }

    /// Reads responses from the transport and routes each to the oldest in-flight
    /// request until no request is waiting.
    ///
    /// Returns `Ready(Ok(()))` once `in_flight` is empty, `Pending` while the
    /// transport has nothing to yield, and `Ready(Err(()))` when the transport
    /// stream has ended. The `Err(())` makes the `forward` driver stop.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
        loop {
            // Only read when some request is waiting for a response.
            if self.in_flight.is_empty() {
                return Poll::Ready(Ok(()));
            }
            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
                Some(result) => result,
                // End of stream: no further responses can arrive.
                None => return Poll::Ready(Err(())),
            };
            self.as_mut().send_result(item);
        }
    }

    /// Records one response (success or error) against the oldest in-flight request.
    ///
    /// Responses are matched to requests purely by position in `in_flight`. Once a
    /// request has received `expected_response_count` responses, it is removed from
    /// the queue. It is then resolved with either every successful value or the
    /// first error it saw. A response that arrives while nothing is in flight is
    /// dropped.
    fn send_result(self: Pin<&mut Self>, result: Result<I, E>) {
        let self_ = self.project();
        {
            let entry = match self_.in_flight.front_mut() {
                Some(entry) => entry,
                None => return,
            };
            match result {
                Ok(item) => {
                    entry.buffer.push(item);
                }
                Err(err) => {
                    // Keep only the first error; later ones are discarded.
                    if entry.first_err.is_none() {
                        entry.first_err = Some(err);
                    }
                }
            }
            entry.current_response_count += 1;
            if entry.current_response_count < entry.expected_response_count {
                // More responses are still expected for this request.
                return;
            }
        }
        // Cannot fail: `front_mut` returned `Some` above and nothing was removed since.
        let entry = self_.in_flight.pop_front().unwrap();
        let response = match entry.first_err {
            Some(err) => Err(err),
            None => Ok(entry.buffer),
        };
        // `send` fails only when the receiver was dropped, i.e. the caller no longer
        // wants the result, so the response is simply discarded.
        entry.output.send(response).ok();
    }
}
impl<SinkItem, T, I, E> Sink<PipelineMessage<SinkItem, I, E>> for PipelineSink<T, I, E>
where
    T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
{
    // Errors are reported to the individual requesters through their oneshot
    // channels; `()` only tells the `forward` driver to stop.
    type Error = ();

    /// Waits until the transport can accept another request.
    ///
    /// A transport error is not surfaced here. It is stored in `error` and delivered
    /// by `start_send` to the next request, whose caller then receives it.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    /// Writes one request to the transport and registers it as in flight.
    ///
    /// - If the requester already dropped its receiver, the request is not sent.
    /// - If `poll_ready` stored an error, that error goes to this requester and the
    ///   sink returns `Err(())`.
    /// - A request expecting zero responses is answered at once with an empty `Vec`
    ///   instead of being queued. Queued, it would swallow the next response on the
    ///   wire, which belongs to a different request, and desynchronise every later
    ///   request.
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            response_count,
        }: PipelineMessage<SinkItem, I, E>,
    ) -> Result<(), Self::Error> {
        // Nobody is waiting for the result, so skip the request entirely.
        if output.is_closed() {
            return Ok(());
        }
        let self_ = self.as_mut().project();
        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }
        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                if response_count == 0 {
                    // No response will ever be attributed to this request; resolve it now.
                    let _ = output.send(Ok(Vec::new()));
                } else {
                    self_
                        .in_flight
                        .push_back(InFlight::new(output, response_count));
                }
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    /// Flushes the transport, then reads responses until no request is in flight.
    ///
    /// A flush error is delivered to the oldest in-flight request, where it counts
    /// as one of its responses. The sink then returns `Err(())`, which stops the
    /// driver.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    /// Finishes outstanding requests (flush + read) before closing the transport.
    ///
    /// A close error is delivered to the oldest in-flight request, if any.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
impl<SinkItem, I, E> Pipeline<SinkItem, I, E>
where
    SinkItem: Send + 'static,
    I: Send + 'static,
    E: Send + 'static,
{
    /// Wraps a combined sink/stream transport in a request pipeline.
    ///
    /// Returns the handle used to submit requests and the driver future. The driver
    /// forwards queued requests into a `PipelineSink`, which routes responses back,
    /// and it must be polled for any request to make progress. It resolves to `()`
    /// once forwarding stops, whether it ended normally or with an error.
    fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
    where
        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
        T: Send + 'static,
        T::Item: Send,
        T::Error: Send,
        T::Error: ::std::fmt::Debug,
    {
        // Capacity of the request queue between the handles and the driver.
        const BUFFER_SIZE: usize = 50;
        let (request_tx, mut request_rx) = mpsc::channel(BUFFER_SIZE);
        let sink = PipelineSink::new::<SinkItem>(sink_stream);
        let driver = stream::poll_fn(move |cx| request_rx.poll_recv(cx))
            .map(Ok)
            .forward(sink)
            // The forward outcome is discarded; failures already reached the requesters.
            .map(|_| ());
        (Pipeline(request_tx), driver)
    }

    /// Sends a request that expects exactly one response and returns that response.
    ///
    /// `Err(None)` means the driver went away before answering.
    async fn send(&mut self, item: SinkItem) -> Result<I, Option<E>> {
        let mut responses = self.send_recv_multiple(item, 1).await?;
        Ok(responses.pop().unwrap())
    }

    /// Sends a request and waits for `count` responses.
    ///
    /// Returns `Err(Some(e))` with the first error reported for the request. Returns
    /// `Err(None)` when the request could not be queued or its reply channel was
    /// dropped without an answer, i.e. the driver stopped.
    async fn send_recv_multiple(
        &mut self,
        input: SinkItem,
        count: usize,
    ) -> Result<Vec<I>, Option<E>> {
        let (output, reply) = oneshot::channel();
        let message = PipelineMessage {
            input,
            response_count: count,
            output,
        };
        if self.0.send(message).await.is_err() {
            // The driver's receiver is gone.
            return Err(None);
        }
        match reply.await {
            Ok(result) => result.map_err(Some),
            // The sink dropped the sender without answering, most likely because the
            // driver stopped.
            Err(_) => Err(None),
        }
    }
}
/// A connection handle whose clones all share one underlying transport.
///
/// Requests are queued to a driver future returned by `MultiplexedConnection::new`;
/// that future must be polled for requests to complete.
#[derive(Clone)]
pub struct MultiplexedConnection {
    /// Database index copied from `RedisConnectionInfo::db`; returned by `get_db`.
    db: i64,
    /// Handle for queueing encoded requests to the driver.
    pipeline: Pipeline<Vec<u8>, Value, RedisError>,
}
impl Debug for MultiplexedConnection {
    /// Formats the pipeline handle and the database index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { pipeline, db } = self;
        f.debug_struct("MultiplexedConnection")
            .field("pipeline", pipeline)
            .field("db", db)
            .finish()
    }
}
impl MultiplexedConnection {
pub async fn new<C>(
connection_info: &RedisConnectionInfo,
stream: C,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
{
fn boxed(
f: impl Future<Output = ()> + Send + 'static,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(f)
}
#[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
compile_error!("tokio-comp or async-std-comp features required for aio feature");
let codec = ValueCodec::default()
.framed(stream)
.and_then(|msg| async move { msg });
let (pipeline, driver) = Pipeline::new(codec);
let driver = boxed(driver);
let mut con = MultiplexedConnection {
pipeline,
db: connection_info.db,
};
let driver = {
let auth = setup_connection(connection_info, &mut con);
futures_util::pin_mut!(auth);
match futures_util::future::select(auth, driver).await {
futures_util::future::Either::Left((result, driver)) => {
result?;
driver
}
futures_util::future::Either::Right(((), _)) => {
unreachable!("Multiplexed connection driver unexpectedly terminated")
}
}
};
Ok((con, driver))
}
pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
let value = self
.pipeline
.send(cmd.get_packed_command())
.await
.map_err(|err| {
err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
})?;
Ok(value)
}
pub async fn send_packed_commands(
&mut self,
cmd: &crate::Pipeline,
offset: usize,
count: usize,
) -> RedisResult<Vec<Value>> {
let mut value = self
.pipeline
.send_recv_multiple(cmd.get_packed_pipeline(), offset + count)
.await
.map_err(|err| {
err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
})?;
value.drain(..offset);
Ok(value)
}
}
impl ConnectionLike for MultiplexedConnection {
    /// Boxes the future returned by `send_packed_command`.
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        self.send_packed_command(cmd).boxed()
    }

    /// Boxes the future returned by `send_packed_commands`.
    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        self.send_packed_commands(cmd, offset, count).boxed()
    }

    /// Returns the database index taken from the connection info at construction.
    fn get_db(&self) -> i64 {
        self.db
    }
}