// axum/routing/path_router.rs

1use crate::{
2    extract::{nested_path::SetNestedPath, Request},
3    handler::Handler,
4};
5use axum_core::response::IntoResponse;
6use matchit::MatchError;
7use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
8use tower_layer::Layer;
9use tower_service::Service;
10
11use super::{
12    future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
13    MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
14};
15
16pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
17    routes: HashMap<RouteId, Endpoint<S>>,
18    node: Arc<Node>,
19    prev_route_id: RouteId,
20}
21
22impl<S> PathRouter<S, true>
23where
24    S: Clone + Send + Sync + 'static,
25{
26    pub(super) fn new_fallback() -> Self {
27        let mut this = Self::default();
28        this.set_fallback(Endpoint::Route(Route::new(NotFound)));
29        this
30    }
31
32    pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
33        self.replace_endpoint("/", endpoint.clone());
34        self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
35    }
36}
37
38impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
39where
40    S: Clone + Send + Sync + 'static,
41{
42    pub(super) fn route(
43        &mut self,
44        path: &str,
45        method_router: MethodRouter<S>,
46    ) -> Result<(), Cow<'static, str>> {
47        fn validate_path(path: &str) -> Result<(), &'static str> {
48            if path.is_empty() {
49                return Err("Paths must start with a `/`. Use \"/\" for root routes");
50            } else if !path.starts_with('/') {
51                return Err("Paths must start with a `/`");
52            }
53
54            Ok(())
55        }
56
57        validate_path(path)?;
58
59        let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
60            .node
61            .path_to_route_id
62            .get(path)
63            .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
64        {
65            // if we're adding a new `MethodRouter` to a route that already has one just
66            // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
67            let service = Endpoint::MethodRouter(
68                prev_method_router
69                    .clone()
70                    .merge_for_path(Some(path), method_router),
71            );
72            self.routes.insert(route_id, service);
73            return Ok(());
74        } else {
75            Endpoint::MethodRouter(method_router)
76        };
77
78        let id = self.next_route_id();
79        self.set_node(path, id)?;
80        self.routes.insert(id, endpoint);
81
82        Ok(())
83    }
84
85    pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
86    where
87        H: Handler<T, S>,
88        T: 'static,
89    {
90        for (_, endpoint) in self.routes.iter_mut() {
91            if let Endpoint::MethodRouter(rt) = endpoint {
92                *rt = rt.clone().default_fallback(handler.clone());
93            }
94        }
95    }
96
97    pub(super) fn route_service<T>(
98        &mut self,
99        path: &str,
100        service: T,
101    ) -> Result<(), Cow<'static, str>>
102    where
103        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
104        T::Response: IntoResponse,
105        T::Future: Send + 'static,
106    {
107        self.route_endpoint(path, Endpoint::Route(Route::new(service)))
108    }
109
110    pub(super) fn route_endpoint(
111        &mut self,
112        path: &str,
113        endpoint: Endpoint<S>,
114    ) -> Result<(), Cow<'static, str>> {
115        if path.is_empty() {
116            return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
117        } else if !path.starts_with('/') {
118            return Err("Paths must start with a `/`".into());
119        }
120
121        let id = self.next_route_id();
122        self.set_node(path, id)?;
123        self.routes.insert(id, endpoint);
124
125        Ok(())
126    }
127
128    fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
129        let node = Arc::make_mut(&mut self.node);
130
131        node.insert(path, id)
132            .map_err(|err| format!("Invalid route {path:?}: {err}"))
133    }
134
135    pub(super) fn merge(
136        &mut self,
137        other: PathRouter<S, IS_FALLBACK>,
138    ) -> Result<(), Cow<'static, str>> {
139        let PathRouter {
140            routes,
141            node,
142            prev_route_id: _,
143        } = other;
144
145        for (id, route) in routes {
146            let path = node
147                .route_id_to_path
148                .get(&id)
149                .expect("no path for route id. This is a bug in axum. Please file an issue");
150
151            if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
152                // when merging two routers it doesn't matter if you do `a.merge(b)` or
153                // `b.merge(a)`. This must also be true for fallbacks.
154                //
155                // However all fallback routers will have routes for `/` and `/*` so when merging
156                // we have to ignore the top level fallbacks on one side otherwise we get
157                // conflicts.
158                //
159                // `Router::merge` makes sure that when merging fallbacks `other` always has the
160                // fallback we want to keep. It panics if both routers have a custom fallback. Thus
161                // it is always okay to ignore one fallback and `Router::merge` also makes sure the
162                // one we can ignore is that of `self`.
163                self.replace_endpoint(path, route);
164            } else {
165                match route {
166                    Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
167                    Endpoint::Route(route) => self.route_service(path, route)?,
168                }
169            }
170        }
171
172        Ok(())
173    }
174
175    pub(super) fn nest(
176        &mut self,
177        path_to_nest_at: &str,
178        router: PathRouter<S, IS_FALLBACK>,
179    ) -> Result<(), Cow<'static, str>> {
180        let prefix = validate_nest_path(path_to_nest_at);
181
182        let PathRouter {
183            routes,
184            node,
185            prev_route_id: _,
186        } = router;
187
188        for (id, endpoint) in routes {
189            let inner_path = node
190                .route_id_to_path
191                .get(&id)
192                .expect("no path for route id. This is a bug in axum. Please file an issue");
193
194            let path = path_for_nested_route(prefix, inner_path);
195
196            let layer = (
197                StripPrefix::layer(prefix),
198                SetNestedPath::layer(path_to_nest_at),
199            );
200            match endpoint.layer(layer) {
201                Endpoint::MethodRouter(method_router) => {
202                    self.route(&path, method_router)?;
203                }
204                Endpoint::Route(route) => {
205                    self.route_endpoint(&path, Endpoint::Route(route))?;
206                }
207            }
208        }
209
210        Ok(())
211    }
212
213    pub(super) fn nest_service<T>(
214        &mut self,
215        path_to_nest_at: &str,
216        svc: T,
217    ) -> Result<(), Cow<'static, str>>
218    where
219        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
220        T::Response: IntoResponse,
221        T::Future: Send + 'static,
222    {
223        let path = validate_nest_path(path_to_nest_at);
224        let prefix = path;
225
226        let path = if path.ends_with('/') {
227            format!("{path}*{NEST_TAIL_PARAM}")
228        } else {
229            format!("{path}/*{NEST_TAIL_PARAM}")
230        };
231
232        let layer = (
233            StripPrefix::layer(prefix),
234            SetNestedPath::layer(path_to_nest_at),
235        );
236        let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
237
238        self.route_endpoint(&path, endpoint.clone())?;
239
240        // `/*rest` is not matched by `/` so we need to also register a router at the
241        // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
242        // wouldn't match, which it should
243        self.route_endpoint(prefix, endpoint.clone())?;
244        if !prefix.ends_with('/') {
245            // same goes for `/foo/`, that should also match
246            self.route_endpoint(&format!("{prefix}/"), endpoint)?;
247        }
248
249        Ok(())
250    }
251
252    pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
253    where
254        L: Layer<Route> + Clone + Send + 'static,
255        L::Service: Service<Request> + Clone + Send + 'static,
256        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
257        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
258        <L::Service as Service<Request>>::Future: Send + 'static,
259    {
260        let routes = self
261            .routes
262            .into_iter()
263            .map(|(id, endpoint)| {
264                let route = endpoint.layer(layer.clone());
265                (id, route)
266            })
267            .collect();
268
269        PathRouter {
270            routes,
271            node: self.node,
272            prev_route_id: self.prev_route_id,
273        }
274    }
275
276    #[track_caller]
277    pub(super) fn route_layer<L>(self, layer: L) -> Self
278    where
279        L: Layer<Route> + Clone + Send + 'static,
280        L::Service: Service<Request> + Clone + Send + 'static,
281        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
282        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
283        <L::Service as Service<Request>>::Future: Send + 'static,
284    {
285        if self.routes.is_empty() {
286            panic!(
287                "Adding a route_layer before any routes is a no-op. \
288                 Add the routes you want the layer to apply to first."
289            );
290        }
291
292        let routes = self
293            .routes
294            .into_iter()
295            .map(|(id, endpoint)| {
296                let route = endpoint.layer(layer.clone());
297                (id, route)
298            })
299            .collect();
300
301        PathRouter {
302            routes,
303            node: self.node,
304            prev_route_id: self.prev_route_id,
305        }
306    }
307
308    pub(super) fn has_routes(&self) -> bool {
309        !self.routes.is_empty()
310    }
311
312    pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
313        let routes = self
314            .routes
315            .into_iter()
316            .map(|(id, endpoint)| {
317                let endpoint: Endpoint<S2> = match endpoint {
318                    Endpoint::MethodRouter(method_router) => {
319                        Endpoint::MethodRouter(method_router.with_state(state.clone()))
320                    }
321                    Endpoint::Route(route) => Endpoint::Route(route),
322                };
323                (id, endpoint)
324            })
325            .collect();
326
327        PathRouter {
328            routes,
329            node: self.node,
330            prev_route_id: self.prev_route_id,
331        }
332    }
333
334    pub(super) fn call_with_state(
335        &self,
336        mut req: Request,
337        state: S,
338    ) -> Result<RouteFuture<Infallible>, (Request, S)> {
339        #[cfg(feature = "original-uri")]
340        {
341            use crate::extract::OriginalUri;
342
343            if req.extensions().get::<OriginalUri>().is_none() {
344                let original_uri = OriginalUri(req.uri().clone());
345                req.extensions_mut().insert(original_uri);
346            }
347        }
348
349        let (mut parts, body) = req.into_parts();
350
351        match self.node.at(parts.uri.path()) {
352            Ok(match_) => {
353                let id = *match_.value;
354
355                if !IS_FALLBACK {
356                    #[cfg(feature = "matched-path")]
357                    crate::extract::matched_path::set_matched_path_for_request(
358                        id,
359                        &self.node.route_id_to_path,
360                        &mut parts.extensions,
361                    );
362                }
363
364                url_params::insert_url_params(&mut parts.extensions, match_.params);
365
366                let endpoint = self
367                    .routes
368                    .get(&id)
369                    .expect("no route for id. This is a bug in axum. Please file an issue");
370
371                let req = Request::from_parts(parts, body);
372                match endpoint {
373                    Endpoint::MethodRouter(method_router) => {
374                        Ok(method_router.call_with_state(req, state))
375                    }
376                    Endpoint::Route(route) => Ok(route.clone().call_owned(req)),
377                }
378            }
379            // explicitly handle all variants in case matchit adds
380            // new ones we need to handle differently
381            Err(
382                MatchError::NotFound
383                | MatchError::ExtraTrailingSlash
384                | MatchError::MissingTrailingSlash,
385            ) => Err((Request::from_parts(parts, body), state)),
386        }
387    }
388
389    pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
390        match self.node.at(path) {
391            Ok(match_) => {
392                let id = *match_.value;
393                self.routes.insert(id, endpoint);
394            }
395            Err(_) => self
396                .route_endpoint(path, endpoint)
397                .expect("path wasn't matched so endpoint shouldn't exist"),
398        }
399    }
400
401    fn next_route_id(&mut self) -> RouteId {
402        let next_id = self
403            .prev_route_id
404            .0
405            .checked_add(1)
406            .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
407        self.prev_route_id = RouteId(next_id);
408        self.prev_route_id
409    }
410}
411
412impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
413    fn default() -> Self {
414        Self {
415            routes: Default::default(),
416            node: Default::default(),
417            prev_route_id: RouteId(0),
418        }
419    }
420}
421
422impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
423    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424        f.debug_struct("PathRouter")
425            .field("routes", &self.routes)
426            .field("node", &self.node)
427            .finish()
428    }
429}
430
431impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
432    fn clone(&self) -> Self {
433        Self {
434            routes: self.routes.clone(),
435            node: self.node.clone(),
436            prev_route_id: self.prev_route_id,
437        }
438    }
439}
440
441/// Wrapper around `matchit::Router` that supports merging two `Router`s.
442#[derive(Clone, Default)]
443struct Node {
444    inner: matchit::Router<RouteId>,
445    route_id_to_path: HashMap<RouteId, Arc<str>>,
446    path_to_route_id: HashMap<Arc<str>, RouteId>,
447}
448
449impl Node {
450    fn insert(
451        &mut self,
452        path: impl Into<String>,
453        val: RouteId,
454    ) -> Result<(), matchit::InsertError> {
455        let path = path.into();
456
457        self.inner.insert(&path, val)?;
458
459        let shared_path: Arc<str> = path.into();
460        self.route_id_to_path.insert(val, shared_path.clone());
461        self.path_to_route_id.insert(shared_path, val);
462
463        Ok(())
464    }
465
466    fn at<'n, 'p>(
467        &'n self,
468        path: &'p str,
469    ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
470        self.inner.at(path)
471    }
472}
473
474impl fmt::Debug for Node {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        f.debug_struct("Node")
477            .field("paths", &self.route_id_to_path)
478            .finish()
479    }
480}
481
/// Normalizes a path passed to `nest`/`nest_service`.
///
/// An empty path is treated as `/`; any other path is returned unchanged.
///
/// # Panics
///
/// Panics if `path` contains a `*` wildcard.
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    match path {
        // nesting at `""` and `"/"` should mean the same thing
        "" => "/",
        p if p.contains('*') => {
            panic!("Invalid route: nested routes cannot contain wildcards (*)")
        }
        p => p,
    }
}
495
/// Joins a nest `prefix` and an inner route `path` into the full route path.
///
/// Exactly one `/` separates the two. An inner path of `/` under a prefix without a
/// trailing slash yields the prefix itself, borrowed without allocating.
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path == "/") {
        (true, _) => Cow::Owned(format!("{prefix}{}", path.trim_start_matches('/'))),
        (false, true) => Cow::Borrowed(prefix),
        (false, false) => Cow::Owned(format!("{prefix}{path}")),
    }
}