1use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18 convert::Infallible,
19 fmt,
20 task::{Context, Poll},
21};
22use tower::{service_fn, util::MapResponseLayer};
23use tower_layer::Layer;
24use tower_service::Service;
25
// Generates a free function (for example `get_service`) that builds a
// `MethodRouter` routing one HTTP method to a service.
//
// The `GET` and `CONNECT` arms attach method-specific documentation and then
// delegate to the final arm. They MUST forward at least one attribute:
// re-invoking the macro with only `$name, GET` would match the same arm again
// and recurse until the compiler's recursion limit is hit.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but
            /// will have the response body removed. Add an explicit `HEAD` route
            /// if `HEAD` requests need different handling.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            $name,
            $method
        );
    };

    // Final arm: emits the function itself, carrying the forwarded attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
103
// Generates a free function (for example `get`) that builds a `MethodRouter`
// routing one HTTP method to a handler.
//
// The `GET` and `CONNECT` arms attach method-specific documentation and then
// delegate to the final arm. They MUST forward at least one attribute:
// re-invoking the macro with only `$name, GET` would match the same arm again
// and recurse until the compiler's recursion limit is hit.
macro_rules! top_level_handler_fn {
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but
            /// will have the response body removed. Add an explicit `HEAD` route
            /// if `HEAD` requests need different handling.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            $name,
            $method
        );
    };

    // Final arm: emits the function itself, carrying the forwarded attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
174
// Generates a `MethodRouter` method (for example `get_service`) that chains an
// additional service restricted to one HTTP method.
//
// The `GET` and `CONNECT` arms attach method-specific documentation and then
// delegate to the final arm. They MUST forward at least one attribute:
// re-invoking the macro with only `$name, GET` would match the same arm again
// and recurse until the compiler's recursion limit is hit.
macro_rules! chained_service_fn {
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but
            /// will have the response body removed. Add an explicit `HEAD` route
            /// if `HEAD` requests need different handling.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    // Final arm: emits the method itself, carrying the forwarded attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
260
// Generates a `MethodRouter` method (for example `get`) that chains an
// additional handler restricted to one HTTP method.
//
// The `GET` and `CONNECT` arms attach method-specific documentation and then
// delegate to the final arm. They MUST forward at least one attribute:
// re-invoking the macro with only `$name, GET` would match the same arm again
// and recurse until the compiler's recursion limit is hit.
macro_rules! chained_handler_fn {
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but
            /// will have the response body removed. Add an explicit `HEAD` route
            /// if `HEAD` requests need different handling.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    // Final arm: emits the method itself, carrying the forwarded attributes.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
332
// One free `*_service` constructor per HTTP method; each forwards to
// `on_service` with the matching `MethodFilter` constant.
top_level_service_fn!(connect_service, CONNECT);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(trace_service, TRACE);
342
343pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
367where
368 T: Service<Request> + Clone + Send + 'static,
369 T::Response: IntoResponse + 'static,
370 T::Future: Send + 'static,
371 S: Clone,
372{
373 MethodRouter::new().on_service(filter, svc)
374}
375
376pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
426where
427 T: Service<Request> + Clone + Send + 'static,
428 T::Response: IntoResponse + 'static,
429 T::Future: Send + 'static,
430 S: Clone,
431{
432 MethodRouter::new()
433 .fallback_service(svc)
434 .skip_allow_header()
435}
436
// One free handler constructor per HTTP method; each forwards to `on` with
// the matching `MethodFilter` constant.
top_level_handler_fn!(connect, CONNECT);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(trace, TRACE);
446
447pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
465where
466 H: Handler<T, S>,
467 T: 'static,
468 S: Clone + Send + Sync + 'static,
469{
470 MethodRouter::new().on(filter, handler)
471}
472
473pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
507where
508 H: Handler<T, S>,
509 T: 'static,
510 S: Clone + Send + Sync + 'static,
511{
512 MethodRouter::new().fallback(handler).skip_allow_header()
513}
514
/// A router that dispatches a request to a per-method endpoint.
///
/// `S` is the state type required by handlers that have not been given their
/// state yet; `E` is the error type of the contained routes. Requests whose
/// method has no endpoint are sent to `fallback`.
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    // One optional endpoint per supported HTTP method.
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    // Called when no endpoint above handles the request's method.
    fallback: Fallback<S, E>,
    // Controls the `Allow` header handed to the fallback's response future.
    allow_header: AllowHeader,
}
558
/// How the `Allow` header is produced for requests that reach the fallback.
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No methods have been registered; an empty value is used.
    None,
    /// Don't set an `Allow` header.
    Skip,
    /// Set the header to these bytes (comma-separated method names).
    Bytes(BytesMut),
}
568
569impl AllowHeader {
570 fn merge(self, other: Self) -> Self {
571 match (self, other) {
572 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
573 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
574 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
575 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
576 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
577 a.extend_from_slice(b",");
578 a.extend_from_slice(&b);
579 AllowHeader::Bytes(a)
580 }
581 }
582 }
583}
584
585impl<S, E> fmt::Debug for MethodRouter<S, E> {
586 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
587 f.debug_struct("MethodRouter")
588 .field("get", &self.get)
589 .field("head", &self.head)
590 .field("delete", &self.delete)
591 .field("options", &self.options)
592 .field("patch", &self.patch)
593 .field("post", &self.post)
594 .field("put", &self.put)
595 .field("trace", &self.trace)
596 .field("connect", &self.connect)
597 .field("fallback", &self.fallback)
598 .field("allow_header", &self.allow_header)
599 .finish()
600 }
601}
602
603impl<S> MethodRouter<S, Infallible>
604where
605 S: Clone,
606{
607 #[track_caller]
629 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
630 where
631 H: Handler<T, S>,
632 T: 'static,
633 S: Send + Sync + 'static,
634 {
635 self.on_endpoint(
636 filter,
637 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
638 )
639 }
640
641 chained_handler_fn!(connect, CONNECT);
642 chained_handler_fn!(delete, DELETE);
643 chained_handler_fn!(get, GET);
644 chained_handler_fn!(head, HEAD);
645 chained_handler_fn!(options, OPTIONS);
646 chained_handler_fn!(patch, PATCH);
647 chained_handler_fn!(post, POST);
648 chained_handler_fn!(put, PUT);
649 chained_handler_fn!(trace, TRACE);
650
651 pub fn fallback<H, T>(mut self, handler: H) -> Self
653 where
654 H: Handler<T, S>,
655 T: 'static,
656 S: Send + Sync + 'static,
657 {
658 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
659 self
660 }
661
662 pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
664 where
665 H: Handler<T, S>,
666 T: 'static,
667 S: Send + Sync + 'static,
668 {
669 match self.fallback {
670 Fallback::Default(_) => self.fallback(handler),
671 _ => self,
672 }
673 }
674}
675
676impl MethodRouter<(), Infallible> {
677 pub fn into_make_service(self) -> IntoMakeService<Self> {
705 IntoMakeService::new(self.with_state(()))
706 }
707
708 #[cfg(feature = "tokio")]
737 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
738 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
739 }
740}
741
742impl<S, E> MethodRouter<S, E>
743where
744 S: Clone,
745{
746 pub fn new() -> Self {
749 let fallback = Route::new(service_fn(|_: Request| async {
750 Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
751 }));
752
753 Self {
754 get: MethodEndpoint::None,
755 head: MethodEndpoint::None,
756 delete: MethodEndpoint::None,
757 options: MethodEndpoint::None,
758 patch: MethodEndpoint::None,
759 post: MethodEndpoint::None,
760 put: MethodEndpoint::None,
761 trace: MethodEndpoint::None,
762 connect: MethodEndpoint::None,
763 allow_header: AllowHeader::None,
764 fallback: Fallback::Default(fallback),
765 }
766 }
767
768 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
770 MethodRouter {
771 get: self.get.with_state(&state),
772 head: self.head.with_state(&state),
773 delete: self.delete.with_state(&state),
774 options: self.options.with_state(&state),
775 patch: self.patch.with_state(&state),
776 post: self.post.with_state(&state),
777 put: self.put.with_state(&state),
778 trace: self.trace.with_state(&state),
779 connect: self.connect.with_state(&state),
780 allow_header: self.allow_header,
781 fallback: self.fallback.with_state(state),
782 }
783 }
784
785 #[track_caller]
809 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
810 where
811 T: Service<Request, Error = E> + Clone + Send + 'static,
812 T::Response: IntoResponse + 'static,
813 T::Future: Send + 'static,
814 {
815 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
816 }
817
818 #[track_caller]
819 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
820 #[track_caller]
822 fn set_endpoint<S, E>(
823 method_name: &str,
824 out: &mut MethodEndpoint<S, E>,
825 endpoint: &MethodEndpoint<S, E>,
826 endpoint_filter: MethodFilter,
827 filter: MethodFilter,
828 allow_header: &mut AllowHeader,
829 methods: &[&'static str],
830 ) where
831 MethodEndpoint<S, E>: Clone,
832 S: Clone,
833 {
834 if endpoint_filter.contains(filter) {
835 if out.is_some() {
836 panic!(
837 "Overlapping method route. Cannot add two method routes that both handle \
838 `{method_name}`",
839 )
840 }
841 *out = endpoint.clone();
842 for method in methods {
843 append_allow_header(allow_header, method);
844 }
845 }
846 }
847
848 set_endpoint(
849 "GET",
850 &mut self.get,
851 &endpoint,
852 filter,
853 MethodFilter::GET,
854 &mut self.allow_header,
855 &["GET", "HEAD"],
856 );
857
858 set_endpoint(
859 "HEAD",
860 &mut self.head,
861 &endpoint,
862 filter,
863 MethodFilter::HEAD,
864 &mut self.allow_header,
865 &["HEAD"],
866 );
867
868 set_endpoint(
869 "TRACE",
870 &mut self.trace,
871 &endpoint,
872 filter,
873 MethodFilter::TRACE,
874 &mut self.allow_header,
875 &["TRACE"],
876 );
877
878 set_endpoint(
879 "PUT",
880 &mut self.put,
881 &endpoint,
882 filter,
883 MethodFilter::PUT,
884 &mut self.allow_header,
885 &["PUT"],
886 );
887
888 set_endpoint(
889 "POST",
890 &mut self.post,
891 &endpoint,
892 filter,
893 MethodFilter::POST,
894 &mut self.allow_header,
895 &["POST"],
896 );
897
898 set_endpoint(
899 "PATCH",
900 &mut self.patch,
901 &endpoint,
902 filter,
903 MethodFilter::PATCH,
904 &mut self.allow_header,
905 &["PATCH"],
906 );
907
908 set_endpoint(
909 "OPTIONS",
910 &mut self.options,
911 &endpoint,
912 filter,
913 MethodFilter::OPTIONS,
914 &mut self.allow_header,
915 &["OPTIONS"],
916 );
917
918 set_endpoint(
919 "DELETE",
920 &mut self.delete,
921 &endpoint,
922 filter,
923 MethodFilter::DELETE,
924 &mut self.allow_header,
925 &["DELETE"],
926 );
927
928 set_endpoint(
929 "CONNECT",
930 &mut self.options,
931 &endpoint,
932 filter,
933 MethodFilter::CONNECT,
934 &mut self.allow_header,
935 &["CONNECT"],
936 );
937
938 self
939 }
940
    // One chained `*_service` method per HTTP method; each forwards to
    // `on_service` with the matching `MethodFilter` constant.
    chained_service_fn!(connect_service, CONNECT);
    chained_service_fn!(delete_service, DELETE);
    chained_service_fn!(get_service, GET);
    chained_service_fn!(head_service, HEAD);
    chained_service_fn!(options_service, OPTIONS);
    chained_service_fn!(patch_service, PATCH);
    chained_service_fn!(post_service, POST);
    chained_service_fn!(put_service, PUT);
    chained_service_fn!(trace_service, TRACE);
950
951 #[doc = include_str!("../docs/method_routing/fallback.md")]
952 pub fn fallback_service<T>(mut self, svc: T) -> Self
953 where
954 T: Service<Request, Error = E> + Clone + Send + 'static,
955 T::Response: IntoResponse + 'static,
956 T::Future: Send + 'static,
957 {
958 self.fallback = Fallback::Service(Route::new(svc));
959 self
960 }
961
962 #[doc = include_str!("../docs/method_routing/layer.md")]
963 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
964 where
965 L: Layer<Route<E>> + Clone + Send + 'static,
966 L::Service: Service<Request> + Clone + Send + 'static,
967 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
968 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
969 <L::Service as Service<Request>>::Future: Send + 'static,
970 E: 'static,
971 S: 'static,
972 NewError: 'static,
973 {
974 let layer_fn = move |route: Route<E>| route.layer(layer.clone());
975
976 MethodRouter {
977 get: self.get.map(layer_fn.clone()),
978 head: self.head.map(layer_fn.clone()),
979 delete: self.delete.map(layer_fn.clone()),
980 options: self.options.map(layer_fn.clone()),
981 patch: self.patch.map(layer_fn.clone()),
982 post: self.post.map(layer_fn.clone()),
983 put: self.put.map(layer_fn.clone()),
984 trace: self.trace.map(layer_fn.clone()),
985 connect: self.connect.map(layer_fn.clone()),
986 fallback: self.fallback.map(layer_fn),
987 allow_header: self.allow_header,
988 }
989 }
990
991 #[doc = include_str!("../docs/method_routing/route_layer.md")]
992 #[track_caller]
993 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
994 where
995 L: Layer<Route<E>> + Clone + Send + 'static,
996 L::Service: Service<Request, Error = E> + Clone + Send + 'static,
997 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
998 <L::Service as Service<Request>>::Future: Send + 'static,
999 E: 'static,
1000 S: 'static,
1001 {
1002 if self.get.is_none()
1003 && self.head.is_none()
1004 && self.delete.is_none()
1005 && self.options.is_none()
1006 && self.patch.is_none()
1007 && self.post.is_none()
1008 && self.put.is_none()
1009 && self.trace.is_none()
1010 && self.connect.is_none()
1011 {
1012 panic!(
1013 "Adding a route_layer before any routes is a no-op. \
1014 Add the routes you want the layer to apply to first."
1015 );
1016 }
1017
1018 let layer_fn = move |svc| {
1019 let svc = layer.layer(svc);
1020 let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
1021 Route::new(svc)
1022 };
1023
1024 self.get = self.get.map(layer_fn.clone());
1025 self.head = self.head.map(layer_fn.clone());
1026 self.delete = self.delete.map(layer_fn.clone());
1027 self.options = self.options.map(layer_fn.clone());
1028 self.patch = self.patch.map(layer_fn.clone());
1029 self.post = self.post.map(layer_fn.clone());
1030 self.put = self.put.map(layer_fn.clone());
1031 self.trace = self.trace.map(layer_fn.clone());
1032 self.connect = self.connect.map(layer_fn);
1033
1034 self
1035 }
1036
1037 #[track_caller]
1038 pub(crate) fn merge_for_path(mut self, path: Option<&str>, other: MethodRouter<S, E>) -> Self {
1039 #[track_caller]
1041 fn merge_inner<S, E>(
1042 path: Option<&str>,
1043 name: &str,
1044 first: MethodEndpoint<S, E>,
1045 second: MethodEndpoint<S, E>,
1046 ) -> MethodEndpoint<S, E> {
1047 match (first, second) {
1048 (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1049 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1050 _ => {
1051 if let Some(path) = path {
1052 panic!(
1053 "Overlapping method route. Handler for `{name} {path}` already exists"
1054 );
1055 } else {
1056 panic!(
1057 "Overlapping method route. Cannot merge two method routes that both \
1058 define `{name}`"
1059 );
1060 }
1061 }
1062 }
1063 }
1064
1065 self.get = merge_inner(path, "GET", self.get, other.get);
1066 self.head = merge_inner(path, "HEAD", self.head, other.head);
1067 self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1068 self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1069 self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1070 self.post = merge_inner(path, "POST", self.post, other.post);
1071 self.put = merge_inner(path, "PUT", self.put, other.put);
1072 self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1073 self.connect = merge_inner(path, "CONNECT", self.connect, other.connect);
1074
1075 self.fallback = self
1076 .fallback
1077 .merge(other.fallback)
1078 .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1079
1080 self.allow_header = self.allow_header.merge(other.allow_header);
1081
1082 self
1083 }
1084
1085 #[doc = include_str!("../docs/method_routing/merge.md")]
1086 #[track_caller]
1087 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1088 self.merge_for_path(None, other)
1089 }
1090
1091 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1095 where
1096 F: Clone + Send + Sync + 'static,
1097 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1098 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1099 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1100 T: 'static,
1101 E: 'static,
1102 S: 'static,
1103 {
1104 self.layer(HandleErrorLayer::new(f))
1105 }
1106
1107 fn skip_allow_header(mut self) -> Self {
1108 self.allow_header = AllowHeader::Skip;
1109 self
1110 }
1111
1112 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1113 let method = req.method();
1114 let is_head = *method == Method::HEAD;
1115
1116 macro_rules! call {
1117 (
1118 $method_variant:ident,
1119 $svc:expr
1120 ) => {
1121 if *method == Method::$method_variant {
1122 match $svc {
1123 MethodEndpoint::None => {}
1124 MethodEndpoint::Route(route) => {
1125 return RouteFuture::from_future(
1126 route.clone().oneshot_inner_owned(req),
1127 )
1128 .strip_body(is_head);
1129 }
1130 MethodEndpoint::BoxedHandler(handler) => {
1131 let route = handler.clone().into_route(state);
1132 return RouteFuture::from_future(
1133 route.clone().oneshot_inner_owned(req),
1134 )
1135 .strip_body(is_head);
1136 }
1137 }
1138 }
1139 };
1140 }
1141
1142 let Self {
1144 get,
1145 head,
1146 delete,
1147 options,
1148 patch,
1149 post,
1150 put,
1151 trace,
1152 connect,
1153 fallback,
1154 allow_header,
1155 } = self;
1156
1157 call!(HEAD, head);
1158 call!(HEAD, get);
1159 call!(GET, get);
1160 call!(POST, post);
1161 call!(OPTIONS, options);
1162 call!(PATCH, patch);
1163 call!(PUT, put);
1164 call!(DELETE, delete);
1165 call!(TRACE, trace);
1166 call!(CONNECT, connect);
1167
1168 let future = fallback.clone().call_with_state(req, state);
1169
1170 match allow_header {
1171 AllowHeader::None => future.allow_header(Bytes::new()),
1172 AllowHeader::Skip => future,
1173 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1174 }
1175 }
1176}
1177
1178fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1179 match allow_header {
1180 AllowHeader::None => {
1181 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1182 }
1183 AllowHeader::Skip => {}
1184 AllowHeader::Bytes(allow_header) => {
1185 if let Ok(s) = std::str::from_utf8(allow_header) {
1186 if !s.contains(method) {
1187 allow_header.extend_from_slice(b",");
1188 allow_header.extend_from_slice(method.as_bytes());
1189 }
1190 } else {
1191 #[cfg(debug_assertions)]
1192 panic!("`allow_header` contained invalid uft-8. This should never happen")
1193 }
1194 }
1195 }
1196}
1197
1198impl<S, E> Clone for MethodRouter<S, E> {
1199 fn clone(&self) -> Self {
1200 Self {
1201 get: self.get.clone(),
1202 head: self.head.clone(),
1203 delete: self.delete.clone(),
1204 options: self.options.clone(),
1205 patch: self.patch.clone(),
1206 post: self.post.clone(),
1207 put: self.put.clone(),
1208 trace: self.trace.clone(),
1209 connect: self.connect.clone(),
1210 fallback: self.fallback.clone(),
1211 allow_header: self.allow_header.clone(),
1212 }
1213 }
1214}
1215
1216impl<S, E> Default for MethodRouter<S, E>
1217where
1218 S: Clone,
1219{
1220 fn default() -> Self {
1221 Self::new()
1222 }
1223}
1224
/// What a `MethodRouter` holds for a single HTTP method.
enum MethodEndpoint<S, E> {
    /// Nothing is registered for the method.
    None,
    /// A ready-to-call route.
    Route(Route<E>),
    /// A handler that is turned into a route once it is given a state `S`.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1230
1231impl<S, E> MethodEndpoint<S, E>
1232where
1233 S: Clone,
1234{
1235 fn is_some(&self) -> bool {
1236 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1237 }
1238
1239 fn is_none(&self) -> bool {
1240 matches!(self, Self::None)
1241 }
1242
1243 fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1244 where
1245 S: 'static,
1246 E: 'static,
1247 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
1248 E2: 'static,
1249 {
1250 match self {
1251 Self::None => MethodEndpoint::None,
1252 Self::Route(route) => MethodEndpoint::Route(f(route)),
1253 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1254 }
1255 }
1256
1257 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1258 match self {
1259 MethodEndpoint::None => MethodEndpoint::None,
1260 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1261 MethodEndpoint::BoxedHandler(handler) => {
1262 MethodEndpoint::Route(handler.into_route(state.clone()))
1263 }
1264 }
1265 }
1266}
1267
1268impl<S, E> Clone for MethodEndpoint<S, E> {
1269 fn clone(&self) -> Self {
1270 match self {
1271 Self::None => Self::None,
1272 Self::Route(inner) => Self::Route(inner.clone()),
1273 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1274 }
1275 }
1276}
1277
1278impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1280 match self {
1281 Self::None => f.debug_tuple("None").finish(),
1282 Self::Route(inner) => inner.fmt(f),
1283 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1284 }
1285 }
1286}
1287
1288impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1289where
1290 B: HttpBody<Data = Bytes> + Send + 'static,
1291 B::Error: Into<BoxError>,
1292{
1293 type Response = Response;
1294 type Error = E;
1295 type Future = RouteFuture<E>;
1296
1297 #[inline]
1298 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1299 Poll::Ready(Ok(()))
1300 }
1301
1302 #[inline]
1303 fn call(&mut self, req: Request<B>) -> Self::Future {
1304 let req = req.map(Body::new);
1305 self.call_with_state(req, ())
1306 }
1307}
1308
1309impl<S> Handler<(), S> for MethodRouter<S>
1310where
1311 S: Clone + 'static,
1312{
1313 type Future = InfallibleRouteFuture;
1314
1315 fn call(self, req: Request, state: S) -> Self::Future {
1316 InfallibleRouteFuture::new(self.call_with_state(req, state))
1317 }
1318}
1319
// Anonymous const scope: keeps the `IncomingStream` import local to this impl.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    impl Service<IncomingStream<'_>> for MethodRouter<()> {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        /// Always reports ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Hands out a clone of the router for each incoming stream.
        fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
            let per_connection = self.clone().with_state(());
            std::future::ready(Ok(per_connection))
        }
    }
};
1339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        // The body is stripped for HEAD requests served by a GET endpoint.
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // Method with an endpoint: the layer rejects before the handler runs.
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // Method without an endpoint: `layer` also wraps the fallback.
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // Method with an endpoint: the layer rejects before the handler runs.
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // Method without an endpoint: `route_layer` leaves the fallback alone.
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Compile-only check that the builder methods compose; it is never run.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        // An `Allow` header set by the fallback itself is kept.
        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body too; otherwise the assertion below would only
        // re-check the body of the GET response above.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Sends an empty-bodied request for `/` with the given method through
    /// `svc` and returns the response status, headers and body as a string.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    /// Handler that responds with `200 OK` and the body `ok`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    /// Handler that responds with `201 Created` and the body `created`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}