1use std::{future::Future, pin::Pin, sync::Arc, task, task::Poll, thread};
10
11use actix_rt::System;
12use crossbeam_channel as cb_channel;
13use futures_core::stream::Stream;
14use log::warn;
15use tokio::sync::oneshot::Sender as SyncSender;
16
17use crate::{
18    actor::{Actor, ActorContext, ActorState, Running},
19    address::{
20        channel, Addr, AddressReceiver, AddressSenderProducer, Envelope, EnvelopeProxy, ToEnvelope,
21    },
22    context::Context,
23    handler::{Handler, Message, MessageResponse},
24};
25
26pub struct SyncArbiter<A>
94where
95    A: Actor<Context = SyncContext<A>>,
96{
97    queue: Option<cb_channel::Sender<Envelope<A>>>,
98    msgs: AddressReceiver<A>,
99}
100
101impl<A> SyncArbiter<A>
102where
103    A: Actor<Context = SyncContext<A>>,
104{
105    pub fn start<F>(threads: usize, factory: F) -> Addr<A>
110    where
111        F: Fn() -> A + Send + Sync + 'static,
112    {
113        Self::start_with_thread_builder(threads, thread::Builder::new, factory)
114    }
115
116    pub fn start_with_thread_builder<F, BF>(
123        threads: usize,
124        mut thread_builder_factory: BF,
125        factory: F,
126    ) -> Addr<A>
127    where
128        F: Fn() -> A + Send + Sync + 'static,
129        BF: FnMut() -> thread::Builder,
130    {
131        let factory = Arc::new(factory);
132        let (sender, receiver) = cb_channel::unbounded();
133        let (tx, rx) = channel::channel(0);
134
135        for _ in 0..threads {
136            let f = Arc::clone(&factory);
137            let sys = System::current();
138            let actor_queue = receiver.clone();
139            let inner_rx = rx.sender_producer();
140
141            thread_builder_factory()
142                .spawn(move || {
143                    System::set_current(sys);
144                    SyncContext::new(f, actor_queue, inner_rx).run();
145                })
146                .expect("failed to spawn thread");
147        }
148
149        System::current().arbiter().spawn(Self {
150            queue: Some(sender),
151            msgs: rx,
152        });
153
154        Addr::new(tx)
155    }
156}
157
158impl<A> Actor for SyncArbiter<A>
159where
160    A: Actor<Context = SyncContext<A>>,
161{
162    type Context = Context<Self>;
163}
164
165#[doc(hidden)]
166impl<A> Future for SyncArbiter<A>
167where
168    A: Actor<Context = SyncContext<A>>,
169{
170    type Output = ();
171
172    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
173        let this = self.get_mut();
174        loop {
175            match Pin::new(&mut this.msgs).poll_next(cx) {
176                Poll::Ready(Some(msg)) => {
177                    if let Some(ref queue) = this.queue {
178                        assert!(queue.send(msg).is_ok());
179                    }
180                }
181                Poll::Pending => break,
182                Poll::Ready(None) => unreachable!(),
183            }
184        }
185
186        if this.msgs.connected() {
188            Poll::Pending
189        } else {
190            this.queue = None;
192            Poll::Ready(())
193        }
194    }
195}
196
197impl<A, M> ToEnvelope<A, M> for SyncContext<A>
198where
199    A: Actor<Context = Self> + Handler<M>,
200    M: Message + Send + 'static,
201    M::Result: Send,
202{
203    fn pack(msg: M, tx: Option<SyncSender<M::Result>>) -> Envelope<A> {
204        Envelope::with_proxy(Box::new(SyncContextEnvelope::new(msg, tx)))
205    }
206}
207
208pub struct SyncContext<A>
238where
239    A: Actor<Context = SyncContext<A>>,
240{
241    act: Option<A>,
242    queue: cb_channel::Receiver<Envelope<A>>,
243    stopping: bool,
244    state: ActorState,
245    factory: Arc<dyn Fn() -> A>,
246    address: AddressSenderProducer<A>,
247}
248
249impl<A> SyncContext<A>
250where
251    A: Actor<Context = Self>,
252{
253    fn new(
254        factory: Arc<dyn Fn() -> A>,
255        queue: cb_channel::Receiver<Envelope<A>>,
256        address: AddressSenderProducer<A>,
257    ) -> Self {
258        let act = factory();
259        Self {
260            queue,
261            factory,
262            act: Some(act),
263            stopping: false,
264            state: ActorState::Started,
265            address,
266        }
267    }
268
269    fn run(&mut self) {
270        let mut act = self.act.take().unwrap();
271
272        A::started(&mut act, self);
274        self.state = ActorState::Running;
275
276        loop {
277            match self.queue.recv() {
278                Ok(mut env) => {
279                    env.handle(&mut act, self);
280                }
281                Err(_) => {
282                    self.state = ActorState::Stopping;
283                    if A::stopping(&mut act, self) != Running::Stop {
284                        warn!("stopping method is not supported for sync actors");
285                    }
286                    self.state = ActorState::Stopped;
287                    A::stopped(&mut act, self);
288                    return;
289                }
290            }
291
292            if self.stopping {
293                self.stopping = false;
294
295                A::stopping(&mut act, self);
297                self.state = ActorState::Stopped;
298                A::stopped(&mut act, self);
299
300                self.state = ActorState::Started;
302                act = (*self.factory)();
303                A::started(&mut act, self);
304                self.state = ActorState::Running;
305            }
306        }
307    }
308
309    pub fn address(&self) -> Addr<A> {
310        Addr::new(self.address.sender())
311    }
312}
313
314impl<A> ActorContext for SyncContext<A>
315where
316    A: Actor<Context = Self>,
317{
318    fn stop(&mut self) {
321        self.stopping = true;
322        self.state = ActorState::Stopping;
323    }
324
325    fn terminate(&mut self) {
328        self.stopping = true;
329        self.state = ActorState::Stopping;
330    }
331
332    fn state(&self) -> ActorState {
334        self.state
335    }
336}
337
338pub(crate) struct SyncContextEnvelope<M>
339where
340    M: Message + Send,
341{
342    msg: Option<M>,
343    tx: Option<SyncSender<M::Result>>,
344}
345
346impl<M> SyncContextEnvelope<M>
347where
348    M: Message + Send,
349    M::Result: Send,
350{
351    pub fn new(msg: M, tx: Option<SyncSender<M::Result>>) -> Self {
352        Self { tx, msg: Some(msg) }
353    }
354}
355
356impl<A, M> EnvelopeProxy<A> for SyncContextEnvelope<M>
357where
358    M: Message + Send + 'static,
359    M::Result: Send,
360    A: Actor<Context = SyncContext<A>> + Handler<M>,
361{
362    fn handle(&mut self, act: &mut A, ctx: &mut A::Context) {
363        let tx = self.tx.take();
364        if tx.is_some() && tx.as_ref().unwrap().is_closed() {
365            return;
366        }
367
368        if let Some(msg) = self.msg.take() {
369            <A as Handler<M>>::handle(act, msg, ctx).handle(ctx, tx)
370        }
371    }
372}
373
#[cfg(test)]
mod tests {
    use tokio::sync::oneshot;

    use crate::prelude::*;

    /// Value the inner actor replies with; the test checks it arrives unchanged.
    const REPLY: u8 = 233;

    /// Inner sync actor: answers every [`Msg`] itself.
    struct SyncActor2;

    impl Actor for SyncActor2 {
        type Context = SyncContext<Self>;
    }

    /// Outer sync actor: forwards every [`Msg`] to its own nested sync arbiter.
    struct SyncActor1(Addr<SyncActor2>);

    impl Actor for SyncActor1 {
        type Context = SyncContext<Self>;
    }

    impl SyncActor1 {
        /// Factory for the outer arbiter; each instance starts a one-thread inner arbiter.
        fn run() -> SyncActor1 {
            let inner = SyncArbiter::start(1, || SyncActor2);
            SyncActor1(inner)
        }
    }

    /// Carries the channel on which the innermost actor reports back.
    struct Msg(oneshot::Sender<u8>);

    impl Message for Msg {
        type Result = ();
    }

    impl Handler<Msg> for SyncActor1 {
        type Result = ();

        fn handle(&mut self, msg: Msg, _: &mut Self::Context) -> Self::Result {
            let SyncActor1(inner) = self;
            inner.do_send(msg);
        }
    }

    impl Handler<Msg> for SyncActor2 {
        type Result = ();

        fn handle(&mut self, msg: Msg, _: &mut Self::Context) -> Self::Result {
            let Msg(reply) = msg;
            reply.send(REPLY).unwrap();
        }
    }

    #[test]
    fn nested_sync_arbiters() {
        System::new().block_on(async {
            let outer = SyncArbiter::start(1, SyncActor1::run);
            let (reply_tx, reply_rx) = oneshot::channel();

            outer.send(Msg(reply_tx)).await.unwrap();
            let received = reply_rx.await.unwrap();
            assert_eq!(REPLY, received);

            System::current().stop();
        })
    }
}