1#![doc = include_str!("../docs/error_handling.md")]
2
3use crate::{
4 extract::FromRequestParts,
5 http::Request,
6 response::{IntoResponse, Response},
7};
8use std::{
9 convert::Infallible,
10 fmt,
11 future::Future,
12 marker::PhantomData,
13 task::{Context, Poll},
14};
15use tower::ServiceExt;
16use tower_layer::Layer;
17use tower_service::Service;
18
/// [`Layer`] that wraps services in [`HandleError`], so that errors produced by the
/// inner service are converted into responses by the handler `f`.
///
/// `T` is the tuple of extractors the handler receives before the error (`()` for a
/// handler that only takes the error).
pub struct HandleErrorLayer<F, T> {
    // Error handler; cloned into every `HandleError` this layer produces.
    f: F,
    // `fn() -> T` marks `T` as used without storing one, so the layer's auto traits
    // (`Send`/`Sync`) do not depend on the extractor types.
    _extractor: PhantomData<fn() -> T>,
}
27
28impl<F, T> HandleErrorLayer<F, T> {
29 pub fn new(f: F) -> Self {
31 Self {
32 f,
33 _extractor: PhantomData,
34 }
35 }
36}
37
38impl<F, T> Clone for HandleErrorLayer<F, T>
39where
40 F: Clone,
41{
42 fn clone(&self) -> Self {
43 Self {
44 f: self.f.clone(),
45 _extractor: PhantomData,
46 }
47 }
48}
49
50impl<F, E> fmt::Debug for HandleErrorLayer<F, E> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("HandleErrorLayer")
53 .field("f", &format_args!("{}", std::any::type_name::<F>()))
54 .finish()
55 }
56}
57
58impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
59where
60 F: Clone,
61{
62 type Service = HandleError<S, F, T>;
63
64 fn layer(&self, inner: S) -> Self::Service {
65 HandleError::new(inner, self.f.clone())
66 }
67}
68
/// A [`Service`] adapter that handles errors from an inner service by converting them
/// into responses.
///
/// When the inner service `S` fails, the error is passed to `f` together with any
/// extractors listed in the tuple `T`. The handler's output becomes the response, so
/// this service's own error type is `Infallible`.
///
/// Created with [`HandleError::new`] or through [`HandleErrorLayer`].
pub struct HandleError<S, F, T> {
    // The wrapped service.
    inner: S,
    // Error handler; cloned on every call.
    f: F,
    // `fn() -> T` marks `T` as used without storing one, so `Send`/`Sync` of this
    // service do not depend on the extractor types.
    _extractor: PhantomData<fn() -> T>,
}
77
78impl<S, F, T> HandleError<S, F, T> {
79 pub fn new(inner: S, f: F) -> Self {
81 Self {
82 inner,
83 f,
84 _extractor: PhantomData,
85 }
86 }
87}
88
89impl<S, F, T> Clone for HandleError<S, F, T>
90where
91 S: Clone,
92 F: Clone,
93{
94 fn clone(&self) -> Self {
95 Self {
96 inner: self.inner.clone(),
97 f: self.f.clone(),
98 _extractor: PhantomData,
99 }
100 }
101}
102
103impl<S, F, E> fmt::Debug for HandleError<S, F, E>
104where
105 S: fmt::Debug,
106{
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 f.debug_struct("HandleError")
109 .field("inner", &self.inner)
110 .field("f", &format_args!("{}", std::any::type_name::<F>()))
111 .finish()
112 }
113}
114
115impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
116where
117 S: Service<Request<B>> + Clone + Send + 'static,
118 S::Response: IntoResponse + Send,
119 S::Error: Send,
120 S::Future: Send,
121 F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
122 Fut: Future<Output = Res> + Send,
123 Res: IntoResponse,
124 B: Send + 'static,
125{
126 type Response = Response;
127 type Error = Infallible;
128 type Future = future::HandleErrorFuture;
129
130 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 Poll::Ready(Ok(()))
132 }
133
134 fn call(&mut self, req: Request<B>) -> Self::Future {
135 let f = self.f.clone();
136
137 let clone = self.inner.clone();
138 let inner = std::mem::replace(&mut self.inner, clone);
139
140 let future = Box::pin(async move {
141 match inner.oneshot(req).await {
142 Ok(res) => Ok(res.into_response()),
143 Err(err) => Ok(f(err).await.into_response()),
144 }
145 });
146
147 future::HandleErrorFuture { future }
148 }
149}
150
151#[allow(unused_macros)]
152macro_rules! impl_service {
153 ( $($ty:ident),* $(,)? ) => {
154 impl<S, F, B, Res, Fut, $($ty,)*> Service<Request<B>>
155 for HandleError<S, F, ($($ty,)*)>
156 where
157 S: Service<Request<B>> + Clone + Send + 'static,
158 S::Response: IntoResponse + Send,
159 S::Error: Send,
160 S::Future: Send,
161 F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
162 Fut: Future<Output = Res> + Send,
163 Res: IntoResponse,
164 $( $ty: FromRequestParts<()> + Send,)*
165 B: Send + 'static,
166 {
167 type Response = Response;
168 type Error = Infallible;
169
170 type Future = future::HandleErrorFuture;
171
172 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173 Poll::Ready(Ok(()))
174 }
175
176 #[allow(non_snake_case)]
177 fn call(&mut self, req: Request<B>) -> Self::Future {
178 let f = self.f.clone();
179
180 let clone = self.inner.clone();
181 let inner = std::mem::replace(&mut self.inner, clone);
182
183 let future = Box::pin(async move {
184 let (mut parts, body) = req.into_parts();
185
186 $(
187 let $ty = match $ty::from_request_parts(&mut parts, &()).await {
188 Ok(value) => value,
189 Err(rejection) => return Ok(rejection.into_response()),
190 };
191 )*
192
193 let req = Request::from_parts(parts, body);
194
195 match inner.oneshot(req).await {
196 Ok(res) => Ok(res.into_response()),
197 Err(err) => Ok(f($($ty),*, err).await.into_response()),
198 }
199 });
200
201 future::HandleErrorFuture { future }
202 }
203 }
204 }
205}
206
207impl_service!(T1);
208impl_service!(T1, T2);
209impl_service!(T1, T2, T3);
210impl_service!(T1, T2, T3, T4);
211impl_service!(T1, T2, T3, T4, T5);
212impl_service!(T1, T2, T3, T4, T5, T6);
213impl_service!(T1, T2, T3, T4, T5, T6, T7);
214impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
215impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
216impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
217impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
218impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
219impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
220impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
221impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
222impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
223
224pub mod future {
225 use crate::response::Response;
228 use pin_project_lite::pin_project;
229 use std::{
230 convert::Infallible,
231 future::Future,
232 pin::Pin,
233 task::{Context, Poll},
234 };
235
236 pin_project! {
237 pub struct HandleErrorFuture {
239 #[pin]
240 pub(super) future: Pin<Box<dyn Future<Output = Result<Response, Infallible>>
241 + Send
242 + 'static
243 >>,
244 }
245 }
246
247 impl Future for HandleErrorFuture {
248 type Output = Result<Response, Infallible>;
249
250 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251 self.project().future.poll(cx)
252 }
253 }
254}
255
#[test]
fn traits() {
    use crate::test_helpers::*;

    // `NotSendSync` is presumably neither `Send` nor `Sync` (per its name).
    // `HandleError` should still be both, because it only stores
    // `PhantomData<fn() -> T>` for the extractor type.
    type Svc = HandleError<(), (), NotSendSync>;

    assert_send::<Svc>();
    assert_sync::<Svc>();
}