tower_http/follow_redirect/policy/mod.rs
1//! Tools for customizing the behavior of a [`FollowRedirect`][super::FollowRedirect] middleware.
2
3mod and;
4mod clone_body_fn;
5mod filter_credentials;
6mod limited;
7mod or;
8mod redirect_fn;
9mod same_origin;
10
11pub use self::{
12 and::And,
13 clone_body_fn::{clone_body_fn, CloneBodyFn},
14 filter_credentials::FilterCredentials,
15 limited::Limited,
16 or::Or,
17 redirect_fn::{redirect_fn, RedirectFn},
18 same_origin::SameOrigin,
19};
20
21use http::{uri::Scheme, Request, StatusCode, Uri};
22
23/// Trait for the policy on handling redirection responses.
24///
25/// # Example
26///
27/// Detecting a cyclic redirection:
28///
29/// ```
30/// use http::{Request, Uri};
31/// use std::collections::HashSet;
32/// use tower_http::follow_redirect::policy::{Action, Attempt, Policy};
33///
34/// #[derive(Clone)]
35/// pub struct DetectCycle {
36/// uris: HashSet<Uri>,
37/// }
38///
39/// impl<B, E> Policy<B, E> for DetectCycle {
40/// fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
41/// if self.uris.contains(attempt.location()) {
42/// Ok(Action::Stop)
43/// } else {
44/// self.uris.insert(attempt.previous().clone());
45/// Ok(Action::Follow)
46/// }
47/// }
48/// }
49/// ```
50pub trait Policy<B, E> {
51 /// Invoked when the service received a response with a redirection status code (`3xx`).
52 ///
53 /// This method returns an [`Action`] which indicates whether the service should follow
54 /// the redirection.
55 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E>;
56
57 /// Invoked right before the service makes a request, regardless of whether it is redirected
58 /// or not.
59 ///
60 /// This can for example be used to remove sensitive headers from the request
61 /// or prepare the request in other ways.
62 ///
63 /// The default implementation does nothing.
64 fn on_request(&mut self, _request: &mut Request<B>) {}
65
66 /// Try to clone a request body before the service makes a redirected request.
67 ///
68 /// If the request body cannot be cloned, return `None`.
69 ///
70 /// This is not invoked when [`B::size_hint`][http_body::Body::size_hint] returns zero,
71 /// in which case `B::default()` will be used to create a new request body.
72 ///
73 /// The default implementation returns `None`.
74 fn clone_body(&self, _body: &B) -> Option<B> {
75 None
76 }
77}
78
79impl<B, E, P> Policy<B, E> for &mut P
80where
81 P: Policy<B, E> + ?Sized,
82{
83 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
84 (**self).redirect(attempt)
85 }
86
87 fn on_request(&mut self, request: &mut Request<B>) {
88 (**self).on_request(request)
89 }
90
91 fn clone_body(&self, body: &B) -> Option<B> {
92 (**self).clone_body(body)
93 }
94}
95
96impl<B, E, P> Policy<B, E> for Box<P>
97where
98 P: Policy<B, E> + ?Sized,
99{
100 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
101 (**self).redirect(attempt)
102 }
103
104 fn on_request(&mut self, request: &mut Request<B>) {
105 (**self).on_request(request)
106 }
107
108 fn clone_body(&self, body: &B) -> Option<B> {
109 (**self).clone_body(body)
110 }
111}
112
113/// An extension trait for `Policy` that provides additional adapters.
114pub trait PolicyExt {
115 /// Create a new `Policy` that returns [`Action::Follow`] only if `self` and `other` return
116 /// `Action::Follow`.
117 ///
118 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
119 /// with both policies.
120 ///
121 /// # Example
122 ///
123 /// ```
124 /// use bytes::Bytes;
125 /// use http_body_util::Full;
126 /// use tower_http::follow_redirect::policy::{self, clone_body_fn, Limited, PolicyExt};
127 ///
128 /// enum MyBody {
129 /// Bytes(Bytes),
130 /// Full(Full<Bytes>),
131 /// }
132 ///
133 /// let policy = Limited::default().and::<_, _, ()>(clone_body_fn(|body| {
134 /// if let MyBody::Bytes(buf) = body {
135 /// Some(MyBody::Bytes(buf.clone()))
136 /// } else {
137 /// None
138 /// }
139 /// }));
140 /// ```
141 fn and<P, B, E>(self, other: P) -> And<Self, P>
142 where
143 Self: Policy<B, E> + Sized,
144 P: Policy<B, E>;
145
146 /// Create a new `Policy` that returns [`Action::Follow`] if either `self` or `other` returns
147 /// `Action::Follow`.
148 ///
149 /// [`clone_body`][Policy::clone_body] method of the returned `Policy` tries to clone the body
150 /// with both policies.
151 ///
152 /// # Example
153 ///
154 /// ```
155 /// use tower_http::follow_redirect::policy::{self, Action, Limited, PolicyExt};
156 ///
157 /// #[derive(Clone)]
158 /// enum MyError {
159 /// TooManyRedirects,
160 /// // ...
161 /// }
162 ///
163 /// let policy = Limited::default().or::<_, (), _>(Err(MyError::TooManyRedirects));
164 /// ```
165 fn or<P, B, E>(self, other: P) -> Or<Self, P>
166 where
167 Self: Policy<B, E> + Sized,
168 P: Policy<B, E>;
169}
170
171impl<T> PolicyExt for T
172where
173 T: ?Sized,
174{
175 fn and<P, B, E>(self, other: P) -> And<Self, P>
176 where
177 Self: Policy<B, E> + Sized,
178 P: Policy<B, E>,
179 {
180 And::new(self, other)
181 }
182
183 fn or<P, B, E>(self, other: P) -> Or<Self, P>
184 where
185 Self: Policy<B, E> + Sized,
186 P: Policy<B, E>,
187 {
188 Or::new(self, other)
189 }
190}
191
192/// A redirection [`Policy`] with a reasonable set of standard behavior.
193///
194/// This policy limits the number of successive redirections ([`Limited`])
195/// and removes credentials from requests in cross-origin redirections ([`FilterCredentials`]).
196pub type Standard = And<Limited, FilterCredentials>;
197
198/// A type that holds information on a redirection attempt.
199pub struct Attempt<'a> {
200 pub(crate) status: StatusCode,
201 pub(crate) location: &'a Uri,
202 pub(crate) previous: &'a Uri,
203}
204
205impl<'a> Attempt<'a> {
206 /// Returns the redirection response.
207 pub fn status(&self) -> StatusCode {
208 self.status
209 }
210
211 /// Returns the destination URI of the redirection.
212 pub fn location(&self) -> &'a Uri {
213 self.location
214 }
215
216 /// Returns the URI of the original request.
217 pub fn previous(&self) -> &'a Uri {
218 self.previous
219 }
220}
221
/// A value returned by [`Policy::redirect`] which indicates the action
/// [`FollowRedirect`][super::FollowRedirect] should take for a redirection response.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Action {
    /// Follow the redirection.
    Follow,
    /// Do not follow the redirection, and return the redirection response as-is.
    Stop,
}

impl Action {
    /// Returns `true` if the `Action` is a `Follow` value.
    #[must_use]
    pub fn is_follow(&self) -> bool {
        matches!(self, Action::Follow)
    }

    /// Returns `true` if the `Action` is a `Stop` value.
    #[must_use]
    pub fn is_stop(&self) -> bool {
        matches!(self, Action::Stop)
    }
}
251
252impl<B, E> Policy<B, E> for Action {
253 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
254 Ok(*self)
255 }
256}
257
258impl<B, E> Policy<B, E> for Result<Action, E>
259where
260 E: Clone,
261{
262 fn redirect(&mut self, _: &Attempt<'_>) -> Result<Action, E> {
263 self.clone()
264 }
265}
266
267/// Compares the origins of two URIs as per RFC 6454 sections 4. through 5.
268fn eq_origin(lhs: &Uri, rhs: &Uri) -> bool {
269 let default_port = match (lhs.scheme(), rhs.scheme()) {
270 (Some(l), Some(r)) if l == r => {
271 if l == &Scheme::HTTP {
272 80
273 } else if l == &Scheme::HTTPS {
274 443
275 } else {
276 return false;
277 }
278 }
279 _ => return false,
280 };
281 match (lhs.host(), rhs.host()) {
282 (Some(l), Some(r)) if l == r => {}
283 _ => return false,
284 }
285 lhs.port_u16().unwrap_or(default_port) == rhs.port_u16().unwrap_or(default_port)
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn eq_origin_works() {
        assert!(eq_origin(
            &Uri::from_static("https://example.com/1"),
            &Uri::from_static("https://example.com/2")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com:443/"),
            &Uri::from_static("https://example.com/")
        ));
        assert!(eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://user@example.com/")
        ));

        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("https://www.example.com/")
        ));
        assert!(!eq_origin(
            &Uri::from_static("https://example.com/"),
            &Uri::from_static("http://example.com/")
        ));
    }

    #[test]
    fn eq_origin_compares_effective_ports() {
        // Spelling out the scheme's default port is the same as omitting it...
        assert!(eq_origin(
            &Uri::from_static("http://example.com:80/"),
            &Uri::from_static("http://example.com/")
        ));
        // ...but any other port yields a different origin.
        assert!(!eq_origin(
            &Uri::from_static("https://example.com:8443/"),
            &Uri::from_static("https://example.com/")
        ));
        assert!(!eq_origin(
            &Uri::from_static("http://example.com:8080/"),
            &Uri::from_static("http://example.com:8081/")
        ));
    }

    #[test]
    fn eq_origin_rejects_unsupported_uris() {
        // Only `http` and `https` origins are ever considered equal.
        assert!(!eq_origin(
            &Uri::from_static("ftp://example.com/"),
            &Uri::from_static("ftp://example.com/")
        ));
        // Relative references have no scheme or host, hence no comparable origin.
        assert!(!eq_origin(
            &Uri::from_static("/path"),
            &Uri::from_static("/path")
        ));
    }
}