1use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, HttpBody},
8 boxed::BoxedIntoRoute,
9 handler::Handler,
10 util::try_downcast,
11};
12use axum_core::{
13 extract::Request,
14 response::{IntoResponse, Response},
15};
16use std::{
17 convert::Infallible,
18 fmt,
19 marker::PhantomData,
20 sync::Arc,
21 task::{Context, Poll},
22};
23use tower_layer::Layer;
24use tower_service::Service;
25
26pub mod future;
27pub mod method_routing;
28
29mod into_make_service;
30mod method_filter;
31mod not_found;
32pub(crate) mod path_router;
33mod route;
34mod strip_prefix;
35pub(crate) mod url_params;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
41
42pub use self::method_routing::{
43 any, any_service, connect, connect_service, delete, delete_service, get, get_service, head,
44 head_service, on, on_service, options, options_service, patch, patch_service, post,
45 post_service, put, put_service, trace, trace_service, MethodRouter,
46};
47
// Unwraps a `Result`, panicking with the error's `Display` output on `Err`.
//
// The panic is emitted inline at the expansion site (not inside a closure or helper),
// so inside a `#[track_caller]` function the reported location is that function's
// caller.
macro_rules! panic_on_err {
    ($result:expr) => {
        match $result {
            Err(err) => panic!("{err}"),
            Ok(value) => value,
        }
    };
}
56
/// Crate-internal identifier wrapping a `u32`; ordered, hashable and cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct RouteId(u32);
59
/// The router type for composing handlers and services.
///
/// All routing state lives behind a shared `Arc<RouterInner<S>>`, so cloning a `Router`
/// is cheap. Builder methods take the router by value and copy the inner state out only
/// when the `Arc` is shared with other clones.
///
/// `S` is the state type that handlers may require (see `Handler<T, S>` and
/// `with_state`); it defaults to `()`.
// Builder methods return a new router; dropping the result silently discards changes.
#[must_use]
pub struct Router<S = ()> {
    inner: Arc<RouterInner<S>>,
}
65
66impl<S> Clone for Router<S> {
67 fn clone(&self) -> Self {
68 Self {
69 inner: Arc::clone(&self.inner),
70 }
71 }
72}
73
/// The state shared by all clones of a [`Router`].
///
/// `Router::call_with_state` consults these in order: `path_router` first, then
/// `fallback_router`, and finally `catch_all_fallback`.
struct RouterInner<S> {
    // Routes added through `route`, `route_service`, `nest`, `nest_service` and `merge`.
    path_router: PathRouter<S, false>,
    // Holds the endpoint set by `fallback`/`fallback_service`, plus the fallback routers
    // of nested routers that had a custom fallback.
    fallback_router: PathRouter<S, true>,
    // `true` from `Router::new` until a custom fallback is installed (directly or
    // inherited through `merge`).
    default_fallback: bool,
    // Last-resort handler, used when neither path router produced a response.
    catch_all_fallback: Fallback<S>,
}
80
81impl<S> Default for Router<S>
82where
83 S: Clone + Send + Sync + 'static,
84{
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl<S> fmt::Debug for Router<S> {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_struct("Router")
93 .field("path_router", &self.inner.path_router)
94 .field("fallback_router", &self.inner.fallback_router)
95 .field("default_fallback", &self.inner.default_fallback)
96 .field("catch_all_fallback", &self.inner.catch_all_fallback)
97 .finish()
98 }
99}
100
// Reserved, crate-private parameter names. Each `*_CAPTURE` / `*_PATH` constant is the
// corresponding bare name prefixed with `/*` (presumably a wildcard path segment for
// the path router — confirm in `path_router`).

/// Internal parameter name associated with nesting.
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
/// Internal parameter name associated with fallbacks.
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";

/// `NEST_TAIL_PARAM` prefixed with `/*`.
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
/// `FALLBACK_PARAM` prefixed with `/*`.
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";
105
// Consumes the router named `$router` and binds its owned `RouterInner` to the pattern
// `$binding`. It then evaluates `$body`, which must yield a `RouterInner`, and wraps the
// result in a fresh `Router`.
macro_rules! map_inner {
    ( $router:ident, $binding:pat_param => $body:expr) => {
        #[allow(redundant_semicolons)]
        {
            let $binding = $router.into_inner();
            let rebuilt = Arc::new($body);
            Router { inner: rebuilt }
        }
    };
}
117
// Consumes the router named `$router` and binds its owned `RouterInner` mutably as
// `$binding`. It runs the given statements against it, then wraps the (possibly
// modified) value back into a fresh `Router`.
//
// The statements are matched with `stmt` fragments, which do not include their trailing
// `;`. Each `;` is matched as a separate empty statement and re-emitted, which is why
// `redundant_semicolons` is allowed.
macro_rules! tap_inner {
    ( $router:ident, mut $binding:ident => { $($stmt:stmt)* } ) => {
        #[allow(redundant_semicolons)]
        {
            let mut $binding = $router.into_inner();
            $($stmt)*
            let rebuilt = Arc::new($binding);
            Router { inner: rebuilt }
        }
    };
}
130
131impl<S> Router<S>
132where
133 S: Clone + Send + Sync + 'static,
134{
135 pub fn new() -> Self {
140 Self {
141 inner: Arc::new(RouterInner {
142 path_router: Default::default(),
143 fallback_router: PathRouter::new_fallback(),
144 default_fallback: true,
145 catch_all_fallback: Fallback::Default(Route::new(NotFound)),
146 }),
147 }
148 }
149
150 fn into_inner(self) -> RouterInner<S> {
151 match Arc::try_unwrap(self.inner) {
152 Ok(inner) => inner,
153 Err(arc) => RouterInner {
154 path_router: arc.path_router.clone(),
155 fallback_router: arc.fallback_router.clone(),
156 default_fallback: arc.default_fallback,
157 catch_all_fallback: arc.catch_all_fallback.clone(),
158 },
159 }
160 }
161
162 #[doc = include_str!("../docs/routing/route.md")]
163 #[track_caller]
164 pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
165 tap_inner!(self, mut this => {
166 panic_on_err!(this.path_router.route(path, method_router));
167 })
168 }
169
170 #[doc = include_str!("../docs/routing/route_service.md")]
171 pub fn route_service<T>(self, path: &str, service: T) -> Self
172 where
173 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
174 T::Response: IntoResponse,
175 T::Future: Send + 'static,
176 {
177 let service = match try_downcast::<Router<S>, _>(service) {
178 Ok(_) => {
179 panic!(
180 "Invalid route: `Router::route_service` cannot be used with `Router`s. \
181 Use `Router::nest` instead"
182 );
183 }
184 Err(service) => service,
185 };
186
187 tap_inner!(self, mut this => {
188 panic_on_err!(this.path_router.route_service(path, service));
189 })
190 }
191
192 #[doc = include_str!("../docs/routing/nest.md")]
193 #[doc(alias = "scope")] #[track_caller]
195 pub fn nest(self, path: &str, router: Router<S>) -> Self {
196 let RouterInner {
197 path_router,
198 fallback_router,
199 default_fallback,
200 catch_all_fallback: _,
204 } = router.into_inner();
205
206 tap_inner!(self, mut this => {
207 panic_on_err!(this.path_router.nest(path, path_router));
208
209 if !default_fallback {
210 panic_on_err!(this.fallback_router.nest(path, fallback_router));
211 }
212 })
213 }
214
215 #[track_caller]
217 pub fn nest_service<T>(self, path: &str, service: T) -> Self
218 where
219 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
220 T::Response: IntoResponse,
221 T::Future: Send + 'static,
222 {
223 tap_inner!(self, mut this => {
224 panic_on_err!(this.path_router.nest_service(path, service));
225 })
226 }
227
228 #[doc = include_str!("../docs/routing/merge.md")]
229 #[track_caller]
230 pub fn merge<R>(self, other: R) -> Self
231 where
232 R: Into<Router<S>>,
233 {
234 const PANIC_MSG: &str =
235 "Failed to merge fallbacks. This is a bug in axum. Please file an issue";
236
237 let other: Router<S> = other.into();
238 let RouterInner {
239 path_router,
240 fallback_router: mut other_fallback,
241 default_fallback,
242 catch_all_fallback,
243 } = other.into_inner();
244
245 map_inner!(self, mut this => {
246 panic_on_err!(this.path_router.merge(path_router));
247
248 match (this.default_fallback, default_fallback) {
249 (true, true) => {
252 this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
253 }
254 (true, false) => {
256 this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
257 this.default_fallback = false;
258 }
259 (false, true) => {
261 let fallback_router = std::mem::take(&mut this.fallback_router);
262 other_fallback.merge(fallback_router).expect(PANIC_MSG);
263 this.fallback_router = other_fallback;
264 }
265 (false, false) => {
267 panic!("Cannot merge two `Router`s that both have a fallback")
268 }
269 };
270
271 this.catch_all_fallback = this
272 .catch_all_fallback
273 .merge(catch_all_fallback)
274 .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
275
276 this
277 })
278 }
279
280 #[doc = include_str!("../docs/routing/layer.md")]
281 pub fn layer<L>(self, layer: L) -> Router<S>
282 where
283 L: Layer<Route> + Clone + Send + 'static,
284 L::Service: Service<Request> + Clone + Send + 'static,
285 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
286 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
287 <L::Service as Service<Request>>::Future: Send + 'static,
288 {
289 map_inner!(self, this => RouterInner {
290 path_router: this.path_router.layer(layer.clone()),
291 fallback_router: this.fallback_router.layer(layer.clone()),
292 default_fallback: this.default_fallback,
293 catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
294 })
295 }
296
297 #[doc = include_str!("../docs/routing/route_layer.md")]
298 #[track_caller]
299 pub fn route_layer<L>(self, layer: L) -> Self
300 where
301 L: Layer<Route> + Clone + Send + 'static,
302 L::Service: Service<Request> + Clone + Send + 'static,
303 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
304 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
305 <L::Service as Service<Request>>::Future: Send + 'static,
306 {
307 map_inner!(self, this => RouterInner {
308 path_router: this.path_router.route_layer(layer),
309 fallback_router: this.fallback_router,
310 default_fallback: this.default_fallback,
311 catch_all_fallback: this.catch_all_fallback,
312 })
313 }
314
315 pub fn has_routes(&self) -> bool {
317 self.inner.path_router.has_routes()
318 }
319
320 #[track_caller]
321 #[doc = include_str!("../docs/routing/fallback.md")]
322 pub fn fallback<H, T>(self, handler: H) -> Self
323 where
324 H: Handler<T, S>,
325 T: 'static,
326 {
327 tap_inner!(self, mut this => {
328 this.catch_all_fallback =
329 Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
330 })
331 .fallback_endpoint(Endpoint::MethodRouter(any(handler)))
332 }
333
334 pub fn fallback_service<T>(self, service: T) -> Self
338 where
339 T: Service<Request, Error = Infallible> + Clone + Send + 'static,
340 T::Response: IntoResponse,
341 T::Future: Send + 'static,
342 {
343 let route = Route::new(service);
344 tap_inner!(self, mut this => {
345 this.catch_all_fallback = Fallback::Service(route.clone());
346 })
347 .fallback_endpoint(Endpoint::Route(route))
348 }
349
350 #[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")]
351 pub fn method_not_allowed_fallback<H, T>(self, handler: H) -> Self
352 where
353 H: Handler<T, S>,
354 T: 'static,
355 {
356 tap_inner!(self, mut this => {
357 this.path_router
358 .method_not_allowed_fallback(handler.clone())
359 })
360 }
361
362 fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
363 tap_inner!(self, mut this => {
364 this.fallback_router.set_fallback(endpoint);
365 this.default_fallback = false;
366 })
367 }
368
369 #[doc = include_str!("../docs/routing/with_state.md")]
370 pub fn with_state<S2>(self, state: S) -> Router<S2> {
371 map_inner!(self, this => RouterInner {
372 path_router: this.path_router.with_state(state.clone()),
373 fallback_router: this.fallback_router.with_state(state.clone()),
374 default_fallback: this.default_fallback,
375 catch_all_fallback: this.catch_all_fallback.with_state(state),
376 })
377 }
378
379 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<Infallible> {
380 let (req, state) = match self.inner.path_router.call_with_state(req, state) {
381 Ok(future) => return future,
382 Err((req, state)) => (req, state),
383 };
384
385 let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
386 Ok(future) => return future,
387 Err((req, state)) => (req, state),
388 };
389
390 self.inner
391 .catch_all_fallback
392 .clone()
393 .call_with_state(req, state)
394 }
395
396 pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S> {
450 RouterAsService {
451 router: self,
452 _marker: PhantomData,
453 }
454 }
455
456 pub fn into_service<B>(self) -> RouterIntoService<B, S> {
462 RouterIntoService {
463 router: self,
464 _marker: PhantomData,
465 }
466 }
467}
468
469impl Router {
470 pub fn into_make_service(self) -> IntoMakeService<Self> {
489 IntoMakeService::new(self.with_state(()))
492 }
493
494 #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
495 #[cfg(feature = "tokio")]
496 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
497 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
500 }
501}
502
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    // Lets a stateless `Router` serve as its own "make service": each call hands out a
    // fresh clone of the router (with `with_state(())` applied), and it is always ready.
    impl Service<IncomingStream<'_>> for Router<()> {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
            let per_call: Self = self.clone().with_state(());
            std::future::ready(Ok(per_call))
        }
    }
};
524
525impl<B> Service<Request<B>> for Router<()>
526where
527 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
528 B::Error: Into<axum_core::BoxError>,
529{
530 type Response = Response;
531 type Error = Infallible;
532 type Future = RouteFuture<Infallible>;
533
534 #[inline]
535 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
536 Poll::Ready(Ok(()))
537 }
538
539 #[inline]
540 fn call(&mut self, req: Request<B>) -> Self::Future {
541 let req = req.map(Body::new);
542 self.call_with_state(req, ())
543 }
544}
545
/// A [`Router`] mutably borrowed as a `Service`; created by `Router::as_service`.
///
/// `B` is the request body type and is tracked only through `PhantomData`. The
/// `Service` implementation exists only for `S = ()`.
pub struct RouterAsService<'a, B, S = ()> {
    router: &'a mut Router<S>,
    _marker: PhantomData<B>,
}
553
554impl<B> Service<Request<B>> for RouterAsService<'_, B, ()>
555where
556 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
557 B::Error: Into<axum_core::BoxError>,
558{
559 type Response = Response;
560 type Error = Infallible;
561 type Future = RouteFuture<Infallible>;
562
563 #[inline]
564 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
565 <Router as Service<Request<B>>>::poll_ready(self.router, cx)
566 }
567
568 #[inline]
569 fn call(&mut self, req: Request<B>) -> Self::Future {
570 self.router.call(req)
571 }
572}
573
574impl<B, S> fmt::Debug for RouterAsService<'_, B, S>
575where
576 S: fmt::Debug,
577{
578 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
579 f.debug_struct("RouterAsService")
580 .field("router", &self.router)
581 .finish()
582 }
583}
584
/// A [`Router`] converted into an owned `Service`; created by `Router::into_service`.
///
/// `B` is the request body type and is tracked only through `PhantomData`. The
/// `Service` implementation exists only for `S = ()`.
pub struct RouterIntoService<B, S = ()> {
    router: Router<S>,
    _marker: PhantomData<B>,
}
592
593impl<B, S> Clone for RouterIntoService<B, S>
594where
595 Router<S>: Clone,
596{
597 fn clone(&self) -> Self {
598 Self {
599 router: self.router.clone(),
600 _marker: PhantomData,
601 }
602 }
603}
604
605impl<B> Service<Request<B>> for RouterIntoService<B, ()>
606where
607 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
608 B::Error: Into<axum_core::BoxError>,
609{
610 type Response = Response;
611 type Error = Infallible;
612 type Future = RouteFuture<Infallible>;
613
614 #[inline]
615 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616 <Router as Service<Request<B>>>::poll_ready(&mut self.router, cx)
617 }
618
619 #[inline]
620 fn call(&mut self, req: Request<B>) -> Self::Future {
621 self.router.call(req)
622 }
623}
624
625impl<B, S> fmt::Debug for RouterIntoService<B, S>
626where
627 S: fmt::Debug,
628{
629 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
630 f.debug_struct("RouterIntoService")
631 .field("router", &self.router)
632 .finish()
633 }
634}
635
/// The router's catch-all fallback, used when neither the path router nor the fallback
/// router handled a request.
///
/// `E` is the error type of the wrapped route; it defaults to `Infallible`.
enum Fallback<S, E = Infallible> {
    // The built-in fallback installed by `Router::new` (wrapping `NotFound`). It always
    // yields to the other side when fallbacks are merged.
    Default(Route<E>),
    // A service set via `Router::fallback_service`. `with_state` also resolves a
    // `BoxedHandler` into this variant.
    Service(Route<E>),
    // A handler set via `Router::fallback`. It still needs the state `S` before it can be
    // turned into a `Route`.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
641
642impl<S, E> Fallback<S, E>
643where
644 S: Clone,
645{
646 fn merge(self, other: Self) -> Option<Self> {
647 match (self, other) {
648 (Self::Default(_), pick @ Self::Default(_)) => Some(pick),
649 (Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick),
650 _ => None,
651 }
652 }
653
654 fn map<F, E2>(self, f: F) -> Fallback<S, E2>
655 where
656 S: 'static,
657 E: 'static,
658 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
659 E2: 'static,
660 {
661 match self {
662 Self::Default(route) => Fallback::Default(f(route)),
663 Self::Service(route) => Fallback::Service(f(route)),
664 Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
665 }
666 }
667
668 fn with_state<S2>(self, state: S) -> Fallback<S2, E> {
669 match self {
670 Fallback::Default(route) => Fallback::Default(route),
671 Fallback::Service(route) => Fallback::Service(route),
672 Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
673 }
674 }
675
676 fn call_with_state(self, req: Request, state: S) -> RouteFuture<E> {
677 match self {
678 Fallback::Default(route) | Fallback::Service(route) => {
679 RouteFuture::from_future(route.oneshot_inner_owned(req))
680 }
681 Fallback::BoxedHandler(handler) => {
682 let route = handler.clone().into_route(state);
683 RouteFuture::from_future(route.oneshot_inner_owned(req))
684 }
685 }
686 }
687}
688
689impl<S, E> Clone for Fallback<S, E> {
690 fn clone(&self) -> Self {
691 match self {
692 Self::Default(inner) => Self::Default(inner.clone()),
693 Self::Service(inner) => Self::Service(inner.clone()),
694 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
695 }
696 }
697}
698
699impl<S, E> fmt::Debug for Fallback<S, E> {
700 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701 match self {
702 Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
703 Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
704 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
705 }
706 }
707}
708
/// An endpoint installed as the fallback router's fallback.
///
/// `Router::fallback` stores a `MethodRouter` built with `any`, while
/// `Router::fallback_service` stores a plain `Route`.
// NOTE(review): the lint is presumably silenced because `MethodRouter<S>` is much larger
// than `Route` and boxing it was judged not worthwhile — confirm.
#[allow(clippy::large_enum_variant)]
enum Endpoint<S> {
    MethodRouter(MethodRouter<S>),
    Route(Route),
}
714
715impl<S> Endpoint<S>
716where
717 S: Clone + Send + Sync + 'static,
718{
719 fn layer<L>(self, layer: L) -> Endpoint<S>
720 where
721 L: Layer<Route> + Clone + Send + 'static,
722 L::Service: Service<Request> + Clone + Send + 'static,
723 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
724 <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
725 <L::Service as Service<Request>>::Future: Send + 'static,
726 {
727 match self {
728 Endpoint::MethodRouter(method_router) => {
729 Endpoint::MethodRouter(method_router.layer(layer))
730 }
731 Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
732 }
733 }
734}
735
736impl<S> Clone for Endpoint<S> {
737 fn clone(&self) -> Self {
738 match self {
739 Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
740 Self::Route(inner) => Self::Route(inner.clone()),
741 }
742 }
743}
744
745impl<S> fmt::Debug for Endpoint<S> {
746 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
747 match self {
748 Self::MethodRouter(method_router) => {
749 f.debug_tuple("MethodRouter").field(method_router).finish()
750 }
751 Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
752 }
753 }
754}
755
// Compile-time check that a stateless router can be shared across threads.
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, assert_sync};

    type Stateless = Router<()>;
    assert_send::<Stateless>();
    assert_sync::<Stateless>();
}