actix_web_actors/
ws.rs

1//! Websocket integration.
2//!
3//! # Examples
4//!
5//! ```no_run
6//! use actix::{Actor, StreamHandler};
7//! use actix_web::{get, web, App, Error, HttpRequest, HttpResponse, HttpServer};
8//! use actix_web_actors::ws;
9//!
10//! /// Define Websocket actor
11//! struct MyWs;
12//!
13//! impl Actor for MyWs {
14//!     type Context = ws::WebsocketContext<Self>;
15//! }
16//!
17//! /// Handler for ws::Message message
18//! impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for MyWs {
19//!     fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
20//!         match msg {
21//!             Ok(ws::Message::Ping(msg)) => ctx.pong(&msg),
22//!             Ok(ws::Message::Text(text)) => ctx.text(text),
23//!             Ok(ws::Message::Binary(bin)) => ctx.binary(bin),
24//!             _ => (),
25//!         }
26//!     }
27//! }
28//!
29//! #[get("/ws")]
30//! async fn websocket(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
31//!     ws::start(MyWs, &req, stream)
32//! }
33//!
34//! const MAX_FRAME_SIZE: usize = 16_384; // 16KiB
35//!
36//! #[get("/custom-ws")]
37//! async fn custom_websocket(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
38//!     // Create a Websocket session with a specific max frame size, and protocols.
39//!     ws::WsResponseBuilder::new(MyWs, &req, stream)
40//!         .frame_size(MAX_FRAME_SIZE)
41//!         .protocols(&["A", "B"])
42//!         .start()
43//! }
44//!
45//! #[actix_web::main]
46//! async fn main() -> std::io::Result<()> {
47//!     HttpServer::new(|| {
48//!             App::new()
49//!                 .service(websocket)
50//!                 .service(custom_websocket)
51//!         })
52//!         .bind(("127.0.0.1", 8080))?
53//!         .run()
54//!         .await
55//! }
56//! ```
57//!
58
59use std::{
60    collections::VecDeque,
61    future::Future,
62    io, mem,
63    pin::Pin,
64    task::{Context, Poll},
65};
66
67use actix::{
68    dev::{
69        AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope,
70    },
71    fut::ActorFuture,
72    Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage,
73    SpawnHandle,
74};
75use actix_http::ws::{hash_key, Codec};
76pub use actix_http::ws::{CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError};
77use actix_web::{
78    error::{Error, PayloadError},
79    http::{
80        header::{self, HeaderValue},
81        Method, StatusCode,
82    },
83    HttpRequest, HttpResponse, HttpResponseBuilder,
84};
85use bytes::{Bytes, BytesMut};
86use bytestring::ByteString;
87use futures_core::Stream;
88use pin_project_lite::pin_project;
89use tokio::sync::oneshot;
90use tokio_util::codec::{Decoder as _, Encoder as _};
91
92/// Builder for Websocket session response.
93///
94/// # Examples
95///
96/// ```no_run
97/// # use actix::{Actor, StreamHandler};
98/// # use actix_web::{get, web, App, Error, HttpRequest, HttpResponse, HttpServer};
99/// # use actix_web_actors::ws;
100/// #
101/// # struct MyWs;
102/// #
103/// # impl Actor for MyWs {
104/// #     type Context = ws::WebsocketContext<Self>;
105/// # }
106/// #
107/// # /// Handler for ws::Message message
108/// # impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for MyWs {
109/// #     fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {}
110/// # }
111/// #
112/// #[get("/ws")]
113/// async fn websocket(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
114///     ws::WsResponseBuilder::new(MyWs, &req, stream).start()
115/// }
116///
117/// const MAX_FRAME_SIZE: usize = 16_384; // 16KiB
118///
119/// #[get("/custom-ws")]
120/// async fn custom_websocket(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
121///     // Create a Websocket session with a specific max frame size, codec, and protocols.
122///     ws::WsResponseBuilder::new(MyWs, &req, stream)
123///         .codec(actix_http::ws::Codec::new())
124///         // This will overwrite the codec's max frame-size
125///         .frame_size(MAX_FRAME_SIZE)
126///         .protocols(&["A", "B"])
127///         .start()
128/// }
129/// #
130/// # #[actix_web::main]
131/// # async fn main() -> std::io::Result<()> {
132/// #     HttpServer::new(|| {
133/// #             App::new()
134/// #                 .service(websocket)
135/// #                 .service(custom_websocket)
136/// #         })
137/// #         .bind(("127.0.0.1", 8080))?
138/// #         .run()
139/// #         .await
140/// # }
141/// ```
pub struct WsResponseBuilder<'a, A, T>
where
    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
    // actor that handles the session's incoming messages
    actor: A,
    // request whose headers are validated during the handshake
    req: &'a HttpRequest,
    // request payload stream from which incoming WebSocket frames are decoded
    stream: T,
    // optional custom codec; when `None` (and no frame size is set) `Codec::new()` is used
    codec: Option<Codec>,
    // server-supported sub-protocols used during negotiation, if any
    protocols: Option<&'a [&'a str]>,
    // optional max frame size; overrides the codec's own limit when set
    frame_size: Option<usize>,
}
154
155impl<'a, A, T> WsResponseBuilder<'a, A, T>
156where
157    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
158    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
159{
160    /// Construct a new `WsResponseBuilder` with actor, request, and payload stream.
161    ///
162    /// For usage example, see docs on [`WsResponseBuilder`] struct.
163    pub fn new(actor: A, req: &'a HttpRequest, stream: T) -> Self {
164        WsResponseBuilder {
165            actor,
166            req,
167            stream,
168            codec: None,
169            protocols: None,
170            frame_size: None,
171        }
172    }
173
174    /// Set the protocols for the session.
175    pub fn protocols(mut self, protocols: &'a [&'a str]) -> Self {
176        self.protocols = Some(protocols);
177        self
178    }
179
180    /// Set the max frame size for each message (in bytes).
181    ///
182    /// **Note**: This will override any given [`Codec`]'s max frame size.
183    pub fn frame_size(mut self, frame_size: usize) -> Self {
184        self.frame_size = Some(frame_size);
185        self
186    }
187
188    /// Set the [`Codec`] for the session. If [`Self::frame_size`] is also set, the given
189    /// [`Codec`]'s max frame size will be overridden.
190    pub fn codec(mut self, codec: Codec) -> Self {
191        self.codec = Some(codec);
192        self
193    }
194
195    fn handshake_resp(&self) -> Result<HttpResponseBuilder, HandshakeError> {
196        match self.protocols {
197            Some(protocols) => handshake_with_protocols(self.req, protocols),
198            None => handshake(self.req),
199        }
200    }
201
202    fn set_frame_size(&mut self) {
203        if let Some(frame_size) = self.frame_size {
204            match &mut self.codec {
205                Some(codec) => {
206                    // modify existing codec's max frame size
207                    let orig_codec = mem::take(codec);
208                    *codec = orig_codec.max_size(frame_size);
209                }
210
211                None => {
212                    // create a new codec with the given size
213                    self.codec = Some(Codec::new().max_size(frame_size));
214                }
215            }
216        }
217    }
218
219    /// Create a new Websocket context from an actor, request stream, and codec.
220    ///
221    /// Returns a pair, where the first item is an addr for the created actor, and the second item
222    /// is a stream intended to be set as part of the response
223    /// via [`HttpResponseBuilder::streaming()`].
224    fn create_with_codec_addr<S>(
225        actor: A,
226        stream: S,
227        codec: Codec,
228    ) -> (Addr<A>, impl Stream<Item = Result<Bytes, Error>>)
229    where
230        A: StreamHandler<Result<Message, ProtocolError>>,
231        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
232    {
233        let mb = Mailbox::default();
234        let mut ctx = WebsocketContext {
235            inner: ContextParts::new(mb.sender_producer()),
236            messages: VecDeque::new(),
237        };
238        ctx.add_stream(WsStream::new(stream, codec.clone()));
239
240        let addr = ctx.address();
241
242        (addr, WebsocketContextFut::new(ctx, actor, mb, codec))
243    }
244
245    /// Perform WebSocket handshake and start actor.
246    ///
247    /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change.
248    /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a
249    /// stream of the body request.
250    ///
251    /// If there is a problem with the handshake, an error is returned.
252    ///
253    /// If successful, consume the [`WsResponseBuilder`] and return a [`HttpResponse`] wrapped in
254    /// a [`Result`].
255    pub fn start(mut self) -> Result<HttpResponse, Error> {
256        let mut res = self.handshake_resp()?;
257        self.set_frame_size();
258
259        match self.codec {
260            Some(codec) => {
261                let out_stream = WebsocketContext::with_codec(self.actor, self.stream, codec);
262                Ok(res.streaming(out_stream))
263            }
264            None => {
265                let out_stream = WebsocketContext::create(self.actor, self.stream);
266                Ok(res.streaming(out_stream))
267            }
268        }
269    }
270
271    /// Perform WebSocket handshake and start actor.
272    ///
273    /// `req` is an [`HttpRequest`] that should be requesting a websocket protocol change.
274    /// `stream` should be a [`Bytes`] stream (such as `actix_web::web::Payload`) that contains a
275    /// stream of the body request.
276    ///
277    /// If there is a problem with the handshake, an error is returned.
278    ///
279    /// If successful, returns a pair where the first item is an address for the created actor and
280    /// the second item is the [`HttpResponse`] that should be returned from the websocket request.
281    pub fn start_with_addr(mut self) -> Result<(Addr<A>, HttpResponse), Error> {
282        let mut res = self.handshake_resp()?;
283        self.set_frame_size();
284
285        match self.codec {
286            Some(codec) => {
287                let (addr, out_stream) =
288                    Self::create_with_codec_addr(self.actor, self.stream, codec);
289                Ok((addr, res.streaming(out_stream)))
290            }
291            None => {
292                let (addr, out_stream) =
293                    WebsocketContext::create_with_addr(self.actor, self.stream);
294                Ok((addr, res.streaming(out_stream)))
295            }
296        }
297    }
298}
299
300/// Perform WebSocket handshake and start actor.
301///
302/// To customize options, see [`WsResponseBuilder`].
303pub fn start<A, T>(actor: A, req: &HttpRequest, stream: T) -> Result<HttpResponse, Error>
304where
305    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
306    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
307{
308    let mut res = handshake(req)?;
309    Ok(res.streaming(WebsocketContext::create(actor, stream)))
310}
311
312/// Perform WebSocket handshake and start actor.
313///
314/// `req` is an HTTP Request that should be requesting a websocket protocol change. `stream` should
315/// be a `Bytes` stream (such as `actix_web::web::Payload`) that contains a stream of the
316/// body request.
317///
318/// If there is a problem with the handshake, an error is returned.
319///
320/// If successful, returns a pair where the first item is an address for the created actor and the
321/// second item is the response that should be returned from the WebSocket request.
322#[deprecated(since = "4.0.0", note = "Prefer `WsResponseBuilder::start_with_addr`.")]
323pub fn start_with_addr<A, T>(
324    actor: A,
325    req: &HttpRequest,
326    stream: T,
327) -> Result<(Addr<A>, HttpResponse), Error>
328where
329    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
330    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
331{
332    let mut res = handshake(req)?;
333    let (addr, out_stream) = WebsocketContext::create_with_addr(actor, stream);
334    Ok((addr, res.streaming(out_stream)))
335}
336
337/// Do WebSocket handshake and start ws actor.
338///
339/// `protocols` is a sequence of known protocols.
340#[deprecated(
341    since = "4.0.0",
342    note = "Prefer `WsResponseBuilder` for setting protocols."
343)]
344pub fn start_with_protocols<A, T>(
345    actor: A,
346    protocols: &[&str],
347    req: &HttpRequest,
348    stream: T,
349) -> Result<HttpResponse, Error>
350where
351    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
352    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
353{
354    let mut res = handshake_with_protocols(req, protocols)?;
355    Ok(res.streaming(WebsocketContext::create(actor, stream)))
356}
357
358/// Prepare WebSocket handshake response.
359///
360/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform
361/// any IO.
362pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeError> {
363    handshake_with_protocols(req, &[])
364}
365
366/// Prepare WebSocket handshake response.
367///
368/// This function returns handshake `HttpResponse`, ready to send to peer. It does not perform
369/// any IO.
370///
371/// `protocols` is a sequence of known protocols. On successful handshake, the returned response
372/// headers contain the first protocol in this list which the server also knows.
373pub fn handshake_with_protocols(
374    req: &HttpRequest,
375    protocols: &[&str],
376) -> Result<HttpResponseBuilder, HandshakeError> {
377    // WebSocket accepts only GET
378    if *req.method() != Method::GET {
379        return Err(HandshakeError::GetMethodRequired);
380    }
381
382    // check for "UPGRADE" to WebSocket header
383    let has_hdr = if let Some(hdr) = req.headers().get(&header::UPGRADE) {
384        if let Ok(s) = hdr.to_str() {
385            s.to_ascii_lowercase().contains("websocket")
386        } else {
387            false
388        }
389    } else {
390        false
391    };
392    if !has_hdr {
393        return Err(HandshakeError::NoWebsocketUpgrade);
394    }
395
396    // Upgrade connection
397    if !req.head().upgrade() {
398        return Err(HandshakeError::NoConnectionUpgrade);
399    }
400
401    // check supported version
402    if !req.headers().contains_key(&header::SEC_WEBSOCKET_VERSION) {
403        return Err(HandshakeError::NoVersionHeader);
404    }
405    let supported_ver = {
406        if let Some(hdr) = req.headers().get(&header::SEC_WEBSOCKET_VERSION) {
407            hdr == "13" || hdr == "8" || hdr == "7"
408        } else {
409            false
410        }
411    };
412    if !supported_ver {
413        return Err(HandshakeError::UnsupportedVersion);
414    }
415
416    // check client handshake for validity
417    if !req.headers().contains_key(&header::SEC_WEBSOCKET_KEY) {
418        return Err(HandshakeError::BadWebsocketKey);
419    }
420    let key = {
421        let key = req.headers().get(&header::SEC_WEBSOCKET_KEY).unwrap();
422        hash_key(key.as_ref())
423    };
424
425    // check requested protocols
426    let protocol = req
427        .headers()
428        .get(&header::SEC_WEBSOCKET_PROTOCOL)
429        .and_then(|req_protocols| {
430            let req_protocols = req_protocols.to_str().ok()?;
431            req_protocols
432                .split(',')
433                .map(|req_p| req_p.trim())
434                .find(|req_p| protocols.iter().any(|p| p == req_p))
435        });
436
437    let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
438        .upgrade("websocket")
439        .insert_header((
440            header::SEC_WEBSOCKET_ACCEPT,
441            // key is known to be header value safe ascii
442            HeaderValue::from_bytes(&key).unwrap(),
443        ))
444        .take();
445
446    if let Some(protocol) = protocol {
447        response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol));
448    }
449
450    Ok(response)
451}
452
/// Execution context for WebSocket actors.
///
/// Messages written through this context (e.g. [`WebsocketContext::text`]) are queued and later
/// encoded into the response body stream.
pub struct WebsocketContext<A>
where
    A: Actor<Context = WebsocketContext<A>>,
{
    // actix context parts: address, spawned futures, and actor state are delegated here
    inner: ContextParts<A>,
    // outgoing message queue drained by `WebsocketContextFut`; a `None` entry marks the
    // outgoing stream as closed
    messages: VecDeque<Option<Message>>,
}
461
462impl<A> ActorContext for WebsocketContext<A>
463where
464    A: Actor<Context = Self>,
465{
466    fn stop(&mut self) {
467        self.inner.stop();
468    }
469
470    fn terminate(&mut self) {
471        self.inner.terminate()
472    }
473
474    fn state(&self) -> ActorState {
475        self.inner.state()
476    }
477}
478
479impl<A> AsyncContext<A> for WebsocketContext<A>
480where
481    A: Actor<Context = Self>,
482{
483    fn spawn<F>(&mut self, fut: F) -> SpawnHandle
484    where
485        F: ActorFuture<A, Output = ()> + 'static,
486    {
487        self.inner.spawn(fut)
488    }
489
490    fn wait<F>(&mut self, fut: F)
491    where
492        F: ActorFuture<A, Output = ()> + 'static,
493    {
494        self.inner.wait(fut)
495    }
496
497    #[doc(hidden)]
498    #[inline]
499    fn waiting(&self) -> bool {
500        self.inner.waiting()
501            || self.inner.state() == ActorState::Stopping
502            || self.inner.state() == ActorState::Stopped
503    }
504
505    fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
506        self.inner.cancel_future(handle)
507    }
508
509    #[inline]
510    fn address(&self) -> Addr<A> {
511        self.inner.address()
512    }
513}
514
515impl<A> WebsocketContext<A>
516where
517    A: Actor<Context = Self>,
518{
519    /// Create a new Websocket context from a request and an actor.
520    #[inline]
521    pub fn create<S>(actor: A, stream: S) -> impl Stream<Item = Result<Bytes, Error>>
522    where
523        A: StreamHandler<Result<Message, ProtocolError>>,
524        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
525    {
526        let (_, stream) = WebsocketContext::create_with_addr(actor, stream);
527        stream
528    }
529
530    /// Create a new Websocket context from a request and an actor.
531    ///
532    /// Returns a pair, where the first item is an addr for the created actor, and the second item
533    /// is a stream intended to be set as part of the response
534    /// via [`HttpResponseBuilder::streaming()`].
535    pub fn create_with_addr<S>(
536        actor: A,
537        stream: S,
538    ) -> (Addr<A>, impl Stream<Item = Result<Bytes, Error>>)
539    where
540        A: StreamHandler<Result<Message, ProtocolError>>,
541        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
542    {
543        let mb = Mailbox::default();
544        let mut ctx = WebsocketContext {
545            inner: ContextParts::new(mb.sender_producer()),
546            messages: VecDeque::new(),
547        };
548        ctx.add_stream(WsStream::new(stream, Codec::new()));
549
550        let addr = ctx.address();
551
552        (addr, WebsocketContextFut::new(ctx, actor, mb, Codec::new()))
553    }
554
555    /// Create a new Websocket context from a request, an actor, and a codec
556    pub fn with_codec<S>(
557        actor: A,
558        stream: S,
559        codec: Codec,
560    ) -> impl Stream<Item = Result<Bytes, Error>>
561    where
562        A: StreamHandler<Result<Message, ProtocolError>>,
563        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
564    {
565        let mb = Mailbox::default();
566        let mut ctx = WebsocketContext {
567            inner: ContextParts::new(mb.sender_producer()),
568            messages: VecDeque::new(),
569        };
570        ctx.add_stream(WsStream::new(stream, codec.clone()));
571
572        WebsocketContextFut::new(ctx, actor, mb, codec)
573    }
574
575    /// Create a new Websocket context
576    pub fn with_factory<S, F>(stream: S, f: F) -> impl Stream<Item = Result<Bytes, Error>>
577    where
578        F: FnOnce(&mut Self) -> A + 'static,
579        A: StreamHandler<Result<Message, ProtocolError>>,
580        S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
581    {
582        let mb = Mailbox::default();
583        let mut ctx = WebsocketContext {
584            inner: ContextParts::new(mb.sender_producer()),
585            messages: VecDeque::new(),
586        };
587        ctx.add_stream(WsStream::new(stream, Codec::new()));
588
589        let act = f(&mut ctx);
590
591        WebsocketContextFut::new(ctx, act, mb, Codec::new())
592    }
593}
594
595impl<A> WebsocketContext<A>
596where
597    A: Actor<Context = Self>,
598{
599    /// Write payload
600    ///
601    /// This is a low-level function that accepts framed messages that should
602    /// be created using `Frame::message()`. If you want to send text or binary
603    /// data you should prefer the `text()` or `binary()` convenience functions
604    /// that handle the framing for you.
605    #[inline]
606    pub fn write_raw(&mut self, msg: Message) {
607        self.messages.push_back(Some(msg));
608    }
609
610    /// Send text frame
611    #[inline]
612    pub fn text(&mut self, text: impl Into<ByteString>) {
613        self.write_raw(Message::Text(text.into()));
614    }
615
616    /// Send binary frame
617    #[inline]
618    pub fn binary(&mut self, data: impl Into<Bytes>) {
619        self.write_raw(Message::Binary(data.into()));
620    }
621
622    /// Send ping frame
623    #[inline]
624    pub fn ping(&mut self, message: &[u8]) {
625        self.write_raw(Message::Ping(Bytes::copy_from_slice(message)));
626    }
627
628    /// Send pong frame
629    #[inline]
630    pub fn pong(&mut self, message: &[u8]) {
631        self.write_raw(Message::Pong(Bytes::copy_from_slice(message)));
632    }
633
634    /// Send close frame
635    #[inline]
636    pub fn close(&mut self, reason: Option<CloseReason>) {
637        self.write_raw(Message::Close(reason));
638    }
639
640    /// Handle of the running future
641    ///
642    /// SpawnHandle is the handle returned by `AsyncContext::spawn()` method.
643    pub fn handle(&self) -> SpawnHandle {
644        self.inner.curr_handle()
645    }
646
647    /// Set mailbox capacity
648    ///
649    /// By default mailbox capacity is 16 messages.
650    pub fn set_mailbox_capacity(&mut self, cap: usize) {
651        self.inner.set_mailbox_capacity(cap)
652    }
653}
654
impl<A> AsyncContextParts<A> for WebsocketContext<A>
where
    A: Actor<Context = Self>,
{
    /// Exposes the inner actix context parts.
    fn parts(&mut self) -> &mut ContextParts<A> {
        &mut self.inner
    }
}
663
/// Response body stream for a WebSocket session: drives the actor's context future and encodes
/// the messages it queues into bytes.
struct WebsocketContextFut<A>
where
    A: Actor<Context = WebsocketContext<A>>,
{
    // the actor's context future, polled on every `poll_next`
    fut: ContextFut<A, WebsocketContext<A>>,
    // codec used to encode outgoing messages
    encoder: Codec,
    // encoded bytes not yet yielded to the response
    buf: BytesMut,
    // set once a `None` entry is popped from the outgoing message queue
    closed: bool,
}
673
674impl<A> WebsocketContextFut<A>
675where
676    A: Actor<Context = WebsocketContext<A>>,
677{
678    fn new(ctx: WebsocketContext<A>, act: A, mailbox: Mailbox<A>, codec: Codec) -> Self {
679        let fut = ContextFut::new(ctx, act, mailbox);
680        WebsocketContextFut {
681            fut,
682            encoder: codec,
683            buf: BytesMut::new(),
684            closed: false,
685        }
686    }
687}
688
689impl<A> Stream for WebsocketContextFut<A>
690where
691    A: Actor<Context = WebsocketContext<A>>,
692{
693    type Item = Result<Bytes, Error>;
694
695    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
696        let this = self.get_mut();
697
698        if this.fut.alive() {
699            let _ = Pin::new(&mut this.fut).poll(cx);
700        }
701
702        // encode messages
703        while let Some(item) = this.fut.ctx().messages.pop_front() {
704            if let Some(msg) = item {
705                this.encoder.encode(msg, &mut this.buf)?;
706            } else {
707                this.closed = true;
708                break;
709            }
710        }
711
712        if !this.buf.is_empty() {
713            Poll::Ready(Some(Ok(std::mem::take(&mut this.buf).freeze())))
714        } else if this.fut.alive() && !this.closed {
715            Poll::Pending
716        } else {
717            Poll::Ready(None)
718        }
719    }
720}
721
impl<A, M> ToEnvelope<A, M> for WebsocketContext<A>
where
    A: Actor<Context = WebsocketContext<A>> + Handler<M>,
    M: ActixMessage + Send + 'static,
    M::Result: Send,
{
    /// Wraps `msg` (and the optional reply channel `tx`) into an [`Envelope`] for this actor.
    fn pack(msg: M, tx: Option<oneshot::Sender<M::Result>>) -> Envelope<A> {
        Envelope::new(msg, tx)
    }
}
732
pin_project! {
    // Adapter that decodes a raw payload byte stream into WebSocket messages.
    #[derive(Debug)]
    struct WsStream<S> {
        // underlying payload stream (pinned)
        #[pin]
        stream: S,
        // codec used to decode frames from `buf`
        decoder: Codec,
        // bytes read from `stream` but not yet decoded
        buf: BytesMut,
        // set once `stream` has returned `None`
        closed: bool,
    }
}
743
744impl<S> WsStream<S>
745where
746    S: Stream<Item = Result<Bytes, PayloadError>>,
747{
748    fn new(stream: S, codec: Codec) -> Self {
749        Self {
750            stream,
751            decoder: codec,
752            buf: BytesMut::new(),
753            closed: false,
754        }
755    }
756}
757
758impl<S> Stream for WsStream<S>
759where
760    S: Stream<Item = Result<Bytes, PayloadError>>,
761{
762    type Item = Result<Message, ProtocolError>;
763
764    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
765        let mut this = self.as_mut().project();
766
767        if !*this.closed {
768            loop {
769                match Pin::new(&mut this.stream).poll_next(cx) {
770                    Poll::Ready(Some(Ok(chunk))) => {
771                        this.buf.extend_from_slice(&chunk[..]);
772                    }
773                    Poll::Ready(None) => {
774                        *this.closed = true;
775                        break;
776                    }
777                    Poll::Pending => break,
778                    Poll::Ready(Some(Err(err))) => {
779                        return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
780                            io::ErrorKind::Other,
781                            format!("{err}"),
782                        )))));
783                    }
784                }
785            }
786        }
787
788        match this.decoder.decode(this.buf)? {
789            None => {
790                if *this.closed {
791                    Poll::Ready(None)
792                } else {
793                    Poll::Pending
794                }
795            }
796            Some(frm) => {
797                let msg = match frm {
798                    Frame::Text(data) => {
799                        Message::Text(ByteString::try_from(data).map_err(|e| {
800                            ProtocolError::Io(io::Error::new(
801                                io::ErrorKind::Other,
802                                format!("{}", e),
803                            ))
804                        })?)
805                    }
806                    Frame::Binary(data) => Message::Binary(data),
807                    Frame::Ping(s) => Message::Ping(s),
808                    Frame::Pong(s) => Message::Pong(s),
809                    Frame::Close(reason) => Message::Close(reason),
810                    Frame::Continuation(item) => Message::Continuation(item),
811                };
812                Poll::Ready(Some(Ok(msg)))
813            }
814        }
815    }
816}
817
818#[cfg(test)]
819mod tests {
820    use actix_web::test::TestRequest;
821
822    use super::*;
823
824    #[test]
825    fn test_handshake() {
826        let req = TestRequest::default()
827            .method(Method::POST)
828            .to_http_request();
829        assert_eq!(
830            HandshakeError::GetMethodRequired,
831            handshake(&req).err().unwrap()
832        );
833
834        let req = TestRequest::default().to_http_request();
835        assert_eq!(
836            HandshakeError::NoWebsocketUpgrade,
837            handshake(&req).err().unwrap()
838        );
839
840        let req = TestRequest::default()
841            .insert_header((header::UPGRADE, header::HeaderValue::from_static("test")))
842            .to_http_request();
843        assert_eq!(
844            HandshakeError::NoWebsocketUpgrade,
845            handshake(&req).err().unwrap()
846        );
847
848        let req = TestRequest::default()
849            .insert_header((
850                header::UPGRADE,
851                header::HeaderValue::from_static("websocket"),
852            ))
853            .to_http_request();
854        assert_eq!(
855            HandshakeError::NoConnectionUpgrade,
856            handshake(&req).err().unwrap()
857        );
858
859        let req = TestRequest::default()
860            .insert_header((
861                header::UPGRADE,
862                header::HeaderValue::from_static("websocket"),
863            ))
864            .insert_header((
865                header::CONNECTION,
866                header::HeaderValue::from_static("upgrade"),
867            ))
868            .to_http_request();
869        assert_eq!(
870            HandshakeError::NoVersionHeader,
871            handshake(&req).err().unwrap()
872        );
873
874        let req = TestRequest::default()
875            .insert_header((
876                header::UPGRADE,
877                header::HeaderValue::from_static("websocket"),
878            ))
879            .insert_header((
880                header::CONNECTION,
881                header::HeaderValue::from_static("upgrade"),
882            ))
883            .insert_header((
884                header::SEC_WEBSOCKET_VERSION,
885                header::HeaderValue::from_static("5"),
886            ))
887            .to_http_request();
888        assert_eq!(
889            HandshakeError::UnsupportedVersion,
890            handshake(&req).err().unwrap()
891        );
892
893        let req = TestRequest::default()
894            .insert_header((
895                header::UPGRADE,
896                header::HeaderValue::from_static("websocket"),
897            ))
898            .insert_header((
899                header::CONNECTION,
900                header::HeaderValue::from_static("upgrade"),
901            ))
902            .insert_header((
903                header::SEC_WEBSOCKET_VERSION,
904                header::HeaderValue::from_static("13"),
905            ))
906            .to_http_request();
907        assert_eq!(
908            HandshakeError::BadWebsocketKey,
909            handshake(&req).err().unwrap()
910        );
911
912        let req = TestRequest::default()
913            .insert_header((
914                header::UPGRADE,
915                header::HeaderValue::from_static("websocket"),
916            ))
917            .insert_header((
918                header::CONNECTION,
919                header::HeaderValue::from_static("upgrade"),
920            ))
921            .insert_header((
922                header::SEC_WEBSOCKET_VERSION,
923                header::HeaderValue::from_static("13"),
924            ))
925            .insert_header((
926                header::SEC_WEBSOCKET_KEY,
927                header::HeaderValue::from_static("13"),
928            ))
929            .to_http_request();
930
931        let resp = handshake(&req).unwrap().finish();
932        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, resp.status());
933        assert_eq!(None, resp.headers().get(&header::CONTENT_LENGTH));
934        assert_eq!(None, resp.headers().get(&header::TRANSFER_ENCODING));
935
936        let req = TestRequest::default()
937            .insert_header((
938                header::UPGRADE,
939                header::HeaderValue::from_static("websocket"),
940            ))
941            .insert_header((
942                header::CONNECTION,
943                header::HeaderValue::from_static("upgrade"),
944            ))
945            .insert_header((
946                header::SEC_WEBSOCKET_VERSION,
947                header::HeaderValue::from_static("13"),
948            ))
949            .insert_header((
950                header::SEC_WEBSOCKET_KEY,
951                header::HeaderValue::from_static("13"),
952            ))
953            .insert_header((
954                header::SEC_WEBSOCKET_PROTOCOL,
955                header::HeaderValue::from_static("graphql"),
956            ))
957            .to_http_request();
958
959        let protocols = ["graphql"];
960
961        assert_eq!(
962            StatusCode::SWITCHING_PROTOCOLS,
963            handshake_with_protocols(&req, &protocols)
964                .unwrap()
965                .finish()
966                .status()
967        );
968        assert_eq!(
969            Some(&header::HeaderValue::from_static("graphql")),
970            handshake_with_protocols(&req, &protocols)
971                .unwrap()
972                .finish()
973                .headers()
974                .get(&header::SEC_WEBSOCKET_PROTOCOL)
975        );
976
977        let req = TestRequest::default()
978            .insert_header((
979                header::UPGRADE,
980                header::HeaderValue::from_static("websocket"),
981            ))
982            .insert_header((
983                header::CONNECTION,
984                header::HeaderValue::from_static("upgrade"),
985            ))
986            .insert_header((
987                header::SEC_WEBSOCKET_VERSION,
988                header::HeaderValue::from_static("13"),
989            ))
990            .insert_header((
991                header::SEC_WEBSOCKET_KEY,
992                header::HeaderValue::from_static("13"),
993            ))
994            .insert_header((
995                header::SEC_WEBSOCKET_PROTOCOL,
996                header::HeaderValue::from_static("p1, p2, p3"),
997            ))
998            .to_http_request();
999
1000        let protocols = vec!["p3", "p2"];
1001
1002        assert_eq!(
1003            StatusCode::SWITCHING_PROTOCOLS,
1004            handshake_with_protocols(&req, &protocols)
1005                .unwrap()
1006                .finish()
1007                .status()
1008        );
1009        assert_eq!(
1010            Some(&header::HeaderValue::from_static("p2")),
1011            handshake_with_protocols(&req, &protocols)
1012                .unwrap()
1013                .finish()
1014                .headers()
1015                .get(&header::SEC_WEBSOCKET_PROTOCOL)
1016        );
1017
1018        let req = TestRequest::default()
1019            .insert_header((
1020                header::UPGRADE,
1021                header::HeaderValue::from_static("websocket"),
1022            ))
1023            .insert_header((
1024                header::CONNECTION,
1025                header::HeaderValue::from_static("upgrade"),
1026            ))
1027            .insert_header((
1028                header::SEC_WEBSOCKET_VERSION,
1029                header::HeaderValue::from_static("13"),
1030            ))
1031            .insert_header((
1032                header::SEC_WEBSOCKET_KEY,
1033                header::HeaderValue::from_static("13"),
1034            ))
1035            .insert_header((
1036                header::SEC_WEBSOCKET_PROTOCOL,
1037                header::HeaderValue::from_static("p1,p2,p3"),
1038            ))
1039            .to_http_request();
1040
1041        let protocols = vec!["p3", "p2"];
1042
1043        assert_eq!(
1044            StatusCode::SWITCHING_PROTOCOLS,
1045            handshake_with_protocols(&req, &protocols)
1046                .unwrap()
1047                .finish()
1048                .status()
1049        );
1050        assert_eq!(
1051            Some(&header::HeaderValue::from_static("p2")),
1052            handshake_with_protocols(&req, &protocols)
1053                .unwrap()
1054                .finish()
1055                .headers()
1056                .get(&header::SEC_WEBSOCKET_PROTOCOL)
1057        );
1058    }
1059}