tower_http/follow_redirect/
mod.rs1pub mod policy;
96
97use self::policy::{Action, Attempt, Policy, Standard};
98use futures_util::future::Either;
99use http::{
100 header::CONTENT_ENCODING, header::CONTENT_LENGTH, header::CONTENT_TYPE, header::LOCATION,
101 header::TRANSFER_ENCODING, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Uri,
102 Version,
103};
104use http_body::Body;
105use iri_string::types::{UriAbsoluteString, UriReferenceStr};
106use pin_project_lite::pin_project;
107use std::{
108 convert::TryFrom,
109 future::Future,
110 mem,
111 pin::Pin,
112 str,
113 task::{ready, Context, Poll},
114};
115use tower::util::Oneshot;
116use tower_layer::Layer;
117use tower_service::Service;
118
119#[derive(Clone, Copy, Debug, Default)]
123pub struct FollowRedirectLayer<P = Standard> {
124 policy: P,
125}
126
127impl FollowRedirectLayer {
128 pub fn new() -> Self {
130 Self::default()
131 }
132}
133
134impl<P> FollowRedirectLayer<P> {
135 pub fn with_policy(policy: P) -> Self {
137 FollowRedirectLayer { policy }
138 }
139}
140
141impl<S, P> Layer<S> for FollowRedirectLayer<P>
142where
143 S: Clone,
144 P: Clone,
145{
146 type Service = FollowRedirect<S, P>;
147
148 fn layer(&self, inner: S) -> Self::Service {
149 FollowRedirect::with_policy(inner, self.policy.clone())
150 }
151}
152
153#[derive(Clone, Copy, Debug)]
157pub struct FollowRedirect<S, P = Standard> {
158 inner: S,
159 policy: P,
160}
161
162impl<S> FollowRedirect<S> {
163 pub fn new(inner: S) -> Self {
165 Self::with_policy(inner, Standard::default())
166 }
167
168 pub fn layer() -> FollowRedirectLayer {
172 FollowRedirectLayer::new()
173 }
174}
175
176impl<S, P> FollowRedirect<S, P>
177where
178 P: Clone,
179{
180 pub fn with_policy(inner: S, policy: P) -> Self {
182 FollowRedirect { inner, policy }
183 }
184
185 pub fn layer_with_policy(policy: P) -> FollowRedirectLayer<P> {
190 FollowRedirectLayer::with_policy(policy)
191 }
192
193 define_inner_service_accessors!();
194}
195
196impl<ReqBody, ResBody, S, P> Service<Request<ReqBody>> for FollowRedirect<S, P>
197where
198 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
199 ReqBody: Body + Default,
200 P: Policy<ReqBody, S::Error> + Clone,
201{
202 type Response = Response<ResBody>;
203 type Error = S::Error;
204 type Future = ResponseFuture<S, ReqBody, P>;
205
206 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
207 self.inner.poll_ready(cx)
208 }
209
210 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
211 let service = self.inner.clone();
212 let mut service = mem::replace(&mut self.inner, service);
213 let mut policy = self.policy.clone();
214 let mut body = BodyRepr::None;
215 body.try_clone_from(req.body(), &policy);
216 policy.on_request(&mut req);
217 ResponseFuture {
218 method: req.method().clone(),
219 uri: req.uri().clone(),
220 version: req.version(),
221 headers: req.headers().clone(),
222 body,
223 future: Either::Left(service.call(req)),
224 service,
225 policy,
226 }
227 }
228}
229
230pin_project! {
231 #[derive(Debug)]
233 pub struct ResponseFuture<S, B, P>
234 where
235 S: Service<Request<B>>,
236 {
237 #[pin]
238 future: Either<S::Future, Oneshot<S, Request<B>>>,
239 service: S,
240 policy: P,
241 method: Method,
242 uri: Uri,
243 version: Version,
244 headers: HeaderMap<HeaderValue>,
245 body: BodyRepr<B>,
246 }
247}
248
249impl<S, ReqBody, ResBody, P> Future for ResponseFuture<S, ReqBody, P>
250where
251 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
252 ReqBody: Body + Default,
253 P: Policy<ReqBody, S::Error>,
254{
255 type Output = Result<Response<ResBody>, S::Error>;
256
257 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
258 let mut this = self.project();
259 let mut res = ready!(this.future.as_mut().poll(cx)?);
260 res.extensions_mut().insert(RequestUri(this.uri.clone()));
261
262 let drop_payload_headers = |headers: &mut HeaderMap| {
263 for header in &[
264 CONTENT_TYPE,
265 CONTENT_LENGTH,
266 CONTENT_ENCODING,
267 TRANSFER_ENCODING,
268 ] {
269 headers.remove(header);
270 }
271 };
272 match res.status() {
273 StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND => {
274 if *this.method == Method::POST {
277 *this.method = Method::GET;
278 *this.body = BodyRepr::Empty;
279 drop_payload_headers(this.headers);
280 }
281 }
282 StatusCode::SEE_OTHER => {
283 if *this.method != Method::HEAD {
285 *this.method = Method::GET;
286 }
287 *this.body = BodyRepr::Empty;
288 drop_payload_headers(this.headers);
289 }
290 StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {}
291 _ => return Poll::Ready(Ok(res)),
292 };
293
294 let body = if let Some(body) = this.body.take() {
295 body
296 } else {
297 return Poll::Ready(Ok(res));
298 };
299
300 let location = res
301 .headers()
302 .get(&LOCATION)
303 .and_then(|loc| resolve_uri(str::from_utf8(loc.as_bytes()).ok()?, this.uri));
304 let location = if let Some(loc) = location {
305 loc
306 } else {
307 return Poll::Ready(Ok(res));
308 };
309
310 let attempt = Attempt {
311 status: res.status(),
312 location: &location,
313 previous: this.uri,
314 };
315 match this.policy.redirect(&attempt)? {
316 Action::Follow => {
317 *this.uri = location;
318 this.body.try_clone_from(&body, &this.policy);
319
320 let mut req = Request::new(body);
321 *req.uri_mut() = this.uri.clone();
322 *req.method_mut() = this.method.clone();
323 *req.version_mut() = *this.version;
324 *req.headers_mut() = this.headers.clone();
325 this.policy.on_request(&mut req);
326 this.future
327 .set(Either::Right(Oneshot::new(this.service.clone(), req)));
328
329 cx.waker().wake_by_ref();
330 Poll::Pending
331 }
332 Action::Stop => Poll::Ready(Ok(res)),
333 }
334 }
335}
336
337#[derive(Clone)]
343pub struct RequestUri(pub Uri);
344
/// Where the body for a possible follow-up request comes from.
#[derive(Debug)]
enum BodyRepr<B> {
    /// A copy of the body, consumed by the next follow-up request.
    Some(B),
    /// Follow-up requests are sent with `B::default()`; this state persists across hops.
    Empty,
    /// No copy of the body is available, so a redirect that keeps the body is not followed.
    None,
}
351
352impl<B> BodyRepr<B>
353where
354 B: Body + Default,
355{
356 fn take(&mut self) -> Option<B> {
357 match mem::replace(self, BodyRepr::None) {
358 BodyRepr::Some(body) => Some(body),
359 BodyRepr::Empty => {
360 *self = BodyRepr::Empty;
361 Some(B::default())
362 }
363 BodyRepr::None => None,
364 }
365 }
366
367 fn try_clone_from<P, E>(&mut self, body: &B, policy: &P)
368 where
369 P: Policy<B, E>,
370 {
371 match self {
372 BodyRepr::Some(_) | BodyRepr::Empty => {}
373 BodyRepr::None => {
374 if let Some(body) = clone_body(policy, body) {
375 *self = BodyRepr::Some(body);
376 }
377 }
378 }
379 }
380}
381
382fn clone_body<P, B, E>(policy: &P, body: &B) -> Option<B>
383where
384 P: Policy<B, E>,
385 B: Body + Default,
386{
387 if body.size_hint().exact() == Some(0) {
388 Some(B::default())
389 } else {
390 policy.clone_body(body)
391 }
392}
393
394fn resolve_uri(relative: &str, base: &Uri) -> Option<Uri> {
396 let relative = UriReferenceStr::new(relative).ok()?;
397 let base = UriAbsoluteString::try_from(base.to_string()).ok()?;
398 let uri = relative.resolve_against(&base).to_string();
399 Uri::try_from(uri).ok()
400}
401
#[cfg(test)]
mod tests {
    use super::{policy::*, *};
    use crate::test_helpers::Body;
    use http::header::LOCATION;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    // Builds an empty-bodied request targeting `uri`.
    fn request_to(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    // Returns the URI recorded on `res` by the middleware.
    fn recorded_uri<B>(res: &Response<B>) -> &Uri {
        &res.extensions().get::<RequestUri>().unwrap().0
    }

    #[tokio::test]
    async fn follows() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Follow))
            .buffer(1)
            .service_fn(handle);

        let res = svc
            .oneshot(request_to("http://example.com/42"))
            .await
            .unwrap();

        // Every hop is followed down to the final, non-redirect response.
        assert_eq!(*res.body(), 0);
        assert_eq!(*recorded_uri(&res), "http://example.com/0");
    }

    #[tokio::test]
    async fn stops() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Action::Stop))
            .buffer(1)
            .service_fn(handle);

        let res = svc
            .oneshot(request_to("http://example.com/42"))
            .await
            .unwrap();

        // The first redirect response is returned as-is.
        assert_eq!(*res.body(), 42);
        assert_eq!(*recorded_uri(&res), "http://example.com/42");
    }

    #[tokio::test]
    async fn limited() {
        let svc = ServiceBuilder::new()
            .layer(FollowRedirectLayer::with_policy(Limited::new(10)))
            .buffer(1)
            .service_fn(handle);

        let res = svc
            .oneshot(request_to("http://example.com/42"))
            .await
            .unwrap();

        // Exactly ten redirects are followed before the limit stops the chain.
        assert_eq!(*res.body(), 42 - 10);
        assert_eq!(*recorded_uri(&res), "http://example.com/32");
    }

    // Answers `/n` with a 301 pointing at `/n-1` (body `n`), or 200 with body 0 at `/0`.
    async fn handle<B>(req: Request<B>) -> Result<Response<u64>, Infallible> {
        let remaining: u64 = req.uri().path()[1..].parse().unwrap();
        let builder = if remaining == 0 {
            Response::builder()
        } else {
            Response::builder()
                .status(StatusCode::MOVED_PERMANENTLY)
                .header(LOCATION, format!("/{}", remaining - 1))
        };
        Ok(builder.body(remaining).unwrap())
    }
}