axum/middleware/
map_response.rs

1use crate::response::{IntoResponse, Response};
2use axum_core::extract::FromRequestParts;
3use futures_util::future::BoxFuture;
4use http::Request;
5use std::{
6    any::type_name,
7    convert::Infallible,
8    fmt,
9    future::Future,
10    marker::PhantomData,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tower_layer::Layer;
15use tower_service::Service;
16
17/// Create a middleware from an async function that transforms a response.
18///
19/// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific
20/// extractors.
21///
22/// # Example
23///
24/// ```
25/// use axum::{
26///     Router,
27///     routing::get,
28///     middleware::map_response,
29///     response::Response,
30/// };
31///
32/// async fn set_header<B>(mut response: Response<B>) -> Response<B> {
33///     response.headers_mut().insert("x-foo", "foo".parse().unwrap());
34///     response
35/// }
36///
37/// let app = Router::new()
38///     .route("/", get(|| async { /* ... */ }))
39///     .layer(map_response(set_header));
40/// # let _: Router = app;
41/// ```
42///
43/// # Running extractors
44///
45/// It is also possible to run extractors that implement [`FromRequestParts`]. These will be run
46/// before calling the handler.
47///
48/// ```
49/// use axum::{
50///     Router,
51///     routing::get,
52///     middleware::map_response,
53///     extract::Path,
54///     response::Response,
55/// };
56/// use std::collections::HashMap;
57///
58/// async fn log_path_params<B>(
59///     Path(path_params): Path<HashMap<String, String>>,
60///     response: Response<B>,
61/// ) -> Response<B> {
62///     tracing::debug!(?path_params);
63///     response
64/// }
65///
66/// let app = Router::new()
67///     .route("/", get(|| async { /* ... */ }))
68///     .layer(map_response(log_path_params));
69/// # let _: Router = app;
70/// ```
71///
72/// Note that to access state you must use either [`map_response_with_state`].
73///
74/// # Returning any `impl IntoResponse`
75///
76/// It is also possible to return anything that implements [`IntoResponse`]
77///
78/// ```
79/// use axum::{
80///     Router,
81///     routing::get,
82///     middleware::map_response,
83///     response::{Response, IntoResponse},
84/// };
85/// use std::collections::HashMap;
86///
87/// async fn set_header(response: Response) -> impl IntoResponse {
88///     (
89///         [("x-foo", "foo")],
90///         response,
91///     )
92/// }
93///
94/// let app = Router::new()
95///     .route("/", get(|| async { /* ... */ }))
96///     .layer(map_response(set_header));
97/// # let _: Router = app;
98/// ```
99pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100    map_response_with_state((), f)
101}
102
103/// Create a middleware from an async function that transforms a response, with the given state.
104///
105/// See [`State`](crate::extract::State) for more details about accessing state.
106///
107/// # Example
108///
109/// ```rust
110/// use axum::{
111///     Router,
112///     http::StatusCode,
113///     routing::get,
114///     response::Response,
115///     middleware::map_response_with_state,
116///     extract::State,
117/// };
118///
119/// #[derive(Clone)]
120/// struct AppState { /* ... */ }
121///
122/// async fn my_middleware<B>(
123///     State(state): State<AppState>,
124///     // you can add more extractors here but they must
125///     // all implement `FromRequestParts`
126///     // `FromRequest` is not allowed
127///     response: Response<B>,
128/// ) -> Response<B> {
129///     // do something with `state` and `response`...
130///     response
131/// }
132///
133/// let state = AppState { /* ... */ };
134///
135/// let app = Router::new()
136///     .route("/", get(|| async { /* ... */ }))
137///     .route_layer(map_response_with_state(state.clone(), my_middleware))
138///     .with_state(state);
139/// # let _: axum::Router = app;
140/// ```
141pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142    MapResponseLayer {
143        f,
144        state,
145        _extractor: PhantomData,
146    }
147}
148
/// A [`tower::Layer`] built from an async function that transforms a response.
///
/// Produced by [`map_response`] and [`map_response_with_state`]; see [`map_response`] for usage.
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    /// The user's response-mapping function.
    f: F,
    /// State handed to the extractors (`()` when created via [`map_response`]).
    state: S,
    /// Marker for the extractor tuple type; `fn() -> T` keeps the layer `Send`/`Sync`
    /// regardless of `T`.
    _extractor: PhantomData<fn() -> T>,
}

// Implemented by hand so that `T` (a marker only) does not need to be `Clone`.
impl<F, S, T> Clone for MapResponseLayer<F, S, T>
where
    F: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self { f, state, .. } = self;
        Self {
            f: f.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
172
173impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174where
175    F: Clone,
176    S: Clone,
177{
178    type Service = MapResponse<F, S, I, T>;
179
180    fn layer(&self, inner: I) -> Self::Service {
181        MapResponse {
182            f: self.f.clone(),
183            state: self.state.clone(),
184            inner,
185            _extractor: PhantomData,
186        }
187    }
188}
189
190impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191where
192    S: fmt::Debug,
193{
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.debug_struct("MapResponseLayer")
196            // Write out the type name, without quoting it as `&type_name::<F>()` would
197            .field("f", &format_args!("{}", type_name::<F>()))
198            .field("state", &self.state)
199            .finish()
200    }
201}
202
/// A middleware service built from an async function that transforms a response.
///
/// Produced by applying a [`MapResponseLayer`] to an inner service; see [`map_response`] for
/// usage.
pub struct MapResponse<F, S, I, T> {
    /// The user's response-mapping function.
    f: F,
    /// The wrapped service whose response is passed to `f`.
    inner: I,
    /// State handed to the extractors.
    state: S,
    /// Marker for the extractor tuple type.
    _extractor: PhantomData<fn() -> T>,
}

// Implemented by hand so that `T` (a marker only) does not need to be `Clone`.
impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
where
    F: Clone,
    I: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            f, inner, state, ..
        } = self;
        Self {
            f: f.clone(),
            inner: inner.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
228
/// Implement [`Service`] for [`MapResponse`] whose extractor tuple is `($($extractor,)*)`.
///
/// For each request, the generated `call`:
///
/// 1. runs every extractor against the request parts, returning the first rejection as the
///    response;
/// 2. calls the inner service with the reassembled request;
/// 3. passes the extracted values and the inner response to the user function and converts its
///    output into a [`Response`].
macro_rules! impl_service {
    (
        $($extractor:ident),*
    ) => {
        // `non_snake_case`: each extracted value is bound to a variable named after its type
        // parameter (e.g. `let T1 = ...`). `unused_mut`: with zero extractors `parts` is never
        // mutated.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($extractor,)*> Service<Request<B>>
            for MapResponse<F, S, I, ($($extractor,)*)>
        where
            F: FnMut($($extractor,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $extractor: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // `self.inner` is the instance that `poll_ready` prepared: move it into the
                // future and leave a fresh clone in its place for the next request.
                let fresh = self.inner.clone();
                let mut ready = std::mem::replace(&mut self.inner, fresh);

                let mut map_fn = self.f.clone();
                let _state = self.state.clone();

                ResponseFuture {
                    inner: Box::pin(async move {
                        let (mut parts, body) = req.into_parts();

                        $(
                            let $extractor =
                                match $extractor::from_request_parts(&mut parts, &_state).await {
                                    Ok(value) => value,
                                    Err(rejection) => return rejection.into_response(),
                                };
                        )*

                        let response = match ready.call(Request::from_parts(parts, body)).await {
                            Ok(response) => response,
                            // `Infallible` has no values, so this arm can never run.
                            Err(never) => match never {},
                        };

                        map_fn($($extractor,)* response).await.into_response()
                    }),
                }
            }
        }
    };
}
292
293impl_service!();
294impl_service!(T1);
295impl_service!(T1, T2);
296impl_service!(T1, T2, T3);
297impl_service!(T1, T2, T3, T4);
298impl_service!(T1, T2, T3, T4, T5);
299impl_service!(T1, T2, T3, T4, T5, T6);
300impl_service!(T1, T2, T3, T4, T5, T6, T7);
301impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
302impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
303impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
304impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
305impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
306impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
307impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
308impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
309impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
310
311impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
312where
313    S: fmt::Debug,
314    I: fmt::Debug,
315{
316    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        f.debug_struct("MapResponse")
318            .field("f", &format_args!("{}", type_name::<F>()))
319            .field("inner", &self.inner)
320            .field("state", &self.state)
321            .finish()
322    }
323}
324
325/// Response future for [`MapResponse`].
326pub struct ResponseFuture {
327    inner: BoxFuture<'static, Response>,
328}
329
330impl Future for ResponseFuture {
331    type Output = Result<Response, Infallible>;
332
333    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334        self.inner.as_mut().poll(cx).map(Ok)
335    }
336}
337
338impl fmt::Debug for ResponseFuture {
339    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340        f.debug_struct("ResponseFuture").finish()
341    }
342}
343
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    #[crate::test]
    async fn works() {
        // Mapping function under test: stamps every response with `x-foo: foo`.
        async fn insert_foo_header<B>(mut response: Response<B>) -> Response<B> {
            response
                .headers_mut()
                .insert("x-foo", "foo".parse().unwrap());
            response
        }

        // No routes are registered, so the fallback response is what gets mapped.
        let client = TestClient::new(Router::new().layer(map_response(insert_foo_header)));

        let response = client.get("/").await;
        assert_eq!(response.headers()["x-foo"], "foo");
    }
}