tower_http/follow_redirect/
mod.rs

1//! Middleware for following redirections.
2//!
3//! # Overview
4//!
5//! The [`FollowRedirect`] middleware retries requests with the inner [`Service`] to follow HTTP
6//! redirections.
7//!
//! The middleware tries to clone the original [`Request`] when making a redirected request.
//! However, only the method, version and headers are carried over to the redirected request,
//! not the request [`Extensions`][http::Extensions], so any extensions set by outer middleware
//! will be discarded. Also, the request body cannot always be cloned. When the original body is
//! known to be empty by [`Body::size_hint`], the middleware uses the `Default` implementation of
//! the body type to create a new request body. If you know that the body can be cloned in some
//! way, you can tell the middleware to clone it by configuring a [`policy`].
14//!
15//! # Examples
16//!
17//! ## Basic usage
18//!
19//! ```
20//! use http::{Request, Response};
21//! use bytes::Bytes;
22//! use http_body_util::Full;
23//! use tower::{Service, ServiceBuilder, ServiceExt};
24//! use tower_http::follow_redirect::{FollowRedirectLayer, RequestUri};
25//!
26//! # #[tokio::main]
27//! # async fn main() -> Result<(), std::convert::Infallible> {
28//! # let http_client = tower::service_fn(|req: Request<_>| async move {
29//! #     let dest = "https://www.rust-lang.org/";
30//! #     let mut res = http::Response::builder();
31//! #     if req.uri() != dest {
32//! #         res = res
33//! #             .status(http::StatusCode::MOVED_PERMANENTLY)
34//! #             .header(http::header::LOCATION, dest);
35//! #     }
36//! #     Ok::<_, std::convert::Infallible>(res.body(Full::<Bytes>::default()).unwrap())
37//! # });
38//! let mut client = ServiceBuilder::new()
39//!     .layer(FollowRedirectLayer::new())
40//!     .service(http_client);
41//!
42//! let request = Request::builder()
43//!     .uri("https://rust-lang.org/")
44//!     .body(Full::<Bytes>::default())
45//!     .unwrap();
46//!
47//! let response = client.ready().await?.call(request).await?;
48//! // Get the final request URI.
49//! assert_eq!(response.extensions().get::<RequestUri>().unwrap().0, "https://www.rust-lang.org/");
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! ## Customizing the `Policy`
55//!
56//! You can use a [`Policy`] value to customize how the middleware handles redirections.
57//!
58//! ```
59//! use http::{Request, Response};
60//! use http_body_util::Full;
61//! use bytes::Bytes;
62//! use tower::{Service, ServiceBuilder, ServiceExt};
63//! use tower_http::follow_redirect::{
64//!     policy::{self, PolicyExt},
65//!     FollowRedirectLayer,
66//! };
67//!
68//! #[derive(Debug)]
69//! enum MyError {
70//!     TooManyRedirects,
71//!     Other(tower::BoxError),
72//! }
73//!
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), MyError> {
76//! # let http_client =
77//! #     tower::service_fn(|_: Request<Full<Bytes>>| async { Ok(Response::new(Full::<Bytes>::default())) });
78//! let policy = policy::Limited::new(10) // Set the maximum number of redirections to 10.
79//!     // Return an error when the limit was reached.
80//!     .or::<_, (), _>(policy::redirect_fn(|_| Err(MyError::TooManyRedirects)))
81//!     // Do not follow cross-origin redirections, and return the redirection responses as-is.
82//!     .and::<_, (), _>(policy::SameOrigin::new());
83//!
84//! let mut client = ServiceBuilder::new()
85//!     .layer(FollowRedirectLayer::with_policy(policy))
86//!     .map_err(MyError::Other)
87//!     .service(http_client);
88//!
89//! // ...
90//! # let _ = client.ready().await?.call(Request::default()).await?;
91//! # Ok(())
92//! # }
93//! ```
94
95pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100    header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION,
101    header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri,
102    Version,
103};
104use http_body::Body;
105use iri_string::types::{UriAbsoluteString, UriReferenceStr};
106use pin_project_lite::pin_project;
107use std::{
108    convert::TryFrom,
109    future::Future,
110    mem,
111    pin::Pin,
112    str,
113    task::{ready, Context, Poll},
114};
115use tower::util::Oneshot;
116use tower_layer::Layer;
117use tower_service::Service;
118
119/// [`Layer`] for retrying requests with a [`Service`] to follow redirection responses.
120///
121/// See the [module docs](self) for more details.
122#[derive(Clone, Copy, Debug, Default)]
123pub struct FollowRedirectLayer<P = Standard> {
124    policy: P,
125}
126
127impl FollowRedirectLayer {
128    /// Create a new [`FollowRedirectLayer`] with a [`Standard`] redirection policy.
129    pub fn new() -> Self {
130        Self::default()
131    }
132}
133
134impl<P> FollowRedirectLayer<P> {
135    /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`].
136    pub fn with_policy(policy: P) -> Self {
137        FollowRedirectLayer { policy }
138    }
139}
140
141impl<S, P> Layer<S> for FollowRedirectLayer<P>
142where
143    S: Clone,
144    P: Clone,
145{
146    type Service = FollowRedirect<S, P>;
147
148    fn layer(&self, inner: S) -> Self::Service {
149        FollowRedirect::with_policy(inner, self.policy.clone())
150    }
151}
152
153/// Middleware that retries requests with a [`Service`] to follow redirection responses.
154///
155/// See the [module docs](self) for more details.
156#[derive(Clone, Copy, Debug)]
157pub struct FollowRedirect<S, P = Standard> {
158    inner: S,
159    policy: P,
160}
161
162impl<S> FollowRedirect<S> {
163    /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy.
164    pub fn new(inner: S) -> Self {
165        Self::with_policy(inner, Standard::default())
166    }
167
168    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware.
169    ///
170    /// [`Layer`]: tower_layer::Layer
171    pub fn layer() -> FollowRedirectLayer {
172        FollowRedirectLayer::new()
173    }
174}
175
176impl<S, P> FollowRedirect<S, P>
177where
178    P: Clone,
179{
180    /// Create a new [`FollowRedirect`] with the given redirection [`Policy`].
181    pub fn with_policy(inner: S, policy: P) -> Self {
182        FollowRedirect { inner, policy }
183    }
184
185    /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
186    /// with the given redirection [`Policy`].
187    ///
188    /// [`Layer`]: tower_layer::Layer
189    pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
190        FollowRedirectLayer::with_policy(policy)
191    }
192
193    define_inner_service_accessors!();
194}
195
196impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
197where
198    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
199    ReqBody: Body + Default,
200    P: Policy<ReqBody, S::Error> + Clone,
201{
202    type Response = Response<ResBody>;
203    type Error = S::Error;
204    type Future = ResponseFuture<S, ReqBody, P>;
205
206    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207        self.inner.poll_ready(cx)
208    }
209
210    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
211        let service = self.inner.clone();
212        let mut service = mem::replace(&mut self.inner, service);
213        let mut policy = self.policy.clone();
214        let mut body = BodyRepr::None;
215        body.try_clone_from(req.body(), &policy);
216        policy.on_request(&mut req);
217        ResponseFuture {
218            method: req.method().clone(),
219            uri: req.uri().clone(),
220            version: req.version(),
221            headers: req.headers().clone(),
222            body,
223            future: Either::Left(service.call(req)),
224            service,
225            policy,
226        }
227    }
228}
229
pin_project! {
    /// Response future for [`FollowRedirect`].
    ///
    /// Holds what is needed to rebuild the request for another hop: the request line
    /// (`method`, `uri`, `version`), the `headers`, and a reusable `body` when one is available.
    #[derive(Debug)]
    pub struct ResponseFuture<S, B, P>
    where
        S: Service<Request<B>>,
    {
        // Request in flight: `Left` is the future returned by the initial `call`,
        // `Right` a `Oneshot` driving a redirected request.
        #[pin]
        future: Either<S::Future, Oneshot<S, Request<B>>>,
        // Service instance that is cloned to issue each redirected request.
        service: S,
        // Redirection policy consulted for every redirection response.
        policy: P,
        // Method for the next request; some redirection statuses rewrite it to GET.
        method: Method,
        // URI of the request currently in flight; recorded in `RequestUri` on responses.
        uri: Uri,
        // HTTP version copied onto every redirected request.
        version: Version,
        // Headers copied onto every redirected request (payload headers may be dropped).
        headers: HeaderMap<HeaderValue>,
        // Source of the body for the next request, if one could be obtained.
        body: BodyRepr<B>,
    }
}
248
impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
    ReqBody: Body + Default,
    P: Policy<ReqBody, S::Error>,
{
    type Output = Result<Response<ResBody>, S::Error>;

    /// Polls the request in flight and handles a possible redirection.
    ///
    /// The response (carrying a [`RequestUri`] extension) is returned as-is in these cases:
    /// - its status is not 301/302/303/307/308;
    /// - no body is available for a follow-up request;
    /// - the `Location` header is missing or cannot be resolved against the current URI;
    /// - the policy returns [`Action::Stop`].
    ///
    /// Errors from the inner service and from [`Policy::redirect`] are propagated. Otherwise
    /// a redirected request is started and the future schedules itself to be polled again.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut res = ready!(this.future.as_mut().poll(cx)?);
        // Record which URI produced this response so callers can learn the final URI.
        res.extensions_mut().insert(RequestUri(this.uri.clone()));

        // Headers describing the request payload; removed whenever the body is discarded.
        let drop_payload_headers = |headers: &mut HeaderMap| {
            for header in &[
                CONTENT_TYPE,
                CONTENT_LENGTH,
                CONTENT_ENCODING,
                TRANSFER_ENCODING,
            ] {
                headers.remove(header);
            }
        };
        // Adjust method/body/headers for the next hop based on the status code. These
        // rewrites happen before the policy decides; if it stops, the future completes with
        // `res` and the modified state is never used.
        match res.status() {
            StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
                // User agents MAY change the request method from POST to GET
                // (RFC 7231 section 6.4.2. and 6.4.3.).
                if *this.method == Method::POST {
                    *this.method = Method::GET;
                    *this.body = BodyRepr::Empty;
                    drop_payload_headers(this.headers);
                }
            }
            StatusCode::SEE_OTHER => {
                // A user agent can perform a GET or HEAD request (RFC 7231 section 6.4.4.).
                if *this.method != Method::HEAD {
                    *this.method = Method::GET;
                }
                *this.body = BodyRepr::Empty;
                drop_payload_headers(this.headers);
            }
            // 307/308: method, body and headers are reused unchanged.
            StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
            // Not a redirection: hand the response back to the caller.
            _ => return Poll::Ready(Ok(res)),
        };

        // Without a reusable body the request cannot be rebuilt, so return the
        // redirection response itself.
        let body = if let Some(body) = this.body.take() {
            body
        } else {
            return Poll::Ready(Ok(res));
        };

        // `Location` may be relative; resolve it against the URI of the current request.
        // A missing, non-UTF-8 or unresolvable value ends redirect handling.
        let location = res
            .headers()
            .get(&LOCATION)
            .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
        let location = if let Some(loc) = location {
            loc
        } else {
            return Poll::Ready(Ok(res));
        };

        let attempt = Attempt {
            status: res.status(),
            location: &location,
            previous: this.uri,
        };
        match this.policy.redirect(&attempt)? {
            Action::Follow => {
                *this.uri = location;
                // `take()` moved a stored body out; try to keep another copy for a possible
                // further hop before `body` is consumed by the new request below.
                this.body.try_clone_from(&body, &this.policy);

                let mut req = Request::new(body);
                *req.uri_mut() = this.uri.clone();
                *req.method_mut() = this.method.clone();
                *req.version_mut() = *this.version;
                *req.headers_mut() = this.headers.clone();
                this.policy.on_request(&mut req);
                // `Oneshot` waits for the cloned service to become ready before calling it.
                this.future
                    .set(Either::Right(Oneshot::new(this.service.clone(), req)));

                // Request an immediate re-poll so the new future gets driven.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Action::Stop => Poll::Ready(Ok(res)),
        }
    }
}
336
337/// Response [`Extensions`][http::Extensions] value that represents the effective request URI of
338/// a response returned by a [`FollowRedirect`] middleware.
339///
340/// The value differs from the original request's effective URI if the middleware has followed
341/// redirections.
342#[derive(Clone)]
343pub struct RequestUri(pub Uri);
344
345#[derive(Debug)]
346enum BodyRepr<B> {
347    Some(B),
348    Empty,
349    None,
350}
351
352impl<B> BodyRepr<B>
353where
354    B: Body + Default,
355{
356    fn take(&mut self) -> Option<B> {
357        match mem::replace(self, BodyRepr::None) {
358            BodyRepr::Some(body) => Some(body),
359            BodyRepr::Empty => {
360                *self = BodyRepr::Empty;
361                Some(B::default())
362            }
363            BodyRepr::None => None,
364        }
365    }
366
367    fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
368    where
369        P: Policy<B, E>,
370    {
371        match self {
372            BodyRepr::Some(_) | BodyRepr::Empty => {}
373            BodyRepr::None => {
374                if let Some(body) = clone_body(policy, body) {
375                    *self = BodyRepr::Some(body);
376                }
377            }
378        }
379    }
380}
381
382fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
383where
384    P: Policy<B, E>,
385    B: Body + Default,
386{
387    if body.size_hint().exact() == Some(0) {
388        Some(B::default())
389    } else {
390        policy.clone_body(body)
391    }
392}
393
394/// Try to resolve a URI reference `relative` against a base URI `base`.
395fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
396    let relative = UriReferenceStr::new(relative).ok()?;
397    let base = UriAbsoluteString::try_from(base.to_string()).ok()?;
398    let uri = relative.resolve_against(&base).to_string();
399    Uri::try_from(uri).ok()
400}
401
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// URI of the first request in every test; `handle` walks it down towards `/0`.
    const START_URI: &str = "http://example.com/42";

    /// Builds the initial request to [`START_URI`] with an empty body.
    fn start_request() -> Request<Body> {
        Request::builder()
            .uri(START_URI)
            .body(Body::empty())
            .unwrap()
    }

    /// Returns the effective request URI that the middleware recorded on `res`.
    fn effective_uri<B>(res: &Response<B>) -> &Uri {
        &res.extensions().get::<RequestUri>().unwrap().0
    }

    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 0);
        assert_eq!(effective_uri(&res), "http://example.com/0");
    }

    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 42);
        assert_eq!(effective_uri(&res), "http://example.com/42");
    }

    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);
        let res = svc.oneshot(start_request()).await.unwrap();
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(effective_uri(&res), "http://example.com/32");
    }

    /// Test server: `GET /{n}` answers `301 Moved Permanently` pointing at `/{n-1}` while
    /// `n` is non-zero, and a plain response otherwise; the body is always `n`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let n: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = if n == 0 {
            Response::builder()
        } else {
            Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", n - 1))
        };
        Ok(builder.body(n).unwrap())
    }
}