// axum/middleware/from_extractor.rs

1use crate::{
2    extract::FromRequestParts,
3    response::{IntoResponse, Response},
4};
5use futures_util::{future::BoxFuture, ready};
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9    fmt,
10    future::Future,
11    marker::PhantomData,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18/// Create a middleware from an extractor.
19///
20/// If the extractor succeeds the value will be discarded and the inner service
21/// will be called. If the extractor fails the rejection will be returned and
22/// the inner service will _not_ be called.
23///
24/// This can be used to perform validation of requests if the validation doesn't
25/// produce any useful output, and run the extractor for several handlers
26/// without repeating it in the function signature.
27///
28/// Note that if the extractor consumes the request body, as `String` or
29/// [`Bytes`] does, an empty body will be left in its place. Thus won't be
30/// accessible to subsequent extractors or handlers.
31///
32/// # Example
33///
34/// ```rust
35/// use axum::{
36///     extract::FromRequestParts,
37///     middleware::from_extractor,
38///     routing::{get, post},
39///     Router,
40///     http::{header, StatusCode, request::Parts},
41/// };
42/// use async_trait::async_trait;
43///
44/// // An extractor that performs authorization.
45/// struct RequireAuth;
46///
47/// #[async_trait]
48/// impl<S> FromRequestParts<S> for RequireAuth
49/// where
50///     S: Send + Sync,
51/// {
52///     type Rejection = StatusCode;
53///
54///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
55///         let auth_header = parts
56///             .headers
57///             .get(header::AUTHORIZATION)
58///             .and_then(|value| value.to_str().ok());
59///
60///         match auth_header {
61///             Some(auth_header) if token_is_valid(auth_header) => {
62///                 Ok(Self)
63///             }
64///             _ => Err(StatusCode::UNAUTHORIZED),
65///         }
66///     }
67/// }
68///
69/// fn token_is_valid(token: &str) -> bool {
70///     // ...
71///     # false
72/// }
73///
74/// async fn handler() {
75///     // If we get here the request has been authorized
76/// }
77///
78/// async fn other_handler() {
79///     // If we get here the request has been authorized
80/// }
81///
82/// let app = Router::new()
83///     .route("/", get(handler))
84///     .route("/foo", post(other_handler))
85///     // The extractor will run before all routes
86///     .route_layer(from_extractor::<RequireAuth>());
87/// # let _: Router = app;
88/// ```
89///
90/// [`Bytes`]: bytes::Bytes
91pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
92    from_extractor_with_state(())
93}
94
95/// Create a middleware from an extractor with the given state.
96///
97/// See [`State`](crate::extract::State) for more details about accessing state.
98pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
99    FromExtractorLayer {
100        state,
101        _marker: PhantomData,
102    }
103}
104
/// [`Layer`] that wraps services in [`FromExtractor`], a middleware which runs
/// an extractor and throws the extracted value away.
///
/// Built by [`from_extractor`] or [`from_extractor_with_state`]; see
/// [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // State cloned into every service this layer produces.
    state: S,
    // Records the extractor type only. `fn() -> E` keeps the layer `Send` and
    // `Sync` whatever `E` is.
    _marker: PhantomData<fn() -> E>,
}
116
117impl<E, S> Clone for FromExtractorLayer<E, S>
118where
119    S: Clone,
120{
121    fn clone(&self) -> Self {
122        Self {
123            state: self.state.clone(),
124            _marker: PhantomData,
125        }
126    }
127}
128
129impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
130where
131    S: fmt::Debug,
132{
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("FromExtractorLayer")
135            .field("state", &self.state)
136            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
137            .finish()
138    }
139}
140
141impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
142where
143    S: Clone,
144{
145    type Service = FromExtractor<T, E, S>;
146
147    fn layer(&self, inner: T) -> Self::Service {
148        FromExtractor {
149            inner,
150            state: self.state.clone(),
151            _extractor: PhantomData,
152        }
153    }
154}
155
/// Middleware that runs an extractor and discards the value.
///
/// Produced by applying [`FromExtractorLayer`] to a service; see
/// [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
    // The wrapped service; it is only called when extraction succeeds.
    inner: T,
    // State handed (by reference) to the extractor on every request.
    state: S,
    // `fn() -> E` keeps this type `Send`/`Sync` regardless of `E`.
    _extractor: PhantomData<fn() -> E>,
}
164
#[test]
fn traits() {
    use crate::test_helpers::*;

    // `E` only appears behind `PhantomData<fn() -> E>`, so neither the
    // middleware nor its layer may inherit `!Send`/`!Sync` from the extractor.
    assert_send::<FromExtractor<(), NotSendSync, ()>>();
    assert_sync::<FromExtractor<(), NotSendSync, ()>>();
    assert_send::<FromExtractorLayer<NotSendSync, ()>>();
    assert_sync::<FromExtractorLayer<NotSendSync, ()>>();
}
171
172impl<T, E, S> Clone for FromExtractor<T, E, S>
173where
174    T: Clone,
175    S: Clone,
176{
177    fn clone(&self) -> Self {
178        Self {
179            inner: self.inner.clone(),
180            state: self.state.clone(),
181            _extractor: PhantomData,
182        }
183    }
184}
185
186impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
187where
188    T: fmt::Debug,
189    S: fmt::Debug,
190{
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        f.debug_struct("FromExtractor")
193            .field("inner", &self.inner)
194            .field("state", &self.state)
195            .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
196            .finish()
197    }
198}
199
200impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
201where
202    E: FromRequestParts<S> + 'static,
203    B: Send + 'static,
204    T: Service<Request<B>> + Clone,
205    T::Response: IntoResponse,
206    S: Clone + Send + Sync + 'static,
207{
208    type Response = Response;
209    type Error = T::Error;
210    type Future = ResponseFuture<B, T, E, S>;
211
212    #[inline]
213    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214        self.inner.poll_ready(cx)
215    }
216
217    fn call(&mut self, req: Request<B>) -> Self::Future {
218        let state = self.state.clone();
219        let extract_future = Box::pin(async move {
220            let (mut parts, body) = req.into_parts();
221            let extracted = E::from_request_parts(&mut parts, &state).await;
222            let req = Request::from_parts(parts, body);
223            (req, extracted)
224        });
225
226        ResponseFuture {
227            state: State::Extracting {
228                future: extract_future,
229            },
230            svc: Some(self.inner.clone()),
231        }
232    }
233}
234
235pin_project! {
236    /// Response future for [`FromExtractor`].
237    #[allow(missing_debug_implementations)]
238    pub struct ResponseFuture<B, T, E, S>
239    where
240        E: FromRequestParts<S>,
241        T: Service<Request<B>>,
242    {
243        #[pin]
244        state: State<B, T, E, S>,
245        svc: Option<T>,
246    }
247}
248
249pin_project! {
250    #[project = StateProj]
251    enum State<B, T, E, S>
252    where
253        E: FromRequestParts<S>,
254        T: Service<Request<B>>,
255    {
256        Extracting {
257            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
258        },
259        Call { #[pin] future: T::Future },
260    }
261}
262
263impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
264where
265    E: FromRequestParts<S>,
266    T: Service<Request<B>>,
267    T::Response: IntoResponse,
268{
269    type Output = Result<Response, T::Error>;
270
271    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
272        loop {
273            let mut this = self.as_mut().project();
274
275            let new_state = match this.state.as_mut().project() {
276                StateProj::Extracting { future } => {
277                    let (req, extracted) = ready!(future.as_mut().poll(cx));
278
279                    match extracted {
280                        Ok(_) => {
281                            let mut svc = this.svc.take().expect("future polled after completion");
282                            let future = svc.call(req);
283                            State::Call { future }
284                        }
285                        Err(err) => {
286                            let res = err.into_response();
287                            return Poll::Ready(Ok(res));
288                        }
289                    }
290                }
291                StateProj::Call { future } => {
292                    return future
293                        .poll(cx)
294                        .map(|result| result.map(IntoResponse::into_response));
295                }
296            };
297
298            this.state.set(new_state);
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        async_trait,
        handler::Handler,
        routing::{get, post},
        test_helpers::*,
        Router,
    };
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        struct RequireAuth;

        #[async_trait]
        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(secret) = Secret::from_ref(state);
                if let Some(auth) = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|v| v.to_str().ok())
                {
                    if auth == secret {
                        return Ok(Self);
                    }
                }

                Err(StatusCode::UNAUTHORIZED)
            }
        }

        async fn handler() {}

        let state = Secret("secret");
        let app = Router::new().route(
            "/",
            get(handler.layer(from_extractor_with_state::<RequireAuth, _>(state))),
        );

        let client = TestClient::new(app);

        let res = client.get("/").await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let res = client
            .get("/")
            .header(http::header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // The extractor only sees the request parts, so the body must reach the
    // handler unchanged.
    #[crate::test]
    async fn body_is_passed_through_to_inner_service() {
        struct AlwaysOk;

        #[async_trait]
        impl<S> FromRequestParts<S> for AlwaysOk
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                Ok(Self)
            }
        }

        async fn echo(body: String) -> String {
            body
        }

        let app = Router::new()
            .route("/", post(echo))
            .route_layer(from_extractor::<AlwaysOk>());

        let client = TestClient::new(app);

        let res = client.post("/").body("hello").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "hello");
    }

    // just needs to compile
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        #[async_trait]
        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}
389}