axum/extract/request_parts.rs
1use super::{Extension, FromRequestParts};
2use async_trait::async_trait;
3use http::{request::Parts, Uri};
4use std::convert::Infallible;
5
/// Extractor for the request URI exactly as it was received, before any
/// nesting stripped a path prefix from it.
///
/// Inside a nested router or service, the plain [`Uri`](http::Uri) extractor
/// only sees the part of the path left over after the nesting prefix was
/// removed. `OriginalUri` hands you the full, unstripped URI instead.
///
/// # Example
///
/// ```
/// use axum::{
///     extract::OriginalUri,
///     http::Uri,
///     routing::get,
///     Router,
/// };
///
/// let api_routes = Router::new().route(
///     "/users",
///     get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
///         // `uri` is `/users`
///         // `original_uri` is `/api/users`
///     }),
/// );
///
/// let app = Router::new().nest("/api", api_routes);
/// # let _: Router = app;
/// ```
///
/// # Extracting via request extensions
///
/// Middleware can also read `OriginalUri` straight from the request
/// extensions. This is handy, for example, with
/// [`Trace`](tower_http::trace::Trace) when you want a span that records the
/// full path even though the service may be nested:
///
/// ```
/// use axum::{
///     extract::OriginalUri,
///     http::Request,
///     routing::get,
///     Router,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let api_routes = Router::new()
///     .route("/users/:id", get(|| async { /* ... */ }))
///     .layer(
///         TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
///             let path = match req.extensions().get::<OriginalUri>() {
///                 // This will include `/api`
///                 Some(OriginalUri(uri)) => uri.path().to_owned(),
///                 // The `OriginalUri` extension will always be present if using
///                 // `Router` unless another extractor or middleware has removed it
///                 None => req.uri().path().to_owned(),
///             };
///             tracing::info_span!("http-request", %path)
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # let _: Router = app;
/// ```
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
71
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    // Extraction cannot fail: when the extension is missing we fall back to
    // the request's current URI.
    type Rejection = Infallible;

    /// Returns the `OriginalUri` stored in the request extensions. If none is
    /// present, it returns a clone of `parts.uri`, the URI as currently seen
    /// by this handler.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let from_extensions = Extension::<Self>::from_request_parts(parts, state).await;

        let original = match from_extensions {
            // Something upstream recorded the URI before any prefix stripping.
            Ok(Extension(recorded)) => recorded,
            // No recorded value; the current URI is the best we have.
            Err(_) => Self(parts.uri.clone()),
        };

        Ok(original)
    }
}
88
// Delegates to `axum_core::__impl_deref!` with `Uri` as the target type.
// Presumably this lets `OriginalUri` be used like the wrapped `Uri` (e.g.
// `original_uri.path()`) without writing `.0`.
// NOTE(review): the generated impls live in axum_core; confirm there.
#[cfg(feature = "original-uri")]
axum_core::__impl_deref! { OriginalUri: Uri }
91
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    /// `http::request::Parts` can be extracted whole and carries the method,
    /// URI, version, headers and extensions of the incoming request.
    #[crate::test]
    async fn extract_request_parts() {
        #[derive(Clone)]
        struct Ext;

        async fn handler(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            parts.extensions.get::<Ext>().unwrap();
        }

        let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));

        let res = client.get("/").header("x-foo", "123").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Inside a router mounted with `Router::nest`, the `Uri` extractor sees
    /// the prefix-stripped path, while `OriginalUri` still reports the full
    /// path the client requested.
    #[cfg(feature = "original-uri")]
    #[crate::test]
    async fn extract_original_uri_in_nested_router() {
        // Scoped imports so builds without `original-uri` get no
        // unused-import warnings.
        use crate::extract::OriginalUri;
        use http::Uri;

        async fn handler(uri: Uri, OriginalUri(original_uri): OriginalUri) {
            assert_eq!(uri, "/users");
            assert_eq!(original_uri, "/api/users");
        }

        let api_routes = Router::new().route("/users", get(handler));
        let client = TestClient::new(Router::new().nest("/api", api_routes));

        let res = client.get("/api/users").await;
        assert_eq!(res.status(), StatusCode::OK);
    }
}
115}