1use crate::{
2 extract::{nested_path::SetNestedPath, Request},
3 handler::Handler,
4};
5use axum_core::response::IntoResponse;
6use matchit::MatchError;
7use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
8use tower_layer::Layer;
9use tower_service::Service;
10
11use super::{
12 future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
13 MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
14};
15
16pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
17 routes: HashMap<RouteId, Endpoint<S>>,
18 node: Arc<Node>,
19 prev_route_id: RouteId,
20}
21
22impl<S> PathRouter<S, true>
23where
24 S: Clone + Send + Sync + 'static,
25{
26 pub(super) fn new_fallback() -> Self {
27 let mut this = Self::default();
28 this.set_fallback(Endpoint::Route(Route::new(NotFound)));
29 this
30 }
31
32 pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
33 self.replace_endpoint("/", endpoint.clone());
34 self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
35 }
36}
37
38impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
39where
40 S: Clone + Send + Sync + 'static,
41{
42 pub(super) fn route(
43 &mut self,
44 path: &str,
45 method_router: MethodRouter<S>,
46 ) -> Result<(), Cow<'static, str>> {
47 fn validate_path(path: &str) -> Result<(), &'static str> {
48 if path.is_empty() {
49 return Err("Paths must start with a `/`. Use \"/\" for root routes");
50 } else if !path.starts_with('/') {
51 return Err("Paths must start with a `/`");
52 }
53
54 Ok(())
55 }
56
57 validate_path(path)?;
58
59 let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
60 .node
61 .path_to_route_id
62 .get(path)
63 .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
64 {
65 let service = Endpoint::MethodRouter(
68 prev_method_router
69 .clone()
70 .merge_for_path(Some(path), method_router),
71 );
72 self.routes.insert(route_id, service);
73 return Ok(());
74 } else {
75 Endpoint::MethodRouter(method_router)
76 };
77
78 let id = self.next_route_id();
79 self.set_node(path, id)?;
80 self.routes.insert(id, endpoint);
81
82 Ok(())
83 }
84
85 pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
86 where
87 H: Handler<T, S>,
88 T: 'static,
89 {
90 for (_, endpoint) in self.routes.iter_mut() {
91 if let Endpoint::MethodRouter(rt) = endpoint {
92 *rt = rt.clone().default_fallback(handler.clone());
93 }
94 }
95 }
96
97 pub(super) fn route_service<T>(
98 &mut self,
99 path: &str,
100 service: T,
101 ) -> Result<(), Cow<'static, str>>
102 where
103 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
104 T::Response: IntoResponse,
105 T::Future: Send + 'static,
106 {
107 self.route_endpoint(path, Endpoint::Route(Route::new(service)))
108 }
109
110 pub(super) fn route_endpoint(
111 &mut self,
112 path: &str,
113 endpoint: Endpoint<S>,
114 ) -> Result<(), Cow<'static, str>> {
115 if path.is_empty() {
116 return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
117 } else if !path.starts_with('/') {
118 return Err("Paths must start with a `/`".into());
119 }
120
121 let id = self.next_route_id();
122 self.set_node(path, id)?;
123 self.routes.insert(id, endpoint);
124
125 Ok(())
126 }
127
128 fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
129 let node = Arc::make_mut(&mut self.node);
130
131 node.insert(path, id)
132 .map_err(|err| format!("Invalid route {path:?}: {err}"))
133 }
134
135 pub(super) fn merge(
136 &mut self,
137 other: PathRouter<S, IS_FALLBACK>,
138 ) -> Result<(), Cow<'static, str>> {
139 let PathRouter {
140 routes,
141 node,
142 prev_route_id: _,
143 } = other;
144
145 for (id, route) in routes {
146 let path = node
147 .route_id_to_path
148 .get(&id)
149 .expect("no path for route id. This is a bug in axum. Please file an issue");
150
151 if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
152 self.replace_endpoint(path, route);
164 } else {
165 match route {
166 Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
167 Endpoint::Route(route) => self.route_service(path, route)?,
168 }
169 }
170 }
171
172 Ok(())
173 }
174
175 pub(super) fn nest(
176 &mut self,
177 path_to_nest_at: &str,
178 router: PathRouter<S, IS_FALLBACK>,
179 ) -> Result<(), Cow<'static, str>> {
180 let prefix = validate_nest_path(path_to_nest_at);
181
182 let PathRouter {
183 routes,
184 node,
185 prev_route_id: _,
186 } = router;
187
188 for (id, endpoint) in routes {
189 let inner_path = node
190 .route_id_to_path
191 .get(&id)
192 .expect("no path for route id. This is a bug in axum. Please file an issue");
193
194 let path = path_for_nested_route(prefix, inner_path);
195
196 let layer = (
197 StripPrefix::layer(prefix),
198 SetNestedPath::layer(path_to_nest_at),
199 );
200 match endpoint.layer(layer) {
201 Endpoint::MethodRouter(method_router) => {
202 self.route(&path, method_router)?;
203 }
204 Endpoint::Route(route) => {
205 self.route_endpoint(&path, Endpoint::Route(route))?;
206 }
207 }
208 }
209
210 Ok(())
211 }
212
213 pub(super) fn nest_service<T>(
214 &mut self,
215 path_to_nest_at: &str,
216 svc: T,
217 ) -> Result<(), Cow<'static, str>>
218 where
219 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
220 T::Response: IntoResponse,
221 T::Future: Send + 'static,
222 {
223 let path = validate_nest_path(path_to_nest_at);
224 let prefix = path;
225
226 let path = if path.ends_with('/') {
227 format!("{path}*{NEST_TAIL_PARAM}")
228 } else {
229 format!("{path}/*{NEST_TAIL_PARAM}")
230 };
231
232 let layer = (
233 StripPrefix::layer(prefix),
234 SetNestedPath::layer(path_to_nest_at),
235 );
236 let endpoint = Endpoint::Route(Route::new(layer.layer(svc)));
237
238 self.route_endpoint(&path, endpoint.clone())?;
239
240 self.route_endpoint(prefix, endpoint.clone())?;
244 if !prefix.ends_with('/') {
245 self.route_endpoint(&format!("{prefix}/"), endpoint)?;
247 }
248
249 Ok(())
250 }
251
252 pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
253 where
254 L: Layer<Route> + Clone + Send + 'static,
255 L::Service: Service<Request> + Clone + Send + 'static,
256 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
257 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
258 <L::Service as Service<Request>>::Future: Send + 'static,
259 {
260 let routes = self
261 .routes
262 .into_iter()
263 .map(|(id, endpoint)| {
264 let route = endpoint.layer(layer.clone());
265 (id, route)
266 })
267 .collect();
268
269 PathRouter {
270 routes,
271 node: self.node,
272 prev_route_id: self.prev_route_id,
273 }
274 }
275
276 #[track_caller]
277 pub(super) fn route_layer<L>(self, layer: L) -> Self
278 where
279 L: Layer<Route> + Clone + Send + 'static,
280 L::Service: Service<Request> + Clone + Send + 'static,
281 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
282 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
283 <L::Service as Service<Request>>::Future: Send + 'static,
284 {
285 if self.routes.is_empty() {
286 panic!(
287 "Adding a route_layer before any routes is a no-op. \
288 Add the routes you want the layer to apply to first."
289 );
290 }
291
292 let routes = self
293 .routes
294 .into_iter()
295 .map(|(id, endpoint)| {
296 let route = endpoint.layer(layer.clone());
297 (id, route)
298 })
299 .collect();
300
301 PathRouter {
302 routes,
303 node: self.node,
304 prev_route_id: self.prev_route_id,
305 }
306 }
307
308 pub(super) fn has_routes(&self) -> bool {
309 !self.routes.is_empty()
310 }
311
312 pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
313 let routes = self
314 .routes
315 .into_iter()
316 .map(|(id, endpoint)| {
317 let endpoint: Endpoint<S2> = match endpoint {
318 Endpoint::MethodRouter(method_router) => {
319 Endpoint::MethodRouter(method_router.with_state(state.clone()))
320 }
321 Endpoint::Route(route) => Endpoint::Route(route),
322 };
323 (id, endpoint)
324 })
325 .collect();
326
327 PathRouter {
328 routes,
329 node: self.node,
330 prev_route_id: self.prev_route_id,
331 }
332 }
333
334 pub(super) fn call_with_state(
335 &self,
336 mut req: Request,
337 state: S,
338 ) -> Result<RouteFuture<Infallible>, (Request, S)> {
339 #[cfg(feature = "original-uri")]
340 {
341 use crate::extract::OriginalUri;
342
343 if req.extensions().get::<OriginalUri>().is_none() {
344 let original_uri = OriginalUri(req.uri().clone());
345 req.extensions_mut().insert(original_uri);
346 }
347 }
348
349 let (mut parts, body) = req.into_parts();
350
351 match self.node.at(parts.uri.path()) {
352 Ok(match_) => {
353 let id = *match_.value;
354
355 if !IS_FALLBACK {
356 #[cfg(feature = "matched-path")]
357 crate::extract::matched_path::set_matched_path_for_request(
358 id,
359 &self.node.route_id_to_path,
360 &mut parts.extensions,
361 );
362 }
363
364 url_params::insert_url_params(&mut parts.extensions, match_.params);
365
366 let endpoint = self
367 .routes
368 .get(&id)
369 .expect("no route for id. This is a bug in axum. Please file an issue");
370
371 let req = Request::from_parts(parts, body);
372 match endpoint {
373 Endpoint::MethodRouter(method_router) => {
374 Ok(method_router.call_with_state(req, state))
375 }
376 Endpoint::Route(route) => Ok(route.clone().call_owned(req)),
377 }
378 }
379 Err(
382 MatchError::NotFound
383 | MatchError::ExtraTrailingSlash
384 | MatchError::MissingTrailingSlash,
385 ) => Err((Request::from_parts(parts, body), state)),
386 }
387 }
388
389 pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
390 match self.node.at(path) {
391 Ok(match_) => {
392 let id = *match_.value;
393 self.routes.insert(id, endpoint);
394 }
395 Err(_) => self
396 .route_endpoint(path, endpoint)
397 .expect("path wasn't matched so endpoint shouldn't exist"),
398 }
399 }
400
401 fn next_route_id(&mut self) -> RouteId {
402 let next_id = self
403 .prev_route_id
404 .0
405 .checked_add(1)
406 .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
407 self.prev_route_id = RouteId(next_id);
408 self.prev_route_id
409 }
410}
411
412impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
413 fn default() -> Self {
414 Self {
415 routes: Default::default(),
416 node: Default::default(),
417 prev_route_id: RouteId(0),
418 }
419 }
420}
421
422impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
423 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
424 f.debug_struct("PathRouter")
425 .field("routes", &self.routes)
426 .field("node", &self.node)
427 .finish()
428 }
429}
430
431impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
432 fn clone(&self) -> Self {
433 Self {
434 routes: self.routes.clone(),
435 node: self.node.clone(),
436 prev_route_id: self.prev_route_id,
437 }
438 }
439}
440
441#[derive(Clone, Default)]
443struct Node {
444 inner: matchit::Router<RouteId>,
445 route_id_to_path: HashMap<RouteId, Arc<str>>,
446 path_to_route_id: HashMap<Arc<str>, RouteId>,
447}
448
449impl Node {
450 fn insert(
451 &mut self,
452 path: impl Into<String>,
453 val: RouteId,
454 ) -> Result<(), matchit::InsertError> {
455 let path = path.into();
456
457 self.inner.insert(&path, val)?;
458
459 let shared_path: Arc<str> = path.into();
460 self.route_id_to_path.insert(val, shared_path.clone());
461 self.path_to_route_id.insert(shared_path, val);
462
463 Ok(())
464 }
465
466 fn at<'n, 'p>(
467 &'n self,
468 path: &'p str,
469 ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
470 self.inner.at(path)
471 }
472}
473
474impl fmt::Debug for Node {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 f.debug_struct("Node")
477 .field("paths", &self.route_id_to_path)
478 .finish()
479 }
480}
481
/// Normalizes and checks a nesting prefix.
///
/// An empty prefix is treated as `/`, so nesting at `""` and at `"/"` means the same
/// thing. Any other path is returned unchanged.
///
/// # Panics
///
/// Panics (reported at the caller) if `path` contains a `*` wildcard.
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    match path {
        "" => "/",
        p if p.contains('*') => {
            panic!("Invalid route: nested routes cannot contain wildcards (*)")
        }
        p => p,
    }
}
495
/// Joins a nesting `prefix` and an inner route `path` into the full route path.
///
/// Exactly one `/` separates the two parts. An inner path of `/` maps to the prefix
/// itself, which is then borrowed rather than allocated. Both arguments are expected
/// to start with `/` (checked only in debug builds).
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path == "/") {
        (true, _) => Cow::Owned(format!("{prefix}{}", path.trim_start_matches('/'))),
        (false, true) => Cow::Borrowed(prefix),
        (false, false) => Cow::Owned(format!("{prefix}{path}")),
    }
}