axum/routing/
method_routing.rs

1//! Route to services and handlers based on HTTP methods.
2
3use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7    body::{Body, Bytes, HttpBody},
8    boxed::BoxedIntoRoute,
9    error_handling::{HandleError, HandleErrorLayer},
10    handler::Handler,
11    http::{Method, StatusCode},
12    response::Response,
13    routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18    convert::Infallible,
19    fmt,
20    task::{Context, Poll},
21};
22use tower::{service_fn, util::MapResponseLayer};
23use tower_layer::Layer;
24use tower_service::Service;
25
macro_rules! top_level_service_fn {
    // Implementation arm. The documentation arms further down call back into
    // the macro with the doc attributes they assembled placed in front of the
    // function name and the `MethodFilter` constant; only that shape (one or
    // more attributes first) matches here.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };

    // `GET` carries the worked example that every other method links to.
    ($name:ident, GET) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::get_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # let _: Router = app;
            /// ```
            ///
            /// `HEAD` requests are also served by `get` routes, with the response
            /// body removed. If `HEAD` needs its own behavior, add an explicit
            /// `HEAD` route afterwards.
            $name,
            GET
        );
    };

    // `CONNECT` additionally points at `MethodFilter::CONNECT` for usage notes.
    ($name:ident, CONNECT) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Every remaining method: a generated summary line plus a link to the example.
    ($name:ident, $method:ident) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $name,
            $method
        );
    };
}
103
macro_rules! top_level_handler_fn {
    // Implementation arm. The documentation arms further down call back into
    // the macro with their doc attributes placed in front of the function name
    // and the `MethodFilter` constant; only that shape matches here.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };

    // `GET` carries the worked example that every other method links to.
    ($name:ident, GET) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// `HEAD` requests are also served by `get` routes, with the response
            /// body removed. If `HEAD` needs its own behavior, add an explicit
            /// `HEAD` route afterwards.
            $name,
            GET
        );
    };

    // `CONNECT` additionally points at `MethodFilter::CONNECT` for usage notes.
    ($name:ident, CONNECT) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get`] for an example.
            $name,
            CONNECT
        );
    };

    // Every remaining method: a generated summary line plus a link to the example.
    ($name:ident, $method:ident) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $name,
            $method
        );
    };
}
174
macro_rules! chained_service_fn {
    // Implementation arm. The documentation arms further down call back into
    // the macro with their doc attributes placed in front of the method name
    // and the `MethodFilter` constant; only that shape matches here. The
    // generated method is meant to be expanded inside an `impl MethodRouter`
    // block whose error type parameter is named `E`.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };

    // `GET` carries the worked example that every other method links to.
    ($name:ident, GET) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::post_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # let _: Router = app;
            /// ```
            ///
            /// `HEAD` requests are also served by `get` routes, with the response
            /// body removed. If `HEAD` needs its own behavior, add an explicit
            /// `HEAD` route afterwards.
            $name,
            GET
        );
    };

    // `CONNECT` additionally points at `MethodFilter::CONNECT` for usage notes.
    ($name:ident, CONNECT) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Every remaining method: a generated summary line plus a link to the example.
    ($name:ident, $method:ident) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $name,
            $method
        );
    };
}
260
macro_rules! chained_handler_fn {
    // Implementation arm. The documentation arms further down call back into
    // the macro with their doc attributes placed in front of the method name
    // and the `MethodFilter` constant; only that shape matches here. The
    // generated method is meant to be expanded inside an `impl MethodRouter`
    // block whose state type parameter is named `S`.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };

    // `GET` carries the worked example that every other method links to.
    ($name:ident, GET) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// `HEAD` requests are also served by `get` routes, with the response
            /// body removed. If `HEAD` needs its own behavior, add an explicit
            /// `HEAD` route afterwards.
            $name,
            GET
        );
    };

    // `CONNECT` additionally points at `MethodFilter::CONNECT` for usage notes.
    ($name:ident, CONNECT) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get`] for an example.
            $name,
            CONNECT
        );
    };

    // Every remaining method: a generated summary line plus a link to the example.
    ($name:ident, $method:ident) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $name,
            $method
        );
    };
}
332
333top_level_service_fn!(connect_service, CONNECT);
334top_level_service_fn!(delete_service, DELETE);
335top_level_service_fn!(get_service, GET);
336top_level_service_fn!(head_service, HEAD);
337top_level_service_fn!(options_service, OPTIONS);
338top_level_service_fn!(patch_service, PATCH);
339top_level_service_fn!(post_service, POST);
340top_level_service_fn!(put_service, PUT);
341top_level_service_fn!(trace_service, TRACE);
342
343/// Route requests with the given method to the service.
344///
345/// # Example
346///
347/// ```rust
348/// use axum::{
349///     extract::Request,
350///     routing::on,
351///     Router,
352///     body::Body,
353///     routing::{MethodFilter, on_service},
354/// };
355/// use http::Response;
356/// use std::convert::Infallible;
357///
358/// let service = tower::service_fn(|request: Request| async {
359///     Ok::<_, Infallible>(Response::new(Body::empty()))
360/// });
361///
362/// // Requests to `POST /` will go to `service`.
363/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
364/// # let _: Router = app;
365/// ```
366pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
367where
368    T: Service<Request> + Clone + Send + 'static,
369    T::Response: IntoResponse + 'static,
370    T::Future: Send + 'static,
371    S: Clone,
372{
373    MethodRouter::new().on_service(filter, svc)
374}
375
376/// Route requests to the given service regardless of its method.
377///
378/// # Example
379///
380/// ```rust
381/// use axum::{
382///     extract::Request,
383///     Router,
384///     routing::any_service,
385///     body::Body,
386/// };
387/// use http::Response;
388/// use std::convert::Infallible;
389///
390/// let service = tower::service_fn(|request: Request| async {
391///     Ok::<_, Infallible>(Response::new(Body::empty()))
392/// });
393///
394/// // All requests to `/` will go to `service`.
395/// let app = Router::new().route("/", any_service(service));
396/// # let _: Router = app;
397/// ```
398///
399/// Additional methods can still be chained:
400///
401/// ```rust
402/// use axum::{
403///     extract::Request,
404///     Router,
405///     routing::any_service,
406///     body::Body,
407/// };
408/// use http::Response;
409/// use std::convert::Infallible;
410///
411/// let service = tower::service_fn(|request: Request| async {
412///     # Ok::<_, Infallible>(Response::new(Body::empty()))
413///     // ...
414/// });
415///
416/// let other_service = tower::service_fn(|request: Request| async {
417///     # Ok::<_, Infallible>(Response::new(Body::empty()))
418///     // ...
419/// });
420///
421/// // `POST /` goes to `other_service`. All other requests go to `service`
422/// let app = Router::new().route("/", any_service(service).post_service(other_service));
423/// # let _: Router = app;
424/// ```
425pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
426where
427    T: Service<Request> + Clone + Send + 'static,
428    T::Response: IntoResponse + 'static,
429    T::Future: Send + 'static,
430    S: Clone,
431{
432    MethodRouter::new()
433        .fallback_service(svc)
434        .skip_allow_header()
435}
436
437top_level_handler_fn!(connect, CONNECT);
438top_level_handler_fn!(delete, DELETE);
439top_level_handler_fn!(get, GET);
440top_level_handler_fn!(head, HEAD);
441top_level_handler_fn!(options, OPTIONS);
442top_level_handler_fn!(patch, PATCH);
443top_level_handler_fn!(post, POST);
444top_level_handler_fn!(put, PUT);
445top_level_handler_fn!(trace, TRACE);
446
447/// Route requests with the given method to the handler.
448///
449/// # Example
450///
451/// ```rust
452/// use axum::{
453///     routing::on,
454///     Router,
455///     routing::MethodFilter,
456/// };
457///
458/// async fn handler() {}
459///
460/// // Requests to `POST /` will go to `handler`.
461/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
462/// # let _: Router = app;
463/// ```
464pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
465where
466    H: Handler<T, S>,
467    T: 'static,
468    S: Clone + Send + Sync + 'static,
469{
470    MethodRouter::new().on(filter, handler)
471}
472
473/// Route requests with the given handler regardless of the method.
474///
475/// # Example
476///
477/// ```rust
478/// use axum::{
479///     routing::any,
480///     Router,
481/// };
482///
483/// async fn handler() {}
484///
485/// // All requests to `/` will go to `handler`.
486/// let app = Router::new().route("/", any(handler));
487/// # let _: Router = app;
488/// ```
489///
490/// Additional methods can still be chained:
491///
492/// ```rust
493/// use axum::{
494///     routing::any,
495///     Router,
496/// };
497///
498/// async fn handler() {}
499///
500/// async fn other_handler() {}
501///
502/// // `POST /` goes to `other_handler`. All other requests go to `handler`
503/// let app = Router::new().route("/", any(handler).post(other_handler));
504/// # let _: Router = app;
505/// ```
506pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
507where
508    H: Handler<T, S>,
509    T: 'static,
510    S: Clone + Send + Sync + 'static,
511{
512    MethodRouter::new().fallback(handler).skip_allow_header()
513}
514
515/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
516/// allows chaining additional handlers and services.
517///
518/// # When does `MethodRouter` implement [`Service`]?
519///
520/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
521///
522/// ```
523/// use tower::Service;
524/// use axum::{routing::get, extract::{State, Request}, body::Body};
525///
526/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
527/// let method_router = get(|| async {});
528/// // and thus it implements `Service`
529/// assert_service(method_router);
530///
531/// // this requires a `String` and doesn't implement `Service`
532/// let method_router = get(|_: State<String>| async {});
533/// // until you provide the `String` with `.with_state(...)`
534/// let method_router_with_state = method_router.with_state(String::new());
535/// // and then it implements `Service`
536/// assert_service(method_router_with_state);
537///
538/// // helper to check that a value implements `Service`
539/// fn assert_service<S>(service: S)
540/// where
541///     S: Service<Request>,
542/// {}
543/// ```
544#[must_use]
545pub struct MethodRouter<S = (), E = Infallible> {
546    get: MethodEndpoint<S, E>,
547    head: MethodEndpoint<S, E>,
548    delete: MethodEndpoint<S, E>,
549    options: MethodEndpoint<S, E>,
550    patch: MethodEndpoint<S, E>,
551    post: MethodEndpoint<S, E>,
552    put: MethodEndpoint<S, E>,
553    trace: MethodEndpoint<S, E>,
554    connect: MethodEndpoint<S, E>,
555    fallback: Fallback<S, E>,
556    allow_header: AllowHeader,
557}
558
559#[derive(Clone, Debug)]
560enum AllowHeader {
561    /// No `Allow` header value has been built-up yet. This is the default state
562    None,
563    /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
564    Skip,
565    /// The current value of the `Allow` header.
566    Bytes(BytesMut),
567}
568
569impl AllowHeader {
570    fn merge(self, other: Self) -> Self {
571        match (self, other) {
572            (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
573            (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
574            (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
575            (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
576            (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
577                a.extend_from_slice(b",");
578                a.extend_from_slice(&b);
579                AllowHeader::Bytes(a)
580            }
581        }
582    }
583}
584
585impl<S, E> fmt::Debug for MethodRouter<S, E> {
586    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
587        f.debug_struct("MethodRouter")
588            .field("get", &self.get)
589            .field("head", &self.head)
590            .field("delete", &self.delete)
591            .field("options", &self.options)
592            .field("patch", &self.patch)
593            .field("post", &self.post)
594            .field("put", &self.put)
595            .field("trace", &self.trace)
596            .field("connect", &self.connect)
597            .field("fallback", &self.fallback)
598            .field("allow_header", &self.allow_header)
599            .finish()
600    }
601}
602
603impl<S> MethodRouter<S, Infallible>
604where
605    S: Clone,
606{
607    /// Chain an additional handler that will accept requests matching the given
608    /// `MethodFilter`.
609    ///
610    /// # Example
611    ///
612    /// ```rust
613    /// use axum::{
614    ///     routing::get,
615    ///     Router,
616    ///     routing::MethodFilter
617    /// };
618    ///
619    /// async fn handler() {}
620    ///
621    /// async fn other_handler() {}
622    ///
623    /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
624    /// // `other_handler`
625    /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
626    /// # let _: Router = app;
627    /// ```
628    #[track_caller]
629    pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
630    where
631        H: Handler<T, S>,
632        T: 'static,
633        S: Send + Sync + 'static,
634    {
635        self.on_endpoint(
636            filter,
637            MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
638        )
639    }
640
641    chained_handler_fn!(connect, CONNECT);
642    chained_handler_fn!(delete, DELETE);
643    chained_handler_fn!(get, GET);
644    chained_handler_fn!(head, HEAD);
645    chained_handler_fn!(options, OPTIONS);
646    chained_handler_fn!(patch, PATCH);
647    chained_handler_fn!(post, POST);
648    chained_handler_fn!(put, PUT);
649    chained_handler_fn!(trace, TRACE);
650
651    /// Add a fallback [`Handler`] to the router.
652    pub fn fallback<H, T>(mut self, handler: H) -> Self
653    where
654        H: Handler<T, S>,
655        T: 'static,
656        S: Send + Sync + 'static,
657    {
658        self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
659        self
660    }
661
662    /// Add a fallback [`Handler`] if no custom one has been provided.
663    pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
664    where
665        H: Handler<T, S>,
666        T: 'static,
667        S: Send + Sync + 'static,
668    {
669        match self.fallback {
670            Fallback::Default(_) => self.fallback(handler),
671            _ => self,
672        }
673    }
674}
675
676impl MethodRouter<(), Infallible> {
677    /// Convert the router into a [`MakeService`].
678    ///
679    /// This allows you to serve a single `MethodRouter` if you don't need any
680    /// routing based on the path:
681    ///
682    /// ```rust
683    /// use axum::{
684    ///     handler::Handler,
685    ///     http::{Uri, Method},
686    ///     response::IntoResponse,
687    ///     routing::get,
688    /// };
689    /// use std::net::SocketAddr;
690    ///
691    /// async fn handler(method: Method, uri: Uri, body: String) -> String {
692    ///     format!("received `{method} {uri}` with body `{body:?}`")
693    /// }
694    ///
695    /// let router = get(handler).post(handler);
696    ///
697    /// # async {
698    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
699    /// axum::serve(listener, router.into_make_service()).await.unwrap();
700    /// # };
701    /// ```
702    ///
703    /// [`MakeService`]: tower::make::MakeService
704    pub fn into_make_service(self) -> IntoMakeService<Self> {
705        IntoMakeService::new(self.with_state(()))
706    }
707
708    /// Convert the router into a [`MakeService`] which stores information
709    /// about the incoming connection.
710    ///
711    /// See [`Router::into_make_service_with_connect_info`] for more details.
712    ///
713    /// ```rust
714    /// use axum::{
715    ///     handler::Handler,
716    ///     response::IntoResponse,
717    ///     extract::ConnectInfo,
718    ///     routing::get,
719    /// };
720    /// use std::net::SocketAddr;
721    ///
722    /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
723    ///     format!("Hello {addr}")
724    /// }
725    ///
726    /// let router = get(handler).post(handler);
727    ///
728    /// # async {
729    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
730    /// axum::serve(listener, router.into_make_service()).await.unwrap();
731    /// # };
732    /// ```
733    ///
734    /// [`MakeService`]: tower::make::MakeService
735    /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
736    #[cfg(feature = "tokio")]
737    pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
738        IntoMakeServiceWithConnectInfo::new(self.with_state(()))
739    }
740}
741
742impl<S, E> MethodRouter<S, E>
743where
744    S: Clone,
745{
746    /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
747    /// requests.
748    pub fn new() -> Self {
749        let fallback = Route::new(service_fn(|_: Request| async {
750            Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
751        }));
752
753        Self {
754            get: MethodEndpoint::None,
755            head: MethodEndpoint::None,
756            delete: MethodEndpoint::None,
757            options: MethodEndpoint::None,
758            patch: MethodEndpoint::None,
759            post: MethodEndpoint::None,
760            put: MethodEndpoint::None,
761            trace: MethodEndpoint::None,
762            connect: MethodEndpoint::None,
763            allow_header: AllowHeader::None,
764            fallback: Fallback::Default(fallback),
765        }
766    }
767
768    /// Provide the state for the router.
769    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
770        MethodRouter {
771            get: self.get.with_state(&state),
772            head: self.head.with_state(&state),
773            delete: self.delete.with_state(&state),
774            options: self.options.with_state(&state),
775            patch: self.patch.with_state(&state),
776            post: self.post.with_state(&state),
777            put: self.put.with_state(&state),
778            trace: self.trace.with_state(&state),
779            connect: self.connect.with_state(&state),
780            allow_header: self.allow_header,
781            fallback: self.fallback.with_state(state),
782        }
783    }
784
785    /// Chain an additional service that will accept requests matching the given
786    /// `MethodFilter`.
787    ///
788    /// # Example
789    ///
790    /// ```rust
791    /// use axum::{
792    ///     extract::Request,
793    ///     Router,
794    ///     routing::{MethodFilter, on_service},
795    ///     body::Body,
796    /// };
797    /// use http::Response;
798    /// use std::convert::Infallible;
799    ///
800    /// let service = tower::service_fn(|request: Request| async {
801    ///     Ok::<_, Infallible>(Response::new(Body::empty()))
802    /// });
803    ///
804    /// // Requests to `DELETE /` will go to `service`
805    /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
806    /// # let _: Router = app;
807    /// ```
808    #[track_caller]
809    pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
810    where
811        T: Service<Request, Error = E> + Clone + Send + 'static,
812        T::Response: IntoResponse + 'static,
813        T::Future: Send + 'static,
814    {
815        self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
816    }
817
818    #[track_caller]
819    fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
820        // written as a separate function to generate less IR
821        #[track_caller]
822        fn set_endpoint<S, E>(
823            method_name: &str,
824            out: &mut MethodEndpoint<S, E>,
825            endpoint: &MethodEndpoint<S, E>,
826            endpoint_filter: MethodFilter,
827            filter: MethodFilter,
828            allow_header: &mut AllowHeader,
829            methods: &[&'static str],
830        ) where
831            MethodEndpoint<S, E>: Clone,
832            S: Clone,
833        {
834            if endpoint_filter.contains(filter) {
835                if out.is_some() {
836                    panic!(
837                        "Overlapping method route. Cannot add two method routes that both handle \
838                         `{method_name}`",
839                    )
840                }
841                *out = endpoint.clone();
842                for method in methods {
843                    append_allow_header(allow_header, method);
844                }
845            }
846        }
847
848        set_endpoint(
849            "GET",
850            &mut self.get,
851            &endpoint,
852            filter,
853            MethodFilter::GET,
854            &mut self.allow_header,
855            &["GET", "HEAD"],
856        );
857
858        set_endpoint(
859            "HEAD",
860            &mut self.head,
861            &endpoint,
862            filter,
863            MethodFilter::HEAD,
864            &mut self.allow_header,
865            &["HEAD"],
866        );
867
868        set_endpoint(
869            "TRACE",
870            &mut self.trace,
871            &endpoint,
872            filter,
873            MethodFilter::TRACE,
874            &mut self.allow_header,
875            &["TRACE"],
876        );
877
878        set_endpoint(
879            "PUT",
880            &mut self.put,
881            &endpoint,
882            filter,
883            MethodFilter::PUT,
884            &mut self.allow_header,
885            &["PUT"],
886        );
887
888        set_endpoint(
889            "POST",
890            &mut self.post,
891            &endpoint,
892            filter,
893            MethodFilter::POST,
894            &mut self.allow_header,
895            &["POST"],
896        );
897
898        set_endpoint(
899            "PATCH",
900            &mut self.patch,
901            &endpoint,
902            filter,
903            MethodFilter::PATCH,
904            &mut self.allow_header,
905            &["PATCH"],
906        );
907
908        set_endpoint(
909            "OPTIONS",
910            &mut self.options,
911            &endpoint,
912            filter,
913            MethodFilter::OPTIONS,
914            &mut self.allow_header,
915            &["OPTIONS"],
916        );
917
918        set_endpoint(
919            "DELETE",
920            &mut self.delete,
921            &endpoint,
922            filter,
923            MethodFilter::DELETE,
924            &mut self.allow_header,
925            &["DELETE"],
926        );
927
928        set_endpoint(
929            "CONNECT",
930            &mut self.options,
931            &endpoint,
932            filter,
933            MethodFilter::CONNECT,
934            &mut self.allow_header,
935            &["CONNECT"],
936        );
937
938        self
939    }
940
941    chained_service_fn!(connect_service, CONNECT);
942    chained_service_fn!(delete_service, DELETE);
943    chained_service_fn!(get_service, GET);
944    chained_service_fn!(head_service, HEAD);
945    chained_service_fn!(options_service, OPTIONS);
946    chained_service_fn!(patch_service, PATCH);
947    chained_service_fn!(post_service, POST);
948    chained_service_fn!(put_service, PUT);
949    chained_service_fn!(trace_service, TRACE);
950
951    #[doc = include_str!("../docs/method_routing/fallback.md")]
952    pub fn fallback_service<T>(mut self, svc: T) -> Self
953    where
954        T: Service<Request, Error = E> + Clone + Send + 'static,
955        T::Response: IntoResponse + 'static,
956        T::Future: Send + 'static,
957    {
958        self.fallback = Fallback::Service(Route::new(svc));
959        self
960    }
961
962    #[doc = include_str!("../docs/method_routing/layer.md")]
963    pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
964    where
965        L: Layer<Route<E>> + Clone + Send + 'static,
966        L::Service: Service<Request> + Clone + Send + 'static,
967        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
968        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
969        <L::Service as Service<Request>>::Future: Send + 'static,
970        E: 'static,
971        S: 'static,
972        NewError: 'static,
973    {
974        let layer_fn = move |route: Route<E>| route.layer(layer.clone());
975
976        MethodRouter {
977            get: self.get.map(layer_fn.clone()),
978            head: self.head.map(layer_fn.clone()),
979            delete: self.delete.map(layer_fn.clone()),
980            options: self.options.map(layer_fn.clone()),
981            patch: self.patch.map(layer_fn.clone()),
982            post: self.post.map(layer_fn.clone()),
983            put: self.put.map(layer_fn.clone()),
984            trace: self.trace.map(layer_fn.clone()),
985            connect: self.connect.map(layer_fn.clone()),
986            fallback: self.fallback.map(layer_fn),
987            allow_header: self.allow_header,
988        }
989    }
990
991    #[doc = include_str!("../docs/method_routing/route_layer.md")]
992    #[track_caller]
993    pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
994    where
995        L: Layer<Route<E>> + Clone + Send + 'static,
996        L::Service: Service<Request, Error = E> + Clone + Send + 'static,
997        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
998        <L::Service as Service<Request>>::Future: Send + 'static,
999        E: 'static,
1000        S: 'static,
1001    {
1002        if self.get.is_none()
1003            && self.head.is_none()
1004            && self.delete.is_none()
1005            && self.options.is_none()
1006            && self.patch.is_none()
1007            && self.post.is_none()
1008            && self.put.is_none()
1009            && self.trace.is_none()
1010            && self.connect.is_none()
1011        {
1012            panic!(
1013                "Adding a route_layer before any routes is a no-op. \
1014                 Add the routes you want the layer to apply to first."
1015            );
1016        }
1017
1018        let layer_fn = move |svc| {
1019            let svc = layer.layer(svc);
1020            let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
1021            Route::new(svc)
1022        };
1023
1024        self.get = self.get.map(layer_fn.clone());
1025        self.head = self.head.map(layer_fn.clone());
1026        self.delete = self.delete.map(layer_fn.clone());
1027        self.options = self.options.map(layer_fn.clone());
1028        self.patch = self.patch.map(layer_fn.clone());
1029        self.post = self.post.map(layer_fn.clone());
1030        self.put = self.put.map(layer_fn.clone());
1031        self.trace = self.trace.map(layer_fn.clone());
1032        self.connect = self.connect.map(layer_fn);
1033
1034        self
1035    }
1036
1037    #[track_caller]
1038    pub(crate) fn merge_for_path(mut self, path: Option<&str>, other: MethodRouter<S, E>) -> Self {
1039        // written using inner functions to generate less IR
1040        #[track_caller]
1041        fn merge_inner<S, E>(
1042            path: Option<&str>,
1043            name: &str,
1044            first: MethodEndpoint<S, E>,
1045            second: MethodEndpoint<S, E>,
1046        ) -> MethodEndpoint<S, E> {
1047            match (first, second) {
1048                (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1049                (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1050                _ => {
1051                    if let Some(path) = path {
1052                        panic!(
1053                            "Overlapping method route. Handler for `{name} {path}` already exists"
1054                        );
1055                    } else {
1056                        panic!(
1057                            "Overlapping method route. Cannot merge two method routes that both \
1058                             define `{name}`"
1059                        );
1060                    }
1061                }
1062            }
1063        }
1064
1065        self.get = merge_inner(path, "GET", self.get, other.get);
1066        self.head = merge_inner(path, "HEAD", self.head, other.head);
1067        self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1068        self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1069        self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1070        self.post = merge_inner(path, "POST", self.post, other.post);
1071        self.put = merge_inner(path, "PUT", self.put, other.put);
1072        self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1073        self.connect = merge_inner(path, "CONNECT", self.connect, other.connect);
1074
1075        self.fallback = self
1076            .fallback
1077            .merge(other.fallback)
1078            .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1079
1080        self.allow_header = self.allow_header.merge(other.allow_header);
1081
1082        self
1083    }
1084
1085    #[doc = include_str!("../docs/method_routing/merge.md")]
1086    #[track_caller]
1087    pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1088        self.merge_for_path(None, other)
1089    }
1090
1091    /// Apply a [`HandleErrorLayer`].
1092    ///
1093    /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
1094    pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1095    where
1096        F: Clone + Send + Sync + 'static,
1097        HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1098        <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1099        <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1100        T: 'static,
1101        E: 'static,
1102        S: 'static,
1103    {
1104        self.layer(HandleErrorLayer::new(f))
1105    }
1106
1107    fn skip_allow_header(mut self) -> Self {
1108        self.allow_header = AllowHeader::Skip;
1109        self
1110    }
1111
1112    pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1113        let method = req.method();
1114        let is_head = *method == Method::HEAD;
1115
1116        macro_rules! call {
1117            (
1118                $method_variant:ident,
1119                $svc:expr
1120            ) => {
1121                if *method == Method::$method_variant {
1122                    match $svc {
1123                        MethodEndpoint::None => {}
1124                        MethodEndpoint::Route(route) => {
1125                            return RouteFuture::from_future(
1126                                route.clone().oneshot_inner_owned(req),
1127                            )
1128                            .strip_body(is_head);
1129                        }
1130                        MethodEndpoint::BoxedHandler(handler) => {
1131                            let route = handler.clone().into_route(state);
1132                            return RouteFuture::from_future(
1133                                route.clone().oneshot_inner_owned(req),
1134                            )
1135                            .strip_body(is_head);
1136                        }
1137                    }
1138                }
1139            };
1140        }
1141
1142        // written with a pattern match like this to ensure we call all routes
1143        let Self {
1144            get,
1145            head,
1146            delete,
1147            options,
1148            patch,
1149            post,
1150            put,
1151            trace,
1152            connect,
1153            fallback,
1154            allow_header,
1155        } = self;
1156
1157        call!(HEAD, head);
1158        call!(HEAD, get);
1159        call!(GET, get);
1160        call!(POST, post);
1161        call!(OPTIONS, options);
1162        call!(PATCH, patch);
1163        call!(PUT, put);
1164        call!(DELETE, delete);
1165        call!(TRACE, trace);
1166        call!(CONNECT, connect);
1167
1168        let future = fallback.clone().call_with_state(req, state);
1169
1170        match allow_header {
1171            AllowHeader::None => future.allow_header(Bytes::new()),
1172            AllowHeader::Skip => future,
1173            AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1174        }
1175    }
1176}
1177
1178fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1179    match allow_header {
1180        AllowHeader::None => {
1181            *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1182        }
1183        AllowHeader::Skip => {}
1184        AllowHeader::Bytes(allow_header) => {
1185            if let Ok(s) = std::str::from_utf8(allow_header) {
1186                if !s.contains(method) {
1187                    allow_header.extend_from_slice(b",");
1188                    allow_header.extend_from_slice(method.as_bytes());
1189                }
1190            } else {
1191                #[cfg(debug_assertions)]
1192                panic!("`allow_header` contained invalid uft-8. This should never happen")
1193            }
1194        }
1195    }
1196}
1197
1198impl<S, E> Clone for MethodRouter<S, E> {
1199    fn clone(&self) -> Self {
1200        Self {
1201            get: self.get.clone(),
1202            head: self.head.clone(),
1203            delete: self.delete.clone(),
1204            options: self.options.clone(),
1205            patch: self.patch.clone(),
1206            post: self.post.clone(),
1207            put: self.put.clone(),
1208            trace: self.trace.clone(),
1209            connect: self.connect.clone(),
1210            fallback: self.fallback.clone(),
1211            allow_header: self.allow_header.clone(),
1212        }
1213    }
1214}
1215
1216impl<S, E> Default for MethodRouter<S, E>
1217where
1218    S: Clone,
1219{
1220    fn default() -> Self {
1221        Self::new()
1222    }
1223}
1224
/// What a [`MethodRouter`] holds in the slot for one HTTP method.
enum MethodEndpoint<S, E> {
    /// Nothing is registered for the method.
    None,
    /// A route that can be called directly.
    Route(Route<E>),
    /// A boxed handler that becomes a [`Route`] once it is given a state `S`
    /// (via `into_route(state)`).
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1230
1231impl<S, E> MethodEndpoint<S, E>
1232where
1233    S: Clone,
1234{
1235    fn is_some(&self) -> bool {
1236        matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1237    }
1238
1239    fn is_none(&self) -> bool {
1240        matches!(self, Self::None)
1241    }
1242
1243    fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1244    where
1245        S: 'static,
1246        E: 'static,
1247        F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
1248        E2: 'static,
1249    {
1250        match self {
1251            Self::None => MethodEndpoint::None,
1252            Self::Route(route) => MethodEndpoint::Route(f(route)),
1253            Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1254        }
1255    }
1256
1257    fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1258        match self {
1259            MethodEndpoint::None => MethodEndpoint::None,
1260            MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1261            MethodEndpoint::BoxedHandler(handler) => {
1262                MethodEndpoint::Route(handler.into_route(state.clone()))
1263            }
1264        }
1265    }
1266}
1267
1268impl<S, E> Clone for MethodEndpoint<S, E> {
1269    fn clone(&self) -> Self {
1270        match self {
1271            Self::None => Self::None,
1272            Self::Route(inner) => Self::Route(inner.clone()),
1273            Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1274        }
1275    }
1276}
1277
1278impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1280        match self {
1281            Self::None => f.debug_tuple("None").finish(),
1282            Self::Route(inner) => inner.fmt(f),
1283            Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1284        }
1285    }
1286}
1287
1288impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1289where
1290    B: HttpBody<Data = Bytes> + Send + 'static,
1291    B::Error: Into<BoxError>,
1292{
1293    type Response = Response;
1294    type Error = E;
1295    type Future = RouteFuture<E>;
1296
1297    #[inline]
1298    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1299        Poll::Ready(Ok(()))
1300    }
1301
1302    #[inline]
1303    fn call(&mut self, req: Request<B>) -> Self::Future {
1304        let req = req.map(Body::new);
1305        self.call_with_state(req, ())
1306    }
1307}
1308
1309impl<S> Handler<(), S> for MethodRouter<S>
1310where
1311    S: Clone + 'static,
1312{
1313    type Future = InfallibleRouteFuture;
1314
1315    fn call(self, req: Request, state: S) -> Self::Future {
1316        InfallibleRouteFuture::new(self.call_with_state(req, state))
1317    }
1318}
1319
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    impl Service<IncomingStream<'_>> for MethodRouter<()> {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        // Always reports ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
            // The stream itself is ignored; every call hands out a fresh clone of the router.
            let router: Self = self.clone().with_state(());
            std::future::ready(Ok(router))
        }
    }
};
1339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    /// An empty router answers every method with `405` and an empty body.
    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    /// `get_service` routes `GET` to a plain `tower` service.
    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    /// `.get(handler)` routes `GET` to a handler.
    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    /// A `GET` route also answers `HEAD`, with the response body removed.
    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    /// An explicit `HEAD` route wins over the `GET` route for `HEAD` requests.
    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    /// Merging two routers keeps the routes of both.
    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    /// `layer` wraps the fallback too, so the middleware also rejects requests for
    /// methods that have no route.
    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    /// `route_layer` wraps only existing routes; other methods still get the plain `405`.
    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    /// Compile-only check that the builder methods compose; it is never run as a test.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things 💣️
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    /// The `405` response lists the registered methods in `Allow`, in registration order.
    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    /// Registering both `GET` and `HEAD` does not duplicate `HEAD` in `Allow`.
    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    /// With no routes at all, `Allow` is present but empty.
    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    /// Merging concatenates the `Allow` lists of both routers.
    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    /// A router built with `any` does not add an `Allow` header.
    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    /// A custom fallback that does not set `Allow` still gets the computed header.
    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    /// An `Allow` header set by the fallback itself is kept as is.
    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    /// Applying a middleware that does nothing keeps the `Allow` header intact.
    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    /// Registering two handlers for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    /// Registering two services for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    /// `GET` followed by `HEAD` is not an overlap.
    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    /// `HEAD` followed by `GET` is not an overlap.
    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    /// Handlers can extract the state supplied through `with_state`.
    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// The fallback handler can extract the state as well.
    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Routes coming from both sides of a merge see the state.
    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body too, so the assertion below checks this response and not
        // the body left over from the GET call.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send a body-less request for `/` with the given method through `svc` and return
    /// the response status, headers and body (decoded as UTF-8).
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    /// Handler that responds `200 OK` with the body `ok`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    /// Handler that responds `201 Created` with the body `created`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}