axum_core/ext_traits/
request_parts.rs

1use crate::extract::FromRequestParts;
2use futures_util::future::BoxFuture;
3use http::request::Parts;
4
5mod sealed {
6    pub trait Sealed {}
7    impl Sealed for http::request::Parts {}
8}
9
10/// Extension trait that adds additional methods to [`Parts`].
11pub trait RequestPartsExt: sealed::Sealed + Sized {
12    /// Apply an extractor to this `Parts`.
13    ///
14    /// This is just a convenience for `E::from_request_parts(parts, &())`.
15    ///
16    /// # Example
17    ///
18    /// ```
19    /// use axum::{
20    ///     extract::{Query, Path, FromRequestParts},
21    ///     response::{Response, IntoResponse},
22    ///     http::request::Parts,
23    ///     RequestPartsExt,
24    ///     async_trait,
25    /// };
26    /// use std::collections::HashMap;
27    ///
28    /// struct MyExtractor {
29    ///     path_params: HashMap<String, String>,
30    ///     query_params: HashMap<String, String>,
31    /// }
32    ///
33    /// #[async_trait]
34    /// impl<S> FromRequestParts<S> for MyExtractor
35    /// where
36    ///     S: Send + Sync,
37    /// {
38    ///     type Rejection = Response;
39    ///
40    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
41    ///         let path_params = parts
42    ///             .extract::<Path<HashMap<String, String>>>()
43    ///             .await
44    ///             .map(|Path(path_params)| path_params)
45    ///             .map_err(|err| err.into_response())?;
46    ///
47    ///         let query_params = parts
48    ///             .extract::<Query<HashMap<String, String>>>()
49    ///             .await
50    ///             .map(|Query(params)| params)
51    ///             .map_err(|err| err.into_response())?;
52    ///
53    ///         Ok(MyExtractor { path_params, query_params })
54    ///     }
55    /// }
56    /// ```
57    fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
58    where
59        E: FromRequestParts<()> + 'static;
60
61    /// Apply an extractor that requires some state to this `Parts`.
62    ///
63    /// This is just a convenience for `E::from_request_parts(parts, state)`.
64    ///
65    /// # Example
66    ///
67    /// ```
68    /// use axum::{
69    ///     extract::{FromRef, FromRequestParts},
70    ///     response::{Response, IntoResponse},
71    ///     http::request::Parts,
72    ///     RequestPartsExt,
73    ///     async_trait,
74    /// };
75    ///
76    /// struct MyExtractor {
77    ///     requires_state: RequiresState,
78    /// }
79    ///
80    /// #[async_trait]
81    /// impl<S> FromRequestParts<S> for MyExtractor
82    /// where
83    ///     String: FromRef<S>,
84    ///     S: Send + Sync,
85    /// {
86    ///     type Rejection = std::convert::Infallible;
87    ///
88    ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
89    ///         let requires_state = parts
90    ///             .extract_with_state::<RequiresState, _>(state)
91    ///             .await?;
92    ///
93    ///         Ok(MyExtractor { requires_state })
94    ///     }
95    /// }
96    ///
97    /// struct RequiresState { /* ... */ }
98    ///
99    /// // some extractor that requires a `String` in the state
100    /// #[async_trait]
101    /// impl<S> FromRequestParts<S> for RequiresState
102    /// where
103    ///     String: FromRef<S>,
104    ///     S: Send + Sync,
105    /// {
106    ///     // ...
107    ///     # type Rejection = std::convert::Infallible;
108    ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
109    ///     #     unimplemented!()
110    ///     # }
111    /// }
112    /// ```
113    fn extract_with_state<'a, E, S>(
114        &'a mut self,
115        state: &'a S,
116    ) -> BoxFuture<'a, Result<E, E::Rejection>>
117    where
118        E: FromRequestParts<S> + 'static,
119        S: Send + Sync;
120}
121
122impl RequestPartsExt for Parts {
123    fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
124    where
125        E: FromRequestParts<()> + 'static,
126    {
127        self.extract_with_state(&())
128    }
129
130    fn extract_with_state<'a, E, S>(
131        &'a mut self,
132        state: &'a S,
133    ) -> BoxFuture<'a, Result<E, E::Rejection>>
134    where
135        E: FromRequestParts<S> + 'static,
136        S: Send + Sync,
137    {
138        E::from_request_parts(self, state)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::{Method, Request};
    use std::convert::Infallible;

    /// A freshly built request uses `GET`, so extracting `Method` with no
    /// state must yield `Method::GET`.
    #[tokio::test]
    async fn extract_without_state() {
        let (mut head, _body) = Request::new(()).into_parts();

        let extracted = head.extract::<Method>().await.unwrap();

        assert_eq!(extracted, Method::GET);
    }

    /// The `State` test extractor must hand back the state it was given.
    #[tokio::test]
    async fn extract_with_state() {
        let state = String::from("state");
        let (mut head, _body) = Request::new(()).into_parts();

        let State(extracted_state): State<String> = head
            .extract_with_state::<State<String>, String>(&state)
            .await
            .unwrap();

        assert_eq!(extracted_state, state);
    }

    // Compile-only check: both extension methods must be callable from inside
    // a user-written extractor that is generic over its state.
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
    }

    #[async_trait]
    impl<S> FromRequestParts<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S>,
    {
        type Rejection = Infallible;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            // Stateful extractor first, then the stateless one, as a real
            // composite extractor would do.
            let RequiresState(from_state) = parts.extract_with_state(state).await?;
            let method: Method = parts.extract().await?;

            Ok(WorksForCustomExtractor { method, from_state })
        }
    }
}