1use std::fmt;
8use std::{error::Error as StdError, sync::Arc};
9
10use crate::header::{AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, REFERER, WWW_AUTHENTICATE};
11use http::{HeaderMap, HeaderValue};
12use hyper::StatusCode;
13
14use crate::{async_impl, Url};
15use tower_http::follow_redirect::policy::{
16 Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
17};
18
/// Controls whether and how redirects are followed.
///
/// Build one with [`Policy::limited`], [`Policy::none`] or [`Policy::custom`];
/// [`Policy::default`] is equivalent to `Policy::limited(10)`.
pub struct Policy {
    // The concrete strategy; kept private so callers go through the constructors.
    inner: PolicyKind,
}
31
/// Information about a single redirect, handed to a [`Policy`] so it can decide
/// what to do.
///
/// Consumed by [`Attempt::follow`], [`Attempt::stop`] or [`Attempt::error`] to
/// produce an [`Action`].
#[derive(Debug)]
pub struct Attempt<'a> {
    // Status code of the redirect response.
    status: StatusCode,
    // The URL the redirect points to.
    next: &'a Url,
    // URLs already visited in this redirect chain.
    previous: &'a [Url],
}
40
/// The decision a [`Policy`] returns for an [`Attempt`].
///
/// Only constructible through the `Attempt` methods (`follow`, `stop`, `error`),
/// since its field is private.
#[derive(Debug)]
pub struct Action {
    // The actual decision, inspected by the crate via `Policy::check`.
    inner: ActionKind,
}
46
47impl Policy {
48 pub fn limited(max: usize) -> Self {
52 Self {
53 inner: PolicyKind::Limit(max),
54 }
55 }
56
57 pub fn none() -> Self {
59 Self {
60 inner: PolicyKind::None,
61 }
62 }
63
64 pub fn custom<T>(policy: T) -> Self
103 where
104 T: Fn(Attempt) -> Action + Send + Sync + 'static,
105 {
106 Self {
107 inner: PolicyKind::Custom(Box::new(policy)),
108 }
109 }
110
111 pub fn redirect(&self, attempt: Attempt) -> Action {
132 match self.inner {
133 PolicyKind::Custom(ref custom) => custom(attempt),
134 PolicyKind::Limit(max) => {
135 if attempt.previous.len() > max {
137 attempt.error(TooManyRedirects)
138 } else {
139 attempt.follow()
140 }
141 }
142 PolicyKind::None => attempt.stop(),
143 }
144 }
145
146 pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
147 self.redirect(Attempt {
148 status,
149 next,
150 previous,
151 })
152 .inner
153 }
154
155 pub(crate) fn is_default(&self) -> bool {
156 matches!(self.inner, PolicyKind::Limit(10))
157 }
158}
159
160impl Default for Policy {
161 fn default() -> Policy {
162 Policy::limited(10)
164 }
165}
166
167impl<'a> Attempt<'a> {
168 pub fn status(&self) -> StatusCode {
170 self.status
171 }
172
173 pub fn url(&self) -> &Url {
175 self.next
176 }
177
178 pub fn previous(&self) -> &[Url] {
180 self.previous
181 }
182 pub fn follow(self) -> Action {
184 Action {
185 inner: ActionKind::Follow,
186 }
187 }
188
189 pub fn stop(self) -> Action {
193 Action {
194 inner: ActionKind::Stop,
195 }
196 }
197
198 pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
202 Action {
203 inner: ActionKind::Error(error.into()),
204 }
205 }
206}
207
/// The strategy behind a [`Policy`].
enum PolicyKind {
    // User-supplied decision function (see `Policy::custom`).
    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
    // Follow until the history exceeds this many entries (see `Policy::limited`).
    Limit(usize),
    // Never follow redirects (see `Policy::none`).
    None,
}
213
214impl fmt::Debug for Policy {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 f.debug_tuple("Policy").field(&self.inner).finish()
217 }
218}
219
220impl fmt::Debug for PolicyKind {
221 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222 match *self {
223 PolicyKind::Custom(..) => f.pad("Custom"),
224 PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
225 PolicyKind::None => f.pad("None"),
226 }
227 }
228}
229
/// Crate-internal form of a redirect decision, extracted from an [`Action`]
/// by `Policy::check`.
#[derive(Debug)]
pub(crate) enum ActionKind {
    // Follow the redirect.
    Follow,
    // Do not follow the redirect.
    Stop,
    // Abort with this error (e.g. `TooManyRedirects` from a limited policy).
    Error(Box<dyn StdError + Send + Sync>),
}
238
239pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
240 if let Some(previous) = previous.last() {
241 let cross_host = next.host_str() != previous.host_str()
242 || next.port_or_known_default() != previous.port_or_known_default();
243 if cross_host {
244 headers.remove(AUTHORIZATION);
245 headers.remove(COOKIE);
246 headers.remove("cookie2");
247 headers.remove(PROXY_AUTHORIZATION);
248 headers.remove(WWW_AUTHENTICATE);
249 }
250 }
251}
252
/// Error produced by a limited policy (`PolicyKind::Limit`) when a redirect
/// would exceed its maximum.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// No underlying cause: `source()` keeps its default of `None`.
impl StdError for TooManyRedirects {}
263
/// Adapter that lets a [`Policy`] drive tower-http's redirect-following
/// middleware.
///
/// The policy sits behind an `Arc`, so clones share the same `Policy` but each
/// clone has its own `urls` history.
#[derive(Clone)]
pub(crate) struct TowerRedirectPolicy {
    // The user-facing redirect policy consulted on every redirect.
    policy: Arc<Policy>,
    // Whether `on_request` should set a `Referer` header.
    referer: bool,
    // Redirect history: the URL of each response that redirected, in order.
    urls: Vec<Url>,
    // When true, following a redirect to a non-https URL is an error.
    https_only: bool,
}
271
272impl TowerRedirectPolicy {
273 pub(crate) fn new(policy: Policy) -> Self {
274 Self {
275 policy: Arc::new(policy),
276 referer: false,
277 urls: Vec::new(),
278 https_only: false,
279 }
280 }
281
282 pub(crate) fn with_referer(&mut self, referer: bool) -> &mut Self {
283 self.referer = referer;
284 self
285 }
286
287 pub(crate) fn with_https_only(&mut self, https_only: bool) -> &mut Self {
288 self.https_only = https_only;
289 self
290 }
291}
292
293fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
294 if next.scheme() == "http" && previous.scheme() == "https" {
295 return None;
296 }
297
298 let mut referer = previous.clone();
299 let _ = referer.set_username("");
300 let _ = referer.set_password(None);
301 referer.set_fragment(None);
302 referer.as_str().parse().ok()
303}
304
305impl TowerPolicy<async_impl::body::Body, crate::Error> for TowerRedirectPolicy {
306 fn redirect(&mut self, attempt: &TowerAttempt<'_>) -> Result<TowerAction, crate::Error> {
307 let previous_url =
308 Url::parse(&attempt.previous().to_string()).expect("Previous URL must be valid");
309
310 let next_url = match Url::parse(&attempt.location().to_string()) {
311 Ok(url) => url,
312 Err(e) => return Err(crate::error::builder(e)),
313 };
314
315 self.urls.push(previous_url.clone());
316
317 match self.policy.check(attempt.status(), &next_url, &self.urls) {
318 ActionKind::Follow => {
319 if next_url.scheme() != "http" && next_url.scheme() != "https" {
320 return Err(crate::error::url_bad_scheme(next_url));
321 }
322
323 if self.https_only && next_url.scheme() != "https" {
324 return Err(crate::error::redirect(
325 crate::error::url_bad_scheme(next_url.clone()),
326 next_url,
327 ));
328 }
329 Ok(TowerAction::Follow)
330 }
331 ActionKind::Stop => Ok(TowerAction::Stop),
332 ActionKind::Error(e) => Err(crate::error::redirect(e, previous_url)),
333 }
334 }
335
336 fn on_request(&mut self, req: &mut http::Request<async_impl::body::Body>) {
337 if let Ok(next_url) = Url::parse(&req.uri().to_string()) {
338 remove_sensitive_headers(req.headers_mut(), &next_url, &self.urls);
339 if self.referer {
340 if let Some(previous_url) = self.urls.last() {
341 if let Some(v) = make_referer(&next_url, previous_url) {
342 req.headers_mut().insert(REFERER, v);
343 }
344 }
345 }
346 };
347 }
348
349 fn clone_body(&self, body: &async_impl::body::Body) -> Option<async_impl::body::Body> {
351 body.try_clone()
352 }
353}
354
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();
    let check = |history: &[Url]| policy.check(StatusCode::FOUND, &next, history);

    let mut previous: Vec<Url> = (0..=9)
        .map(|i| Url::parse(&format!("http://a.b/c/{i}")).unwrap())
        .collect();

    // Ten history entries: this is the 10th redirect, still within the default limit.
    match check(previous.as_slice()) {
        ActionKind::Follow => {}
        other => panic!("unexpected {other:?}"),
    }

    // An eleventh entry means the 11th redirect, which exceeds the limit.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());
    match check(previous.as_slice()) {
        ActionKind::Error(err) if err.is::<TooManyRedirects>() => {}
        other => panic!("unexpected {other:?}"),
    }
}
375
#[test]
fn test_redirect_policy_limit_to_0() {
    // A limit of zero rejects even the first redirect (history holds only the
    // original request URL).
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    let history = [Url::parse("http://a.b/c").unwrap()];

    let outcome = policy.check(StatusCode::FOUND, &next, &history);
    match outcome {
        ActionKind::Error(err) if err.is::<TooManyRedirects>() => {}
        other => panic!("unexpected {other:?}"),
    }
}
387
#[test]
fn test_redirect_policy_custom() {
    // Stop on redirects to host "foo"; follow everything else.
    let policy = Policy::custom(|attempt| {
        let to_foo = attempt.url().host_str() == Some("foo");
        if to_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let followed = Url::parse("http://bar/baz").unwrap();
    match policy.check(StatusCode::FOUND, &followed, &[]) {
        ActionKind::Follow => {}
        other => panic!("unexpected {other:?}"),
    }

    let stopped = Url::parse("http://foo/baz").unwrap();
    match policy.check(StatusCode::FOUND, &stopped, &[]) {
        ActionKind::Stop => {}
        other => panic!("unexpected {other:?}"),
    }
}
410
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    let mut headers = HeaderMap::new();
    for (name, value) in [
        (ACCEPT, "*/*"),
        (AUTHORIZATION, "let me in"),
        (COOKIE, "foo=bar"),
    ] {
        headers.insert(name, HeaderValue::from_static(value));
    }

    let next = Url::parse("http://initial-domain.com/path").unwrap();
    let mut history = vec![Url::parse("http://initial-domain.com/new_path").unwrap()];
    let mut expected = headers.clone();

    // Same scheme, host and port as the last hop: nothing is stripped.
    remove_sensitive_headers(&mut headers, &next, &history);
    assert_eq!(headers, expected);

    // The last hop is on another host: credentials and cookies are stripped.
    history.push(Url::parse("http://new-domain.com/path").unwrap());
    expected.remove(AUTHORIZATION);
    expected.remove(COOKIE);

    remove_sensitive_headers(&mut headers, &next, &history);
    assert_eq!(headers, expected);
}