// tower_http/classify/status_in_range_is_error.rs
use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
2use http::StatusCode;
3use std::{fmt, ops::RangeInclusive};
4
/// Response classifier that reports a failure for every response whose status code lies
/// inside a configured inclusive range.
///
/// Two classifiers compare equal exactly when their ranges compare equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StatusInRangeAsFailures {
    /// Inclusive range of numeric status codes that are classified as failures.
    range: RangeInclusive<u16>,
}
42
43impl StatusInRangeAsFailures {
44    pub fn new(range: RangeInclusive<u16>) -> Self {
53        assert!(
54            StatusCode::from_u16(*range.start()).is_ok(),
55            "range start isn't a valid status code"
56        );
57        assert!(
58            StatusCode::from_u16(*range.end()).is_ok(),
59            "range end isn't a valid status code"
60        );
61
62        Self { range }
63    }
64
65    pub fn new_for_client_and_server_errors() -> Self {
70        Self::new(400..=599)
71    }
72
73    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
77        SharedClassifier::new(self)
78    }
79}
80
81impl ClassifyResponse for StatusInRangeAsFailures {
82    type FailureClass = StatusInRangeFailureClass;
83    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;
84
85    fn classify_response<B>(
86        self,
87        res: &http::Response<B>,
88    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
89        if self.range.contains(&res.status().as_u16()) {
90            let class = StatusInRangeFailureClass::StatusCode(res.status());
91            ClassifiedResponse::Ready(Err(class))
92        } else {
93            ClassifiedResponse::Ready(Ok(()))
94        }
95    }
96
97    fn classify_error<E>(self, error: &E) -> Self::FailureClass
98    where
99        E: std::fmt::Display + 'static,
100    {
101        StatusInRangeFailureClass::Error(error.to_string())
102    }
103}
104
105#[derive(Debug)]
107pub enum StatusInRangeFailureClass {
108    StatusCode(StatusCode),
110    Error(String),
112}
113
114impl fmt::Display for StatusInRangeFailureClass {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        match self {
117            Self::StatusCode(code) => write!(f, "Status code: {}", code),
118            Self::Error(error) => write!(f, "Error: {}", error),
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use http::Response;

    /// Builds an empty-bodied response carrying the given numeric status.
    fn response_with_status(status: u16) -> Response<()> {
        Response::builder().status(status).body(()).unwrap()
    }

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(200)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(400)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::BAD_REQUEST
            ))),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(500)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::INTERNAL_SERVER_ERROR
            ))),
        ));
    }

    // The range is inclusive on both ends: 399 and 600 sit just outside it, 599 is the
    // last status inside it.
    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(399)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(599)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code)))
                if code.as_u16() == 599
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(600)),
            ClassifiedResponse::Ready(Ok(())),
        ));
    }

    #[test]
    fn classify_error_keeps_display_text() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        let class = classifier.classify_error(&"connection reset");

        assert!(matches!(
            &class,
            StatusInRangeFailureClass::Error(text) if text == "connection reset"
        ));
        assert_eq!(class.to_string(), "Error: connection reset");
    }

    #[test]
    fn display_includes_status_code() {
        let class = StatusInRangeFailureClass::StatusCode(StatusCode::BAD_REQUEST);

        assert_eq!(class.to_string(), "Status code: 400 Bad Request");
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn new_rejects_invalid_start() {
        let _ = StatusInRangeAsFailures::new(0..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn new_rejects_invalid_end() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }
}