1use crate::{
2 body::{Body, HttpBody},
3 response::Response,
4 util::AxumMutex,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9 header::{self, CONTENT_LENGTH},
10 HeaderMap, HeaderValue,
11};
12use pin_project_lite::pin_project;
13use std::{
14 convert::Infallible,
15 fmt,
16 future::Future,
17 pin::Pin,
18 task::{Context, Poll},
19};
20use tower::{
21 util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot},
22 ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
/// A type-erased route handler: a boxed, clonable [`Service`] that accepts a
/// [`Request`] and yields a [`Response`], failing with `E`.
///
/// The boxed service is stored inside an `AxumMutex`. `Route::clone` locks it
/// and clones the inner service. The `&mut self` / owned call paths use
/// `get_mut` / `into_inner` and so never contend on the lock.
/// NOTE(review): the mutex presumably exists so `Route` can be `Sync` even
/// though `BoxCloneService` is only `Send` — confirm before relying on it.
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
32
33impl<E> Route<E> {
34 pub(crate) fn new<T>(svc: T) -> Self
35 where
36 T: Service<Request, Error = E> + Clone + Send + 'static,
37 T::Response: IntoResponse + 'static,
38 T::Future: Send + 'static,
39 {
40 Self(AxumMutex::new(BoxCloneService::new(
41 svc.map_response(IntoResponse::into_response),
42 )))
43 }
44
45 pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
47 let req = req.map(Body::new);
48 RouteFuture::from_future(self.oneshot_inner_owned(req))
49 }
50
51 pub(crate) fn oneshot_inner(
52 &mut self,
53 req: Request,
54 ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
55 self.0.get_mut().unwrap().clone().oneshot(req)
56 }
57
58 pub(crate) fn oneshot_inner_owned(
60 self,
61 req: Request,
62 ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
63 self.0.into_inner().unwrap().oneshot(req)
64 }
65
66 pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
67 where
68 L: Layer<Route<E>> + Clone + Send + 'static,
69 L::Service: Service<Request> + Clone + Send + 'static,
70 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
71 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
72 <L::Service as Service<Request>>::Future: Send + 'static,
73 NewError: 'static,
74 {
75 let layer = (
76 MapErrLayer::new(Into::into),
77 MapResponseLayer::new(IntoResponse::into_response),
78 layer,
79 );
80
81 Route::new(layer.layer(self))
82 }
83}
84
85impl<E> Clone for Route<E> {
86 #[track_caller]
87 fn clone(&self) -> Self {
88 Self(AxumMutex::new(self.0.lock().unwrap().clone()))
89 }
90}
91
92impl<E> fmt::Debug for Route<E> {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_struct("Route").finish()
95 }
96}
97
98impl<B, E> Service<Request<B>> for Route<E>
99where
100 B: HttpBody<Data = bytes::Bytes> + Send + 'static,
101 B::Error: Into<axum_core::BoxError>,
102{
103 type Response = Response;
104 type Error = E;
105 type Future = RouteFuture<E>;
106
107 #[inline]
108 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 Poll::Ready(Ok(()))
110 }
111
112 #[inline]
113 fn call(&mut self, req: Request<B>) -> Self::Future {
114 let req = req.map(Body::new);
115 RouteFuture::from_future(self.oneshot_inner(req)).not_top_level()
116 }
117}
118
pin_project! {
    /// Response future for [`Route`].
    ///
    /// When `top_level` is set, the response is post-processed once it is
    /// ready:
    /// - an `Allow` header is inserted from `allow_header`;
    /// - `Content-Length` is set from the body's exact size hint;
    /// - the body is cleared if `strip_body` is set.
    pub struct RouteFuture<E> {
        // Either an in-flight service call or an already-available response.
        #[pin]
        kind: RouteFutureKind<E>,
        // If true (and top-level), the body is replaced with `Body::empty()`
        // after Content-Length has been computed.
        strip_body: bool,
        // `Allow` header value inserted if the response lacks one; taken on use.
        allow_header: Option<Bytes>,
        // Only top-level futures post-process the response (see `not_top_level`).
        top_level: bool,
    }
}
129
pin_project! {
    // The two states a `RouteFuture` can be driven from.
    #[project = RouteFutureKindProj]
    enum RouteFutureKind<E> {
        // A pending call into the boxed route service.
        Future {
            #[pin]
            future: Oneshot<
                BoxCloneService<Request, Response, E>,
                Request,
            >,
        },
        // A response that is already available. `RouteFuture::poll` takes it
        // out, so polling again after completion panics.
        Response {
            response: Option<Response>,
        }
    }
}
145
146impl<E> RouteFuture<E> {
147 pub(crate) fn from_future(
148 future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
149 ) -> Self {
150 Self {
151 kind: RouteFutureKind::Future { future },
152 strip_body: false,
153 allow_header: None,
154 top_level: true,
155 }
156 }
157
158 pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
159 self.strip_body = strip_body;
160 self
161 }
162
163 pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
164 self.allow_header = Some(allow_header);
165 self
166 }
167
168 pub(crate) fn not_top_level(mut self) -> Self {
169 self.top_level = false;
170 self
171 }
172}
173
174impl<E> Future for RouteFuture<E> {
175 type Output = Result<Response, E>;
176
177 #[inline]
178 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 let this = self.project();
180
181 let mut res = match this.kind.project() {
182 RouteFutureKindProj::Future { future } => match future.poll(cx) {
183 Poll::Ready(Ok(res)) => res,
184 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
185 Poll::Pending => return Poll::Pending,
186 },
187 RouteFutureKindProj::Response { response } => {
188 response.take().expect("future polled after completion")
189 }
190 };
191
192 if *this.top_level {
193 set_allow_header(res.headers_mut(), this.allow_header);
194
195 set_content_length(res.size_hint(), res.headers_mut());
197
198 if *this.strip_body {
199 *res.body_mut() = Body::empty();
200 }
201 }
202
203 Poll::Ready(Ok(res))
204 }
205}
206
207fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
208 match allow_header.take() {
209 Some(allow_header) if !headers.contains_key(header::ALLOW) => {
210 headers.insert(
211 header::ALLOW,
212 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
213 );
214 }
215 _ => {}
216 }
217}
218
219fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
220 if headers.contains_key(CONTENT_LENGTH) {
221 return;
222 }
223
224 if let Some(size) = size_hint.exact() {
225 let header_value = if size == 0 {
226 #[allow(clippy::declare_interior_mutable_const)]
227 const ZERO: HeaderValue = HeaderValue::from_static("0");
228
229 ZERO
230 } else {
231 let mut buffer = itoa::Buffer::new();
232 HeaderValue::from_str(buffer.format(size)).unwrap()
233 };
234
235 headers.insert(CONTENT_LENGTH, header_value);
236 }
237}
238
pin_project! {
    /// A [`RouteFuture`] whose error type is [`Infallible`], so it resolves
    /// directly to a [`Response`] rather than a `Result`.
    pub struct InfallibleRouteFuture {
        // The wrapped future; its `Err` case can never occur.
        #[pin]
        future: RouteFuture<Infallible>,
    }
}
246
247impl InfallibleRouteFuture {
248 pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
249 Self { future }
250 }
251}
252
253impl Future for InfallibleRouteFuture {
254 type Output = Response;
255
256 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257 match futures_util::ready!(self.project().future.poll(cx)) {
258 Ok(response) => Poll::Ready(response),
259 Err(err) => match err {},
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Route` can be sent across threads.
    #[test]
    fn traits() {
        crate::test_helpers::assert_send::<Route<()>>();
    }
}