axum/error_handling/
mod.rs

1#![doc = include_str!("../docs/error_handling.md")]
2
3use crate::{
4    extract::FromRequestParts,
5    http::Request,
6    response::{IntoResponse, Response},
7};
8use std::{
9    convert::Infallible,
10    fmt,
11    future::Future,
12    marker::PhantomData,
13    task::{Context, Poll},
14};
15use tower::ServiceExt;
16use tower_layer::Layer;
17use tower_service::Service;
18
/// A [`Layer`] that wraps services in [`HandleError`], a [`Service`] adapter
/// that turns the errors of the wrapped service into responses.
///
/// `F` is the error handling function. `T` is the tuple of extractor types
/// that are handed to `F` ahead of the error (`()` when there are none).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleErrorLayer<F, T> {
    f: F,
    // `fn() -> T` records the extractor types without storing a `T`, so the
    // auto traits of the layer do not depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}
27
28impl<F, T> HandleErrorLayer<F, T> {
29    /// Create a new `HandleErrorLayer`.
30    pub fn new(f: F) -> Self {
31        Self {
32            f,
33            _extractor: PhantomData,
34        }
35    }
36}
37
38impl<F, T> Clone for HandleErrorLayer<F, T>
39where
40    F: Clone,
41{
42    fn clone(&self) -> Self {
43        Self {
44            f: self.f.clone(),
45            _extractor: PhantomData,
46        }
47    }
48}
49
50impl<F, E> fmt::Debug for HandleErrorLayer<F, E> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("HandleErrorLayer")
53            .field("f", &format_args!("{}", std::any::type_name::<F>()))
54            .finish()
55    }
56}
57
58impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
59where
60    F: Clone,
61{
62    type Service = HandleError<S, F, T>;
63
64    fn layer(&self, inner: S) -> Self::Service {
65        HandleError::new(inner, self.f.clone())
66    }
67}
68
/// A [`Service`] adapter that converts the errors of the wrapped service into
/// responses.
///
/// `S` is the wrapped service and `F` the error handling function. `T` is the
/// tuple of extractor types that are handed to `F` ahead of the error (`()`
/// when there are none).
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, T> {
    inner: S,
    f: F,
    // `fn() -> T` records the extractor types without storing a `T`, so the
    // auto traits of the adapter do not depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}
77
78impl<S, F, T> HandleError<S, F, T> {
79    /// Create a new `HandleError`.
80    pub fn new(inner: S, f: F) -> Self {
81        Self {
82            inner,
83            f,
84            _extractor: PhantomData,
85        }
86    }
87}
88
89impl<S, F, T> Clone for HandleError<S, F, T>
90where
91    S: Clone,
92    F: Clone,
93{
94    fn clone(&self) -> Self {
95        Self {
96            inner: self.inner.clone(),
97            f: self.f.clone(),
98            _extractor: PhantomData,
99        }
100    }
101}
102
103impl<S, F, E> fmt::Debug for HandleError<S, F, E>
104where
105    S: fmt::Debug,
106{
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("HandleError")
109            .field("inner", &self.inner)
110            .field("f", &format_args!("{}", std::any::type_name::<F>()))
111            .finish()
112    }
113}
114
115impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
116where
117    S: Service<Request<B>> + Clone + Send + 'static,
118    S::Response: IntoResponse + Send,
119    S::Error: Send,
120    S::Future: Send,
121    F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
122    Fut: Future<Output = Res> + Send,
123    Res: IntoResponse,
124    B: Send + 'static,
125{
126    type Response = Response;
127    type Error = Infallible;
128    type Future = future::HandleErrorFuture;
129
130    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131        Poll::Ready(Ok(()))
132    }
133
134    fn call(&mut self, req: Request<B>) -> Self::Future {
135        let f = self.f.clone();
136
137        let clone = self.inner.clone();
138        let inner = std::mem::replace(&mut self.inner, clone);
139
140        let future = Box::pin(async move {
141            match inner.oneshot(req).await {
142                Ok(res) => Ok(res.into_response()),
143                Err(err) => Ok(f(err).await.into_response()),
144            }
145        });
146
147        future::HandleErrorFuture { future }
148    }
149}
150
151#[allow(unused_macros)]
152macro_rules! impl_service {
153    ( $($ty:ident),* $(,)? ) => {
154        impl<S, F, B, Res, Fut, $($ty,)*> Service<Request<B>>
155            for HandleError<S, F, ($($ty,)*)>
156        where
157            S: Service<Request<B>> + Clone + Send + 'static,
158            S::Response: IntoResponse + Send,
159            S::Error: Send,
160            S::Future: Send,
161            F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
162            Fut: Future<Output = Res> + Send,
163            Res: IntoResponse,
164            $( $ty: FromRequestParts<()> + Send,)*
165            B: Send + 'static,
166        {
167            type Response = Response;
168            type Error = Infallible;
169
170            type Future = future::HandleErrorFuture;
171
172            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173                Poll::Ready(Ok(()))
174            }
175
176            #[allow(non_snake_case)]
177            fn call(&mut self, req: Request<B>) -> Self::Future {
178                let f = self.f.clone();
179
180                let clone = self.inner.clone();
181                let inner = std::mem::replace(&mut self.inner, clone);
182
183                let future = Box::pin(async move {
184                    let (mut parts, body) = req.into_parts();
185
186                    $(
187                        let $ty = match $ty::from_request_parts(&mut parts, &()).await {
188                            Ok(value) => value,
189                            Err(rejection) => return Ok(rejection.into_response()),
190                        };
191                    )*
192
193                    let req = Request::from_parts(parts, body);
194
195                    match inner.oneshot(req).await {
196                        Ok(res) => Ok(res.into_response()),
197                        Err(err) => Ok(f($($ty),*, err).await.into_response()),
198                    }
199                });
200
201                future::HandleErrorFuture { future }
202            }
203        }
204    }
205}
206
207impl_service!(T1);
208impl_service!(T1, T2);
209impl_service!(T1, T2, T3);
210impl_service!(T1, T2, T3, T4);
211impl_service!(T1, T2, T3, T4, T5);
212impl_service!(T1, T2, T3, T4, T5, T6);
213impl_service!(T1, T2, T3, T4, T5, T6, T7);
214impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
215impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
216impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
217impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
218impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
219impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
220impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
221impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
222impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
223
224pub mod future {
225    //! Future types.
226
227    use crate::response::Response;
228    use pin_project_lite::pin_project;
229    use std::{
230        convert::Infallible,
231        future::Future,
232        pin::Pin,
233        task::{Context, Poll},
234    };
235
236    pin_project! {
237        /// Response future for [`HandleError`].
238        pub struct HandleErrorFuture {
239            #[pin]
240            pub(super) future: Pin<Box<dyn Future<Output = Result<Response, Infallible>>
241                + Send
242                + 'static
243            >>,
244        }
245    }
246
247    impl Future for HandleErrorFuture {
248        type Output = Result<Response, Infallible>;
249
250        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251            self.project().future.poll(cx)
252        }
253    }
254}
255
// `HandleError` must stay `Send + Sync` even when the extractor tuple `T` is
// neither; this fails to compile if the marker field ever starts to depend on
// the auto traits of `T`.
#[test]
fn traits() {
    use crate::test_helpers::*;

    type Subject = HandleError<(), (), NotSendSync>;

    assert_send::<Subject>();
    assert_sync::<Subject>();
}
262}