1use std::{
4 fmt,
5 hash::{Hash, Hasher},
6 pin::Pin,
7 sync::{
8 atomic::{
9 AtomicBool, AtomicUsize,
10 Ordering::{Relaxed, SeqCst},
11 },
12 Arc, Weak,
13 },
14 task::{self, Poll},
15 thread,
16};
17
18use futures_core::{stream::Stream, task::__internal::AtomicWaker};
19use parking_lot::Mutex;
20use tokio::sync::oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver};
21
22use super::{
23 envelope::{Envelope, ToEnvelope},
24 queue::Queue,
25 SendError,
26};
27use crate::{
28 actor::Actor,
29 handler::{Handler, Message},
30};
31
32pub trait Sender<M>: Send
33where
34 M::Result: Send,
35 M: Message + Send,
36{
37 fn do_send(&self, msg: M) -> Result<(), SendError<M>>;
38
39 fn try_send(&self, msg: M) -> Result<(), SendError<M>>;
40
41 fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>;
42
43 fn boxed(&self) -> Box<dyn Sender<M> + Sync>;
44
45 fn hash(&self) -> usize;
46
47 fn connected(&self) -> bool;
48
49 fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static>;
51}
52
53impl<S, M> Sender<M> for Box<S>
54where
55 S: Sender<M> + ?Sized,
56 M::Result: Send,
57 M: Message + Send,
58{
59 fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
60 (**self).do_send(msg)
61 }
62
63 fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
64 (**self).try_send(msg)
65 }
66
67 fn send(&self, msg: M) -> Result<OneshotReceiver<<M as Message>::Result>, SendError<M>> {
68 (**self).send(msg)
69 }
70
71 fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
72 (**self).boxed()
73 }
74
75 fn hash(&self) -> usize {
76 (**self).hash()
77 }
78
79 fn connected(&self) -> bool {
80 (**self).connected()
81 }
82
83 fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync> {
84 (**self).downgrade()
85 }
86}
87
88pub trait WeakSender<M>: Send
89where
90 M::Result: Send,
91 M: Message + Send,
92{
93 fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>>;
97
98 fn boxed(&self) -> Box<dyn WeakSender<M> + Sync>;
99}
100
101pub struct AddressSender<A: Actor> {
105 inner: Arc<Inner<A>>,
107
108 sender_task: Arc<Mutex<SenderTask>>,
112
113 maybe_parked: Arc<AtomicBool>,
116}
117
118impl<A: Actor> fmt::Debug for AddressSender<A> {
119 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
120 fmt.debug_struct("AddressSender")
121 .field("sender_task", &self.sender_task)
122 .field("maybe_parked", &self.maybe_parked)
123 .finish()
124 }
125}
126
127pub struct WeakAddressSender<A: Actor> {
131 inner: Weak<Inner<A>>,
132}
133
134impl<A: Actor> Clone for WeakAddressSender<A> {
135 fn clone(&self) -> WeakAddressSender<A> {
136 WeakAddressSender {
137 inner: self.inner.clone(),
138 }
139 }
140}
141
142impl<A: Actor> fmt::Debug for WeakAddressSender<A> {
143 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
144 fmt.debug_struct("WeakAddressSender").finish()
145 }
146}
147
148impl<A: Actor> PartialEq for WeakAddressSender<A> {
149 fn eq(&self, other: &Self) -> bool {
150 self.inner.ptr_eq(&other.inner)
151 }
152}
153
154impl<A: Actor> Eq for WeakAddressSender<A> {}
155
/// Marker trait bundling `Send + Sync + Clone`.
///
/// It is not implemented for any type in this file, hence the lint allowance.
#[allow(dead_code)]
trait AssertKinds
where
    Self: Send + Sync + Clone,
{
}
158
159pub struct AddressReceiver<A: Actor> {
165 inner: Arc<Inner<A>>,
166}
167
168pub struct AddressSenderProducer<A: Actor> {
170 inner: Arc<Inner<A>>,
171}
172
173struct Inner<A: Actor> {
174 buffer: AtomicUsize,
176
177 state: AtomicUsize,
180
181 message_queue: Queue<Envelope<A>>,
183
184 parked_queue: Queue<Arc<Mutex<SenderTask>>>,
186
187 num_senders: AtomicUsize,
189
190 recv_task: AtomicWaker,
192}
193
/// Decoded form of `Inner::state`.
#[derive(Debug, Clone, Copy)]
struct State {
    /// `false` once the receiver has closed the channel.
    is_open: bool,

    /// Number of messages currently queued.
    num_messages: usize,
}

impl State {
    /// A channel is fully closed only when it is marked closed and fully drained.
    fn is_closed(&self) -> bool {
        let drained = self.num_messages == 0;
        drained && !self.is_open
    }
}
209
/// Top bit of `Inner::state`: set while the receiver has not closed the channel.
const OPEN_MASK: usize = !(usize::MAX >> 1);

/// A fresh channel is open and has no queued messages.
const INIT_STATE: usize = OPEN_MASK;

/// Largest value representable in the message-count bits of `Inner::state`.
const MAX_CAPACITY: usize = !OPEN_MASK;

/// Exclusive upper bound on the buffer size accepted by `channel`.
const MAX_BUFFER: usize = MAX_CAPACITY >> 1;
223
/// Parking slot shared between one `AddressSender` and the parked queue.
#[derive(Debug)]
struct SenderTask {
    /// Waker to invoke when the sender is unparked, if one was registered.
    task: Option<task::Waker>,
    /// `true` while the sender is waiting for the receiver to make room.
    is_parked: bool,
}

impl SenderTask {
    /// Creates an unparked slot with no registered waker.
    fn new() -> Self {
        Self {
            is_parked: false,
            task: None,
        }
    }

    /// Clears the parked flag and wakes the stored waker, if any.
    ///
    /// Returns `true` only when a waker was actually taken and woken.
    fn notify(&mut self) -> bool {
        self.is_parked = false;
        self.task.take().map(task::Waker::wake).is_some()
    }
}
250
251pub fn channel<A: Actor>(buffer: usize) -> (AddressSender<A>, AddressReceiver<A>) {
265 assert!(buffer < MAX_BUFFER, "requested buffer size too large");
268
269 let inner = Arc::new(Inner {
270 buffer: AtomicUsize::new(buffer),
271 state: AtomicUsize::new(INIT_STATE),
272 message_queue: Queue::new(),
273 parked_queue: Queue::new(),
274 num_senders: AtomicUsize::new(1),
275 recv_task: AtomicWaker::new(),
276 });
277
278 let tx = AddressSender {
279 inner: Arc::clone(&inner),
280 sender_task: Arc::new(Mutex::new(SenderTask::new())),
281 maybe_parked: Arc::new(AtomicBool::new(false)),
282 };
283
284 let rx = AddressReceiver { inner };
285
286 (tx, rx)
287}
288
289impl<A: Actor> AddressSender<A> {
295 pub fn connected(&self) -> bool {
297 let curr = self.inner.state.load(SeqCst);
298 let state = decode_state(curr);
299
300 state.is_open
301 }
302
303 pub fn send<M>(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>>
307 where
308 A: Handler<M>,
309 A::Context: ToEnvelope<A, M>,
310 M::Result: Send,
311 M: Message + Send,
312 {
313 if !self.poll_unparked(false, None).is_ready() {
315 return Err(SendError::Full(msg));
316 }
317
318 let park_self = match self.inc_num_messages() {
326 Some(num_messages) => {
327 let buffer = self.inner.buffer.load(Relaxed);
329 buffer != 0 && num_messages >= buffer
330 }
331 None => return Err(SendError::Closed(msg)),
332 };
333
334 if park_self {
337 self.park();
338 }
339 let (tx, rx) = oneshot_channel();
340 let env = <A::Context as ToEnvelope<A, M>>::pack(msg, Some(tx));
341 self.queue_push_and_signal(env);
342 Ok(rx)
343 }
344
345 pub fn try_send<M>(&self, msg: M, park: bool) -> Result<(), SendError<M>>
347 where
348 A: Handler<M>,
349 <A as Actor>::Context: ToEnvelope<A, M>,
350 M::Result: Send,
351 M: Message + Send + 'static,
352 {
353 if !self.poll_unparked(false, None).is_ready() {
355 return Err(SendError::Full(msg));
356 }
357
358 let park_self = match self.inc_num_messages() {
359 Some(num_messages) => {
360 let buffer = self.inner.buffer.load(Relaxed);
362 buffer != 0 && num_messages >= buffer
363 }
364 None => return Err(SendError::Closed(msg)),
365 };
366
367 if park_self && park {
368 self.park();
369 }
370 let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
371 self.queue_push_and_signal(env);
372 Ok(())
373 }
374
375 pub fn do_send<M>(&self, msg: M) -> Result<(), SendError<M>>
379 where
380 A: Handler<M>,
381 <A as Actor>::Context: ToEnvelope<A, M>,
382 M::Result: Send,
383 M: Message + Send,
384 {
385 if self.inc_num_messages().is_none() {
386 Err(SendError::Closed(msg))
387 } else {
388 let env = <A::Context as ToEnvelope<A, M>>::pack(msg, None);
392 self.queue_push_and_signal(env);
393 Ok(())
394 }
395 }
396
397 pub fn downgrade(&self) -> WeakAddressSender<A> {
399 WeakAddressSender {
400 inner: Arc::downgrade(&self.inner),
401 }
402 }
403
404 fn queue_push_and_signal(&self, msg: Envelope<A>) {
406 self.inner.message_queue.push(msg);
408
409 self.inner.recv_task.wake();
412 }
413
414 fn inc_num_messages(&self) -> Option<usize> {
417 let mut curr = self.inner.state.load(SeqCst);
418 loop {
419 let mut state = decode_state(curr);
420 if !state.is_open {
421 return None;
422 }
423 state.num_messages += 1;
424
425 let next = encode_state(&state);
426 match self
427 .inner
428 .state
429 .compare_exchange(curr, next, SeqCst, SeqCst)
430 {
431 Ok(_) => {
432 return Some(state.num_messages);
433 }
434 Err(actual) => curr = actual,
435 }
436 }
437 }
438
439 fn park(&self) {
441 {
442 let mut sender = self.sender_task.lock();
443 sender.task = None;
444 sender.is_parked = true;
445 }
446
447 self.inner.parked_queue.push(Arc::clone(&self.sender_task));
449
450 let state = decode_state(self.inner.state.load(SeqCst));
452 self.maybe_parked.store(state.is_open, Relaxed);
453 }
454
455 fn poll_unparked(&self, do_park: bool, cx: Option<&mut task::Context<'_>>) -> Poll<()> {
456 if self.maybe_parked.load(Relaxed) {
459 let mut task = self.sender_task.lock();
461
462 if !task.is_parked {
463 self.maybe_parked.store(false, Relaxed);
464 return Poll::Ready(());
465 }
466
467 task.task = if do_park {
474 cx.map(|cx| cx.waker().clone())
475 } else {
476 None
477 };
478
479 Poll::Pending
480 } else {
481 Poll::Ready(())
482 }
483 }
484}
485
486impl<A, M> Sender<M> for AddressSender<A>
487where
488 A: Handler<M>,
489 A::Context: ToEnvelope<A, M>,
490 M::Result: Send,
491 M: Message + Send + 'static,
492{
493 fn do_send(&self, msg: M) -> Result<(), SendError<M>> {
494 self.do_send(msg)
495 }
496 fn try_send(&self, msg: M) -> Result<(), SendError<M>> {
497 self.try_send(msg, true)
498 }
499 fn send(&self, msg: M) -> Result<OneshotReceiver<M::Result>, SendError<M>> {
500 self.send(msg)
501 }
502 fn boxed(&self) -> Box<dyn Sender<M> + Sync> {
503 Box::new(self.clone())
504 }
505
506 fn hash(&self) -> usize {
507 let hash: *const _ = self.inner.as_ref();
508 hash as usize
509 }
510
511 fn connected(&self) -> bool {
512 self.connected()
513 }
514
515 fn downgrade(&self) -> Box<dyn WeakSender<M> + Sync + 'static> {
516 Box::new(WeakAddressSender {
517 inner: Arc::downgrade(&self.inner),
518 })
519 }
520}
521
522impl<A: Actor> Clone for AddressSender<A> {
523 fn clone(&self) -> AddressSender<A> {
524 let mut curr = self.inner.num_senders.load(SeqCst);
528
529 loop {
530 if curr == self.inner.max_senders() {
532 panic!("cannot clone `Sender` -- too many outstanding senders");
533 }
534
535 debug_assert!(curr < self.inner.max_senders());
536
537 let next = curr + 1;
538 #[allow(deprecated)]
539 let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
540
541 if actual == curr {
544 return AddressSender {
545 inner: Arc::clone(&self.inner),
546 sender_task: Arc::new(Mutex::new(SenderTask::new())),
547 maybe_parked: Arc::new(AtomicBool::new(false)),
548 };
549 }
550
551 curr = actual;
552 }
553 }
554}
555
556impl<A: Actor> Drop for AddressSender<A> {
557 fn drop(&mut self) {
558 let prev = self.inner.num_senders.fetch_sub(1, SeqCst);
560 if prev == 1 {
562 self.inner.recv_task.wake();
563 }
564 }
565}
566
567impl<A: Actor> PartialEq for AddressSender<A> {
568 fn eq(&self, other: &Self) -> bool {
569 Arc::ptr_eq(&self.inner, &other.inner)
570 }
571}
572
573impl<A: Actor> Eq for AddressSender<A> {}
574
575impl<A: Actor> Hash for AddressSender<A> {
576 fn hash<H: Hasher>(&self, state: &mut H) {
577 let hash: *const Inner<A> = self.inner.as_ref();
578 hash.hash(state);
579 }
580}
581
582impl<A: Actor> WeakAddressSender<A> {
588 pub fn upgrade(&self) -> Option<AddressSender<A>> {
592 Weak::upgrade(&self.inner).map(|inner| AddressSenderProducer { inner }.sender())
593 }
594}
595
596impl<A, M> WeakSender<M> for WeakAddressSender<A>
597where
598 A: Handler<M>,
599 A::Context: ToEnvelope<A, M>,
600 M::Result: Send,
601 M: Message + Send + 'static,
602{
603 fn upgrade(&self) -> Option<Box<dyn Sender<M> + Sync>> {
604 if let Some(inner) = WeakAddressSender::upgrade(self) {
605 Some(Box::new(inner))
606 } else {
607 None
608 }
609 }
610
611 fn boxed(&self) -> Box<dyn WeakSender<M> + Sync> {
612 Box::new(self.clone())
613 }
614}
615
616impl<A: Actor> AddressSenderProducer<A> {
622 pub fn connected(&self) -> bool {
624 self.inner.num_senders.load(SeqCst) != 0
625 }
626
627 pub fn capacity(&self) -> usize {
629 self.inner.buffer.load(Relaxed)
630 }
631
632 pub fn set_capacity(&mut self, cap: usize) {
637 let buffer = self.inner.buffer.load(Relaxed);
638 self.inner.buffer.store(cap, Relaxed);
639
640 if cap > buffer {
642 while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
643 task.lock().notify();
644 }
645 }
646 }
647
648 pub fn sender(&self) -> AddressSender<A> {
650 let mut curr = self.inner.num_senders.load(SeqCst);
652
653 loop {
654 if curr == self.inner.max_senders() {
656 panic!("cannot clone `Sender` -- too many outstanding senders");
657 }
658
659 let next = curr + 1;
660 #[allow(deprecated)]
661 let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
662
663 if actual == curr {
666 return AddressSender {
667 inner: Arc::clone(&self.inner),
668 sender_task: Arc::new(Mutex::new(SenderTask::new())),
669 maybe_parked: Arc::new(AtomicBool::new(false)),
670 };
671 }
672
673 curr = actual;
674 }
675 }
676}
677
678impl<A: Actor> AddressReceiver<A> {
684 pub fn connected(&self) -> bool {
686 self.inner.num_senders.load(SeqCst) != 0
687 }
688
689 pub fn capacity(&self) -> usize {
691 self.inner.buffer.load(Relaxed)
692 }
693
694 pub fn set_capacity(&mut self, cap: usize) {
699 let buffer = self.inner.buffer.load(Relaxed);
700 self.inner.buffer.store(cap, Relaxed);
701
702 if cap > buffer {
704 while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
705 task.lock().notify();
706 }
707 }
708 }
709
710 pub fn sender(&self) -> AddressSender<A> {
712 let mut curr = self.inner.num_senders.load(SeqCst);
714
715 loop {
716 if curr == self.inner.max_senders() {
718 panic!("cannot clone `Sender` -- too many outstanding senders");
719 }
720
721 let next = curr + 1;
722 #[allow(deprecated)]
723 let actual = self.inner.num_senders.compare_and_swap(curr, next, SeqCst);
724
725 if actual == curr {
728 return AddressSender {
729 inner: Arc::clone(&self.inner),
730 sender_task: Arc::new(Mutex::new(SenderTask::new())),
731 maybe_parked: Arc::new(AtomicBool::new(false)),
732 };
733 }
734
735 curr = actual;
736 }
737 }
738
739 pub fn sender_producer(&self) -> AddressSenderProducer<A> {
741 AddressSenderProducer {
742 inner: self.inner.clone(),
743 }
744 }
745
746 fn next_message(&mut self) -> Poll<Option<Envelope<A>>> {
747 match unsafe { self.inner.message_queue.pop_spin() } {
749 Some(msg) => {
750 self.unpark_one();
753
754 self.dec_num_messages();
756
757 Poll::Ready(Some(msg))
758 }
759 None => {
760 let state = decode_state(self.inner.state.load(SeqCst));
761 if state.is_closed() {
762 Poll::Ready(None)
765 } else {
766 Poll::Pending
774 }
775 }
776 }
777 }
778
779 fn unpark_one(&mut self) {
781 if let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
782 task.lock().notify();
783 }
784 }
785
786 fn dec_num_messages(&self) {
787 self.inner.state.fetch_sub(1, SeqCst);
791 }
792}
793
794impl<A: Actor> Stream for AddressReceiver<A> {
795 type Item = Envelope<A>;
796
797 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
798 let this = self.get_mut();
799 match this.next_message() {
800 Poll::Ready(msg) => Poll::Ready(msg),
801 Poll::Pending => {
802 this.inner.recv_task.register(cx.waker());
804 this.next_message()
808 }
809 }
810 }
811}
812
813impl<A: Actor> Drop for AddressReceiver<A> {
814 fn drop(&mut self) {
815 self.inner.set_closed();
817
818 while let Some(task) = unsafe { self.inner.parked_queue.pop_spin() } {
821 task.lock().notify();
822 }
823
824 loop {
826 match self.next_message() {
827 Poll::Ready(Some(_)) => {}
828 Poll::Ready(None) => break,
829 Poll::Pending => {
830 let state = decode_state(self.inner.state.load(SeqCst));
831
832 if state.is_closed() {
834 break;
835 }
836
837 thread::yield_now();
843 }
844 }
845 }
846 }
847}
848
849impl<A: Actor> Inner<A> {
855 fn max_senders(&self) -> usize {
858 MAX_CAPACITY - self.buffer.load(Relaxed)
859 }
860
861 fn set_closed(&self) {
863 let curr = self.state.load(SeqCst);
864 if !decode_state(curr).is_open {
865 return;
866 }
867
868 self.state.fetch_and(!OPEN_MASK, SeqCst);
869 }
870}
871
872unsafe impl<A: Actor> Send for Inner<A> {}
873unsafe impl<A: Actor> Sync for Inner<A> {}
874
875fn decode_state(num: usize) -> State {
881 State {
882 is_open: num & OPEN_MASK == OPEN_MASK,
883 num_messages: num & MAX_CAPACITY,
884 }
885}
886
887fn encode_state(state: &State) -> usize {
888 let mut num = state.num_messages;
889
890 if state.is_open {
891 num |= OPEN_MASK;
892 }
893
894 num
895}
896
#[cfg(test)]
mod tests {
    use std::time;

    use super::*;
    use crate::{address::queue::PopResult, prelude::*};

    struct Act;

    impl Actor for Act {
        type Context = Context<Act>;
    }

    struct Ping;

    impl Message for Ping {
        type Result = ();
    }

    impl Handler<Ping> for Act {
        type Result = ();

        fn handle(&mut self, _: Ping, _: &mut Context<Act>) {}
    }

    /// With a buffer of 1, the second sender parks after filling the buffer.
    /// Raising the capacity releases it and empties the parked queue.
    #[test]
    fn test_cap() {
        System::new().block_on(async {
            let (s1, mut recv) = channel::<Act>(1);
            let s2 = recv.sender();

            // Pops one parked sender, retrying while the queue is mid-update.
            let pop_parked = |rx: &AddressReceiver<Act>| loop {
                match unsafe { rx.inner.parked_queue.pop() } {
                    PopResult::Data(task) => break Some(task),
                    PopResult::Empty => break None,
                    PopResult::Inconsistent => thread::yield_now(),
                }
            };
            let queued =
                |rx: &AddressReceiver<Act>| decode_state(rx.inner.state.load(SeqCst)).num_messages;
            let settle = || thread::sleep(time::Duration::from_millis(100));

            let arb = Arbiter::new();
            arb.spawn_fn(move || {
                let _ = s1.send(Ping);
            });
            settle();

            let arb2 = Arbiter::new();
            arb2.spawn_fn(move || {
                let _ = s2.send(Ping);
                let _ = s2.send(Ping);
            });

            settle();
            assert_eq!(queued(&recv), 2);

            let parked = pop_parked(&recv);
            assert!(parked.is_some());
            recv.inner.parked_queue.push(parked.unwrap());

            recv.set_capacity(10);

            settle();
            assert_eq!(queued(&recv), 2);
            assert!(pop_parked(&recv).is_none());

            System::current().stop();
        });
    }
}
969}