// tower_http/classify/status_in_range_is_error.rs

1use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
2use http::StatusCode;
3use std::{fmt, ops::RangeInclusive};
4
/// Response classifier that considers responses with a status code within some range to be
/// failures.
///
/// The range is inclusive on both ends, so `400..=599` classifies both `400` and `599` as
/// failures. For that common "client and server errors" case see
/// [`StatusInRangeAsFailures::new_for_client_and_server_errors`].
///
/// # Example
///
/// A client with tracing where server errors _and_ client errors are considered failures.
///
/// ```no_run
/// use tower_http::{trace::TraceLayer, classify::StatusInRangeAsFailures};
/// use tower::{ServiceBuilder, Service, ServiceExt};
/// use http::{Request, Method};
/// use http_body_util::Full;
/// use bytes::Bytes;
/// use hyper_util::{rt::TokioExecutor, client::legacy::Client};
///
/// # async fn foo() -> Result<(), tower::BoxError> {
/// let classifier = StatusInRangeAsFailures::new(400..=599);
///
/// let client = Client::builder(TokioExecutor::new()).build_http();
/// let mut client = ServiceBuilder::new()
///     .layer(TraceLayer::new(classifier.into_make_classifier()))
///     .service(client);
///
/// // `build_http` uses a plain HTTP connector, so the example targets an `http://` URI.
/// let request = Request::builder()
///     .method(Method::GET)
///     .uri("http://example.com")
///     .body(Full::<Bytes>::default())
///     .unwrap();
///
/// let _response = client.ready().await?.call(request).await?;
/// # Ok(())
/// # }
/// ```
// `PartialEq`/`Eq` are derived so configured classifiers can be compared (e.g. in tests);
// `RangeInclusive<u16>` implements both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StatusInRangeAsFailures {
    /// Inclusive range of status codes (as `u16`) that are classified as failures.
    range: RangeInclusive<u16>,
}
42
43impl StatusInRangeAsFailures {
44    /// Creates a new `StatusInRangeAsFailures`.
45    ///
46    /// # Panics
47    ///
48    /// Panics if the start or end of `range` aren't valid status codes as determined by
49    /// [`StatusCode::from_u16`].
50    ///
51    /// [`StatusCode::from_u16`]: https://docs.rs/http/latest/http/status/struct.StatusCode.html#method.from_u16
52    pub fn new(range: RangeInclusive<u16>) -> Self {
53        assert!(
54            StatusCode::from_u16(*range.start()).is_ok(),
55            "range start isn't a valid status code"
56        );
57        assert!(
58            StatusCode::from_u16(*range.end()).is_ok(),
59            "range end isn't a valid status code"
60        );
61
62        Self { range }
63    }
64
65    /// Creates a new `StatusInRangeAsFailures` that classifies client and server responses as
66    /// failures.
67    ///
68    /// This is a convenience for `StatusInRangeAsFailures::new(400..=599)`.
69    pub fn new_for_client_and_server_errors() -> Self {
70        Self::new(400..=599)
71    }
72
73    /// Convert this `StatusInRangeAsFailures` into a [`MakeClassifier`].
74    ///
75    /// [`MakeClassifier`]: super::MakeClassifier
76    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
77        SharedClassifier::new(self)
78    }
79}
80
81impl ClassifyResponse for StatusInRangeAsFailures {
82    type FailureClass = StatusInRangeFailureClass;
83    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;
84
85    fn classify_response<B>(
86        self,
87        res: &http::Response<B>,
88    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
89        if self.range.contains(&res.status().as_u16()) {
90            let class = StatusInRangeFailureClass::StatusCode(res.status());
91            ClassifiedResponse::Ready(Err(class))
92        } else {
93            ClassifiedResponse::Ready(Ok(()))
94        }
95    }
96
97    fn classify_error<E>(self, error: &E) -> Self::FailureClass
98    where
99        E: std::fmt::Display + 'static,
100    {
101        StatusInRangeFailureClass::Error(error.to_string())
102    }
103}
104
105/// The failure class for [`StatusInRangeAsFailures`].
106#[derive(Debug)]
107pub enum StatusInRangeFailureClass {
108    /// A response was classified as a failure with the corresponding status.
109    StatusCode(StatusCode),
110    /// A response was classified as an error with the corresponding error description.
111    Error(String),
112}
113
114impl fmt::Display for StatusInRangeFailureClass {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        match self {
117            Self::StatusCode(code) => write!(f, "Status code: {}", code),
118            Self::Error(error) => write!(f, "Error: {}", error),
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use http::Response;

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(200)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(400)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::BAD_REQUEST
            ))),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(500)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::INTERNAL_SERVER_ERROR
            ))),
        ));
    }

    /// Both ends of the range are failures; the codes just outside it are not.
    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        for status in [399, 600] {
            assert!(matches!(
                classifier
                    .clone()
                    .classify_response(&response_with_status(status)),
                ClassifiedResponse::Ready(Ok(())),
            ));
        }

        for status in [400, 599] {
            assert!(matches!(
                classifier
                    .clone()
                    .classify_response(&response_with_status(status)),
                ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code)))
                    if code.as_u16() == status,
            ));
        }
    }

    #[test]
    fn errors_are_classified_with_their_description() {
        let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();
        let error = String::from("connection reset");

        let class = classifier.classify_error(&error);

        assert!(matches!(
            &class,
            StatusInRangeFailureClass::Error(description) if description == "connection reset",
        ));
        assert_eq!(class.to_string(), "Error: connection reset");
    }

    #[test]
    fn status_code_failures_display_the_status() {
        let class = StatusInRangeFailureClass::StatusCode(StatusCode::BAD_REQUEST);
        assert!(class.to_string().starts_with("Status code: 400"));
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn rejects_invalid_range_start() {
        let _ = StatusInRangeAsFailures::new(0..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn rejects_invalid_range_end() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }

    /// Builds an empty-bodied response with the given status code.
    fn response_with_status(status: u16) -> Response<()> {
        Response::builder().status(status).body(()).unwrap()
    }
}