axum_core/ext_traits/request.rs
1use crate::body::Body;
2use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4
5mod sealed {
6 pub trait Sealed {}
7 impl Sealed for http::Request<crate::body::Body> {}
8}
9
10/// Extension trait that adds additional methods to [`Request`].
11pub trait RequestExt: sealed::Sealed + Sized {
12 /// Apply an extractor to this `Request`.
13 ///
14 /// This is just a convenience for `E::from_request(req, &())`.
15 ///
16 /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
17 /// the body and don't want to consume the request.
18 ///
19 /// # Example
20 ///
21 /// ```
22 /// use axum::{
23 /// async_trait,
24 /// extract::{Request, FromRequest},
25 /// body::Body,
26 /// http::{header::CONTENT_TYPE, StatusCode},
27 /// response::{IntoResponse, Response},
28 /// Form, Json, RequestExt,
29 /// };
30 ///
31 /// struct FormOrJson<T>(T);
32 ///
33 /// #[async_trait]
34 /// impl<S, T> FromRequest<S> for FormOrJson<T>
35 /// where
36 /// Json<T>: FromRequest<()>,
37 /// Form<T>: FromRequest<()>,
38 /// T: 'static,
39 /// S: Send + Sync,
40 /// {
41 /// type Rejection = Response;
42 ///
43 /// async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
44 /// let content_type = req
45 /// .headers()
46 /// .get(CONTENT_TYPE)
47 /// .and_then(|value| value.to_str().ok())
48 /// .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?;
49 ///
50 /// if content_type.starts_with("application/json") {
51 /// let Json(payload) = req
52 /// .extract::<Json<T>, _>()
53 /// .await
54 /// .map_err(|err| err.into_response())?;
55 ///
56 /// Ok(Self(payload))
57 /// } else if content_type.starts_with("application/x-www-form-urlencoded") {
58 /// let Form(payload) = req
59 /// .extract::<Form<T>, _>()
60 /// .await
61 /// .map_err(|err| err.into_response())?;
62 ///
63 /// Ok(Self(payload))
64 /// } else {
65 /// Err(StatusCode::BAD_REQUEST.into_response())
66 /// }
67 /// }
68 /// }
69 /// ```
70 fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
71 where
72 E: FromRequest<(), M> + 'static,
73 M: 'static;
74
75 /// Apply an extractor that requires some state to this `Request`.
76 ///
77 /// This is just a convenience for `E::from_request(req, state)`.
78 ///
79 /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
80 /// extracting the body and don't want to consume the request.
81 ///
82 /// # Example
83 ///
84 /// ```
85 /// use axum::{
86 /// async_trait,
87 /// body::Body,
88 /// extract::{Request, FromRef, FromRequest},
89 /// RequestExt,
90 /// };
91 ///
92 /// struct MyExtractor {
93 /// requires_state: RequiresState,
94 /// }
95 ///
96 /// #[async_trait]
97 /// impl<S> FromRequest<S> for MyExtractor
98 /// where
99 /// String: FromRef<S>,
100 /// S: Send + Sync,
101 /// {
102 /// type Rejection = std::convert::Infallible;
103 ///
104 /// async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
105 /// let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?;
106 ///
107 /// Ok(Self { requires_state })
108 /// }
109 /// }
110 ///
111 /// // some extractor that consumes the request body and requires state
112 /// struct RequiresState { /* ... */ }
113 ///
114 /// #[async_trait]
115 /// impl<S> FromRequest<S> for RequiresState
116 /// where
117 /// String: FromRef<S>,
118 /// S: Send + Sync,
119 /// {
120 /// // ...
121 /// # type Rejection = std::convert::Infallible;
122 /// # async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
123 /// # todo!()
124 /// # }
125 /// }
126 /// ```
127 fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
128 where
129 E: FromRequest<S, M> + 'static,
130 S: Send + Sync;
131
132 /// Apply a parts extractor to this `Request`.
133 ///
134 /// This is just a convenience for `E::from_request_parts(parts, state)`.
135 ///
136 /// # Example
137 ///
138 /// ```
139 /// use axum::{
140 /// async_trait,
141 /// extract::{Path, Request, FromRequest},
142 /// response::{IntoResponse, Response},
143 /// body::Body,
144 /// Json, RequestExt,
145 /// };
146 /// use axum_extra::{
147 /// TypedHeader,
148 /// headers::{authorization::Bearer, Authorization},
149 /// };
150 /// use std::collections::HashMap;
151 ///
152 /// struct MyExtractor<T> {
153 /// path_params: HashMap<String, String>,
154 /// payload: T,
155 /// }
156 ///
157 /// #[async_trait]
158 /// impl<S, T> FromRequest<S> for MyExtractor<T>
159 /// where
160 /// S: Send + Sync,
161 /// Json<T>: FromRequest<()>,
162 /// T: 'static,
163 /// {
164 /// type Rejection = Response;
165 ///
166 /// async fn from_request(mut req: Request, _state: &S) -> Result<Self, Self::Rejection> {
167 /// let path_params = req
168 /// .extract_parts::<Path<_>>()
169 /// .await
170 /// .map(|Path(path_params)| path_params)
171 /// .map_err(|err| err.into_response())?;
172 ///
173 /// let Json(payload) = req
174 /// .extract::<Json<T>, _>()
175 /// .await
176 /// .map_err(|err| err.into_response())?;
177 ///
178 /// Ok(Self { path_params, payload })
179 /// }
180 /// }
181 /// ```
182 fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
183 where
184 E: FromRequestParts<()> + 'static;
185
186 /// Apply a parts extractor that requires some state to this `Request`.
187 ///
188 /// This is just a convenience for `E::from_request_parts(parts, state)`.
189 ///
190 /// # Example
191 ///
192 /// ```
193 /// use axum::{
194 /// async_trait,
195 /// extract::{Request, FromRef, FromRequest, FromRequestParts},
196 /// http::request::Parts,
197 /// response::{IntoResponse, Response},
198 /// body::Body,
199 /// Json, RequestExt,
200 /// };
201 ///
202 /// struct MyExtractor<T> {
203 /// requires_state: RequiresState,
204 /// payload: T,
205 /// }
206 ///
207 /// #[async_trait]
208 /// impl<S, T> FromRequest<S> for MyExtractor<T>
209 /// where
210 /// String: FromRef<S>,
211 /// Json<T>: FromRequest<()>,
212 /// T: 'static,
213 /// S: Send + Sync,
214 /// {
215 /// type Rejection = Response;
216 ///
217 /// async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
218 /// let requires_state = req
219 /// .extract_parts_with_state::<RequiresState, _>(state)
220 /// .await
221 /// .map_err(|err| err.into_response())?;
222 ///
223 /// let Json(payload) = req
224 /// .extract::<Json<T>, _>()
225 /// .await
226 /// .map_err(|err| err.into_response())?;
227 ///
228 /// Ok(Self {
229 /// requires_state,
230 /// payload,
231 /// })
232 /// }
233 /// }
234 ///
235 /// struct RequiresState {}
236 ///
237 /// #[async_trait]
238 /// impl<S> FromRequestParts<S> for RequiresState
239 /// where
240 /// String: FromRef<S>,
241 /// S: Send + Sync,
242 /// {
243 /// // ...
244 /// # type Rejection = std::convert::Infallible;
245 /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
246 /// # todo!()
247 /// # }
248 /// }
249 /// ```
250 fn extract_parts_with_state<'a, E, S>(
251 &'a mut self,
252 state: &'a S,
253 ) -> BoxFuture<'a, Result<E, E::Rejection>>
254 where
255 E: FromRequestParts<S> + 'static,
256 S: Send + Sync;
257
258 /// Apply the [default body limit](crate::extract::DefaultBodyLimit).
259 ///
260 /// If it is disabled, the request is returned as-is.
261 fn with_limited_body(self) -> Request;
262
263 /// Consumes the request, returning the body wrapped in [`http_body_util::Limited`] if a
264 /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
265 /// default limit is disabled.
266 fn into_limited_body(self) -> Body;
267}
268
269impl RequestExt for Request {
270 fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
271 where
272 E: FromRequest<(), M> + 'static,
273 M: 'static,
274 {
275 self.extract_with_state(&())
276 }
277
278 fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
279 where
280 E: FromRequest<S, M> + 'static,
281 S: Send + Sync,
282 {
283 E::from_request(self, state)
284 }
285
286 fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
287 where
288 E: FromRequestParts<()> + 'static,
289 {
290 self.extract_parts_with_state(&())
291 }
292
293 fn extract_parts_with_state<'a, E, S>(
294 &'a mut self,
295 state: &'a S,
296 ) -> BoxFuture<'a, Result<E, E::Rejection>>
297 where
298 E: FromRequestParts<S> + 'static,
299 S: Send + Sync,
300 {
301 let mut req = Request::new(());
302 *req.version_mut() = self.version();
303 *req.method_mut() = self.method().clone();
304 *req.uri_mut() = self.uri().clone();
305 *req.headers_mut() = std::mem::take(self.headers_mut());
306 *req.extensions_mut() = std::mem::take(self.extensions_mut());
307 let (mut parts, ()) = req.into_parts();
308
309 Box::pin(async move {
310 let result = E::from_request_parts(&mut parts, state).await;
311
312 *self.version_mut() = parts.version;
313 *self.method_mut() = parts.method.clone();
314 *self.uri_mut() = parts.uri.clone();
315 *self.headers_mut() = std::mem::take(&mut parts.headers);
316 *self.extensions_mut() = std::mem::take(&mut parts.extensions);
317
318 result
319 })
320 }
321
322 fn with_limited_body(self) -> Request {
323 // update docs in `axum-core/src/extract/default_body_limit.rs` and
324 // `axum/src/docs/extract.md` if this changes
325 const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
326
327 match self.extensions().get::<DefaultBodyLimitKind>().copied() {
328 Some(DefaultBodyLimitKind::Disable) => self,
329 Some(DefaultBodyLimitKind::Limit(limit)) => {
330 self.map(|b| Body::new(http_body_util::Limited::new(b, limit)))
331 }
332 None => self.map(|b| Body::new(http_body_util::Limited::new(b, DEFAULT_LIMIT))),
333 }
334 }
335
336 fn into_limited_body(self) -> Body {
337 self.with_limited_body().into_body()
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::{Method, Version};

    #[tokio::test]
    async fn extract_without_state() {
        let req = Request::new(Body::empty());

        let method: Method = req.extract().await.unwrap();

        assert_eq!(method, Method::GET);
    }

    #[tokio::test]
    async fn extract_body_without_state() {
        let req = Request::new(Body::from("foobar"));

        let body: String = req.extract().await.unwrap();

        assert_eq!(body, "foobar");
    }

    #[tokio::test]
    async fn extract_with_state() {
        let req = Request::new(Body::empty());

        let state = "state".to_owned();

        let State(extracted_state): State<String> = req.extract_with_state(&state).await.unwrap();

        assert_eq!(extracted_state, state);
    }

    #[tokio::test]
    async fn extract_parts_without_state() {
        let mut req = Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap();

        let method: Method = req.extract_parts().await.unwrap();

        assert_eq!(method, Method::GET);
        assert_eq!(req.headers()["x-foo"], "foo");
    }

    #[tokio::test]
    async fn extract_parts_with_state() {
        let mut req = Request::builder()
            .header("x-foo", "foo")
            .body(Body::empty())
            .unwrap();

        let state = "state".to_owned();

        let State(extracted_state): State<String> =
            req.extract_parts_with_state(&state).await.unwrap();

        assert_eq!(extracted_state, state);
        assert_eq!(req.headers()["x-foo"], "foo");
    }

    // `extract_parts` temporarily moves the request head out of the request. Every piece of it
    // must be written back, and the body must remain usable afterwards.
    #[tokio::test]
    async fn extract_parts_restores_request_head() {
        let mut req = Request::builder()
            .method(Method::POST)
            .uri("/foo?bar=baz")
            .version(Version::HTTP_2)
            .header("x-foo", "foo")
            .extension(42_u32)
            .body(Body::from("payload"))
            .unwrap();

        let method: Method = req.extract_parts().await.unwrap();
        assert_eq!(method, Method::POST);

        assert_eq!(*req.method(), Method::POST);
        assert_eq!(req.uri().path(), "/foo");
        assert_eq!(req.uri().query(), Some("bar=baz"));
        assert_eq!(req.version(), Version::HTTP_2);
        assert_eq!(req.headers()["x-foo"], "foo");
        assert_eq!(req.extensions().get::<u32>(), Some(&42));

        let body: String = req.extract().await.unwrap();
        assert_eq!(body, "payload");
    }

    // this stuff just needs to compile; the fields are never read, hence `dead_code`
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
        body: String,
    }

    #[async_trait]
    impl<S> FromRequest<S> for WorksForCustomExtractor
    where
        S: Send + Sync,
        String: FromRef<S> + FromRequest<()>,
    {
        type Rejection = <String as FromRequest<()>>::Rejection;

        async fn from_request(mut req: Request, state: &S) -> Result<Self, Self::Rejection> {
            let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap();
            let method = req.extract_parts().await.unwrap();
            let body = req.extract().await?;

            Ok(Self {
                method,
                from_state,
                body,
            })
        }
    }
}