axum/middleware/
map_request.rs

1use crate::body::{Body, Bytes, HttpBody};
2use crate::response::{IntoResponse, Response};
3use crate::BoxError;
4use axum_core::extract::{FromRequest, FromRequestParts};
5use futures_util::future::BoxFuture;
6use http::Request;
7use std::{
8    any::type_name,
9    convert::Infallible,
10    fmt,
11    future::Future,
12    marker::PhantomData,
13    pin::Pin,
14    task::{Context, Poll},
15};
16use tower_layer::Layer;
17use tower_service::Service;
18
19/// Create a middleware from an async function that transforms a request.
20///
21/// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific
22/// extractors.
23///
24/// # Example
25///
26/// ```
27/// use axum::{
28///     Router,
29///     routing::get,
30///     middleware::map_request,
31///     http::Request,
32/// };
33///
34/// async fn set_header<B>(mut request: Request<B>) -> Request<B> {
35///     request.headers_mut().insert("x-foo", "foo".parse().unwrap());
36///     request
37/// }
38///
39/// async fn handler<B>(request: Request<B>) {
40///     // `request` will have an `x-foo` header
41/// }
42///
43/// let app = Router::new()
44///     .route("/", get(handler))
45///     .layer(map_request(set_header));
46/// # let _: Router = app;
47/// ```
48///
49/// # Rejecting the request
50///
51/// The function given to `map_request` is allowed to also return a `Result` which can be used to
52/// reject the request and return a response immediately, without calling the remaining
53/// middleware.
54///
55/// Specifically the valid return types are:
56///
57/// - `Request<B>`
58/// - `Result<Request<B>, E> where E:  IntoResponse`
59///
60/// ```
61/// use axum::{
62///     Router,
63///     http::{Request, StatusCode},
64///     routing::get,
65///     middleware::map_request,
66/// };
67///
68/// async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
69///     let auth_header = request.headers()
70///         .get(http::header::AUTHORIZATION)
71///         .and_then(|header| header.to_str().ok());
72///
73///     match auth_header {
74///         Some(auth_header) if token_is_valid(auth_header) => Ok(request),
75///         _ => Err(StatusCode::UNAUTHORIZED),
76///     }
77/// }
78///
79/// fn token_is_valid(token: &str) -> bool {
80///     // ...
81///     # false
82/// }
83///
84/// let app = Router::new()
85///     .route("/", get(|| async { /* ... */ }))
86///     .route_layer(map_request(auth));
87/// # let app: Router = app;
88/// ```
89///
90/// # Running extractors
91///
92/// ```
93/// use axum::{
94///     Router,
95///     routing::get,
96///     middleware::map_request,
97///     extract::Path,
98///     http::Request,
99/// };
100/// use std::collections::HashMap;
101///
102/// async fn log_path_params<B>(
103///     Path(path_params): Path<HashMap<String, String>>,
104///     request: Request<B>,
105/// ) -> Request<B> {
106///     tracing::debug!(?path_params);
107///     request
108/// }
109///
110/// let app = Router::new()
111///     .route("/", get(|| async { /* ... */ }))
112///     .layer(map_request(log_path_params));
113/// # let _: Router = app;
114/// ```
115///
116/// Note that to access state you must use either [`map_request_with_state`].
117pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
118    map_request_with_state((), f)
119}
120
121/// Create a middleware from an async function that transforms a request, with the given state.
122///
123/// See [`State`](crate::extract::State) for more details about accessing state.
124///
125/// # Example
126///
127/// ```rust
128/// use axum::{
129///     Router,
130///     http::{Request, StatusCode},
131///     routing::get,
132///     response::IntoResponse,
133///     middleware::map_request_with_state,
134///     extract::State,
135/// };
136///
137/// #[derive(Clone)]
138/// struct AppState { /* ... */ }
139///
140/// async fn my_middleware<B>(
141///     State(state): State<AppState>,
142///     // you can add more extractors here but the last
143///     // extractor must implement `FromRequest` which
144///     // `Request` does
145///     request: Request<B>,
146/// ) -> Request<B> {
147///     // do something with `state` and `request`...
148///     request
149/// }
150///
151/// let state = AppState { /* ... */ };
152///
153/// let app = Router::new()
154///     .route("/", get(|| async { /* ... */ }))
155///     .route_layer(map_request_with_state(state.clone(), my_middleware))
156///     .with_state(state);
157/// # let _: axum::Router = app;
158/// ```
159pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
160    MapRequestLayer {
161        f,
162        state,
163        _extractor: PhantomData,
164    }
165}
166
167/// A [`tower::Layer`] from an async function that transforms a request.
168///
169/// Created with [`map_request`]. See that function for more details.
170#[must_use]
171pub struct MapRequestLayer<F, S, T> {
172    f: F,
173    state: S,
174    _extractor: PhantomData<fn() -> T>,
175}
176
177impl<F, S, T> Clone for MapRequestLayer<F, S, T>
178where
179    F: Clone,
180    S: Clone,
181{
182    fn clone(&self) -> Self {
183        Self {
184            f: self.f.clone(),
185            state: self.state.clone(),
186            _extractor: self._extractor,
187        }
188    }
189}
190
191impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
192where
193    F: Clone,
194    S: Clone,
195{
196    type Service = MapRequest<F, S, I, T>;
197
198    fn layer(&self, inner: I) -> Self::Service {
199        MapRequest {
200            f: self.f.clone(),
201            state: self.state.clone(),
202            inner,
203            _extractor: PhantomData,
204        }
205    }
206}
207
208impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
209where
210    S: fmt::Debug,
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        f.debug_struct("MapRequestLayer")
214            // Write out the type name, without quoting it as `&type_name::<F>()` would
215            .field("f", &format_args!("{}", type_name::<F>()))
216            .field("state", &self.state)
217            .finish()
218    }
219}
220
/// A middleware created from an async function that transforms a request.
///
/// Built by [`MapRequestLayer`], which [`map_request`] and [`map_request_with_state`] return.
/// See [`map_request`] for more details.
pub struct MapRequest<F, S, I, T> {
    // The user's async function, cloned once per request.
    f: F,
    // The wrapped service that receives the transformed request.
    inner: I,
    // State the extractors run against, cloned once per request.
    state: S,
    // Marks the extractor tuple `T`; `fn() -> T` keeps auto traits independent of `T`.
    _extractor: PhantomData<fn() -> T>,
}

impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
where
    F: Clone,
    I: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self { f, inner, state, .. } = self;
        Self {
            f: f.clone(),
            inner: inner.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
246
247macro_rules! impl_service {
248    (
249        [$($ty:ident),*], $last:ident
250    ) => {
251        #[allow(non_snake_case, unused_mut)]
252        impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
253        where
254            F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
255            $( $ty: FromRequestParts<S> + Send, )*
256            $last: FromRequest<S> + Send,
257            Fut: Future + Send + 'static,
258            Fut::Output: IntoMapRequestResult<B> + Send + 'static,
259            I: Service<Request<B>, Error = Infallible>
260                + Clone
261                + Send
262                + 'static,
263            I::Response: IntoResponse,
264            I::Future: Send + 'static,
265            B: HttpBody<Data = Bytes> + Send + 'static,
266            B::Error: Into<BoxError>,
267            S: Clone + Send + Sync + 'static,
268        {
269            type Response = Response;
270            type Error = Infallible;
271            type Future = ResponseFuture;
272
273            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274                self.inner.poll_ready(cx)
275            }
276
277            fn call(&mut self, req: Request<B>) -> Self::Future {
278                let req = req.map(Body::new);
279
280                let not_ready_inner = self.inner.clone();
281                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
282
283                let mut f = self.f.clone();
284                let state = self.state.clone();
285
286                let future = Box::pin(async move {
287                    let (mut parts, body) = req.into_parts();
288
289                    $(
290                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
291                            Ok(value) => value,
292                            Err(rejection) => return rejection.into_response(),
293                        };
294                    )*
295
296                    let req = Request::from_parts(parts, body);
297
298                    let $last = match $last::from_request(req, &state).await {
299                        Ok(value) => value,
300                        Err(rejection) => return rejection.into_response(),
301                    };
302
303                    match f($($ty,)* $last).await.into_map_request_result() {
304                        Ok(req) => {
305                            ready_inner.call(req).await.into_response()
306                        }
307                        Err(res) => {
308                            res
309                        }
310                    }
311                });
312
313                ResponseFuture {
314                    inner: future
315                }
316            }
317        }
318    };
319}
320
321all_the_tuples!(impl_service);
322
323impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
324where
325    S: fmt::Debug,
326    I: fmt::Debug,
327{
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        f.debug_struct("MapRequest")
330            .field("f", &format_args!("{}", type_name::<F>()))
331            .field("inner", &self.inner)
332            .field("state", &self.state)
333            .finish()
334    }
335}
336
337/// Response future for [`MapRequest`].
338pub struct ResponseFuture {
339    inner: BoxFuture<'static, Response>,
340}
341
342impl Future for ResponseFuture {
343    type Output = Result<Response, Infallible>;
344
345    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346        self.inner.as_mut().poll(cx).map(Ok)
347    }
348}
349
350impl fmt::Debug for ResponseFuture {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        f.debug_struct("ResponseFuture").finish()
353    }
354}
355
356mod private {
357    use crate::{http::Request, response::IntoResponse};
358
359    pub trait Sealed<B> {}
360    impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
361    impl<B> Sealed<B> for Request<B> {}
362}
363
364/// Trait implemented by types that can be returned from [`map_request`],
365/// [`map_request_with_state`].
366///
367/// This trait is sealed such that it cannot be implemented outside this crate.
368pub trait IntoMapRequestResult<B>: private::Sealed<B> {
369    /// Perform the conversion.
370    fn into_map_request_result(self) -> Result<Request<B>, Response>;
371}
372
373impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
374where
375    E: IntoResponse,
376{
377    fn into_map_request_result(self) -> Result<Request<B>, Response> {
378        self.map_err(IntoResponse::into_response)
379    }
380}
381
382impl<B> IntoMapRequestResult<B> for Request<B> {
383    fn into_map_request_result(self) -> Result<Request<B>, Response> {
384        Ok(self)
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");
    }

    #[crate::test]
    async fn works_for_short_circuiting() {
        // Always rejects, so the handler below must never run.
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }
}