axum_core/ext_traits/request_parts.rs
1use crate::extract::FromRequestParts;
2use futures_util::future::BoxFuture;
3use http::request::Parts;
4
5mod sealed {
6 pub trait Sealed {}
7 impl Sealed for http::request::Parts {}
8}
9
10/// Extension trait that adds additional methods to [`Parts`].
11pub trait RequestPartsExt: sealed::Sealed + Sized {
12 /// Apply an extractor to this `Parts`.
13 ///
14 /// This is just a convenience for `E::from_request_parts(parts, &())`.
15 ///
16 /// # Example
17 ///
18 /// ```
19 /// use axum::{
20 /// extract::{Query, Path, FromRequestParts},
21 /// response::{Response, IntoResponse},
22 /// http::request::Parts,
23 /// RequestPartsExt,
24 /// async_trait,
25 /// };
26 /// use std::collections::HashMap;
27 ///
28 /// struct MyExtractor {
29 /// path_params: HashMap<String, String>,
30 /// query_params: HashMap<String, String>,
31 /// }
32 ///
33 /// #[async_trait]
34 /// impl<S> FromRequestParts<S> for MyExtractor
35 /// where
36 /// S: Send + Sync,
37 /// {
38 /// type Rejection = Response;
39 ///
40 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
41 /// let path_params = parts
42 /// .extract::<Path<HashMap<String, String>>>()
43 /// .await
44 /// .map(|Path(path_params)| path_params)
45 /// .map_err(|err| err.into_response())?;
46 ///
47 /// let query_params = parts
48 /// .extract::<Query<HashMap<String, String>>>()
49 /// .await
50 /// .map(|Query(params)| params)
51 /// .map_err(|err| err.into_response())?;
52 ///
53 /// Ok(MyExtractor { path_params, query_params })
54 /// }
55 /// }
56 /// ```
57 fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
58 where
59 E: FromRequestParts<()> + 'static;
60
61 /// Apply an extractor that requires some state to this `Parts`.
62 ///
63 /// This is just a convenience for `E::from_request_parts(parts, state)`.
64 ///
65 /// # Example
66 ///
67 /// ```
68 /// use axum::{
69 /// extract::{FromRef, FromRequestParts},
70 /// response::{Response, IntoResponse},
71 /// http::request::Parts,
72 /// RequestPartsExt,
73 /// async_trait,
74 /// };
75 ///
76 /// struct MyExtractor {
77 /// requires_state: RequiresState,
78 /// }
79 ///
80 /// #[async_trait]
81 /// impl<S> FromRequestParts<S> for MyExtractor
82 /// where
83 /// String: FromRef<S>,
84 /// S: Send + Sync,
85 /// {
86 /// type Rejection = std::convert::Infallible;
87 ///
88 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
89 /// let requires_state = parts
90 /// .extract_with_state::<RequiresState, _>(state)
91 /// .await?;
92 ///
93 /// Ok(MyExtractor { requires_state })
94 /// }
95 /// }
96 ///
97 /// struct RequiresState { /* ... */ }
98 ///
99 /// // some extractor that requires a `String` in the state
100 /// #[async_trait]
101 /// impl<S> FromRequestParts<S> for RequiresState
102 /// where
103 /// String: FromRef<S>,
104 /// S: Send + Sync,
105 /// {
106 /// // ...
107 /// # type Rejection = std::convert::Infallible;
108 /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
109 /// # unimplemented!()
110 /// # }
111 /// }
112 /// ```
113 fn extract_with_state<'a, E, S>(
114 &'a mut self,
115 state: &'a S,
116 ) -> BoxFuture<'a, Result<E, E::Rejection>>
117 where
118 E: FromRequestParts<S> + 'static,
119 S: Send + Sync;
120}
121
122impl RequestPartsExt for Parts {
123 fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
124 where
125 E: FromRequestParts<()> + 'static,
126 {
127 self.extract_with_state(&())
128 }
129
130 fn extract_with_state<'a, E, S>(
131 &'a mut self,
132 state: &'a S,
133 ) -> BoxFuture<'a, Result<E, E::Rejection>>
134 where
135 E: FromRequestParts<S> + 'static,
136 S: Send + Sync,
137 {
138 E::from_request_parts(self, state)
139 }
140}
141
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::{Method, Request};

    /// A stateless extractor can be run through `extract`.
    #[tokio::test]
    async fn extract_without_state() {
        let mut parts = Request::new(()).into_parts().0;

        let extracted = parts.extract::<Method>().await.unwrap();

        assert_eq!(extracted, Method::GET);
    }

    /// A state-reading extractor can be run through `extract_with_state`.
    #[tokio::test]
    async fn extract_with_state() {
        let (mut parts, _body) = Request::new(()).into_parts();
        let state = String::from("state");

        let State(extracted_state) = parts
            .extract_with_state::<State<String>, String>(&state)
            .await
            .unwrap();

        assert_eq!(extracted_state, state);
    }

    // Compile-only check: a custom extractor can call both `extract_with_state` and
    // `extract` on the same `Parts`.
    #[allow(dead_code)] // never constructed at runtime; it only has to type-check
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
    }

    #[async_trait]
    impl<S> FromRequestParts<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S>,
    {
        type Rejection = Infallible;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            let RequiresState(from_state) = parts.extract_with_state(state).await?;
            let method: Method = parts.extract().await?;

            Ok(WorksForCustomExtractor { method, from_state })
        }
    }
}