1use crate::response::{IntoResponse, Response};
2use axum_core::extract::FromRequestParts;
3use futures_util::future::BoxFuture;
4use http::Request;
5use std::{
6 any::type_name,
7 convert::Infallible,
8 fmt,
9 future::Future,
10 marker::PhantomData,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100 map_response_with_state((), f)
101}
102
103pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142 MapResponseLayer {
143 f,
144 state,
145 _extractor: PhantomData,
146 }
147}
148
/// A layer that wraps services in [`MapResponse`].
///
/// Built by [`map_response`] and [`map_response_with_state`].
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    /// The response-mapping function; cloned into each produced service.
    f: F,
    /// State made available to the extractors; cloned into each produced service.
    state: S,
    /// Type-level record of the extractor tuple `T`. No value of `T` is ever stored.
    _extractor: PhantomData<fn() -> T>,
}
158
159impl<F, S, T> Clone for MapResponseLayer<F, S, T>
160where
161 F: Clone,
162 S: Clone,
163{
164 fn clone(&self) -> Self {
165 Self {
166 f: self.f.clone(),
167 state: self.state.clone(),
168 _extractor: self._extractor,
169 }
170 }
171}
172
173impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174where
175 F: Clone,
176 S: Clone,
177{
178 type Service = MapResponse<F, S, I, T>;
179
180 fn layer(&self, inner: I) -> Self::Service {
181 MapResponse {
182 f: self.f.clone(),
183 state: self.state.clone(),
184 inner,
185 _extractor: PhantomData,
186 }
187 }
188}
189
190impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191where
192 S: fmt::Debug,
193{
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("MapResponseLayer")
196 .field("f", &format_args!("{}", type_name::<F>()))
198 .field("state", &self.state)
199 .finish()
200 }
201}
202
/// A middleware service that runs an inner service and passes its response through a
/// mapping function.
///
/// Produced by [`MapResponseLayer`].
pub struct MapResponse<F, S, I, T> {
    /// The response-mapping function.
    f: F,
    /// The wrapped service whose responses are mapped.
    inner: I,
    /// State handed to the extractors that run ahead of the mapping function.
    state: S,
    /// Type-level record of the extractor tuple `T`. No value of `T` is ever stored.
    _extractor: PhantomData<fn() -> T>,
}
212
213impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
214where
215 F: Clone,
216 I: Clone,
217 S: Clone,
218{
219 fn clone(&self) -> Self {
220 Self {
221 f: self.f.clone(),
222 inner: self.inner.clone(),
223 state: self.state.clone(),
224 _extractor: self._extractor,
225 }
226 }
227}
228
// Generates `impl Service<Request<B>> for MapResponse<F, S, I, (T1, .., Tn)>` for the
// given list of extractor type parameters.
//
// The generated service runs each extractor against the request parts in order. If one
// of them is rejected, the rejection is returned as the response and the inner service
// is never called. Otherwise the request is reassembled, sent to the inner service, and
// the resulting response is passed to the mapping function after the extracted values.
macro_rules! impl_service {
    (
        $($ty:ident),*
    ) => {
        // `non_snake_case`: every `$ty` identifier is reused as the name of the local
        // holding the value extracted for that type.
        // `unused_mut`: when the macro is invoked with no extractors, `parts` is never
        // borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
        where
            F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Move the instance that `poll_ready` was called on into the future and
                // leave a fresh clone in `self`; the clone has not been polled for
                // readiness, so it must not be the one that handles this request.
                let replacement = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, replacement);

                let mut f = self.f.clone();
                // Underscore-prefixed because it goes unused when there are no extractors.
                let _state = self.state.clone();

                let inner = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the extractors in declaration order; the first rejection
                    // short-circuits and becomes the response.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
                            Ok(extracted) => extracted,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    let req = Request::from_parts(parts, body);

                    let res = match ready_inner.call(req).await {
                        Ok(res) => res,
                        // The error type is `Infallible`, so this arm cannot be reached.
                        Err(never) => match never {},
                    };

                    f($($ty,)* res).await.into_response()
                });

                ResponseFuture { inner }
            }
        }
    };
}
292
// Implement `Service` for `MapResponse` with extractor tuples of every arity from zero
// (`()`) up to sixteen, so the mapping function may take up to sixteen extractor
// arguments ahead of the response.
impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
310
311impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
312where
313 S: fmt::Debug,
314 I: fmt::Debug,
315{
316 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 f.debug_struct("MapResponse")
318 .field("f", &format_args!("{}", type_name::<F>()))
319 .field("inner", &self.inner)
320 .field("state", &self.state)
321 .finish()
322 }
323}
324
325pub struct ResponseFuture {
327 inner: BoxFuture<'static, Response>,
328}
329
330impl Future for ResponseFuture {
331 type Output = Result<Response, Infallible>;
332
333 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334 self.inner.as_mut().poll(cx).map(Ok)
335 }
336}
337
338impl fmt::Debug for ResponseFuture {
339 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340 f.debug_struct("ResponseFuture").finish()
341 }
342}
343
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    /// A response-mapping function installed with `map_response` gets to modify the
    /// response returned for a request.
    #[crate::test]
    async fn works() {
        // Mapping function under test: tags every response with an `x-foo` header.
        async fn tag_with_x_foo<B>(mut response: Response<B>) -> Response<B> {
            let value = "foo".parse().unwrap();
            response.headers_mut().insert("x-foo", value);
            response
        }

        let router = Router::new().layer(map_response(tag_with_x_foo));
        let client = TestClient::new(router);

        let response = client.get("/").await;

        assert_eq!(response.headers()["x-foo"], "foo");
    }
}