1use std::{
2 cell::RefCell,
3 collections::VecDeque,
4 io,
5 marker::PhantomData,
6 ops::DerefMut,
7 pin::Pin,
8 rc::Rc,
9 task,
10 task::{Context, Poll},
11};
12
13use bitflags::bitflags;
14use bytes::BytesMut;
15use futures_sink::Sink;
16use tokio::io::{AsyncWrite, AsyncWriteExt};
17use tokio_util::codec::Encoder;
18
19use crate::{
20 actor::{Actor, ActorContext, AsyncContext, Running, SpawnHandle},
21 fut::ActorFuture,
22};
23
24#[allow(unused_variables)]
29pub trait WriteHandler<E>
30where
31 Self: Actor,
32 Self::Context: ActorContext,
33{
34 fn error(&mut self, err: E, ctx: &mut Self::Context) -> Running {
39 Running::Stop
40 }
41
42 fn finished(&mut self, ctx: &mut Self::Context) {
46 ctx.stop()
47 }
48}
49
50bitflags! {
51 struct Flags: u8 {
52 const CLOSING = 0b0000_0001;
53 const CLOSED = 0b0000_0010;
54 }
55}
56
/// Default low watermark of a writer's buffer, in bytes (4 KiB).
const LOW_WATERMARK: usize = 4096;
/// Default high watermark of a writer's buffer, in bytes (four times the low one, 16 KiB).
const HIGH_WATERMARK: usize = LOW_WATERMARK * 4;
59
60pub struct Writer<T: AsyncWrite, E: From<io::Error>> {
62 inner: UnsafeWriter<T, E>,
63}
64
65struct UnsafeWriter<T: AsyncWrite, E: From<io::Error>>(Rc<RefCell<InnerWriter<E>>>, Rc<RefCell<T>>);
66
67impl<T: AsyncWrite, E: From<io::Error>> Clone for UnsafeWriter<T, E> {
68 fn clone(&self) -> Self {
69 UnsafeWriter(self.0.clone(), self.1.clone())
70 }
71}
72
73struct InnerWriter<E: From<io::Error>> {
74 flags: Flags,
75 buffer: BytesMut,
76 error: Option<E>,
77 low: usize,
78 high: usize,
79 handle: SpawnHandle,
80 task: Option<task::Waker>,
81}
82
83impl<T: AsyncWrite, E: From<io::Error> + 'static> Writer<T, E> {
84 pub fn new<A, C>(io: T, ctx: &mut C) -> Self
85 where
86 A: Actor<Context = C> + WriteHandler<E>,
87 C: AsyncContext<A>,
88 T: Unpin + 'static,
89 {
90 let inner = UnsafeWriter(
91 Rc::new(RefCell::new(InnerWriter {
92 flags: Flags::empty(),
93 buffer: BytesMut::new(),
94 error: None,
95 low: LOW_WATERMARK,
96 high: HIGH_WATERMARK,
97 handle: SpawnHandle::default(),
98 task: None,
99 })),
100 Rc::new(RefCell::new(io)),
101 );
102 let h = ctx.spawn(WriterFut {
103 inner: inner.clone(),
104 });
105
106 let writer = Self { inner };
107 writer.inner.0.borrow_mut().handle = h;
108 writer
109 }
110
111 pub fn close(&mut self) {
115 self.inner.0.borrow_mut().flags.insert(Flags::CLOSING);
116 }
117
118 pub fn closed(&self) -> bool {
120 self.inner.0.borrow().flags.contains(Flags::CLOSED)
121 }
122
123 pub fn set_buffer_capacity(&mut self, low_watermark: usize, high_watermark: usize) {
125 let mut inner = self.inner.0.borrow_mut();
126 inner.low = low_watermark;
127 inner.high = high_watermark;
128 }
129
130 pub fn write(&mut self, msg: &[u8]) {
132 let mut inner = self.inner.0.borrow_mut();
133 inner.buffer.extend_from_slice(msg);
134 if let Some(task) = inner.task.take() {
135 task.wake_by_ref();
136 }
137 }
138
139 pub fn handle(&self) -> SpawnHandle {
141 self.inner.0.borrow().handle
142 }
143}
144
145struct WriterFut<T, E>
146where
147 T: AsyncWrite + Unpin,
148 E: From<io::Error>,
149{
150 inner: UnsafeWriter<T, E>,
151}
152
153impl<T: 'static, E: 'static, A> ActorFuture<A> for WriterFut<T, E>
154where
155 T: AsyncWrite + Unpin,
156 E: From<io::Error>,
157 A: Actor + WriteHandler<E>,
158 A::Context: AsyncContext<A>,
159{
160 type Output = ();
161
162 fn poll(
163 self: Pin<&mut Self>,
164 act: &mut A,
165 ctx: &mut A::Context,
166 task: &mut Context<'_>,
167 ) -> Poll<Self::Output> {
168 let this = self.get_mut();
169 let mut inner = this.inner.0.borrow_mut();
170 if let Some(err) = inner.error.take() {
171 if act.error(err, ctx) == Running::Stop {
172 act.finished(ctx);
173 return Poll::Ready(());
174 }
175 }
176
177 let mut io = this.inner.1.borrow_mut();
178 inner.task = None;
179 while !inner.buffer.is_empty() {
180 match Pin::new(io.deref_mut()).poll_write(task, &inner.buffer) {
181 Poll::Ready(Ok(n)) => {
182 if n == 0
183 && act.error(
184 io::Error::new(
185 io::ErrorKind::WriteZero,
186 "failed to write frame to transport",
187 )
188 .into(),
189 ctx,
190 ) == Running::Stop
191 {
192 act.finished(ctx);
193 return Poll::Ready(());
194 }
195 let _ = inner.buffer.split_to(n);
196 }
197 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
198 if inner.buffer.len() > inner.high {
199 ctx.wait(WriterDrain {
200 inner: this.inner.clone(),
201 });
202 }
203 return Poll::Pending;
204 }
205 Poll::Ready(Err(e)) => {
206 if act.error(e.into(), ctx) == Running::Stop {
207 act.finished(ctx);
208 return Poll::Ready(());
209 }
210 }
211 Poll::Pending => return Poll::Pending,
212 }
213 }
214
215 match Pin::new(io.deref_mut()).poll_flush(task) {
217 Poll::Ready(Ok(_)) => (),
218 Poll::Pending => return Poll::Pending,
219 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
220 return Poll::Pending;
221 }
222 Poll::Ready(Err(e)) => {
223 if act.error(e.into(), ctx) == Running::Stop {
224 act.finished(ctx);
225 return Poll::Ready(());
226 }
227 }
228 }
229
230 if inner.flags.contains(Flags::CLOSING) {
232 inner.flags |= Flags::CLOSED;
233 act.finished(ctx);
234 Poll::Ready(())
235 } else {
236 inner.task = Some(task.waker().clone());
237 Poll::Pending
238 }
239 }
240}
241
242struct WriterDrain<T, E>
243where
244 T: AsyncWrite + Unpin,
245 E: From<io::Error>,
246{
247 inner: UnsafeWriter<T, E>,
248}
249
250impl<T, E, A> ActorFuture<A> for WriterDrain<T, E>
251where
252 T: AsyncWrite + Unpin,
253 E: From<io::Error>,
254 A: Actor,
255 A::Context: AsyncContext<A>,
256{
257 type Output = ();
258
259 fn poll(
260 self: Pin<&mut Self>,
261 _: &mut A,
262 _: &mut A::Context,
263 task: &mut Context<'_>,
264 ) -> Poll<Self::Output> {
265 let this = self.get_mut();
266 let mut inner = this.inner.0.borrow_mut();
267 if inner.error.is_some() {
268 return Poll::Ready(());
269 }
270 let mut io = this.inner.1.borrow_mut();
271 while !inner.buffer.is_empty() {
272 match Pin::new(io.deref_mut()).poll_write(task, &inner.buffer) {
273 Poll::Ready(Ok(n)) => {
274 if n == 0 {
275 inner.error = Some(
276 io::Error::new(
277 io::ErrorKind::WriteZero,
278 "failed to write frame to transport",
279 )
280 .into(),
281 );
282 return Poll::Ready(());
283 }
284 let _ = inner.buffer.split_to(n);
285 }
286 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::WouldBlock => {
287 return if inner.buffer.len() < inner.low {
288 Poll::Ready(())
289 } else {
290 Poll::Pending
291 };
292 }
293 Poll::Ready(Err(e)) => {
294 inner.error = Some(e.into());
295 return Poll::Ready(());
296 }
297 Poll::Pending => return Poll::Pending,
298 }
299 }
300 Poll::Ready(())
301 }
302}
303
304pub struct FramedWrite<I, T: AsyncWrite + Unpin, U: Encoder<I>> {
307 enc: U,
308 inner: UnsafeWriter<T, U::Error>,
309}
310
311impl<I, T: AsyncWrite + Unpin, U: Encoder<I>> FramedWrite<I, T, U> {
312 pub fn new<A, C>(io: T, enc: U, ctx: &mut C) -> Self
313 where
314 A: Actor<Context = C> + WriteHandler<U::Error>,
315 C: AsyncContext<A>,
316 U::Error: 'static,
317 T: Unpin + 'static,
318 {
319 let inner = UnsafeWriter(
320 Rc::new(RefCell::new(InnerWriter {
321 flags: Flags::empty(),
322 buffer: BytesMut::new(),
323 error: None,
324 low: LOW_WATERMARK,
325 high: HIGH_WATERMARK,
326 handle: SpawnHandle::default(),
327 task: None,
328 })),
329 Rc::new(RefCell::new(io)),
330 );
331 let h = ctx.spawn(WriterFut {
332 inner: inner.clone(),
333 });
334
335 let writer = Self { enc, inner };
336 writer.inner.0.borrow_mut().handle = h;
337 writer
338 }
339
340 pub fn from_buffer<A, C>(io: T, enc: U, buffer: BytesMut, ctx: &mut C) -> Self
341 where
342 A: Actor<Context = C> + WriteHandler<U::Error>,
343 C: AsyncContext<A>,
344 U::Error: 'static,
345 T: Unpin + 'static,
346 {
347 let inner = UnsafeWriter(
348 Rc::new(RefCell::new(InnerWriter {
349 buffer,
350 flags: Flags::empty(),
351 error: None,
352 low: LOW_WATERMARK,
353 high: HIGH_WATERMARK,
354 handle: SpawnHandle::default(),
355 task: None,
356 })),
357 Rc::new(RefCell::new(io)),
358 );
359 let h = ctx.spawn(WriterFut {
360 inner: inner.clone(),
361 });
362
363 let writer = Self { enc, inner };
364 writer.inner.0.borrow_mut().handle = h;
365 writer
366 }
367
368 pub fn close(&mut self) {
372 self.inner.0.borrow_mut().flags.insert(Flags::CLOSING);
373 }
374
375 pub fn closed(&self) -> bool {
377 self.inner.0.borrow().flags.contains(Flags::CLOSED)
378 }
379
380 pub fn set_buffer_capacity(&mut self, low: usize, high: usize) {
382 let mut inner = self.inner.0.borrow_mut();
383 inner.low = low;
384 inner.high = high;
385 }
386
387 pub fn write(&mut self, item: I) {
389 let mut inner = self.inner.0.borrow_mut();
390 let _ = self.enc.encode(item, &mut inner.buffer).map_err(|e| {
391 inner.error = Some(e);
392 });
393 if let Some(task) = inner.task.take() {
394 task.wake_by_ref();
395 }
396 }
397
398 pub fn handle(&self) -> SpawnHandle {
400 self.inner.0.borrow().handle
401 }
402}
403
404impl<I, T: AsyncWrite + Unpin, U: Encoder<I>> Drop for FramedWrite<I, T, U> {
405 fn drop(&mut self) {
406 let mut async_writer = self.inner.1.borrow_mut();
408 let inner = self.inner.0.borrow_mut();
409 if !inner.buffer.is_empty() {
410 drop(async_writer.write(&inner.buffer));
412 drop(async_writer.flush());
413 }
414 }
415}
416
417pub struct SinkWrite<I, S: Sink<I> + Unpin> {
419 inner: Rc<RefCell<InnerSinkWrite<I, S>>>,
420}
421
422impl<I: 'static, S: Sink<I> + Unpin + 'static> SinkWrite<I, S> {
423 pub fn new<A, C>(sink: S, ctxt: &mut C) -> Self
424 where
425 A: Actor<Context = C> + WriteHandler<S::Error>,
426 C: AsyncContext<A>,
427 {
428 let inner = Rc::new(RefCell::new(InnerSinkWrite {
429 _i: PhantomData,
430 closing_flag: Flags::empty(),
431 sink,
432 task: None,
433 handle: SpawnHandle::default(),
434 buffer: VecDeque::new(),
435 }));
436
437 let handle = ctxt.spawn(SinkWriteFuture {
438 inner: inner.clone(),
439 });
440
441 inner.borrow_mut().handle = handle;
442 SinkWrite { inner }
443 }
444
445 pub fn write(&mut self, item: I) -> Result<(), I> {
449 if self.inner.borrow().closing_flag.is_empty() {
450 self.inner.borrow_mut().buffer.push_back(item);
451 self.notify_task();
452 Ok(())
453 } else {
454 Err(item)
455 }
456 }
457
458 pub fn close(&mut self) {
462 self.inner.borrow_mut().closing_flag.insert(Flags::CLOSING);
463 self.notify_task();
464 }
465
466 pub fn closed(&self) -> bool {
468 self.inner.borrow_mut().closing_flag.contains(Flags::CLOSED)
469 }
470
471 fn notify_task(&self) {
472 if let Some(task) = &self.inner.borrow().task {
473 task.wake_by_ref()
474 }
475 }
476
477 pub fn handle(&self) -> SpawnHandle {
479 self.inner.borrow().handle
480 }
481}
482
483struct InnerSinkWrite<I, S: Sink<I>> {
484 _i: PhantomData<I>,
485 closing_flag: Flags,
486 sink: S,
487 task: Option<task::Waker>,
488 handle: SpawnHandle,
489
490 buffer: VecDeque<I>,
493}
494
495struct SinkWriteFuture<I: 'static, S: Sink<I>> {
496 inner: Rc<RefCell<InnerSinkWrite<I, S>>>,
497}
498
499impl<I: 'static, S: Sink<I>, A> ActorFuture<A> for SinkWriteFuture<I, S>
500where
501 S: Sink<I> + Unpin,
502 A: Actor + WriteHandler<S::Error>,
503 A::Context: AsyncContext<A>,
504{
505 type Output = ();
506
507 fn poll(
508 self: Pin<&mut Self>,
509 act: &mut A,
510 ctxt: &mut A::Context,
511 cx: &mut Context<'_>,
512 ) -> Poll<Self::Output> {
513 let this = self.get_mut();
514 let inner = &mut this.inner.borrow_mut();
515
516 loop {
519 match Pin::new(&mut inner.sink).poll_ready(cx) {
521 Poll::Ready(Ok(())) => {
522 if let Some(item) = inner.buffer.pop_front() {
523 let _ = Pin::new(&mut inner.sink).start_send(item);
525 } else {
526 break;
527 }
528 }
529 Poll::Ready(Err(_err)) => {
530 break;
531 }
532 Poll::Pending => {
533 break;
534 }
535 }
536 }
537
538 if !inner.closing_flag.contains(Flags::CLOSING) {
539 match Pin::new(&mut inner.sink).poll_flush(cx) {
540 Poll::Ready(Err(e)) => {
541 if act.error(e, ctxt) == Running::Stop {
542 act.finished(ctxt);
543 return Poll::Ready(());
544 }
545 }
546 Poll::Ready(Ok(())) => {}
547 Poll::Pending => {}
548 }
549 } else {
550 assert!(!inner.closing_flag.contains(Flags::CLOSED));
551 match Pin::new(&mut inner.sink).poll_close(cx) {
552 Poll::Ready(Err(e)) => {
553 if act.error(e, ctxt) == Running::Stop {
554 act.finished(ctxt);
555 return Poll::Ready(());
556 }
557 }
558 Poll::Ready(Ok(())) => {
559 if inner.buffer.is_empty() {
561 inner.closing_flag |= Flags::CLOSED;
562 act.finished(ctxt);
563 return Poll::Ready(());
564 }
565 }
566 Poll::Pending => {}
567 }
568 }
569
570 inner.task.replace(cx.waker().clone());
571
572 Poll::Pending
573 }
574}