axum/extract/nested_path.rs

1use std::{
2    sync::Arc,
3    task::{Context, Poll},
4};
5
6use crate::extract::Request;
7use async_trait::async_trait;
8use axum_core::extract::FromRequestParts;
9use http::request::Parts;
10use tower_layer::{layer_fn, Layer};
11use tower_service::Service;
12
13use super::rejection::NestedPathRejection;
14
/// Access the path the matched route is nested at.
///
/// When routers are nested several levels deep the prefixes are joined, so a
/// route inside `Router::new().nest("/api", Router::new().nest("/v2", api))`
/// sees `"/api/v2"`.
///
/// This can for example be used when doing redirects.
///
/// If the matched route is not nested at all, extraction fails with
/// [`NestedPathRejection`].
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::NestedPath,
///     routing::get,
/// };
///
/// let api = Router::new().route(
///     "/users",
///     get(|path: NestedPath| async move {
///         // `path` will be "/api" because that's what this
///         // router is nested at when we build `app`
///         let path = path.as_str();
///     })
/// );
///
/// let app = Router::new().nest("/api", api);
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone)]
pub struct NestedPath(Arc<str>);

impl NestedPath {
    /// Returns a `str` representation of the path.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
49
50#[async_trait]
51impl<S> FromRequestParts<S> for NestedPath
52where
53    S: Send + Sync,
54{
55    type Rejection = NestedPathRejection;
56
57    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
58        match parts.extensions.get::<Self>() {
59            Some(nested_path) => Ok(nested_path.clone()),
60            None => Err(NestedPathRejection),
61        }
62    }
63}
64
65#[derive(Clone)]
66pub(crate) struct SetNestedPath<S> {
67    inner: S,
68    path: Arc<str>,
69}
70
71impl<S> SetNestedPath<S> {
72    pub(crate) fn layer(path: &str) -> impl Layer<S, Service = Self> + Clone {
73        let path = Arc::from(path);
74        layer_fn(move |inner| Self {
75            inner,
76            path: Arc::clone(&path),
77        })
78    }
79}
80
81impl<S, B> Service<Request<B>> for SetNestedPath<S>
82where
83    S: Service<Request<B>>,
84{
85    type Response = S::Response;
86    type Error = S::Error;
87    type Future = S::Future;
88
89    #[inline]
90    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91        self.inner.poll_ready(cx)
92    }
93
94    fn call(&mut self, mut req: Request<B>) -> Self::Future {
95        if let Some(prev) = req.extensions_mut().get_mut::<NestedPath>() {
96            let new_path = if prev.as_str() == "/" {
97                Arc::clone(&self.path)
98            } else {
99                format!("{}{}", prev.as_str().trim_end_matches('/'), self.path).into()
100            };
101            prev.0 = new_path;
102        } else {
103            req.extensions_mut()
104                .insert(NestedPath(Arc::clone(&self.path)));
105        };
106
107        self.inner.call(req)
108    }
109}
110
#[cfg(test)]
mod tests {
    use axum_core::response::Response;
    use http::StatusCode;

    use crate::{
        extract::{NestedPath, Request},
        middleware::{from_fn, Next},
        routing::get,
        test_helpers::*,
        Router,
    };

    /// Router with a single `/users` route whose handler asserts that the
    /// extracted `NestedPath` equals `expected`.
    fn users_route_expecting(expected: &'static str) -> Router {
        Router::new().route(
            "/users",
            get(move |nested_path: NestedPath| {
                assert_eq!(nested_path.as_str(), expected);
                async {}
            }),
        )
    }

    /// Sends `GET uri` to `app` and asserts the response is `200 OK`.
    async fn assert_get_ok(app: Router, uri: &str) {
        let client = TestClient::new(app);
        let res = client.get(uri).await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    #[crate::test]
    async fn one_level_of_nesting() {
        let app = Router::new().nest("/api", users_route_expecting("/api"));
        assert_get_ok(app, "/api/users").await;
    }

    #[crate::test]
    async fn one_level_of_nesting_with_trailing_slash() {
        let app = Router::new().nest("/api/", users_route_expecting("/api/"));
        assert_get_ok(app, "/api/users").await;
    }

    #[crate::test]
    async fn two_levels_of_nesting() {
        let inner = Router::new().nest("/v2", users_route_expecting("/api/v2"));
        let app = Router::new().nest("/api", inner);
        assert_get_ok(app, "/api/v2/users").await;
    }

    #[crate::test]
    async fn two_levels_of_nesting_with_trailing_slash() {
        let inner = Router::new().nest("/v2", users_route_expecting("/api/v2"));
        let app = Router::new().nest("/api/", inner);
        assert_get_ok(app, "/api/v2/users").await;
    }

    #[crate::test]
    async fn nested_at_root() {
        let app = Router::new().nest("/", users_route_expecting("/"));
        assert_get_ok(app, "/users").await;
    }

    #[crate::test]
    async fn deeply_nested_from_root() {
        let inner = Router::new().nest("/api", users_route_expecting("/api"));
        let app = Router::new().nest("/", inner);
        assert_get_ok(app, "/api/users").await;
    }

    #[crate::test]
    async fn in_fallbacks() {
        let api = Router::new().fallback(get(|nested_path: NestedPath| {
            assert_eq!(nested_path.as_str(), "/api");
            async {}
        }));

        let app = Router::new().nest("/api", api);
        assert_get_ok(app, "/api/doesnt-exist").await;
    }

    #[crate::test]
    async fn in_middleware() {
        async fn middleware(nested_path: NestedPath, req: Request, next: Next) -> Response {
            assert_eq!(nested_path.as_str(), "/api");
            next.run(req).await
        }

        let api = Router::new()
            .route("/users", get(|| async {}))
            .layer(from_fn(middleware));

        let app = Router::new().nest("/api", api);
        assert_get_ok(app, "/api/users").await;
    }
}