axum/routing/
route.rs

1use crate::{
2    body::{Body, HttpBody},
3    response::Response,
4    util::AxumMutex,
5};
6use axum_core::{extract::Request, response::IntoResponse};
7use bytes::Bytes;
8use http::{
9    header::{self, CONTENT_LENGTH},
10    HeaderMap, HeaderValue,
11};
12use pin_project_lite::pin_project;
13use std::{
14    convert::Infallible,
15    fmt,
16    future::Future,
17    pin::Pin,
18    task::{Context, Poll},
19};
20use tower::{
21    util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot},
22    ServiceExt,
23};
24use tower_layer::Layer;
25use tower_service::Service;
26
27/// How routes are stored inside a [`Router`](super::Router).
28///
29/// You normally shouldn't need to care about this type. It's used in
30/// [`Router::layer`](super::Router::layer).
31pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
32
33impl<E> Route<E> {
34    pub(crate) fn new<T>(svc: T) -> Self
35    where
36        T: Service<Request, Error = E> + Clone + Send + 'static,
37        T::Response: IntoResponse + 'static,
38        T::Future: Send + 'static,
39    {
40        Self(AxumMutex::new(BoxCloneService::new(
41            svc.map_response(IntoResponse::into_response),
42        )))
43    }
44
45    /// Variant of [`Route::call`] that takes ownership of the route to avoid cloning.
46    pub(crate) fn call_owned(self, req: Request<Body>) -> RouteFuture<E> {
47        let req = req.map(Body::new);
48        RouteFuture::from_future(self.oneshot_inner_owned(req))
49    }
50
51    pub(crate) fn oneshot_inner(
52        &mut self,
53        req: Request,
54    ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
55        self.0.get_mut().unwrap().clone().oneshot(req)
56    }
57
58    /// Variant of [`Route::oneshot_inner`] that takes ownership of the route to avoid cloning.
59    pub(crate) fn oneshot_inner_owned(
60        self,
61        req: Request,
62    ) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
63        self.0.into_inner().unwrap().oneshot(req)
64    }
65
66    pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
67    where
68        L: Layer<Route<E>> + Clone + Send + 'static,
69        L::Service: Service<Request> + Clone + Send + 'static,
70        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
71        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
72        <L::Service as Service<Request>>::Future: Send + 'static,
73        NewError: 'static,
74    {
75        let layer = (
76            MapErrLayer::new(Into::into),
77            MapResponseLayer::new(IntoResponse::into_response),
78            layer,
79        );
80
81        Route::new(layer.layer(self))
82    }
83}
84
85impl<E> Clone for Route<E> {
86    #[track_caller]
87    fn clone(&self) -> Self {
88        Self(AxumMutex::new(self.0.lock().unwrap().clone()))
89    }
90}
91
92impl<E> fmt::Debug for Route<E> {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("Route").finish()
95    }
96}
97
98impl<B, E> Service<Request<B>> for Route<E>
99where
100    B: HttpBody<Data = bytes::Bytes> + Send + 'static,
101    B::Error: Into<axum_core::BoxError>,
102{
103    type Response = Response;
104    type Error = E;
105    type Future = RouteFuture<E>;
106
107    #[inline]
108    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        Poll::Ready(Ok(()))
110    }
111
112    #[inline]
113    fn call(&mut self, req: Request<B>) -> Self::Future {
114        let req = req.map(Body::new);
115        RouteFuture::from_future(self.oneshot_inner(req)).not_top_level()
116    }
117}
118
119pin_project! {
120    /// Response future for [`Route`].
121    pub struct RouteFuture<E> {
122        #[pin]
123        kind: RouteFutureKind<E>,
124        strip_body: bool,
125        allow_header: Option<Bytes>,
126        top_level: bool,
127    }
128}
129
130pin_project! {
131    #[project = RouteFutureKindProj]
132    enum RouteFutureKind<E> {
133        Future {
134            #[pin]
135            future: Oneshot<
136                BoxCloneService<Request, Response, E>,
137                Request,
138            >,
139        },
140        Response {
141            response: Option<Response>,
142        }
143    }
144}
145
146impl<E> RouteFuture<E> {
147    pub(crate) fn from_future(
148        future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
149    ) -> Self {
150        Self {
151            kind: RouteFutureKind::Future { future },
152            strip_body: false,
153            allow_header: None,
154            top_level: true,
155        }
156    }
157
158    pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
159        self.strip_body = strip_body;
160        self
161    }
162
163    pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
164        self.allow_header = Some(allow_header);
165        self
166    }
167
168    pub(crate) fn not_top_level(mut self) -> Self {
169        self.top_level = false;
170        self
171    }
172}
173
174impl<E> Future for RouteFuture<E> {
175    type Output = Result<Response, E>;
176
177    #[inline]
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        let this = self.project();
180
181        let mut res = match this.kind.project() {
182            RouteFutureKindProj::Future { future } => match future.poll(cx) {
183                Poll::Ready(Ok(res)) => res,
184                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
185                Poll::Pending => return Poll::Pending,
186            },
187            RouteFutureKindProj::Response { response } => {
188                response.take().expect("future polled after completion")
189            }
190        };
191
192        if *this.top_level {
193            set_allow_header(res.headers_mut(), this.allow_header);
194
195            // make sure to set content-length before removing the body
196            set_content_length(res.size_hint(), res.headers_mut());
197
198            if *this.strip_body {
199                *res.body_mut() = Body::empty();
200            }
201        }
202
203        Poll::Ready(Ok(res))
204    }
205}
206
207fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
208    match allow_header.take() {
209        Some(allow_header) if !headers.contains_key(header::ALLOW) => {
210            headers.insert(
211                header::ALLOW,
212                HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
213            );
214        }
215        _ => {}
216    }
217}
218
219fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
220    if headers.contains_key(CONTENT_LENGTH) {
221        return;
222    }
223
224    if let Some(size) = size_hint.exact() {
225        let header_value = if size == 0 {
226            #[allow(clippy::declare_interior_mutable_const)]
227            const ZERO: HeaderValue = HeaderValue::from_static("0");
228
229            ZERO
230        } else {
231            let mut buffer = itoa::Buffer::new();
232            HeaderValue::from_str(buffer.format(size)).unwrap()
233        };
234
235        headers.insert(CONTENT_LENGTH, header_value);
236    }
237}
238
239pin_project! {
240    /// A [`RouteFuture`] that always yields a [`Response`].
241    pub struct InfallibleRouteFuture {
242        #[pin]
243        future: RouteFuture<Infallible>,
244    }
245}
246
247impl InfallibleRouteFuture {
248    pub(crate) fn new(future: RouteFuture<Infallible>) -> Self {
249        Self { future }
250    }
251}
252
253impl Future for InfallibleRouteFuture {
254    type Output = Response;
255
256    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257        match futures_util::ready!(self.project().future.poll(cx)) {
258            Ok(response) => Poll::Ready(response),
259            Err(err) => match err {},
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Route` is `Send`.
    #[test]
    fn traits() {
        use crate::test_helpers::assert_send;

        assert_send::<Route<()>>();
    }
}
273}