actix/address/channel.rs

//! This is a copy of [sync/mpsc/](https://github.com/rust-lang/futures-rs).

use std::{
    fmt,
    hash::{Hash, Hasher},
    pin::Pin,
    sync::{
        atomic::{
            AtomicBool, AtomicUsize,
            Ordering::{Relaxed, SeqCst},
        },
        Arc, Weak,
    },
    task::{self, Poll},
    thread,
};

use futures_core::{stream::Stream, task::__internal::AtomicWaker};
use parking_lot::Mutex;
use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver};

use super::{
    envelope::{Envelope, ToEnvelope},
    queue::Queue,
    SendError,
};
use crate::{
    actor::Actor,
    handler::{Handler, Message},
};

pub trait Sender<M>: Send
where
    M::Result: Send,
    M: Message + Send,
{
    fn do_send(&self, msg: M) -> Result<(), SendError<M>>;

    fn try_send(&self, msg: M) -> Result<(), SendError<M>>;

    fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>;

    fn boxed(&self) -> Box<dyn Sender<M> + Sync>;

    fn hash(&self) -> usize;

    fn connected(&self) -> bool;

    /// Returns this sender downgraded into its weak counterpart.
    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static>;
}
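
// Illustrative sketch (not part of this module; `fan_out` and its parameters are
// hypothetical): because `Sender<M>` is object safe and exposes `boxed()`, senders
// backed by different actor types that handle the same message can be stored and
// used uniformly once erased to `Box<dyn Sender<M> + Sync>`.
//
//     fn fan_out<M>(targets: &[Box<dyn Sender<M> + Sync>], msg: M)
//     where
//         M: Message + Send + Clone,
//         M::Result: Send,
//     {
//         for target in targets {
//             // `do_send` bypasses back pressure, so it never parks the caller.
//             let _ = target.do_send(msg.clone());
//         }
//     }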

impl<S, M> Sender<M> for Box<S>
where
    S: Sender<M> + ?Sized,
    M::Result: Send,
    M: Message + Send,
{
    fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
        (**self).do_send(msg)
    }

    fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
        (**self).try_send(msg)
    }

    fn send(&self, msg: M) -> Result<OneshotReceiver<<M as Message>::Result>, SendError<M>> {
        (**self).send(msg)
    }

    fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
        (**self).boxed()
    }

    fn hash(&self) -> usize {
        (**self).hash()
    }

    fn connected(&self) -> bool {
        (**self).connected()
    }

    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync> {
        (**self).downgrade()
    }
}

pub trait WeakSender<M>: Send
where
    M::Result: Send,
    M: Message + Send,
{
    /// Attempts to upgrade a `WeakAddressSender<A>` to a [`Sender<M>`]
    ///
    /// Returns [`None`] if the actor has since been dropped.
    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>>;

    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync>;
}

/// The transmission end of a channel which is used to send values.
///
/// This is created by the `channel` method.
pub struct AddressSender<A: Actor> {
    // Channel state shared between the sender and receiver.
    inner: Arc<Inner<A>>,

    // Handle to the task that is blocked on this sender. This handle is sent
    // to the receiver half in order to be notified when the sender becomes
    // unblocked.
    sender_task: Arc<Mutex<SenderTask>>,

    // True if the sender might be blocked. This is an optimization to avoid
    // having to lock the mutex most of the time.
    maybe_parked: Arc<AtomicBool>,
}

impl<A: Actor> fmt::Debug for AddressSender<A> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("AddressSender")
            .field("sender_task", &self.sender_task)
            .field("maybe_parked", &self.maybe_parked)
            .finish()
    }
}

/// A weakly referenced version of `AddressSender`.
///
/// This is created by the `AddressSender::downgrade` method.
pub struct WeakAddressSender<A: Actor> {
    inner: Weak<Inner<A>>,
}

impl<A: Actor> Clone for WeakAddressSender<A> {
    fn clone(&self) -> WeakAddressSender<A> {
        WeakAddressSender {
            inner: self.inner.clone(),
        }
    }
}

impl<A: Actor> fmt::Debug for WeakAddressSender<A> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("WeakAddressSender").finish()
    }
}

impl<A: Actor> PartialEq for WeakAddressSender<A> {
    fn eq(&self, other: &Self) -> bool {
        self.inner.ptr_eq(&other.inner)
    }
}

impl<A: Actor> Eq for WeakAddressSender<A> {}

#[allow(dead_code)]
trait AssertKinds: Send + Sync + Clone {}

/// The receiving end of a channel which implements the `Stream` trait.
///
/// This is a concrete implementation of a stream which can be used to represent
/// a stream of values being computed elsewhere. This is created by the
/// `channel` method.
pub struct AddressReceiver<A: Actor> {
    inner: Arc<Inner<A>>,
}

/// Generates `AddressSender`s for the channel.
pub struct AddressSenderProducer<A: Actor> {
    inner: Arc<Inner<A>>,
}

struct Inner<A: Actor> {
    // Max buffer size of the channel. If `0` then the channel is unbounded.
    buffer: AtomicUsize,

    // Internal channel state. Consists of the number of messages stored in the
    // channel as well as a flag signalling that the channel is closed.
    state: AtomicUsize,

    // Atomic, FIFO queue used to send messages to the receiver.
    message_queue: Queue<Envelope<A>>,

    // Atomic, FIFO queue used to send parked task handles to the receiver.
    parked_queue: Queue<Arc<Mutex<SenderTask>>>,

    // Number of senders in existence.
    num_senders: AtomicUsize,

    // Handle to the receiver's task.
    recv_task: AtomicWaker,
}

// Struct representation of `Inner::state`.
#[derive(Debug, Clone, Copy)]
struct State {
    // `true` when the channel is open
    is_open: bool,

    // Number of messages in the channel
    num_messages: usize,
}

impl State {
    fn is_closed(&self) -> bool {
        !self.is_open && self.num_messages == 0
    }
}

// The `is_open` flag is stored in the left-most bit of `Inner::state`
const OPEN_MASK: usize = usize::MAX - (usize::MAX >> 1);

// When a new channel is created, it is created in the open state with no
// pending messages.
const INIT_STATE: usize = OPEN_MASK;

// The maximum number of messages that a channel can track is `usize::MAX >> 1`
const MAX_CAPACITY: usize = !(OPEN_MASK);

// The maximum requested buffer size must be less than the maximum capacity of
// a channel. This is because each sender gets a guaranteed slot.
const MAX_BUFFER: usize = MAX_CAPACITY >> 1;
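
// Worked example (assuming a 64-bit `usize`): `OPEN_MASK` is the highest bit,
// `0x8000_0000_0000_0000`, so an open channel holding 3 messages is encoded as
// `OPEN_MASK | 3`. Clearing the top bit closes the channel while leaving the
// message count in the low bits intact.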

// Sent to the consumer to wake up blocked producers
#[derive(Debug)]
struct SenderTask {
    task: Option<task::Waker>,
    is_parked: bool,
}

impl SenderTask {
    fn new() -> Self {
        SenderTask {
            task: None,
            is_parked: false,
        }
    }

    fn notify(&mut self) -> bool {
        self.is_parked = false;

        if let Some(task) = self.task.take() {
            task.wake();
            true
        } else {
            false
        }
    }
}

/// Creates an in-memory channel implementation of the `Stream` trait with
/// bounded capacity.
///
/// This method creates a concrete implementation of the `Stream` trait which
/// can be used to send values across threads in a streaming fashion. This
/// channel is unique in that it implements back pressure to ensure that the
/// sender never outpaces the receiver. The channel capacity is equal to
/// `buffer + num-senders`. In other words, each sender gets a guaranteed slot
/// in the channel capacity, and on top of that there are `buffer` "first come,
/// first serve" slots available to all senders.
///
/// The `Receiver` returned implements the `Stream` trait and has access to any
/// number of the associated combinators for transforming the result.
pub fn channel<A: Actor>(buffer: usize) -> (AddressSender<A>, AddressReceiver<A>) {
    // Check that the requested buffer size does not exceed the maximum buffer
    // size permitted by the system.
    assert!(buffer < MAX_BUFFER, "requested buffer size too large");

    let inner = Arc::new(Inner {
        buffer: AtomicUsize::new(buffer),
        state: AtomicUsize::new(INIT_STATE),
        message_queue: Queue::new(),
        parked_queue: Queue::new(),
        num_senders: AtomicUsize::new(1),
        recv_task: AtomicWaker::new(),
    });

    let tx = AddressSender {
        inner: Arc::clone(&inner),
        sender_task: Arc::new(Mutex::new(SenderTask::new())),
        maybe_parked: Arc::new(AtomicBool::new(false)),
    };

    let rx = AddressReceiver { inner };

    (tx, rx)
}
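
// A minimal usage sketch (assuming some actor type `MyActor` and a message `Msg`
// it handles; the names are illustrative only). The bounded `send` hands back a
// oneshot receiver that resolves with the handler's result, while the receiver
// half is polled as a `Stream` of envelopes:
//
//     let (tx, rx) = channel::<MyActor>(16);
//     let response_rx = tx.send(Msg)?; // may park the sending task once the buffer is full
//     // `rx` is driven elsewhere, e.g. `while let Some(envelope) = rx.next().await { ... }`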

//
//
// ===== impl Sender =====
//
//
impl<A: Actor> AddressSender<A> {
    /// Returns whether the channel is still open.
    pub fn connected(&self) -> bool {
        let curr = self.inner.state.load(SeqCst);
        let state = decode_state(curr);

        state.is_open
    }

    /// Attempts to send a message on this `AddressSender<A>` with blocking.
    ///
    /// This function must be called from inside a task.
    pub fn send<M>(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>
    where
        A: Handler<M>,
        A::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send,
    {
        // If the sender is currently blocked, reject the message
        if !self.poll_unparked(false, None).is_ready() {
            return Err(SendError::Full(msg));
        }

        // First, increment the number of messages contained by the channel.
        // This operation will also atomically determine if the sender task
        // should be parked.
        //
        // `None` is returned in the case that the channel has been closed by
        // the receiver, which happens when the receiver is dropped.
        let park_self = match self.inc_num_messages() {
            Some(num_messages) => {
                // receiver is full
                let buffer = self.inner.buffer.load(Relaxed);
                buffer != 0 && num_messages >= buffer
            }
            None => return Err(SendError::Closed(msg)),
        };

        // If the channel has reached capacity, then the sender task needs to
        // be parked. This will send the task handle on the parked task queue.
        if park_self {
            self.park();
        }
        let (tx, rx) = oneshot_channel();
        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, Some(tx));
        self.queue_push_and_signal(env);
        Ok(rx)
    }
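
    // Sketch of consuming the returned receiver (illustrative; assumes the caller
    // is inside an async task and `addr_sender` / `msg` are placeholder names):
    //
    //     let rx = addr_sender.send(msg)?;
    //     let result = rx.await; // resolves once the actor has processed the envelope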

    /// Attempts to send a message on this `AddressSender<A>` without blocking.
    pub fn try_send<M>(&self, msg: M, park: bool) -> Result<(), SendError<M>>
    where
        A: Handler<M>,
        <A as Actor>::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send + 'static,
    {
        // If the sender is currently blocked, reject the message
        if !self.poll_unparked(false, None).is_ready() {
            return Err(SendError::Full(msg));
        }

        let park_self = match self.inc_num_messages() {
            Some(num_messages) => {
                // receiver is full
                let buffer = self.inner.buffer.load(Relaxed);
                buffer != 0 && num_messages >= buffer
            }
            None => return Err(SendError::Closed(msg)),
        };

        if park_self && park {
            self.park();
        }
        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
        self.queue_push_and_signal(env);
        Ok(())
    }

    /// Sends a message on this `AddressSender<A>` without blocking.
    ///
    /// This function does not park the current task.
    pub fn do_send<M>(&self, msg: M) -> Result<(), SendError<M>>
    where
        A: Handler<M>,
        <A as Actor>::Context: ToEnvelope<A, M>,
        M::Result: Send,
        M: Message + Send,
    {
        if self.inc_num_messages().is_none() {
            Err(SendError::Closed(msg))
        } else {
            // If inc_num_messages returned Some, the mailbox is still open. We ignore
            // the returned message count (normally used to decide whether to park and
            // wait) and queue the message regardless.
            let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
            self.queue_push_and_signal(env);
            Ok(())
        }
    }

    /// Downgrades to a `WeakAddressSender` which can later be upgraded.
    pub fn downgrade(&self) -> WeakAddressSender<A> {
        WeakAddressSender {
            inner: Arc::downgrade(&self.inner),
        }
    }

    // Push message to the queue and signal to the receiver
    fn queue_push_and_signal(&self, msg: Envelope<A>) {
        // Push the message onto the message queue
        self.inner.message_queue.push(msg);

        // Signal to the receiver that a message has been enqueued. If the
        // receiver is parked, this will unpark the task.
        self.inner.recv_task.wake();
    }

    // Increment the number of queued messages. Returns the new message count,
    // or `None` if the channel has been closed.
    fn inc_num_messages(&self) -> Option<usize> {
        let mut curr = self.inner.state.load(SeqCst);
        loop {
            let mut state = decode_state(curr);
            if !state.is_open {
                return None;
            }
            state.num_messages += 1;

            let next = encode_state(&state);
            match self
                .inner
                .state
                .compare_exchange(curr, next, SeqCst, SeqCst)
            {
                Ok(_) => {
                    return Some(state.num_messages);
                }
                Err(actual) => curr = actual,
            }
        }
    }

    // TODO: this was adjusted to match the futures-rs implementation and may still be buggy.
    fn park(&self) {
        {
            let mut sender = self.sender_task.lock();
            sender.task = None;
            sender.is_parked = true;
        }

        // Send handle over queue
        self.inner.parked_queue.push(Arc::clone(&self.sender_task));

        // Check to make sure we weren't closed after we sent our task on the queue
        let state = decode_state(self.inner.state.load(SeqCst));
        self.maybe_parked.store(state.is_open, Relaxed);
    }

    fn poll_unparked(&self, do_park: bool, cx: Option<&mut task::Context<'_>>) -> Poll<()> {
        // First check the `maybe_parked` variable. This avoids acquiring the
        // lock in most cases
        if self.maybe_parked.load(Relaxed) {
            // Get a lock on the task handle
            let mut task = self.sender_task.lock();

            if !task.is_parked {
                self.maybe_parked.store(false, Relaxed);
                return Poll::Ready(());
            }

            // At this point, an unpark request is pending, so there will be an
            // unpark sometime in the future. We just need to make sure that
            // the correct task will be notified.
            //
            // Update the task in case the `Sender` has been moved to another
            // task
            task.task = if do_park {
                cx.map(|cx| cx.waker().clone())
            } else {
                None
            };

            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

impl<A, M> Sender<M> for AddressSender<A>
where
    A: Handler<M>,
    A::Context: ToEnvelope<A, M>,
    M::Result: Send,
    M: Message + Send + 'static,
{
    fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
        self.do_send(msg)
    }
    fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
        self.try_send(msg, true)
    }
    fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>> {
        self.send(msg)
    }
    fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
        Box::new(self.clone())
    }

    fn hash(&self) -> usize {
        let hash: *const _ = self.inner.as_ref();
        hash as usize
    }

    fn connected(&self) -> bool {
        self.connected()
    }

    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static> {
        Box::new(WeakAddressSender {
            inner: Arc::downgrade(&self.inner),
        })
    }
}

impl<A: Actor> Clone for AddressSender<A> {
    fn clone(&self) -> AddressSender<A> {
        // This atomic op isn't actually guarding any memory and we don't care
        // about any orderings besides the ordering on this single atomic
        // variable, so a relaxed ordering would be acceptable; `SeqCst` is
        // simply what is used for the counters throughout this module.
        let mut curr = self.inner.num_senders.load(SeqCst);

        loop {
            // If the maximum number of senders has been reached, then fail
            if curr == self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }

            debug_assert!(curr < self.inner.max_senders());

            let next = curr + 1;
            #[allow(deprecated)]
            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);

            // The ABA problem doesn't matter here. We only care that the
            // number of senders never exceeds the maximum.
            if actual == curr {
                return AddressSender {
                    inner: Arc::clone(&self.inner),
                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
                    maybe_parked: Arc::new(AtomicBool::new(false)),
                };
            }

            curr = actual;
        }
    }
}

impl<A: Actor> Drop for AddressSender<A> {
    fn drop(&mut self) {
        // Ordering between variables doesn't matter here
        let prev = self.inner.num_senders.fetch_sub(1, SeqCst);
        // last sender, notify receiver task
        if prev == 1 {
            self.inner.recv_task.wake();
        }
    }
}

impl<A: Actor> PartialEq for AddressSender<A> {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl<A: Actor> Eq for AddressSender<A> {}

impl<A: Actor> Hash for AddressSender<A> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let hash: *const Inner<A> = self.inner.as_ref();
        hash.hash(state);
    }
}

//
//
// ===== impl WeakSender =====
//
//
impl<A: Actor> WeakAddressSender<A> {
    /// Attempts to upgrade the `WeakAddressSender<A>` pointer to an [`AddressSender<A>`]
    ///
    /// Returns [`None`] if the actor has since been dropped.
    pub fn upgrade(&self) -> Option<AddressSender<A>> {
        Weak::upgrade(&self.inner).map(|inner| AddressSenderProducer { inner }.sender())
    }
}
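
// A small lifecycle sketch (illustrative; `strong_sender` is a placeholder name):
// a weak handle does not keep the channel alive on its own, so upgrading succeeds
// only while a strong `AddressSender` or the `AddressReceiver` still holds the
// shared `Inner`.
//
//     let weak = strong_sender.downgrade();
//     drop(strong_sender);
//     // Once the remaining strong handles are gone, `weak.upgrade()` returns `None`.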

impl<A, M> WeakSender<M> for WeakAddressSender<A>
where
    A: Handler<M>,
    A::Context: ToEnvelope<A, M>,
    M::Result: Send,
    M: Message + Send + 'static,
{
    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>> {
        if let Some(inner) = WeakAddressSender::upgrade(self) {
            Some(Box::new(inner))
        } else {
            None
        }
    }

    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync> {
        Box::new(self.clone())
    }
}

//
//
// ===== impl SenderProducer =====
//
//
impl<A: Actor> AddressSenderProducer<A> {
    /// Returns whether any senders are still connected.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the channel capacity.
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the channel capacity.
    ///
    /// This method wakes up all waiting senders if the new capacity
    /// is greater than the current one.
    pub fn set_capacity(&mut self, cap: usize) {
        let buffer = self.inner.buffer.load(Relaxed);
        self.inner.buffer.store(cap, Relaxed);

        // wake up all
        if cap > buffer {
            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
                task.lock().notify();
            }
        }
    }

    /// Returns the sender side of the channel.
    pub fn sender(&self) -> AddressSender<A> {
        // This code is the same as `AddressSender::clone`
        let mut curr = self.inner.num_senders.load(SeqCst);

        loop {
            // If the maximum number of senders has been reached, then fail
            if curr == self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }

            let next = curr + 1;
            #[allow(deprecated)]
            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);

            // The ABA problem doesn't matter here. We only care that the
            // number of senders never exceeds the maximum.
            if actual == curr {
                return AddressSender {
                    inner: Arc::clone(&self.inner),
                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
                    maybe_parked: Arc::new(AtomicBool::new(false)),
                };
            }

            curr = actual;
        }
    }
}

//
//
// ===== impl Receiver =====
//
//
impl<A: Actor> AddressReceiver<A> {
    /// Returns whether any senders are still connected.
    pub fn connected(&self) -> bool {
        self.inner.num_senders.load(SeqCst) != 0
    }

    /// Returns the channel capacity.
    pub fn capacity(&self) -> usize {
        self.inner.buffer.load(Relaxed)
    }

    /// Sets the channel capacity.
    ///
    /// This method wakes up all waiting senders if the new capacity
    /// is greater than the current one.
    pub fn set_capacity(&mut self, cap: usize) {
        let buffer = self.inner.buffer.load(Relaxed);
        self.inner.buffer.store(cap, Relaxed);

        // wake up all
        if cap > buffer {
            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
                task.lock().notify();
            }
        }
    }

    /// Returns the sender side of the channel.
    pub fn sender(&self) -> AddressSender<A> {
        // This code is the same as `AddressSender::clone`
        let mut curr = self.inner.num_senders.load(SeqCst);

        loop {
            // If the maximum number of senders has been reached, then fail
            if curr == self.inner.max_senders() {
                panic!("cannot clone `Sender` -- too many outstanding senders");
            }

            let next = curr + 1;
            #[allow(deprecated)]
            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);

            // The ABA problem doesn't matter here. We only care that the
            // number of senders never exceeds the maximum.
            if actual == curr {
                return AddressSender {
                    inner: Arc::clone(&self.inner),
                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
                    maybe_parked: Arc::new(AtomicBool::new(false)),
                };
            }

            curr = actual;
        }
    }

    /// Creates the sender producer.
    pub fn sender_producer(&self) -> AddressSenderProducer<A> {
        AddressSenderProducer {
            inner: self.inner.clone(),
        }
    }

    fn next_message(&mut self) -> Poll<Option<Envelope<A>>> {
        // Pop off a message
        match unsafe { self.inner.message_queue.pop_spin() } {
            Some(msg) => {
                // If there are any parked task handles in the parked queue,
                // pop one and unpark it.
                self.unpark_one();

                // Decrement number of messages
                self.dec_num_messages();

                Poll::Ready(Some(msg))
            }
            None => {
                let state = decode_state(self.inner.state.load(SeqCst));
                if state.is_closed() {
                    // If the closed flag is set AND there are no pending messages,
                    // it means end of stream
                    Poll::Ready(None)
                } else {
                    // If the queue is open, we need to return Pending
                    // to be woken up when new messages arrive.
                    // If the queue is closed but num_messages is non-zero,
                    // it means that a sender updated the state but hasn't
                    // put the message on the queue yet, so we need to park
                    // until that sender wakes the task after queueing the message.
                    Poll::Pending
                }
            }
        }
    }

    // Unpark a single task handle if there is one pending in the parked queue
    fn unpark_one(&mut self) {
        if let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
            task.lock().notify();
        }
    }

    fn dec_num_messages(&self) {
        // OPEN_MASK is highest bit, so it's unaffected by subtraction
        // unless there's underflow, and we know there's no underflow
        // because number of messages at this point is always > 0.
        self.inner.state.fetch_sub(1, SeqCst);
    }
}

impl<A: Actor> Stream for AddressReceiver<A> {
    type Item = Envelope<A>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match this.next_message() {
            Poll::Ready(msg) => Poll::Ready(msg),
            Poll::Pending => {
                // There are no messages to read; in this case, park.
                this.inner.recv_task.register(cx.waker());
                // Check the queue again after parking to prevent a race condition:
                // a message could have been added to the queue after the previous
                // `next_message` call but before `register` was called.
                this.next_message()
            }
        }
    }
}
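
// Sketch of how the receiver half is typically driven (illustrative only; in the
// crate this is done by the actor's mailbox, and `handle_envelope` is a
// hypothetical placeholder for that step):
//
//     use futures_util::StreamExt as _;
//
//     while let Some(envelope) = receiver.next().await {
//         handle_envelope(envelope);
//     }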

impl<A: Actor> Drop for AddressReceiver<A> {
    fn drop(&mut self) {
        // close
        self.inner.set_closed();

        // Wake up any threads waiting as they'll see that we've closed the
        // channel and will continue on their merry way.
        while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
            task.lock().notify();
        }

        // Drain the channel of all pending messages
        loop {
            match self.next_message() {
                Poll::Ready(Some(_)) => {}
                Poll::Ready(None) => break,
                Poll::Pending => {
                    let state = decode_state(self.inner.state.load(SeqCst));

                    // If the channel is closed, then there is no need to park.
                    if state.is_closed() {
                        break;
                    }

                    // TODO: Spinning isn't ideal, it might be worth
                    // investigating using a condvar or some other strategy
                    // here. That said, if this case is hit, then another thread
                    // is about to push the value into the queue and this isn't
                    // the only spinlock in the impl right now.
                    thread::yield_now();
                }
            }
        }
    }
}

//
//
// ===== impl Inner =====
//
//
impl<A: Actor> Inner<A> {
    // The return value is such that the total number of messages that can be
    // enqueued into the channel will never exceed MAX_CAPACITY
    fn max_senders(&self) -> usize {
        MAX_CAPACITY - self.buffer.load(Relaxed)
    }

    // Clear `open` flag in the state, keep `num_messages` intact.
    fn set_closed(&self) {
        let curr = self.state.load(SeqCst);
        if !decode_state(curr).is_open {
            return;
        }

        self.state.fetch_and(!OPEN_MASK, SeqCst);
    }
}

unsafe impl<A: Actor> Send for Inner<A> {}
unsafe impl<A: Actor> Sync for Inner<A> {}

//
//
// ===== Helpers =====
//
//
fn decode_state(num: usize) -> State {
    State {
        is_open: num & OPEN_MASK == OPEN_MASK,
        num_messages: num & MAX_CAPACITY,
    }
}

fn encode_state(state: &State) -> usize {
    let mut num = state.num_messages;

    if state.is_open {
        num |= OPEN_MASK;
    }

    num
}

#[cfg(test)]
mod tests {
    use std::time;

    use super::*;
    use crate::{address::queue::PopResult, prelude::*};

    struct Act;
    impl Actor for Act {
        type Context = Context<Act>;
    }

    struct Ping;
    impl Message for Ping {
        type Result = ();
    }

    impl Handler<Ping> for Act {
        type Result = ();
        fn handle(&mut self, _: Ping, _: &mut Context<Act>) {}
    }
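
    // Sanity check (illustrative): `encode_state` / `decode_state` round-trip, and
    // the open flag lives in the top bit independently of the message count.
    #[test]
    fn test_state_encoding_roundtrip() {
        let state = State {
            is_open: true,
            num_messages: 3,
        };
        let encoded = encode_state(&state);
        assert_eq!(encoded, OPEN_MASK | 3);

        let decoded = decode_state(encoded);
        assert!(decoded.is_open);
        assert_eq!(decoded.num_messages, 3);

        // Closing clears only the top bit; the message count is untouched.
        let closed = decode_state(encoded & !OPEN_MASK);
        assert!(!closed.is_open);
        assert_eq!(closed.num_messages, 3);
    }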

    #[test]
    fn test_cap() {
        System::new().block_on(async {
            let (s1, mut recv) = channel::<Act>(1);
            let s2 = recv.sender();

            let arb = Arbiter::new();
            arb.spawn_fn(move || {
                let _ = s1.send(Ping);
            });
            thread::sleep(time::Duration::from_millis(100));
            let arb2 = Arbiter::new();
            arb2.spawn_fn(move || {
                let _ = s2.send(Ping);
                let _ = s2.send(Ping);
            });

            thread::sleep(time::Duration::from_millis(100));
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);

            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };

            assert!(p.is_some());
            recv.inner.parked_queue.push(p.unwrap());

            recv.set_capacity(10);

            thread::sleep(time::Duration::from_millis(100));
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);

            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };
            assert!(p.is_none());

            System::current().stop();
        });
    }
}