1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use actix_http::encoding::Encoder;
11use actix_service::{Service, Transform};
12use actix_utils::future::{ok, Either, Ready};
13use futures_core::ready;
14use mime::Mime;
15use once_cell::sync::Lazy;
16use pin_project_lite::pin_project;
17
18use crate::{
19 body::{EitherBody, MessageBody},
20 http::{
21 header::{self, AcceptEncoding, ContentEncoding, Encoding, HeaderValue},
22 StatusCode,
23 },
24 service::{ServiceRequest, ServiceResponse},
25 Error, HttpMessage, HttpResponse,
26};
27
/// Middleware for compressing response payloads.
///
/// # Encoding negotiation
/// For each request the request's `Accept-Encoding` header is read:
/// - If the header is absent (or not obtainable as [`AcceptEncoding`]), the inner service is
///   called and its response is passed through with the identity coding.
/// - Otherwise a coding is negotiated from the codings this build supports. `identity` is always
///   supported; `br`, `gzip`/`deflate` and `zstd` are supported only when the matching
///   `compress-*` cargo feature is enabled.
/// - If none of the supported codings is acceptable, the inner service is not called. A
///   `406 Not Acceptable` response is returned whose body lists the supported compressed codings
///   and which carries `Vary: Accept-Encoding`.
///
/// # Content-type exemptions
/// Even after a compressed coding is negotiated, responses whose `Content-Type` is `video/*`, or
/// `image/*` other than SVG, are sent with the identity coding.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate constructs it with
/// `Compress::default()`.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct Compress;
78
79impl<S, B> Transform<S, ServiceRequest> for Compress
80where
81 B: MessageBody,
82 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
83{
84 type Response = ServiceResponse<EitherBody<Encoder<B>>>;
85 type Error = Error;
86 type Transform = CompressMiddleware<S>;
87 type InitError = ();
88 type Future = Ready<Result<Self::Transform, Self::InitError>>;
89
90 fn new_transform(&self, service: S) -> Self::Future {
91 ok(CompressMiddleware { service })
92 }
93}
94
/// Service produced by [`Compress`].
///
/// For each request it negotiates a response content-coding and then delegates to the wrapped
/// service, or it short-circuits with `406 Not Acceptable`.
pub struct CompressMiddleware<S> {
    /// The wrapped inner service, which produces the uncompressed response.
    service: S,
}
98
99impl<S, B> Service<ServiceRequest> for CompressMiddleware<S>
100where
101 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
102 B: MessageBody,
103{
104 type Response = ServiceResponse<EitherBody<Encoder<B>>>;
105 type Error = Error;
106 #[allow(clippy::type_complexity)]
107 type Future = Either<CompressResponse<S, B>, Ready<Result<Self::Response, Self::Error>>>;
108
109 actix_service::forward_ready!(service);
110
111 #[allow(clippy::borrow_interior_mutable_const)]
112 fn call(&self, req: ServiceRequest) -> Self::Future {
113 let accept_encoding = req.get_header::<AcceptEncoding>();
115
116 let accept_encoding = match accept_encoding {
117 None => {
119 return Either::left(CompressResponse {
120 encoding: Encoding::identity(),
121 fut: self.service.call(req),
122 _phantom: PhantomData,
123 })
124 }
125
126 Some(accept_encoding) => accept_encoding,
128 };
129
130 match accept_encoding.negotiate(SUPPORTED_ENCODINGS.iter()) {
131 None => {
132 let mut res = HttpResponse::with_body(
133 StatusCode::NOT_ACCEPTABLE,
134 SUPPORTED_ENCODINGS_STRING.as_str(),
135 );
136
137 res.headers_mut()
138 .insert(header::VARY, HeaderValue::from_static("Accept-Encoding"));
139
140 Either::right(ok(req
141 .into_response(res)
142 .map_into_boxed_body()
143 .map_into_right_body()))
144 }
145
146 Some(encoding) => Either::left(CompressResponse {
147 fut: self.service.call(req),
148 encoding,
149 _phantom: PhantomData,
150 }),
151 }
152 }
153}
154
pin_project! {
    /// Response future returned by [`CompressMiddleware`] when a coding was negotiated.
    ///
    /// Drives the inner service's future. On success, the response body is wrapped in an
    /// [`Encoder`] for `encoding`, or for identity when the response's `Content-Type` is exempt.
    pub struct CompressResponse<S, B>
    where
        S: Service<ServiceRequest>,
    {
        // Inner service future; structurally pinned so it can be polled in place.
        #[pin]
        fut: S::Future,
        // Coding chosen in `CompressMiddleware::call`.
        encoding: Encoding,
        // `B` appears only in the `Future` impl's output type, not in any other field.
        _phantom: PhantomData<B>,
    }
}
166
167impl<S, B> Future for CompressResponse<S, B>
168where
169 B: MessageBody,
170 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
171{
172 type Output = Result<ServiceResponse<EitherBody<Encoder<B>>>, Error>;
173
174 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175 let this = self.as_mut().project();
176
177 match ready!(this.fut.poll(cx)) {
178 Ok(resp) => {
179 let enc = match this.encoding {
180 Encoding::Known(enc) => *enc,
181 Encoding::Unknown(enc) => {
182 unimplemented!("encoding '{enc}' should not be here");
183 }
184 };
185
186 Poll::Ready(Ok(resp.map_body(move |head, body| {
187 let content_type = head.headers.get(header::CONTENT_TYPE);
188
189 fn default_compress_predicate(content_type: Option<&HeaderValue>) -> bool {
190 match content_type {
191 None => true,
192 Some(hdr) => {
193 match hdr.to_str().ok().and_then(|hdr| hdr.parse::<Mime>().ok()) {
194 Some(mime) if mime.type_() == mime::IMAGE => {
195 matches!(mime.subtype(), mime::SVG)
196 }
197 Some(mime) if mime.type_() == mime::VIDEO => false,
198 _ => true,
199 }
200 }
201 }
202 }
203
204 let enc = if default_compress_predicate(content_type) {
205 enc
206 } else {
207 ContentEncoding::Identity
208 };
209
210 EitherBody::left(Encoder::response(enc, head, body))
211 })))
212 }
213
214 Err(err) => Poll::Ready(Err(err)),
215 }
216 }
217}
218
219static SUPPORTED_ENCODINGS_STRING: Lazy<String> = Lazy::new(|| {
220 #[allow(unused_mut)] let mut encoding: Vec<&str> = vec![];
222
223 #[cfg(feature = "compress-brotli")]
224 {
225 encoding.push("br");
226 }
227
228 #[cfg(feature = "compress-gzip")]
229 {
230 encoding.push("gzip");
231 encoding.push("deflate");
232 }
233
234 #[cfg(feature = "compress-zstd")]
235 {
236 encoding.push("zstd");
237 }
238
239 assert!(
240 !encoding.is_empty(),
241 "encoding can not be empty unless __compress feature has been explicitly enabled by itself"
242 );
243
244 encoding.join(", ")
245});
246
/// Content-codings that `CompressMiddleware::call` passes to [`AcceptEncoding::negotiate`].
///
/// `identity` is always offered. Each compressed coding is included only when its `compress-*`
/// cargo feature is enabled; `compress-gzip` enables both `gzip` and `deflate`.
/// `CompressResponse::poll` panics on `Encoding::Unknown`, so every entry must be a known coding.
// NOTE(review): entry order may influence tie-breaking inside `negotiate` — confirm before
// reordering.
static SUPPORTED_ENCODINGS: &[Encoding] = &[
    Encoding::identity(),
    #[cfg(feature = "compress-brotli")]
    {
        Encoding::brotli()
    },
    #[cfg(feature = "compress-gzip")]
    {
        Encoding::gzip()
    },
    #[cfg(feature = "compress-gzip")]
    {
        Encoding::deflate()
    },
    #[cfg(feature = "compress-zstd")]
    {
        Encoding::zstd()
    },
];
266
#[cfg(feature = "compress-gzip")]
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use static_assertions::assert_impl_all;

    use super::*;
    use crate::{http::header::ContentType, middleware::DefaultHeaders, test, web, App};

    // Well-formed HTML fragment, repeated so the body is large enough to be worth compressing.
    const HTML_DATA_PART: &str = "<html><h1>hello world</h1></html>";
    const HTML_DATA: &str = const_str::repeat!(HTML_DATA_PART, 100);

    const TEXT_DATA_PART: &str = "hello world ";
    const TEXT_DATA: &str = const_str::repeat!(TEXT_DATA_PART, 100);

    assert_impl_all!(Compress: Send, Sync);

    /// Decompresses a complete gzip stream, panicking if the stream is malformed.
    pub fn gzip_decode(bytes: impl AsRef<[u8]>) -> Vec<u8> {
        use std::io::Read as _;
        let mut decoder = flate2::read::GzDecoder::new(bytes.as_ref());
        let mut buf = Vec::new();
        decoder.read_to_end(&mut buf).unwrap();
        buf
    }

    /// Asserts a 2xx status and a `Content-Type` header containing `ct`.
    #[track_caller]
    fn assert_successful_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert!(res.status().is_success());
        assert!(
            res.headers()
                .get(header::CONTENT_TYPE)
                .expect("content-type header should be present")
                .to_str()
                .expect("content-type header should be utf-8")
                .contains(ct),
            "response's content-type did not match {}",
            ct
        );
    }

    /// Asserts a 2xx status, a `Content-Type` containing `ct`, and `Content-Encoding: gzip`.
    #[track_caller]
    fn assert_successful_gzip_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert_successful_res_with_content_type(res, ct);
        assert_eq!(
            res.headers()
                .get(header::CONTENT_ENCODING)
                .expect("response should be gzip compressed"),
            "gzip",
        );
    }

    /// Asserts a 2xx status, a `Content-Type` containing `ct`, and no `Content-Encoding` header.
    #[track_caller]
    fn assert_successful_identity_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert_successful_res_with_content_type(res, ct);
        assert!(
            res.headers().get(header::CONTENT_ENCODING).is_none(),
            "response should not be compressed",
        );
    }

    /// A resource wrapped in `Compress` at both app and resource level is gzip-encoded only once.
    #[actix_rt::test]
    async fn prevents_double_compressing() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .route(
                    "/single",
                    web::get().to(move || HttpResponse::Ok().body(TEXT_DATA)),
                )
                .service(
                    web::resource("/double")
                        .wrap(Compress::default())
                        .wrap(DefaultHeaders::new().add(("x-double", "true")))
                        .route(web::get().to(move || HttpResponse::Ok().body(TEXT_DATA))),
                )
        })
        .await;

        let req = test::TestRequest::default()
            .uri("/single")
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers().get("x-double"), None);
        assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip");
        let bytes = test::read_body(res).await;
        assert_eq!(gzip_decode(bytes), TEXT_DATA.as_bytes());

        let req = test::TestRequest::default()
            .uri("/double")
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers().get("x-double").unwrap(), "true");
        assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip");
        let bytes = test::read_body(res).await;
        assert_eq!(gzip_decode(bytes), TEXT_DATA.as_bytes());
    }

    /// A `Vary` value set by the handler is kept alongside the `accept-encoding` one.
    #[actix_rt::test]
    async fn retains_previously_set_vary_header() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(move || {
                    HttpResponse::Ok()
                        .insert_header((header::VARY, "x-test"))
                        .body(TEXT_DATA)
                }))
        })
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        #[allow(clippy::mutable_key_type)]
        let vary_headers = res.headers().get_all(header::VARY).collect::<HashSet<_>>();
        assert!(vary_headers.contains(&HeaderValue::from_static("x-test")));
        assert!(vary_headers.contains(&HeaderValue::from_static("accept-encoding")));
    }

    /// Routes for the content-type predicate tests: `/html` (text/html) and `/image` (image/jpeg).
    fn configure_predicate_test(cfg: &mut web::ServiceConfig) {
        cfg.route(
            "/html",
            web::get().to(|| {
                HttpResponse::Ok()
                    .content_type(ContentType::html())
                    .body(HTML_DATA)
            }),
        )
        .route(
            "/image",
            web::get().to(|| {
                HttpResponse::Ok()
                    .content_type(ContentType::jpeg())
                    .body(TEXT_DATA)
            }),
        );
    }

    /// HTML is gzip-encoded; JPEG is passed through unencoded despite `Accept-Encoding: gzip`.
    #[actix_rt::test]
    async fn prevents_compression_jpeg() {
        let app = test::init_service(
            App::new()
                .wrap(Compress::default())
                .configure(configure_predicate_test),
        )
        .await;

        let req =
            test::TestRequest::with_uri("/html").insert_header((header::ACCEPT_ENCODING, "gzip"));
        let res = test::call_service(&app, req.to_request()).await;
        assert_successful_gzip_res_with_content_type(&res, "text/html");
        // Verify the payload round-trips through gzip, not merely that it differs from the input.
        assert_eq!(gzip_decode(test::read_body(res).await), HTML_DATA.as_bytes());

        let req =
            test::TestRequest::with_uri("/image").insert_header((header::ACCEPT_ENCODING, "gzip"));
        let res = test::call_service(&app, req.to_request()).await;
        assert_successful_identity_res_with_content_type(&res, "image/jpeg");
        assert_eq!(test::read_body(res).await, TEXT_DATA.as_bytes());
    }

    /// An empty body gets no `Content-Encoding` even when gzip is accepted.
    #[actix_rt::test]
    async fn prevents_compression_empty() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(move || HttpResponse::Ok().finish()))
        })
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res.headers().contains_key(header::CONTENT_ENCODING));
        assert!(test::read_body(res).await.is_empty());
    }
}
453
#[cfg(feature = "compress-brotli")]
#[cfg(test)]
mod tests_brotli {
    use super::*;
    use crate::{test, web, App};

    /// An empty body gets no `Content-Encoding` even when only `br` is accepted.
    #[actix_rt::test]
    async fn prevents_compression_empty() {
        let app = test::init_service(
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(|| HttpResponse::Ok().finish())),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "br"))
            .to_request();
        let res = test::call_service(&app, req).await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res.headers().contains_key(header::CONTENT_ENCODING));

        let body = test::read_body(res).await;
        assert!(body.is_empty());
    }
}