1use std::{
4    fmt,
5    hash::{Hash, Hasher},
6    pin::Pin,
7    sync::{
8        atomic::{
9            AtomicBool, AtomicUsize,
10            Ordering::{Relaxed, SeqCst},
11        },
12        Arc, Weak,
13    },
14    task::{self, Poll},
15    thread,
16};
17
18use futures_core::{stream::Stream, task::__internal::AtomicWaker};
19use parking_lot::Mutex;
20use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver};
21
22use super::{
23    envelope::{Envelope, ToEnvelope},
24    queue::Queue,
25    SendError,
26};
27use crate::{
28    actor::Actor,
29    handler::{Handler, Message},
30};
31
32pub trait Sender<M>: Send
33where
34    M::Result: Send,
35    M: Message + Send,
36{
37    fn do_send(&self, msg: M) -> Result<(), SendError<M>>;
38
39    fn try_send(&self, msg: M) -> Result<(), SendError<M>>;
40
41    fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>;
42
43    fn boxed(&self) -> Box<dyn Sender<M> + Sync>;
44
45    fn hash(&self) -> usize;
46
47    fn connected(&self) -> bool;
48
49    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static>;
51}
52
53impl<S, M> Sender<M> for Box<S>
54where
55    S: Sender<M> + ?Sized,
56    M::Result: Send,
57    M: Message + Send,
58{
59    fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
60        (**self).do_send(msg)
61    }
62
63    fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
64        (**self).try_send(msg)
65    }
66
67    fn send(&self, msg: M) -> Result<OneshotReceiver<<M as Message>::Result>, SendError<M>> {
68        (**self).send(msg)
69    }
70
71    fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
72        (**self).boxed()
73    }
74
75    fn hash(&self) -> usize {
76        (**self).hash()
77    }
78
79    fn connected(&self) -> bool {
80        (**self).connected()
81    }
82
83    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync> {
84        (**self).downgrade()
85    }
86}
87
88pub trait WeakSender<M>: Send
89where
90    M::Result: Send,
91    M: Message + Send,
92{
93    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>>;
97
98    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync>;
99}
100
101pub struct AddressSender<A: Actor> {
105    inner: Arc<Inner<A>>,
107
108    sender_task: Arc<Mutex<SenderTask>>,
112
113    maybe_parked: Arc<AtomicBool>,
116}
117
118impl<A: Actor> fmt::Debug for AddressSender<A> {
119    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
120        fmt.debug_struct("AddressSender")
121            .field("sender_task", &self.sender_task)
122            .field("maybe_parked", &self.maybe_parked)
123            .finish()
124    }
125}
126
127pub struct WeakAddressSender<A: Actor> {
131    inner: Weak<Inner<A>>,
132}
133
134impl<A: Actor> Clone for WeakAddressSender<A> {
135    fn clone(&self) -> WeakAddressSender<A> {
136        WeakAddressSender {
137            inner: self.inner.clone(),
138        }
139    }
140}
141
142impl<A: Actor> fmt::Debug for WeakAddressSender<A> {
143    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
144        fmt.debug_struct("WeakAddressSender").finish()
145    }
146}
147
148impl<A: Actor> PartialEq for WeakAddressSender<A> {
149    fn eq(&self, other: &Self) -> bool {
150        self.inner.ptr_eq(&other.inner)
151    }
152}
153
154impl<A: Actor> Eq for WeakAddressSender<A> {}
155
// Marker trait whose supertraits are `Send + Sync + Clone`.
// NOTE(review): nothing in this file implements or references it (hence the
// `dead_code` allowance); presumably a leftover compile-time assertion helper.
#[allow(dead_code)]
trait AssertKinds: Send + Sync + Clone {}

/// Receiving half of an actor mailbox; yields queued envelopes as a `Stream`.
pub struct AddressReceiver<A: Actor> {
    /// Shared channel state.
    inner: Arc<Inner<A>>,
}

/// Handle that can create new `AddressSender`s and adjust the mailbox
/// capacity without being counted in `num_senders` itself.
pub struct AddressSenderProducer<A: Actor> {
    /// Shared channel state.
    inner: Arc<Inner<A>>,
}
172
/// State shared by every sender, the receiver, and sender producers of one mailbox.
struct Inner<A: Actor> {
    /// Buffer size. A send that brings the queued-message count to at least
    /// this value parks the sending handle; `0` disables parking.
    buffer: AtomicUsize,

    /// Packed channel state: the `OPEN_MASK` bit plus the queued-message
    /// count (see `decode_state` / `encode_state`).
    state: AtomicUsize,

    /// Envelopes waiting to be taken by the receiver.
    message_queue: Queue<Envelope<A>>,

    /// Parking slots of senders waiting to be released.
    parked_queue: Queue<Arc<Mutex<SenderTask>>>,

    /// Number of live `AddressSender` handles.
    num_senders: AtomicUsize,

    /// Waker registered by the task polling the receiver.
    recv_task: AtomicWaker,
}
193
/// Decoded view of the packed `Inner::state` word.
#[derive(Debug, Clone, Copy)]
struct State {
    /// `false` once the receiver has closed the channel.
    is_open: bool,

    /// Messages counted into the channel and not yet taken by the receiver.
    num_messages: usize,
}

impl State {
    /// The channel is closed only when it is no longer open *and* fully drained.
    fn is_closed(&self) -> bool {
        matches!(
            *self,
            State {
                is_open: false,
                num_messages: 0,
            }
        )
    }
}
209
/// Highest bit of the `state` word; set while the channel is open.
const OPEN_MASK: usize = !(usize::MAX >> 1);

/// Initial `state`: open, with no queued messages.
const INIT_STATE: usize = OPEN_MASK;

/// Mask for the message-count bits of `state` (every bit except `OPEN_MASK`);
/// also the ceiling used when computing the maximum number of senders.
const MAX_CAPACITY: usize = usize::MAX >> 1;

/// Exclusive upper bound on the buffer size accepted by `channel`.
const MAX_BUFFER: usize = MAX_CAPACITY >> 1;
223
/// Parking slot for one sender handle.
#[derive(Debug)]
struct SenderTask {
    /// Waker to call when the sender is released, if one was stored.
    task: Option<task::Waker>,
    /// Whether the sender is currently parked.
    is_parked: bool,
}

impl SenderTask {
    /// Returns an unparked slot with no stored waker.
    fn new() -> Self {
        Self {
            is_parked: false,
            task: None,
        }
    }

    /// Releases the sender: clears the parked flag and wakes the stored waker,
    /// if any. Returns `true` exactly when a waker was woken.
    fn notify(&mut self) -> bool {
        self.is_parked = false;
        self.task.take().map(task::Waker::wake).is_some()
    }
}
250
251pub fn channel<A: Actor>(buffer: usize) -> (AddressSender<A>, AddressReceiver<A>) {
265    assert!(buffer < MAX_BUFFER, "requested buffer size too large");
268
269    let inner = Arc::new(Inner {
270        buffer: AtomicUsize::new(buffer),
271        state: AtomicUsize::new(INIT_STATE),
272        message_queue: Queue::new(),
273        parked_queue: Queue::new(),
274        num_senders: AtomicUsize::new(1),
275        recv_task: AtomicWaker::new(),
276    });
277
278    let tx = AddressSender {
279        inner: Arc::clone(&inner),
280        sender_task: Arc::new(Mutex::new(SenderTask::new())),
281        maybe_parked: Arc::new(AtomicBool::new(false)),
282    };
283
284    let rx = AddressReceiver { inner };
285
286    (tx, rx)
287}
288
289impl<A: Actor> AddressSender<A> {
295    pub fn connected(&self) -> bool {
297        let curr = self.inner.state.load(SeqCst);
298        let state = decode_state(curr);
299
300        state.is_open
301    }
302
303    pub fn send<M>(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>
307    where
308        A: Handler<M>,
309        A::Context: ToEnvelope<A, M>,
310        M::Result: Send,
311        M: Message + Send,
312    {
313        if !self.poll_unparked(false, None).is_ready() {
315            return Err(SendError::Full(msg));
316        }
317
318        let park_self = match self.inc_num_messages() {
326            Some(num_messages) => {
327                let buffer = self.inner.buffer.load(Relaxed);
329                buffer != 0 && num_messages >= buffer
330            }
331            None => return Err(SendError::Closed(msg)),
332        };
333
334        if park_self {
337            self.park();
338        }
339        let (tx, rx) = oneshot_channel();
340        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, Some(tx));
341        self.queue_push_and_signal(env);
342        Ok(rx)
343    }
344
345    pub fn try_send<M>(&self, msg: M, park: bool) -> Result<(), SendError<M>>
347    where
348        A: Handler<M>,
349        <A as Actor>::Context: ToEnvelope<A, M>,
350        M::Result: Send,
351        M: Message + Send + 'static,
352    {
353        if !self.poll_unparked(false, None).is_ready() {
355            return Err(SendError::Full(msg));
356        }
357
358        let park_self = match self.inc_num_messages() {
359            Some(num_messages) => {
360                let buffer = self.inner.buffer.load(Relaxed);
362                buffer != 0 && num_messages >= buffer
363            }
364            None => return Err(SendError::Closed(msg)),
365        };
366
367        if park_self && park {
368            self.park();
369        }
370        let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
371        self.queue_push_and_signal(env);
372        Ok(())
373    }
374
375    pub fn do_send<M>(&self, msg: M) -> Result<(), SendError<M>>
379    where
380        A: Handler<M>,
381        <A as Actor>::Context: ToEnvelope<A, M>,
382        M::Result: Send,
383        M: Message + Send,
384    {
385        if self.inc_num_messages().is_none() {
386            Err(SendError::Closed(msg))
387        } else {
388            let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
392            self.queue_push_and_signal(env);
393            Ok(())
394        }
395    }
396
397    pub fn downgrade(&self) -> WeakAddressSender<A> {
399        WeakAddressSender {
400            inner: Arc::downgrade(&self.inner),
401        }
402    }
403
404    fn queue_push_and_signal(&self, msg: Envelope<A>) {
406        self.inner.message_queue.push(msg);
408
409        self.inner.recv_task.wake();
412    }
413
414    fn inc_num_messages(&self) -> Option<usize> {
417        let mut curr = self.inner.state.load(SeqCst);
418        loop {
419            let mut state = decode_state(curr);
420            if !state.is_open {
421                return None;
422            }
423            state.num_messages += 1;
424
425            let next = encode_state(&state);
426            match self
427                .inner
428                .state
429                .compare_exchange(curr, next, SeqCst, SeqCst)
430            {
431                Ok(_) => {
432                    return Some(state.num_messages);
433                }
434                Err(actual) => curr = actual,
435            }
436        }
437    }
438
439    fn park(&self) {
441        {
442            let mut sender = self.sender_task.lock();
443            sender.task = None;
444            sender.is_parked = true;
445        }
446
447        self.inner.parked_queue.push(Arc::clone(&self.sender_task));
449
450        let state = decode_state(self.inner.state.load(SeqCst));
452        self.maybe_parked.store(state.is_open, Relaxed);
453    }
454
455    fn poll_unparked(&self, do_park: bool, cx: Option<&mut task::Context<'_>>) -> Poll<()> {
456        if self.maybe_parked.load(Relaxed) {
459            let mut task = self.sender_task.lock();
461
462            if !task.is_parked {
463                self.maybe_parked.store(false, Relaxed);
464                return Poll::Ready(());
465            }
466
467            task.task = if do_park {
474                cx.map(|cx| cx.waker().clone())
475            } else {
476                None
477            };
478
479            Poll::Pending
480        } else {
481            Poll::Ready(())
482        }
483    }
484}
485
486impl<A, M> Sender<M> for AddressSender<A>
487where
488    A: Handler<M>,
489    A::Context: ToEnvelope<A, M>,
490    M::Result: Send,
491    M: Message + Send + 'static,
492{
493    fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
494        self.do_send(msg)
495    }
496    fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
497        self.try_send(msg, true)
498    }
499    fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>> {
500        self.send(msg)
501    }
502    fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
503        Box::new(self.clone())
504    }
505
506    fn hash(&self) -> usize {
507        let hash: *const _ = self.inner.as_ref();
508        hash as usize
509    }
510
511    fn connected(&self) -> bool {
512        self.connected()
513    }
514
515    fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static> {
516        Box::new(WeakAddressSender {
517            inner: Arc::downgrade(&self.inner),
518        })
519    }
520}
521
522impl<A: Actor> Clone for AddressSender<A> {
523    fn clone(&self) -> AddressSender<A> {
524        let mut curr = self.inner.num_senders.load(SeqCst);
528
529        loop {
530            if curr == self.inner.max_senders() {
532                panic!("cannot clone `Sender` -- too many outstanding senders");
533            }
534
535            debug_assert!(curr < self.inner.max_senders());
536
537            let next = curr + 1;
538            #[allow(deprecated)]
539            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
540
541            if actual == curr {
544                return AddressSender {
545                    inner: Arc::clone(&self.inner),
546                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
547                    maybe_parked: Arc::new(AtomicBool::new(false)),
548                };
549            }
550
551            curr = actual;
552        }
553    }
554}
555
556impl<A: Actor> Drop for AddressSender<A> {
557    fn drop(&mut self) {
558        let prev = self.inner.num_senders.fetch_sub(1, SeqCst);
560        if prev == 1 {
562            self.inner.recv_task.wake();
563        }
564    }
565}
566
567impl<A: Actor> PartialEq for AddressSender<A> {
568    fn eq(&self, other: &Self) -> bool {
569        Arc::ptr_eq(&self.inner, &other.inner)
570    }
571}
572
573impl<A: Actor> Eq for AddressSender<A> {}
574
575impl<A: Actor> Hash for AddressSender<A> {
576    fn hash<H: Hasher>(&self, state: &mut H) {
577        let hash: *const Inner<A> = self.inner.as_ref();
578        hash.hash(state);
579    }
580}
581
582impl<A: Actor> WeakAddressSender<A> {
588    pub fn upgrade(&self) -> Option<AddressSender<A>> {
592        Weak::upgrade(&self.inner).map(|inner| AddressSenderProducer { inner }.sender())
593    }
594}
595
596impl<A, M> WeakSender<M> for WeakAddressSender<A>
597where
598    A: Handler<M>,
599    A::Context: ToEnvelope<A, M>,
600    M::Result: Send,
601    M: Message + Send + 'static,
602{
603    fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>> {
604        if let Some(inner) = WeakAddressSender::upgrade(self) {
605            Some(Box::new(inner))
606        } else {
607            None
608        }
609    }
610
611    fn boxed(&self) -> Box<dyn WeakSender<M> + Sync> {
612        Box::new(self.clone())
613    }
614}
615
616impl<A: Actor> AddressSenderProducer<A> {
622    pub fn connected(&self) -> bool {
624        self.inner.num_senders.load(SeqCst) != 0
625    }
626
627    pub fn capacity(&self) -> usize {
629        self.inner.buffer.load(Relaxed)
630    }
631
632    pub fn set_capacity(&mut self, cap: usize) {
637        let buffer = self.inner.buffer.load(Relaxed);
638        self.inner.buffer.store(cap, Relaxed);
639
640        if cap > buffer {
642            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
643                task.lock().notify();
644            }
645        }
646    }
647
648    pub fn sender(&self) -> AddressSender<A> {
650        let mut curr = self.inner.num_senders.load(SeqCst);
652
653        loop {
654            if curr == self.inner.max_senders() {
656                panic!("cannot clone `Sender` -- too many outstanding senders");
657            }
658
659            let next = curr + 1;
660            #[allow(deprecated)]
661            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
662
663            if actual == curr {
666                return AddressSender {
667                    inner: Arc::clone(&self.inner),
668                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
669                    maybe_parked: Arc::new(AtomicBool::new(false)),
670                };
671            }
672
673            curr = actual;
674        }
675    }
676}
677
678impl<A: Actor> AddressReceiver<A> {
684    pub fn connected(&self) -> bool {
686        self.inner.num_senders.load(SeqCst) != 0
687    }
688
689    pub fn capacity(&self) -> usize {
691        self.inner.buffer.load(Relaxed)
692    }
693
694    pub fn set_capacity(&mut self, cap: usize) {
699        let buffer = self.inner.buffer.load(Relaxed);
700        self.inner.buffer.store(cap, Relaxed);
701
702        if cap > buffer {
704            while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
705                task.lock().notify();
706            }
707        }
708    }
709
710    pub fn sender(&self) -> AddressSender<A> {
712        let mut curr = self.inner.num_senders.load(SeqCst);
714
715        loop {
716            if curr == self.inner.max_senders() {
718                panic!("cannot clone `Sender` -- too many outstanding senders");
719            }
720
721            let next = curr + 1;
722            #[allow(deprecated)]
723            let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
724
725            if actual == curr {
728                return AddressSender {
729                    inner: Arc::clone(&self.inner),
730                    sender_task: Arc::new(Mutex::new(SenderTask::new())),
731                    maybe_parked: Arc::new(AtomicBool::new(false)),
732                };
733            }
734
735            curr = actual;
736        }
737    }
738
739    pub fn sender_producer(&self) -> AddressSenderProducer<A> {
741        AddressSenderProducer {
742            inner: self.inner.clone(),
743        }
744    }
745
746    fn next_message(&mut self) -> Poll<Option<Envelope<A>>> {
747        match unsafe { self.inner.message_queue.pop_spin() } {
749            Some(msg) => {
750                self.unpark_one();
753
754                self.dec_num_messages();
756
757                Poll::Ready(Some(msg))
758            }
759            None => {
760                let state = decode_state(self.inner.state.load(SeqCst));
761                if state.is_closed() {
762                    Poll::Ready(None)
765                } else {
766                    Poll::Pending
774                }
775            }
776        }
777    }
778
779    fn unpark_one(&mut self) {
781        if let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
782            task.lock().notify();
783        }
784    }
785
786    fn dec_num_messages(&self) {
787        self.inner.state.fetch_sub(1, SeqCst);
791    }
792}
793
794impl<A: Actor> Stream for AddressReceiver<A> {
795    type Item = Envelope<A>;
796
797    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
798        let this = self.get_mut();
799        match this.next_message() {
800            Poll::Ready(msg) => Poll::Ready(msg),
801            Poll::Pending => {
802                this.inner.recv_task.register(cx.waker());
804                this.next_message()
808            }
809        }
810    }
811}
812
impl<A: Actor> Drop for AddressReceiver<A> {
    /// Closes the channel, releases every parked sender, and drains the queue.
    fn drop(&mut self) {
        // Clear the open bit: from now on `inc_num_messages` returns `None`,
        // so sends fail with `SendError::Closed`.
        self.inner.set_closed();

        // Release all parked senders so their next send observes the closed state.
        // SAFETY: NOTE(review): see `Queue::pop_spin` for the caller contract.
        while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
            task.lock().notify();
        }

        // Drop all remaining envelopes. Senders increment the message count
        // before pushing their envelope, so the queue can look empty while
        // the count is still non-zero. In that case yield the thread and retry
        // until the count reaches zero.
        loop {
            match self.next_message() {
                Poll::Ready(Some(_)) => {}
                Poll::Ready(None) => break,
                Poll::Pending => {
                    let state = decode_state(self.inner.state.load(SeqCst));

                    if state.is_closed() {
                        break;
                    }

                    thread::yield_now();
                }
            }
        }
    }
}
848
849impl<A: Actor> Inner<A> {
855    fn max_senders(&self) -> usize {
858        MAX_CAPACITY - self.buffer.load(Relaxed)
859    }
860
861    fn set_closed(&self) {
863        let curr = self.state.load(SeqCst);
864        if !decode_state(curr).is_open {
865            return;
866        }
867
868        self.state.fetch_and(!OPEN_MASK, SeqCst);
869    }
870}
871
// SAFETY: NOTE(review): these impls hold for every `A: Actor`, without further
// bounds. They presumably rely on `Queue`, `AtomicWaker`, and the queued
// `Envelope<A>` values being safe to share and move across threads. Confirm
// this against the `queue` and `envelope` modules; the visible code places no
// `Send` bound on `Envelope<A>`.
unsafe impl<A: Actor> Send for Inner<A> {}
unsafe impl<A: Actor> Sync for Inner<A> {}
874
875fn decode_state(num: usize) -> State {
881    State {
882        is_open: num & OPEN_MASK == OPEN_MASK,
883        num_messages: num & MAX_CAPACITY,
884    }
885}
886
887fn encode_state(state: &State) -> usize {
888    let mut num = state.num_messages;
889
890    if state.is_open {
891        num |= OPEN_MASK;
892    }
893
894    num
895}
896
#[cfg(test)]
mod tests {
    use std::time;

    use super::*;
    use crate::{address::queue::PopResult, prelude::*};

    // Minimal actor that gives the channel a concrete type.
    struct Act;
    impl Actor for Act {
        type Context = Context<Act>;
    }

    // Message with no result.
    struct Ping;
    impl Message for Ping {
        type Result = ();
    }

    impl Handler<Ping> for Act {
        type Result = ();
        fn handle(&mut self, _: Ping, _: &mut Context<Act>) {}
    }

    // Checks that senders park once the buffer is full and that raising the
    // capacity with `set_capacity` releases (drains) the parked queue.
    #[test]
    fn test_cap() {
        System::new().block_on(async {
            // Buffer of 1: any send that brings the count to >= 1 parks its handle.
            let (s1, mut recv) = channel::<Act>(1);
            // Second, independent handle with its own parking slot.
            let s2 = recv.sender();

            let arb = Arbiter::new();
            arb.spawn_fn(move || {
                // Count -> 1; s1 parks.
                let _ = s1.send(Ping);
            });
            // NOTE(review): blocking sleeps inside `block_on`; the test relies on
            // the arbiter threads finishing within these delays.
            thread::sleep(time::Duration::from_millis(100));
            let arb2 = Arbiter::new();
            arb2.spawn_fn(move || {
                // First send: count -> 2 and s2 parks.
                // Second send: fails with `Full` because s2 is parked.
                let _ = s2.send(Ping);
                let _ = s2.send(Ping);
            });

            thread::sleep(time::Duration::from_millis(100));
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);

            // At least one parked slot is queued; put it back afterwards.
            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };

            assert!(p.is_some());
            recv.inner.parked_queue.push(p.unwrap());

            // Growing the capacity releases every parked sender.
            recv.set_capacity(10);

            thread::sleep(time::Duration::from_millis(100));
            let state = decode_state(recv.inner.state.load(SeqCst));
            assert_eq!(state.num_messages, 2);

            // The parked queue must now be empty.
            let p = loop {
                match unsafe { recv.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };
            assert!(p.is_none());

            System::current().stop();
        });
    }
}
969}