axum/middleware/
from_fn.rs

1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5    any::type_name,
6    convert::Infallible,
7    fmt,
8    future::Future,
9    marker::PhantomData,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tower::{util::BoxCloneService, ServiceBuilder};
14use tower_layer::Layer;
15use tower_service::Service;
16
17/// Create a middleware from an async function.
18///
19/// `from_fn` requires the function given to
20///
21/// 1. Be an `async fn`.
22/// 2. Take zero or more [`FromRequestParts`] extractors.
23/// 3. Take exactly one [`FromRequest`] extractor as the second to last argument.
24/// 4. Take [`Next`](Next) as the last argument.
25/// 5. Return something that implements [`IntoResponse`].
26///
27/// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
28///
29/// # Example
30///
31/// ```rust
32/// use axum::{
33///     Router,
34///     http,
35///     routing::get,
36///     response::Response,
37///     middleware::{self, Next},
38///     extract::Request,
39/// };
40///
41/// async fn my_middleware(
42///     request: Request,
43///     next: Next,
44/// ) -> Response {
45///     // do something with `request`...
46///
47///     let response = next.run(request).await;
48///
49///     // do something with `response`...
50///
51///     response
52/// }
53///
54/// let app = Router::new()
55///     .route("/", get(|| async { /* ... */ }))
56///     .layer(middleware::from_fn(my_middleware));
57/// # let app: Router = app;
58/// ```
59///
60/// # Running extractors
61///
62/// ```rust
63/// use axum::{
64///     Router,
65///     extract::Request,
66///     http::{StatusCode, HeaderMap},
67///     middleware::{self, Next},
68///     response::Response,
69///     routing::get,
70/// };
71///
72/// async fn auth(
73///     // run the `HeaderMap` extractor
74///     headers: HeaderMap,
75///     // you can also add more extractors here but the last
76///     // extractor must implement `FromRequest` which
77///     // `Request` does
78///     request: Request,
79///     next: Next,
80/// ) -> Result<Response, StatusCode> {
81///     match get_token(&headers) {
82///         Some(token) if token_is_valid(token) => {
83///             let response = next.run(request).await;
84///             Ok(response)
85///         }
86///         _ => {
87///             Err(StatusCode::UNAUTHORIZED)
88///         }
89///     }
90/// }
91///
92/// fn get_token(headers: &HeaderMap) -> Option<&str> {
93///     // ...
94///     # None
95/// }
96///
97/// fn token_is_valid(token: &str) -> bool {
98///     // ...
99///     # false
100/// }
101///
102/// let app = Router::new()
103///     .route("/", get(|| async { /* ... */ }))
104///     .route_layer(middleware::from_fn(auth));
105/// # let app: Router = app;
106/// ```
107///
108/// [extractors]: crate::extract::FromRequest
109/// [`State`]: crate::extract::State
110pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
111    from_fn_with_state((), f)
112}
113
114/// Create a middleware from an async function with the given state.
115///
116/// See [`State`](crate::extract::State) for more details about accessing state.
117///
118/// # Example
119///
120/// ```rust
121/// use axum::{
122///     Router,
123///     http::StatusCode,
124///     routing::get,
125///     response::{IntoResponse, Response},
126///     middleware::{self, Next},
127///     extract::{Request, State},
128/// };
129///
130/// #[derive(Clone)]
131/// struct AppState { /* ... */ }
132///
133/// async fn my_middleware(
134///     State(state): State<AppState>,
135///     // you can add more extractors here but the last
136///     // extractor must implement `FromRequest` which
137///     // `Request` does
138///     request: Request,
139///     next: Next,
140/// ) -> Response {
141///     // do something with `request`...
142///
143///     let response = next.run(request).await;
144///
145///     // do something with `response`...
146///
147///     response
148/// }
149///
150/// let state = AppState { /* ... */ };
151///
152/// let app = Router::new()
153///     .route("/", get(|| async { /* ... */ }))
154///     .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
155///     .with_state(state);
156/// # let _: axum::Router = app;
157/// ```
158pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
159    FromFnLayer {
160        f,
161        state,
162        _extractor: PhantomData,
163    }
164}
165
/// A [`tower::Layer`] that applies a middleware written as an async function.
///
/// A [`tower::Layer`] is what gets applied to a [`Router`](crate::Router) to add middleware.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    // The middleware function.
    f: F,
    // State stored with the layer and cloned into every service it produces.
    state: S,
    // Tracks the extractor tuple `T` at the type level without storing a value.
    _extractor: PhantomData<fn() -> T>,
}

// Written by hand so that cloning only requires `F: Clone` and `S: Clone`,
// with no bound on `T`.
impl<F, S, T> Clone for FromFnLayer<F, S, T>
where
    F: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self { f, state, .. } = self;

        Self {
            f: f.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
191
192impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
193where
194    F: Clone,
195    S: Clone,
196{
197    type Service = FromFn<F, S, I, T>;
198
199    fn layer(&self, inner: I) -> Self::Service {
200        FromFn {
201            f: self.f.clone(),
202            state: self.state.clone(),
203            inner,
204            _extractor: PhantomData,
205        }
206    }
207}
208
209impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
210where
211    S: fmt::Debug,
212{
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        f.debug_struct("FromFnLayer")
215            // Write out the type name, without quoting it as `&type_name::<F>()` would
216            .field("f", &format_args!("{}", type_name::<F>()))
217            .field("state", &self.state)
218            .finish()
219    }
220}
221
/// A middleware service backed by an async function.
///
/// Produced by applying a [`FromFnLayer`] to an inner service. See [`from_fn`] for more details.
pub struct FromFn<F, S, I, T> {
    // The middleware function.
    f: F,
    // The wrapped service.
    inner: I,
    // State passed by reference to the extractors.
    state: S,
    // Tracks the extractor tuple `T` at the type level without storing a value.
    _extractor: PhantomData<fn() -> T>,
}

// Written by hand so that cloning places no bound on `T`.
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where
    F: Clone,
    I: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            f, inner, state, ..
        } = self;

        Self {
            f: f.clone(),
            inner: inner.clone(),
            state: state.clone(),
            _extractor: PhantomData,
        }
    }
}
247
248macro_rules! impl_service {
249    (
250        [$($ty:ident),*], $last:ident
251    ) => {
252        #[allow(non_snake_case, unused_mut)]
253        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
254        where
255            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
256            $( $ty: FromRequestParts<S> + Send, )*
257            $last: FromRequest<S> + Send,
258            Fut: Future<Output = Out> + Send + 'static,
259            Out: IntoResponse + 'static,
260            I: Service<Request, Error = Infallible>
261                + Clone
262                + Send
263                + 'static,
264            I::Response: IntoResponse,
265            I::Future: Send + 'static,
266            S: Clone + Send + Sync + 'static,
267        {
268            type Response = Response;
269            type Error = Infallible;
270            type Future = ResponseFuture;
271
272            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273                self.inner.poll_ready(cx)
274            }
275
276            fn call(&mut self, req: Request) -> Self::Future {
277                let not_ready_inner = self.inner.clone();
278                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
279
280                let mut f = self.f.clone();
281                let state = self.state.clone();
282
283                let future = Box::pin(async move {
284                    let (mut parts, body) = req.into_parts();
285
286                    $(
287                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
288                            Ok(value) => value,
289                            Err(rejection) => return rejection.into_response(),
290                        };
291                    )*
292
293                    let req = Request::from_parts(parts, body);
294
295                    let $last = match $last::from_request(req, &state).await {
296                        Ok(value) => value,
297                        Err(rejection) => return rejection.into_response(),
298                    };
299
300                    let inner = ServiceBuilder::new()
301                        .boxed_clone()
302                        .map_response(IntoResponse::into_response)
303                        .service(ready_inner);
304                    let next = Next { inner };
305
306                    f($($ty,)* $last, next).await.into_response()
307                });
308
309                ResponseFuture {
310                    inner: future
311                }
312            }
313        }
314    };
315}
316
317all_the_tuples!(impl_service);
318
319impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
320where
321    S: fmt::Debug,
322    I: fmt::Debug,
323{
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        f.debug_struct("FromFnLayer")
326            .field("f", &format_args!("{}", type_name::<F>()))
327            .field("inner", &self.inner)
328            .field("state", &self.state)
329            .finish()
330    }
331}
332
333/// The remainder of a middleware stack, including the handler.
334#[derive(Debug, Clone)]
335pub struct Next {
336    inner: BoxCloneService<Request, Response, Infallible>,
337}
338
339impl Next {
340    /// Execute the remaining middleware stack.
341    pub async fn run(mut self, req: Request) -> Response {
342        match self.inner.call(req).await {
343            Ok(res) => res,
344            Err(err) => match err {},
345        }
346    }
347}
348
349impl Service<Request> for Next {
350    type Response = Response;
351    type Error = Infallible;
352    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
353
354    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355        self.inner.poll_ready(cx)
356    }
357
358    fn call(&mut self, req: Request) -> Self::Future {
359        self.inner.call(req)
360    }
361}
362
363/// Response future for [`FromFn`].
364pub struct ResponseFuture {
365    inner: BoxFuture<'static, Response>,
366}
367
368impl Future for ResponseFuture {
369    type Output = Result<Response, Infallible>;
370
371    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372        self.inner.as_mut().poll(cx).map(Ok)
373    }
374}
375
376impl fmt::Debug for ResponseFuture {
377    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378        f.debug_struct("ResponseFuture").finish()
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    // A header inserted by the middleware must be visible to the handler behind it.
    #[crate::test]
    async fn basic() {
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let value = "ok".parse().unwrap();
            req.headers_mut().insert("x-axum-test", value);

            next.run(req).await
        }

        async fn handle(headers: HeaderMap) -> String {
            let value = &headers["x-axum-test"];
            value.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let res = app.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let body = res.collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ok");
    }
}