1use std::{
60 collections::VecDeque,
61 future::Future,
62 io, mem,
63 pin::Pin,
64 task::{Context, Poll},
65};
66
67use actix::{
68 dev::{
69 AsyncContextParts, ContextFut, ContextParts, Envelope, Mailbox, StreamHandler, ToEnvelope,
70 },
71 fut::ActorFuture,
72 Actor, ActorContext, ActorState, Addr, AsyncContext, Handler, Message as ActixMessage,
73 SpawnHandle,
74};
75use actix_http::ws::{hash_key, Codec};
76pub use actix_http::ws::{CloseCode, CloseReason, Frame, HandshakeError, Message, ProtocolError};
77use actix_web::{
78 error::{Error, PayloadError},
79 http::{
80 header::{self, HeaderValue},
81 Method, StatusCode,
82 },
83 HttpRequest, HttpResponse, HttpResponseBuilder,
84};
85use bytes::{Bytes, BytesMut};
86use bytestring::ByteString;
87use futures_core::Stream;
88use pin_project_lite::pin_project;
89use tokio::sync::oneshot;
90use tokio_util::codec::{Decoder as _, Encoder as _};
91
/// Builder for a WebSocket session response.
///
/// Collects the actor, the HTTP request being upgraded, its payload stream, and optional
/// handshake/codec settings, then performs the handshake in `start` / `start_with_addr`.
pub struct WsResponseBuilder<'a, A, T>
where
    A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
    T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
    /// Actor that will handle decoded WebSocket messages.
    actor: A,
    /// Request carrying the upgrade handshake headers.
    req: &'a HttpRequest,
    /// Raw request payload; decoded into frames for the actor.
    stream: T,
    /// Custom codec; `None` means a default `Codec::new()` is used.
    codec: Option<Codec>,
    /// Subprotocols the server accepts; `None` means plain `handshake` is used.
    protocols: Option<&'a [&'a str]>,
    /// Maximum frame size, applied to the codec (custom or default) when starting.
    frame_size: Option<usize>,
}
154
155impl<'a, A, T> WsResponseBuilder<'a, A, T>
156where
157 A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
158 T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
159{
160 pub fn new(actor: A, req: &'a HttpRequest, stream: T) -> Self {
164 WsResponseBuilder {
165 actor,
166 req,
167 stream,
168 codec: None,
169 protocols: None,
170 frame_size: None,
171 }
172 }
173
174 pub fn protocols(mut self, protocols: &'a [&'a str]) -> Self {
176 self.protocols = Some(protocols);
177 self
178 }
179
180 pub fn frame_size(mut self, frame_size: usize) -> Self {
184 self.frame_size = Some(frame_size);
185 self
186 }
187
188 pub fn codec(mut self, codec: Codec) -> Self {
191 self.codec = Some(codec);
192 self
193 }
194
195 fn handshake_resp(&self) -> Result<HttpResponseBuilder, HandshakeError> {
196 match self.protocols {
197 Some(protocols) => handshake_with_protocols(self.req, protocols),
198 None => handshake(self.req),
199 }
200 }
201
202 fn set_frame_size(&mut self) {
203 if let Some(frame_size) = self.frame_size {
204 match &mut self.codec {
205 Some(codec) => {
206 let orig_codec = mem::take(codec);
208 *codec = orig_codec.max_size(frame_size);
209 }
210
211 None => {
212 self.codec = Some(Codec::new().max_size(frame_size));
214 }
215 }
216 }
217 }
218
219 fn create_with_codec_addr<S>(
225 actor: A,
226 stream: S,
227 codec: Codec,
228 ) -> (Addr<A>, impl Stream<Item = Result<Bytes, Error>>)
229 where
230 A: StreamHandler<Result<Message, ProtocolError>>,
231 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
232 {
233 let mb = Mailbox::default();
234 let mut ctx = WebsocketContext {
235 inner: ContextParts::new(mb.sender_producer()),
236 messages: VecDeque::new(),
237 };
238 ctx.add_stream(WsStream::new(stream, codec.clone()));
239
240 let addr = ctx.address();
241
242 (addr, WebsocketContextFut::new(ctx, actor, mb, codec))
243 }
244
245 pub fn start(mut self) -> Result<HttpResponse, Error> {
256 let mut res = self.handshake_resp()?;
257 self.set_frame_size();
258
259 match self.codec {
260 Some(codec) => {
261 let out_stream = WebsocketContext::with_codec(self.actor, self.stream, codec);
262 Ok(res.streaming(out_stream))
263 }
264 None => {
265 let out_stream = WebsocketContext::create(self.actor, self.stream);
266 Ok(res.streaming(out_stream))
267 }
268 }
269 }
270
271 pub fn start_with_addr(mut self) -> Result<(Addr<A>, HttpResponse), Error> {
282 let mut res = self.handshake_resp()?;
283 self.set_frame_size();
284
285 match self.codec {
286 Some(codec) => {
287 let (addr, out_stream) =
288 Self::create_with_codec_addr(self.actor, self.stream, codec);
289 Ok((addr, res.streaming(out_stream)))
290 }
291 None => {
292 let (addr, out_stream) =
293 WebsocketContext::create_with_addr(self.actor, self.stream);
294 Ok((addr, res.streaming(out_stream)))
295 }
296 }
297 }
298}
299
300pub fn start<A, T>(actor: A, req: &HttpRequest, stream: T) -> Result<HttpResponse, Error>
304where
305 A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
306 T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
307{
308 let mut res = handshake(req)?;
309 Ok(res.streaming(WebsocketContext::create(actor, stream)))
310}
311
312#[deprecated(since = "4.0.0", note = "Prefer `WsResponseBuilder::start_with_addr`.")]
323pub fn start_with_addr<A, T>(
324 actor: A,
325 req: &HttpRequest,
326 stream: T,
327) -> Result<(Addr<A>, HttpResponse), Error>
328where
329 A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
330 T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
331{
332 let mut res = handshake(req)?;
333 let (addr, out_stream) = WebsocketContext::create_with_addr(actor, stream);
334 Ok((addr, res.streaming(out_stream)))
335}
336
337#[deprecated(
341 since = "4.0.0",
342 note = "Prefer `WsResponseBuilder` for setting protocols."
343)]
344pub fn start_with_protocols<A, T>(
345 actor: A,
346 protocols: &[&str],
347 req: &HttpRequest,
348 stream: T,
349) -> Result<HttpResponse, Error>
350where
351 A: Actor<Context = WebsocketContext<A>> + StreamHandler<Result<Message, ProtocolError>>,
352 T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
353{
354 let mut res = handshake_with_protocols(req, protocols)?;
355 Ok(res.streaming(WebsocketContext::create(actor, stream)))
356}
357
358pub fn handshake(req: &HttpRequest) -> Result<HttpResponseBuilder, HandshakeError> {
363 handshake_with_protocols(req, &[])
364}
365
366pub fn handshake_with_protocols(
374 req: &HttpRequest,
375 protocols: &[&str],
376) -> Result<HttpResponseBuilder, HandshakeError> {
377 if *req.method() != Method::GET {
379 return Err(HandshakeError::GetMethodRequired);
380 }
381
382 let has_hdr = if let Some(hdr) = req.headers().get(&header::UPGRADE) {
384 if let Ok(s) = hdr.to_str() {
385 s.to_ascii_lowercase().contains("websocket")
386 } else {
387 false
388 }
389 } else {
390 false
391 };
392 if !has_hdr {
393 return Err(HandshakeError::NoWebsocketUpgrade);
394 }
395
396 if !req.head().upgrade() {
398 return Err(HandshakeError::NoConnectionUpgrade);
399 }
400
401 if !req.headers().contains_key(&header::SEC_WEBSOCKET_VERSION) {
403 return Err(HandshakeError::NoVersionHeader);
404 }
405 let supported_ver = {
406 if let Some(hdr) = req.headers().get(&header::SEC_WEBSOCKET_VERSION) {
407 hdr == "13" || hdr == "8" || hdr == "7"
408 } else {
409 false
410 }
411 };
412 if !supported_ver {
413 return Err(HandshakeError::UnsupportedVersion);
414 }
415
416 if !req.headers().contains_key(&header::SEC_WEBSOCKET_KEY) {
418 return Err(HandshakeError::BadWebsocketKey);
419 }
420 let key = {
421 let key = req.headers().get(&header::SEC_WEBSOCKET_KEY).unwrap();
422 hash_key(key.as_ref())
423 };
424
425 let protocol = req
427 .headers()
428 .get(&header::SEC_WEBSOCKET_PROTOCOL)
429 .and_then(|req_protocols| {
430 let req_protocols = req_protocols.to_str().ok()?;
431 req_protocols
432 .split(',')
433 .map(|req_p| req_p.trim())
434 .find(|req_p| protocols.iter().any(|p| p == req_p))
435 });
436
437 let mut response = HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS)
438 .upgrade("websocket")
439 .insert_header((
440 header::SEC_WEBSOCKET_ACCEPT,
441 HeaderValue::from_bytes(&key).unwrap(),
443 ))
444 .take();
445
446 if let Some(protocol) = protocol {
447 response.insert_header((header::SEC_WEBSOCKET_PROTOCOL, protocol));
448 }
449
450 Ok(response)
451}
452
/// Execution context for WebSocket actors.
///
/// Outgoing messages are queued here and encoded by the response stream.
pub struct WebsocketContext<A>
where
    A: Actor<Context = WebsocketContext<A>>,
{
    /// Shared actix context machinery (mailbox sender, spawned futures, state).
    inner: ContextParts<A>,
    /// Outgoing queue. `Some` entries are encoded and sent; a `None` entry makes the
    /// response stream stop encoding and mark itself closed.
    messages: VecDeque<Option<Message>>,
}
461
462impl<A> ActorContext for WebsocketContext<A>
463where
464 A: Actor<Context = Self>,
465{
466 fn stop(&mut self) {
467 self.inner.stop();
468 }
469
470 fn terminate(&mut self) {
471 self.inner.terminate()
472 }
473
474 fn state(&self) -> ActorState {
475 self.inner.state()
476 }
477}
478
479impl<A> AsyncContext<A> for WebsocketContext<A>
480where
481 A: Actor<Context = Self>,
482{
483 fn spawn<F>(&mut self, fut: F) -> SpawnHandle
484 where
485 F: ActorFuture<A, Output = ()> + 'static,
486 {
487 self.inner.spawn(fut)
488 }
489
490 fn wait<F>(&mut self, fut: F)
491 where
492 F: ActorFuture<A, Output = ()> + 'static,
493 {
494 self.inner.wait(fut)
495 }
496
497 #[doc(hidden)]
498 #[inline]
499 fn waiting(&self) -> bool {
500 self.inner.waiting()
501 || self.inner.state() == ActorState::Stopping
502 || self.inner.state() == ActorState::Stopped
503 }
504
505 fn cancel_future(&mut self, handle: SpawnHandle) -> bool {
506 self.inner.cancel_future(handle)
507 }
508
509 #[inline]
510 fn address(&self) -> Addr<A> {
511 self.inner.address()
512 }
513}
514
515impl<A> WebsocketContext<A>
516where
517 A: Actor<Context = Self>,
518{
519 #[inline]
521 pub fn create<S>(actor: A, stream: S) -> impl Stream<Item = Result<Bytes, Error>>
522 where
523 A: StreamHandler<Result<Message, ProtocolError>>,
524 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
525 {
526 let (_, stream) = WebsocketContext::create_with_addr(actor, stream);
527 stream
528 }
529
530 pub fn create_with_addr<S>(
536 actor: A,
537 stream: S,
538 ) -> (Addr<A>, impl Stream<Item = Result<Bytes, Error>>)
539 where
540 A: StreamHandler<Result<Message, ProtocolError>>,
541 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
542 {
543 let mb = Mailbox::default();
544 let mut ctx = WebsocketContext {
545 inner: ContextParts::new(mb.sender_producer()),
546 messages: VecDeque::new(),
547 };
548 ctx.add_stream(WsStream::new(stream, Codec::new()));
549
550 let addr = ctx.address();
551
552 (addr, WebsocketContextFut::new(ctx, actor, mb, Codec::new()))
553 }
554
555 pub fn with_codec<S>(
557 actor: A,
558 stream: S,
559 codec: Codec,
560 ) -> impl Stream<Item = Result<Bytes, Error>>
561 where
562 A: StreamHandler<Result<Message, ProtocolError>>,
563 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
564 {
565 let mb = Mailbox::default();
566 let mut ctx = WebsocketContext {
567 inner: ContextParts::new(mb.sender_producer()),
568 messages: VecDeque::new(),
569 };
570 ctx.add_stream(WsStream::new(stream, codec.clone()));
571
572 WebsocketContextFut::new(ctx, actor, mb, codec)
573 }
574
575 pub fn with_factory<S, F>(stream: S, f: F) -> impl Stream<Item = Result<Bytes, Error>>
577 where
578 F: FnOnce(&mut Self) -> A + 'static,
579 A: StreamHandler<Result<Message, ProtocolError>>,
580 S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
581 {
582 let mb = Mailbox::default();
583 let mut ctx = WebsocketContext {
584 inner: ContextParts::new(mb.sender_producer()),
585 messages: VecDeque::new(),
586 };
587 ctx.add_stream(WsStream::new(stream, Codec::new()));
588
589 let act = f(&mut ctx);
590
591 WebsocketContextFut::new(ctx, act, mb, Codec::new())
592 }
593}
594
595impl<A> WebsocketContext<A>
596where
597 A: Actor<Context = Self>,
598{
599 #[inline]
606 pub fn write_raw(&mut self, msg: Message) {
607 self.messages.push_back(Some(msg));
608 }
609
610 #[inline]
612 pub fn text(&mut self, text: impl Into<ByteString>) {
613 self.write_raw(Message::Text(text.into()));
614 }
615
616 #[inline]
618 pub fn binary(&mut self, data: impl Into<Bytes>) {
619 self.write_raw(Message::Binary(data.into()));
620 }
621
622 #[inline]
624 pub fn ping(&mut self, message: &[u8]) {
625 self.write_raw(Message::Ping(Bytes::copy_from_slice(message)));
626 }
627
628 #[inline]
630 pub fn pong(&mut self, message: &[u8]) {
631 self.write_raw(Message::Pong(Bytes::copy_from_slice(message)));
632 }
633
634 #[inline]
636 pub fn close(&mut self, reason: Option<CloseReason>) {
637 self.write_raw(Message::Close(reason));
638 }
639
640 pub fn handle(&self) -> SpawnHandle {
644 self.inner.curr_handle()
645 }
646
647 pub fn set_mailbox_capacity(&mut self, cap: usize) {
651 self.inner.set_mailbox_capacity(cap)
652 }
653}
654
/// Exposes the inner `ContextParts` to actix's context machinery.
impl<A> AsyncContextParts<A> for WebsocketContext<A>
where
    A: Actor<Context = Self>,
{
    /// Mutable access to the wrapped `ContextParts`.
    fn parts(&mut self) -> &mut ContextParts<A> {
        &mut self.inner
    }
}
663
/// Outgoing side of a WebSocket session: drives the actor and yields encoded frames.
struct WebsocketContextFut<A>
where
    A: Actor<Context = WebsocketContext<A>>,
{
    /// Actor future owning the `WebsocketContext` (and its outgoing message queue).
    fut: ContextFut<A, WebsocketContext<A>>,
    /// Codec used to encode queued messages into `buf`.
    encoder: Codec,
    /// Encoded bytes not yet yielded by the stream.
    buf: BytesMut,
    /// Set once a `None` entry is popped from the message queue.
    closed: bool,
}
673
674impl<A> WebsocketContextFut<A>
675where
676 A: Actor<Context = WebsocketContext<A>>,
677{
678 fn new(ctx: WebsocketContext<A>, act: A, mailbox: Mailbox<A>, codec: Codec) -> Self {
679 let fut = ContextFut::new(ctx, act, mailbox);
680 WebsocketContextFut {
681 fut,
682 encoder: codec,
683 buf: BytesMut::new(),
684 closed: false,
685 }
686 }
687}
688
689impl<A> Stream for WebsocketContextFut<A>
690where
691 A: Actor<Context = WebsocketContext<A>>,
692{
693 type Item = Result<Bytes, Error>;
694
695 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
696 let this = self.get_mut();
697
698 if this.fut.alive() {
699 let _ = Pin::new(&mut this.fut).poll(cx);
700 }
701
702 while let Some(item) = this.fut.ctx().messages.pop_front() {
704 if let Some(msg) = item {
705 this.encoder.encode(msg, &mut this.buf)?;
706 } else {
707 this.closed = true;
708 break;
709 }
710 }
711
712 if !this.buf.is_empty() {
713 Poll::Ready(Some(Ok(std::mem::take(&mut this.buf).freeze())))
714 } else if this.fut.alive() && !this.closed {
715 Poll::Pending
716 } else {
717 Poll::Ready(None)
718 }
719 }
720}
721
/// Packs messages for actors running in a `WebsocketContext` into an `Envelope`.
impl<A, M> ToEnvelope<A, M> for WebsocketContext<A>
where
    A: Actor<Context = WebsocketContext<A>> + Handler<M>,
    M: ActixMessage + Send + 'static,
    M::Result: Send,
{
    /// Wraps `msg` and its optional reply sender in an `Envelope`.
    fn pack(msg: M, tx: Option<oneshot::Sender<M::Result>>) -> Envelope<A> {
        Envelope::new(msg, tx)
    }
}
732
pin_project! {
    /// Incoming side of a WebSocket session: buffers payload bytes and decodes messages.
    #[derive(Debug)]
    struct WsStream<S> {
        // Raw request payload.
        #[pin]
        stream: S,
        // Codec used to decode frames from `buf`.
        decoder: Codec,
        // Payload bytes received but not yet decoded.
        buf: BytesMut,
        // Set once `stream` has returned `None`.
        closed: bool,
    }
}
743
744impl<S> WsStream<S>
745where
746 S: Stream<Item = Result<Bytes, PayloadError>>,
747{
748 fn new(stream: S, codec: Codec) -> Self {
749 Self {
750 stream,
751 decoder: codec,
752 buf: BytesMut::new(),
753 closed: false,
754 }
755 }
756}
757
758impl<S> Stream for WsStream<S>
759where
760 S: Stream<Item = Result<Bytes, PayloadError>>,
761{
762 type Item = Result<Message, ProtocolError>;
763
764 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
765 let mut this = self.as_mut().project();
766
767 if !*this.closed {
768 loop {
769 match Pin::new(&mut this.stream).poll_next(cx) {
770 Poll::Ready(Some(Ok(chunk))) => {
771 this.buf.extend_from_slice(&chunk[..]);
772 }
773 Poll::Ready(None) => {
774 *this.closed = true;
775 break;
776 }
777 Poll::Pending => break,
778 Poll::Ready(Some(Err(err))) => {
779 return Poll::Ready(Some(Err(ProtocolError::Io(io::Error::new(
780 io::ErrorKind::Other,
781 format!("{err}"),
782 )))));
783 }
784 }
785 }
786 }
787
788 match this.decoder.decode(this.buf)? {
789 None => {
790 if *this.closed {
791 Poll::Ready(None)
792 } else {
793 Poll::Pending
794 }
795 }
796 Some(frm) => {
797 let msg = match frm {
798 Frame::Text(data) => {
799 Message::Text(ByteString::try_from(data).map_err(|e| {
800 ProtocolError::Io(io::Error::new(
801 io::ErrorKind::Other,
802 format!("{}", e),
803 ))
804 })?)
805 }
806 Frame::Binary(data) => Message::Binary(data),
807 Frame::Ping(s) => Message::Ping(s),
808 Frame::Pong(s) => Message::Pong(s),
809 Frame::Close(reason) => Message::Close(reason),
810 Frame::Continuation(item) => Message::Continuation(item),
811 };
812 Poll::Ready(Some(Ok(msg)))
813 }
814 }
815 }
816}
817
#[cfg(test)]
mod tests {
    use actix_web::test::TestRequest;

    use super::*;

    /// Builds a GET test request carrying `headers`, inserted in the given order.
    fn request_with(headers: &[(header::HeaderName, &'static str)]) -> HttpRequest {
        headers
            .iter()
            .fold(TestRequest::default(), |req, (name, value)| {
                req.insert_header((name.clone(), header::HeaderValue::from_static(*value)))
            })
            .to_http_request()
    }

    /// Asserts that negotiating `protocols` against `req` succeeds and echoes `expected`.
    fn assert_negotiates(req: &HttpRequest, protocols: &[&str], expected: &'static str) {
        let resp = handshake_with_protocols(req, protocols).unwrap().finish();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, resp.status());
        assert_eq!(
            Some(&header::HeaderValue::from_static(expected)),
            resp.headers().get(&header::SEC_WEBSOCKET_PROTOCOL)
        );
    }

    #[test]
    fn test_handshake() {
        // A complete, valid header set; each shorter prefix lacks the next requirement.
        let valid = [
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
            (header::SEC_WEBSOCKET_VERSION, "13"),
            (header::SEC_WEBSOCKET_KEY, "13"),
        ];

        let post = TestRequest::default()
            .method(Method::POST)
            .to_http_request();
        assert_eq!(
            handshake(&post).err(),
            Some(HandshakeError::GetMethodRequired)
        );

        assert_eq!(
            handshake(&request_with(&[])).err(),
            Some(HandshakeError::NoWebsocketUpgrade)
        );
        assert_eq!(
            handshake(&request_with(&[(header::UPGRADE, "test")])).err(),
            Some(HandshakeError::NoWebsocketUpgrade)
        );
        assert_eq!(
            handshake(&request_with(&valid[..1])).err(),
            Some(HandshakeError::NoConnectionUpgrade)
        );
        assert_eq!(
            handshake(&request_with(&valid[..2])).err(),
            Some(HandshakeError::NoVersionHeader)
        );

        let bad_version = [
            (header::UPGRADE, "websocket"),
            (header::CONNECTION, "upgrade"),
            (header::SEC_WEBSOCKET_VERSION, "5"),
        ];
        assert_eq!(
            handshake(&request_with(&bad_version)).err(),
            Some(HandshakeError::UnsupportedVersion)
        );

        assert_eq!(
            handshake(&request_with(&valid[..3])).err(),
            Some(HandshakeError::BadWebsocketKey)
        );

        let resp = handshake(&request_with(&valid)).unwrap().finish();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, resp.status());
        assert_eq!(None, resp.headers().get(&header::CONTENT_LENGTH));
        assert_eq!(None, resp.headers().get(&header::TRANSFER_ENCODING));

        let with_protocol = |offered: &'static str| {
            let mut headers = valid.to_vec();
            headers.push((header::SEC_WEBSOCKET_PROTOCOL, offered));
            request_with(&headers)
        };

        assert_negotiates(&with_protocol("graphql"), &["graphql"], "graphql");
        assert_negotiates(&with_protocol("p1, p2, p3"), &["p3", "p2"], "p2");
        assert_negotiates(&with_protocol("p1,p2,p3"), &["p3", "p2"], "p2");
    }
}