actix/
io.rs

1use std::{
2    cell::RefCell,
3    collections::VecDeque,
4    io,
5    marker::PhantomData,
6    ops::DerefMut,
7    pin::Pin,
8    rc::Rc,
9    task,
10    task::{Context, Poll},
11};
12
13use bitflags::bitflags;
14use bytes::BytesMut;
15use futures_sink::Sink;
16use tokio::io::{AsyncWrite, AsyncWriteExt};
17use tokio_util::codec::Encoder;
18
19use crate::{
20    actor::{Actor, ActorContext, AsyncContext, Running, SpawnHandle},
21    fut::ActorFuture,
22};
23
24/// A helper trait for write handling.
25///
26/// `WriteHandler` is a helper for `AsyncWrite` types. Implementation
27/// of this trait is required for `Writer` and `FramedWrite` support.
28#[allow(unused_variables)]
29pub trait WriteHandler<E>
30where
31    Self: Actor,
32    Self::Context: ActorContext,
33{
34    /// Called when the writer emits error.
35    ///
36    /// If this method returns `ErrorAction::Continue` writer processing
37    /// continues otherwise stream processing stops.
38    fn error(&mut self, err: E, ctx: &mut Self::Context) -> Running {
39        Running::Stop
40    }
41
42    /// Called when the writer finishes.
43    ///
44    /// By default this method stops actor's `Context`.
45    fn finished(&mut self, ctx: &mut Self::Context) {
46        ctx.stop()
47    }
48}
49
50bitflags! {
51    struct Flags: u8 {
52        const CLOSING = 0b0000_0001;
53        const CLOSED = 0b0000_0010;
54    }
55}
56
/// Default low watermark (4 KiB). A backpressure drain may stop early once the buffer is smaller.
const LOW_WATERMARK: usize = 1024 * 4;
/// Default high watermark (16 KiB). A buffer larger than this makes the writer schedule a drain.
const HIGH_WATERMARK: usize = LOW_WATERMARK * 4;
59
60/// A wrapper for `AsyncWrite` types.
61pub struct Writer<T: AsyncWrite, E: From<io::Error>> {
62    inner: UnsafeWriter<T, E>,
63}
64
65struct UnsafeWriter<T: AsyncWrite, E: From<io::Error>>(Rc<RefCell<InnerWriter<E>>>, Rc<RefCell<T>>);
66
67impl<T: AsyncWrite, E: From<io::Error>> Clone for UnsafeWriter<T, E> {
68    fn clone(&self) -> Self {
69        UnsafeWriter(self.0.clone(), self.1.clone())
70    }
71}
72
73struct InnerWriter<E: From<io::Error>> {
74    flags: Flags,
75    buffer: BytesMut,
76    error: Option<E>,
77    low: usize,
78    high: usize,
79    handle: SpawnHandle,
80    task: Option<task::Waker>,
81}
82
83impl<T: AsyncWrite, E: From<io::Error> + 'static> Writer<T, E> {
84    pub fn new<A, C>(io: T, ctx: &mut C) -> Self
85    where
86        A: Actor<Context = C> + WriteHandler<E>,
87        C: AsyncContext<A>,
88        T: Unpin + 'static,
89    {
90        let inner = UnsafeWriter(
91            Rc::new(RefCell::new(InnerWriter {
92                flags: Flags::empty(),
93                buffer: BytesMut::new(),
94                error: None,
95                low: LOW_WATERMARK,
96                high: HIGH_WATERMARK,
97                handle: SpawnHandle::default(),
98                task: None,
99            })),
100            Rc::new(RefCell::new(io)),
101        );
102        let h = ctx.spawn(WriterFut {
103            inner: inner.clone(),
104        });
105
106        let writer = Self { inner };
107        writer.inner.0.borrow_mut().handle = h;
108        writer
109    }
110
111    /// Gracefully closes the sink.
112    ///
113    /// The closing happens asynchronously.
114    pub fn close(&mut self) {
115        self.inner.0.borrow_mut().flags.insert(Flags::CLOSING);
116    }
117
118    /// Checks if the sink is closed.
119    pub fn closed(&self) -> bool {
120        self.inner.0.borrow().flags.contains(Flags::CLOSED)
121    }
122
123    /// Sets the write buffer capacity.
124    pub fn set_buffer_capacity(&mut self, low_watermark: usize, high_watermark: usize) {
125        let mut inner = self.inner.0.borrow_mut();
126        inner.low = low_watermark;
127        inner.high = high_watermark;
128    }
129
130    /// Sends an item to the sink.
131    pub fn write(&mut self, msg: &[u8]) {
132        let mut inner = self.inner.0.borrow_mut();
133        inner.buffer.extend_from_slice(msg);
134        if let Some(task) = inner.task.take() {
135            task.wake_by_ref();
136        }
137    }
138
139    /// Returns the `SpawnHandle` for this writer.
140    pub fn handle(&self) -> SpawnHandle {
141        self.inner.0.borrow().handle
142    }
143}
144
145struct WriterFut<T, E>
146where
147    T: AsyncWrite + Unpin,
148    E: From<io::Error>,
149{
150    inner: UnsafeWriter<T, E>,
151}
152
153impl<T: 'static, E: 'static, A> ActorFuture<A> for WriterFut<T, E>
154where
155    T: AsyncWrite + Unpin,
156    E: From<io::Error>,
157    A: Actor + WriteHandler<E>,
158    A::Context: AsyncContext<A>,
159{
160    type Output = ();
161
162    fn poll(
163        self: Pin<&mut Self>,
164        act: &mut A,
165        ctx: &mut A::Context,
166        task: &mut Context<'_>,
167    ) -> Poll<Self::Output> {
168        let this = self.get_mut();
169        let mut inner = this.inner.0.borrow_mut();
170        if let Some(err) = inner.error.take() {
171            if act.error(err, ctx) == Running::Stop {
172                act.finished(ctx);
173                return Poll::Ready(());
174            }
175        }
176
177        let mut io = this.inner.1.borrow_mut();
178        inner.task = None;
179        while !inner.buffer.is_empty() {
180            match Pin::new(io.deref_mut()).poll_write(task, &inner.buffer) {
181                Poll::Ready(Ok(n)) => {
182                    if n == 0
183                        && act.error(
184                            io::Error::new(
185                                io::ErrorKind::WriteZero,
186                                "failed to write frame to transport",
187                            )
188                            .into(),
189                            ctx,
190                        ) == Running::Stop
191                    {
192                        act.finished(ctx);
193                        return Poll::Ready(());
194                    }
195                    let _ = inner.buffer.split_to(n);
196                }
197                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
198                    if inner.buffer.len() > inner.high {
199                        ctx.wait(WriterDrain {
200                            inner: this.inner.clone(),
201                        });
202                    }
203                    return Poll::Pending;
204                }
205                Poll::Ready(Err(e)) => {
206                    if act.error(e.into(), ctx) == Running::Stop {
207                        act.finished(ctx);
208                        return Poll::Ready(());
209                    }
210                }
211                Poll::Pending => return Poll::Pending,
212            }
213        }
214
215        // Try flushing the underlying IO
216        match Pin::new(io.deref_mut()).poll_flush(task) {
217            Poll::Ready(Ok(_)) => (),
218            Poll::Pending => return Poll::Pending,
219            Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
220                return Poll::Pending;
221            }
222            Poll::Ready(Err(e)) => {
223                if act.error(e.into(), ctx) == Running::Stop {
224                    act.finished(ctx);
225                    return Poll::Ready(());
226                }
227            }
228        }
229
230        // close if closing and we don't need to flush any data
231        if inner.flags.contains(Flags::CLOSING) {
232            inner.flags |= Flags::CLOSED;
233            act.finished(ctx);
234            Poll::Ready(())
235        } else {
236            inner.task = Some(task.waker().clone());
237            Poll::Pending
238        }
239    }
240}
241
242struct WriterDrain<T, E>
243where
244    T: AsyncWrite + Unpin,
245    E: From<io::Error>,
246{
247    inner: UnsafeWriter<T, E>,
248}
249
250impl<T, E, A> ActorFuture<A> for WriterDrain<T, E>
251where
252    T: AsyncWrite + Unpin,
253    E: From<io::Error>,
254    A: Actor,
255    A::Context: AsyncContext<A>,
256{
257    type Output = ();
258
259    fn poll(
260        self: Pin<&mut Self>,
261        _: &mut A,
262        _: &mut A::Context,
263        task: &mut Context<'_>,
264    ) -> Poll<Self::Output> {
265        let this = self.get_mut();
266        let mut inner = this.inner.0.borrow_mut();
267        if inner.error.is_some() {
268            return Poll::Ready(());
269        }
270        let mut io = this.inner.1.borrow_mut();
271        while !inner.buffer.is_empty() {
272            match Pin::new(io.deref_mut()).poll_write(task, &inner.buffer) {
273                Poll::Ready(Ok(n)) => {
274                    if n == 0 {
275                        inner.error = Some(
276                            io::Error::new(
277                                io::ErrorKind::WriteZero,
278                                "failed to write frame to transport",
279                            )
280                            .into(),
281                        );
282                        return Poll::Ready(());
283                    }
284                    let _ = inner.buffer.split_to(n);
285                }
286                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
287                    return if inner.buffer.len() < inner.low {
288                        Poll::Ready(())
289                    } else {
290                        Poll::Pending
291                    };
292                }
293                Poll::Ready(Err(e)) => {
294                    inner.error = Some(e.into());
295                    return Poll::Ready(());
296                }
297                Poll::Pending => return Poll::Pending,
298            }
299        }
300        Poll::Ready(())
301    }
302}
303
304/// A wrapper for the `AsyncWrite` and `Encoder` types. The [`AsyncWrite`] will be flushed when this
305/// struct is dropped.
306pub struct FramedWrite<I, T: AsyncWrite + Unpin, U: Encoder<I>> {
307    enc: U,
308    inner: UnsafeWriter<T, U::Error>,
309}
310
311impl<I, T: AsyncWrite + Unpin, U: Encoder<I>> FramedWrite<I, T, U> {
312    pub fn new<A, C>(io: T, enc: U, ctx: &mut C) -> Self
313    where
314        A: Actor<Context = C> + WriteHandler<U::Error>,
315        C: AsyncContext<A>,
316        U::Error: 'static,
317        T: Unpin + 'static,
318    {
319        let inner = UnsafeWriter(
320            Rc::new(RefCell::new(InnerWriter {
321                flags: Flags::empty(),
322                buffer: BytesMut::new(),
323                error: None,
324                low: LOW_WATERMARK,
325                high: HIGH_WATERMARK,
326                handle: SpawnHandle::default(),
327                task: None,
328            })),
329            Rc::new(RefCell::new(io)),
330        );
331        let h = ctx.spawn(WriterFut {
332            inner: inner.clone(),
333        });
334
335        let writer = Self { enc, inner };
336        writer.inner.0.borrow_mut().handle = h;
337        writer
338    }
339
340    pub fn from_buffer<A, C>(io: T, enc: U, buffer: BytesMut, ctx: &mut C) -> Self
341    where
342        A: Actor<Context = C> + WriteHandler<U::Error>,
343        C: AsyncContext<A>,
344        U::Error: 'static,
345        T: Unpin + 'static,
346    {
347        let inner = UnsafeWriter(
348            Rc::new(RefCell::new(InnerWriter {
349                buffer,
350                flags: Flags::empty(),
351                error: None,
352                low: LOW_WATERMARK,
353                high: HIGH_WATERMARK,
354                handle: SpawnHandle::default(),
355                task: None,
356            })),
357            Rc::new(RefCell::new(io)),
358        );
359        let h = ctx.spawn(WriterFut {
360            inner: inner.clone(),
361        });
362
363        let writer = Self { enc, inner };
364        writer.inner.0.borrow_mut().handle = h;
365        writer
366    }
367
368    /// Gracefully closes the sink.
369    ///
370    /// The closing happens asynchronously.
371    pub fn close(&mut self) {
372        self.inner.0.borrow_mut().flags.insert(Flags::CLOSING);
373    }
374
375    /// Checks if the sink is closed.
376    pub fn closed(&self) -> bool {
377        self.inner.0.borrow().flags.contains(Flags::CLOSED)
378    }
379
380    /// Sets the write buffer capacity.
381    pub fn set_buffer_capacity(&mut self, low: usize, high: usize) {
382        let mut inner = self.inner.0.borrow_mut();
383        inner.low = low;
384        inner.high = high;
385    }
386
387    /// Writes an item to the sink.
388    pub fn write(&mut self, item: I) {
389        let mut inner = self.inner.0.borrow_mut();
390        let _ = self.enc.encode(item, &mut inner.buffer).map_err(|e| {
391            inner.error = Some(e);
392        });
393        if let Some(task) = inner.task.take() {
394            task.wake_by_ref();
395        }
396    }
397
398    /// Returns the `SpawnHandle` for this writer.
399    pub fn handle(&self) -> SpawnHandle {
400        self.inner.0.borrow().handle
401    }
402}
403
404impl<I, T: AsyncWrite + Unpin, U: Encoder<I>> Drop for FramedWrite<I, T, U> {
405    fn drop(&mut self) {
406        // Attempts to write any remaining bytes to the stream and flush it
407        let mut async_writer = self.inner.1.borrow_mut();
408        let inner = self.inner.0.borrow_mut();
409        if !inner.buffer.is_empty() {
410            // Results must be ignored during drop, as the errors cannot be handled meaningfully
411            drop(async_writer.write(&inner.buffer));
412            drop(async_writer.flush());
413        }
414    }
415}
416
417/// A wrapper for the `Sink` type.
418pub struct SinkWrite<I, S: Sink<I> + Unpin> {
419    inner: Rc<RefCell<InnerSinkWrite<I, S>>>,
420}
421
422impl<I: 'static, S: Sink<I> + Unpin + 'static> SinkWrite<I, S> {
423    pub fn new<A, C>(sink: S, ctxt: &mut C) -> Self
424    where
425        A: Actor<Context = C> + WriteHandler<S::Error>,
426        C: AsyncContext<A>,
427    {
428        let inner = Rc::new(RefCell::new(InnerSinkWrite {
429            _i: PhantomData,
430            closing_flag: Flags::empty(),
431            sink,
432            task: None,
433            handle: SpawnHandle::default(),
434            buffer: VecDeque::new(),
435        }));
436
437        let handle = ctxt.spawn(SinkWriteFuture {
438            inner: inner.clone(),
439        });
440
441        inner.borrow_mut().handle = handle;
442        SinkWrite { inner }
443    }
444
445    /// Queues an item to be sent to the sink.
446    ///
447    /// Returns unsent item if sink is closing or closed.
448    pub fn write(&mut self, item: I) -> Result<(), I> {
449        if self.inner.borrow().closing_flag.is_empty() {
450            self.inner.borrow_mut().buffer.push_back(item);
451            self.notify_task();
452            Ok(())
453        } else {
454            Err(item)
455        }
456    }
457
458    /// Gracefully closes the sink.
459    ///
460    /// The closing happens asynchronously.
461    pub fn close(&mut self) {
462        self.inner.borrow_mut().closing_flag.insert(Flags::CLOSING);
463        self.notify_task();
464    }
465
466    /// Checks if the sink is closed.
467    pub fn closed(&self) -> bool {
468        self.inner.borrow_mut().closing_flag.contains(Flags::CLOSED)
469    }
470
471    fn notify_task(&self) {
472        if let Some(task) = &self.inner.borrow().task {
473            task.wake_by_ref()
474        }
475    }
476
477    /// Returns the `SpawnHandle` for this writer.
478    pub fn handle(&self) -> SpawnHandle {
479        self.inner.borrow().handle
480    }
481}
482
483struct InnerSinkWrite<I, S: Sink<I>> {
484    _i: PhantomData<I>,
485    closing_flag: Flags,
486    sink: S,
487    task: Option<task::Waker>,
488    handle: SpawnHandle,
489
490    // buffer of items to be sent so that multiple
491    // calls to start_send don't silently skip items
492    buffer: VecDeque<I>,
493}
494
495struct SinkWriteFuture<I: 'static, S: Sink<I>> {
496    inner: Rc<RefCell<InnerSinkWrite<I, S>>>,
497}
498
499impl<I: 'static, S: Sink<I>, A> ActorFuture<A> for SinkWriteFuture<I, S>
500where
501    S: Sink<I> + Unpin,
502    A: Actor + WriteHandler<S::Error>,
503    A::Context: AsyncContext<A>,
504{
505    type Output = ();
506
507    fn poll(
508        self: Pin<&mut Self>,
509        act: &mut A,
510        ctxt: &mut A::Context,
511        cx: &mut Context<'_>,
512    ) -> Poll<Self::Output> {
513        let this = self.get_mut();
514        let inner = &mut this.inner.borrow_mut();
515
516        // Loop to ensure we either process all items in the buffer, or trigger the inner sink to be pending
517        // and wake this task later.
518        loop {
519            // ensure sink is ready to receive next item
520            match Pin::new(&mut inner.sink).poll_ready(cx) {
521                Poll::Ready(Ok(())) => {
522                    if let Some(item) = inner.buffer.pop_front() {
523                        // send front of buffer to sink
524                        let _ = Pin::new(&mut inner.sink).start_send(item);
525                    } else {
526                        break;
527                    }
528                }
529                Poll::Ready(Err(_err)) => {
530                    break;
531                }
532                Poll::Pending => {
533                    break;
534                }
535            }
536        }
537
538        if !inner.closing_flag.contains(Flags::CLOSING) {
539            match Pin::new(&mut inner.sink).poll_flush(cx) {
540                Poll::Ready(Err(e)) => {
541                    if act.error(e, ctxt) == Running::Stop {
542                        act.finished(ctxt);
543                        return Poll::Ready(());
544                    }
545                }
546                Poll::Ready(Ok(())) => {}
547                Poll::Pending => {}
548            }
549        } else {
550            assert!(!inner.closing_flag.contains(Flags::CLOSED));
551            match Pin::new(&mut inner.sink).poll_close(cx) {
552                Poll::Ready(Err(e)) => {
553                    if act.error(e, ctxt) == Running::Stop {
554                        act.finished(ctxt);
555                        return Poll::Ready(());
556                    }
557                }
558                Poll::Ready(Ok(())) => {
559                    // ensure all items in buffer have been sent before closing
560                    if inner.buffer.is_empty() {
561                        inner.closing_flag |= Flags::CLOSED;
562                        act.finished(ctxt);
563                        return Poll::Ready(());
564                    }
565                }
566                Poll::Pending => {}
567            }
568        }
569
570        inner.task.replace(cx.waker().clone());
571
572        Poll::Pending
573    }
574}