1use crate::body::{Body, Bytes, HttpBody};
2use crate::response::{IntoResponse, Response};
3use crate::BoxError;
4use axum_core::extract::{FromRequest, FromRequestParts};
5use futures_util::future::BoxFuture;
6use http::Request;
7use std::{
8 any::type_name,
9 convert::Infallible,
10 fmt,
11 future::Future,
12 marker::PhantomData,
13 pin::Pin,
14 task::{Context, Poll},
15};
16use tower_layer::Layer;
17use tower_service::Service;
18
19pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
118 map_request_with_state((), f)
119}
120
121pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
160 MapRequestLayer {
161 f,
162 state,
163 _extractor: PhantomData,
164 }
165}
166
/// A [`Layer`] that wraps services in [`MapRequest`]; built by
/// [`map_request`] or [`map_request_with_state`].
///
/// `T` records the mapping function's extractor tuple. It is held as
/// `PhantomData<fn() -> T>`, so no `T` is owned and the layer's `Send`/`Sync`
/// do not depend on `T`.
#[must_use]
pub struct MapRequestLayer<F, S, T> {
    /// State handed to the mapping function's extractors.
    state: S,
    /// The request-mapping function, cloned into every service this layer builds.
    f: F,
    _extractor: PhantomData<fn() -> T>,
}
176
177impl<F, S, T> Clone for MapRequestLayer<F, S, T>
178where
179 F: Clone,
180 S: Clone,
181{
182 fn clone(&self) -> Self {
183 Self {
184 f: self.f.clone(),
185 state: self.state.clone(),
186 _extractor: self._extractor,
187 }
188 }
189}
190
191impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
192where
193 F: Clone,
194 S: Clone,
195{
196 type Service = MapRequest<F, S, I, T>;
197
198 fn layer(&self, inner: I) -> Self::Service {
199 MapRequest {
200 f: self.f.clone(),
201 state: self.state.clone(),
202 inner,
203 _extractor: PhantomData,
204 }
205 }
206}
207
208impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
209where
210 S: fmt::Debug,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("MapRequestLayer")
214 .field("f", &format_args!("{}", type_name::<F>()))
216 .field("state", &self.state)
217 .finish()
218 }
219}
220
/// Middleware service that runs a mapping function on each request before
/// (optionally) forwarding it to the wrapped service. Built by
/// [`MapRequestLayer`].
pub struct MapRequest<F, S, I, T> {
    /// The wrapped service.
    inner: I,
    /// The request-mapping function.
    f: F,
    /// State handed to the mapping function's extractors.
    state: S,
    _extractor: PhantomData<fn() -> T>,
}
230
231impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
232where
233 F: Clone,
234 I: Clone,
235 S: Clone,
236{
237 fn clone(&self) -> Self {
238 Self {
239 f: self.f.clone(),
240 inner: self.inner.clone(),
241 state: self.state.clone(),
242 _extractor: self._extractor,
243 }
244 }
245}
246
247macro_rules! impl_service {
248 (
249 [$($ty:ident),*], $last:ident
250 ) => {
251 #[allow(non_snake_case, unused_mut)]
252 impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
253 where
254 F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
255 $( $ty: FromRequestParts<S> + Send, )*
256 $last: FromRequest<S> + Send,
257 Fut: Future + Send + 'static,
258 Fut::Output: IntoMapRequestResult<B> + Send + 'static,
259 I: Service<Request<B>, Error = Infallible>
260 + Clone
261 + Send
262 + 'static,
263 I::Response: IntoResponse,
264 I::Future: Send + 'static,
265 B: HttpBody<Data = Bytes> + Send + 'static,
266 B::Error: Into<BoxError>,
267 S: Clone + Send + Sync + 'static,
268 {
269 type Response = Response;
270 type Error = Infallible;
271 type Future = ResponseFuture;
272
273 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274 self.inner.poll_ready(cx)
275 }
276
277 fn call(&mut self, req: Request<B>) -> Self::Future {
278 let req = req.map(Body::new);
279
280 let not_ready_inner = self.inner.clone();
281 let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
282
283 let mut f = self.f.clone();
284 let state = self.state.clone();
285
286 let future = Box::pin(async move {
287 let (mut parts, body) = req.into_parts();
288
289 $(
290 let $ty = match $ty::from_request_parts(&mut parts, &state).await {
291 Ok(value) => value,
292 Err(rejection) => return rejection.into_response(),
293 };
294 )*
295
296 let req = Request::from_parts(parts, body);
297
298 let $last = match $last::from_request(req, &state).await {
299 Ok(value) => value,
300 Err(rejection) => return rejection.into_response(),
301 };
302
303 match f($($ty,)* $last).await.into_map_request_result() {
304 Ok(req) => {
305 ready_inner.call(req).await.into_response()
306 }
307 Err(res) => {
308 res
309 }
310 }
311 });
312
313 ResponseFuture {
314 inner: future
315 }
316 }
317 }
318 };
319}
320
321all_the_tuples!(impl_service);
322
323impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
324where
325 S: fmt::Debug,
326 I: fmt::Debug,
327{
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("MapRequest")
330 .field("f", &format_args!("{}", type_name::<F>()))
331 .field("inner", &self.inner)
332 .field("state", &self.state)
333 .finish()
334 }
335}
336
337pub struct ResponseFuture {
339 inner: BoxFuture<'static, Response>,
340}
341
342impl Future for ResponseFuture {
343 type Output = Result<Response, Infallible>;
344
345 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346 self.inner.as_mut().poll(cx).map(Ok)
347 }
348}
349
350impl fmt::Debug for ResponseFuture {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 f.debug_struct("ResponseFuture").finish()
353 }
354}
355
356mod private {
357 use crate::{http::Request, response::IntoResponse};
358
359 pub trait Sealed<B> {}
360 impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
361 impl<B> Sealed<B> for Request<B> {}
362}
363
364pub trait IntoMapRequestResult<B>: private::Sealed<B> {
369 fn into_map_request_result(self) -> Result<Request<B>, Response>;
371}
372
373impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
374where
375 E: IntoResponse,
376{
377 fn into_map_request_result(self) -> Result<Request<B>, Response> {
378 self.map_err(IntoResponse::into_response)
379 }
380}
381
382impl<B> IntoMapRequestResult<B> for Request<B> {
383 fn into_map_request_result(self) -> Result<Request<B>, Response> {
384 Ok(self)
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    // A mapping function that returns a modified request: the handler must see
    // the modification.
    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.text().await, "foo");
    }

    // A mapping function that returns `Err`: its response is returned and the
    // handler is never called.
    #[crate::test]
    async fn works_for_short_circuiting() {
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }
}