// tower_http/follow_redirect/policy/mod.rs

1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12    and::And,
13    clone_body_fn::{clone_body_fn, CloneBodyFn},
14    filter_credentials::FilterCredentials,
15    limited::Limited,
16    or::Or,
17    redirect_fn::{redirect_fn, RedirectFn},
18    same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36///     uris: HashSet<Uri>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40///     fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41///         if self.uris.contains(attempt.location()) {
42///             Ok(Action::Stop)
43///         } else {
44///             self.uris.insert(attempt.previous().clone());
45///             Ok(Action::Follow)
46///         }
47///     }
48/// }
49/// ```
50pub trait Policy<B, E> {
51    /// Invoked when the service received a response with a redirection status code (`3xx`).
52    ///
53    /// This method returns an [`Action`] which indicates whether the service should follow
54    /// the redirection.
55    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57    /// Invoked right before the service makes a request, regardless of whether it is redirected
58    /// or not.
59    ///
60    /// This can for example be used to remove sensitive headers from the request
61    /// or prepare the request in other ways.
62    ///
63    /// The default implementation does nothing.
64    fn on_request(&mut self, _request: &mut Request<B>) {}
65
66    /// Try to clone a request body before the service makes a redirected request.
67    ///
68    /// If the request body cannot be cloned, return `None`.
69    ///
70    /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
71    /// in which case `B::default()` will be used to create a new request body.
72    ///
73    /// The default implementation returns `None`.
74    fn clone_body(&self, _body: &B) -> Option<B> {
75        None
76    }
77}
78
79impl<B, E, P> Policy<B, E> for &mut P
80where
81    P: Policy<B, E> + ?Sized,
82{
83    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
84        (**self).redirect(attempt)
85    }
86
87    fn on_request(&mut self, request: &mut Request<B>) {
88        (**self).on_request(request)
89    }
90
91    fn clone_body(&self, body: &B) -> Option<B> {
92        (**self).clone_body(body)
93    }
94}
95
96impl<B, E, P> Policy<B, E> for Box<P>
97where
98    P: Policy<B, E> + ?Sized,
99{
100    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
101        (**self).redirect(attempt)
102    }
103
104    fn on_request(&mut self, request: &mut Request<B>) {
105        (**self).on_request(request)
106    }
107
108    fn clone_body(&self, body: &B) -> Option<B> {
109        (**self).clone_body(body)
110    }
111}
112
113/// An extension trait for `Policy` that provides additional adapters.
114pub trait PolicyExt {
115    /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
116    /// `Action::Follow`.
117    ///
118    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
119    /// with both policies.
120    ///
121    /// # Example
122    ///
123    /// ```
124    /// use bytes::Bytes;
125    /// use http_body_util::Full;
126    /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
127    ///
128    /// enum MyBody {
129    ///     Bytes(Bytes),
130    ///     Full(Full<Bytes>),
131    /// }
132    ///
133    /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
134    ///     if let MyBody::Bytes(buf) = body {
135    ///         Some(MyBody::Bytes(buf.clone()))
136    ///     } else {
137    ///         None
138    ///     }
139    /// }));
140    /// ```
141    fn and<P, B, E>(self, other: P) -> And<Self, P>
142    where
143        Self: Policy<B, E> + Sized,
144        P: Policy<B, E>;
145
146    /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
147    /// `Action::Follow`.
148    ///
149    /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
150    /// with both policies.
151    ///
152    /// # Example
153    ///
154    /// ```
155    /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
156    ///
157    /// #[derive(Clone)]
158    /// enum MyError {
159    ///     TooManyRedirects,
160    ///     // ...
161    /// }
162    ///
163    /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
164    /// ```
165    fn or<P, B, E>(self, other: P) -> Or<Self, P>
166    where
167        Self: Policy<B, E> + Sized,
168        P: Policy<B, E>;
169}
170
171impl<T> PolicyExt for T
172where
173    T: ?Sized,
174{
175    fn and<P, B, E>(self, other: P) -> And<Self, P>
176    where
177        Self: Policy<B, E> + Sized,
178        P: Policy<B, E>,
179    {
180        And::new(self, other)
181    }
182
183    fn or<P, B, E>(self, other: P) -> Or<Self, P>
184    where
185        Self: Policy<B, E> + Sized,
186        P: Policy<B, E>,
187    {
188        Or::new(self, other)
189    }
190}
191
192/// A redirection [`Policy`] with a reasonable set of standard behavior.
193///
194/// This policy limits the number of successive redirections ([`Limited`])
195/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
196pub type Standard = And<Limited, FilterCredentials>;
197
198/// A type that holds information on a redirection attempt.
199pub struct Attempt<'a> {
200    pub(crate) status: StatusCode,
201    pub(crate) location: &'a Uri,
202    pub(crate) previous: &'a Uri,
203}
204
205impl<'a> Attempt<'a> {
206    /// Returns the redirection response.
207    pub fn status(&self) -> StatusCode {
208        self.status
209    }
210
211    /// Returns the destination URI of the redirection.
212    pub fn location(&self) -> &'a Uri {
213        self.location
214    }
215
216    /// Returns the URI of the original request.
217    pub fn previous(&self) -> &'a Uri {
218        self.previous
219    }
220}
221
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}

impl Action {
    /// Returns `true` if the `Action` is a `Follow` value.
    pub fn is_follow(&self) -> bool {
        matches!(self, Action::Follow)
    }

    /// Returns `true` if the `Action` is a `Stop` value.
    pub fn is_stop(&self) -> bool {
        matches!(self, Action::Stop)
    }
}

252impl<B, E> Policy<B, E> for Action {
253    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
254        Ok(*self)
255    }
256}
257
258impl<B, E> Policy<B, E> for Result<Action, E>
259where
260    E: Clone,
261{
262    fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
263        self.clone()
264    }
265}
266
267/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
268fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
269    let default_port = match (lhs.scheme(), rhs.scheme()) {
270        (Some(l), Some(r)) if l == r => {
271            if l == &Scheme::HTTP {
272                80
273            } else if l == &Scheme::HTTPS {
274                443
275            } else {
276                return false;
277            }
278        }
279        _ => return false,
280    };
281    match (lhs.host(), rhs.host()) {
282        (Some(l), Some(r)) if l == r => {}
283        _ => return false,
284    }
285    lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn eq_origin_works() {
        assert!(eq_origin(
            &Uri::from_static("https://example.com/1"),
            &Uri::from_static("https://example.com/2")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com:443/"),
            &Uri::from_static("https://example.com/")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://user@example.com/")
        ));

        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://www.example.com/")
        ));
        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("http://example.com/")
        ));
    }

    #[test]
    fn eq_origin_edge_cases() {
        // An explicit default port on `http` matches the implicit one.
        assert!(eq_origin(
            &Uri::from_static("http://example.com:80/a"),
            &Uri::from_static("http://example.com/b")
        ));
        // Non-default ports must match exactly.
        assert!(!eq_origin(
            &Uri::from_static("https://example.com:8443/"),
            &Uri::from_static("https://example.com/")
        ));
        // Schemes other than http/https never compare equal, not even to themselves.
        assert!(!eq_origin(
            &Uri::from_static("ftp://example.com/"),
            &Uri::from_static("ftp://example.com/")
        ));
        // Relative references have no scheme or host, hence no origin.
        assert!(!eq_origin(
            &Uri::from_static("/path"),
            &Uri::from_static("/path")
        ));
    }
}