1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5 any::type_name,
6 convert::Infallible,
7 fmt,
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tower::{util::BoxCloneService, ServiceBuilder};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
111 from_fn_with_state((), f)
112}
113
114pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
159 FromFnLayer {
160 f,
161 state,
162 _extractor: PhantomData,
163 }
164}
165
/// A layer that wraps services in [`FromFn`] middleware.
///
/// Constructed by `from_fn` or `from_fn_with_state`.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// The middleware function; cloned into each service produced by the layer.
    f: F,
    /// State value cloned into each service produced by the layer.
    state: S,
    /// Type-level marker for `T`; no value of type `T` is stored.
    _extractor: PhantomData<fn() -> T>,
}
177
178impl<F, S, T> Clone for FromFnLayer<F, S, T>
179where
180 F: Clone,
181 S: Clone,
182{
183 fn clone(&self) -> Self {
184 Self {
185 f: self.f.clone(),
186 state: self.state.clone(),
187 _extractor: self._extractor,
188 }
189 }
190}
191
192impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
193where
194 F: Clone,
195 S: Clone,
196{
197 type Service = FromFn<F, S, I, T>;
198
199 fn layer(&self, inner: I) -> Self::Service {
200 FromFn {
201 f: self.f.clone(),
202 state: self.state.clone(),
203 inner,
204 _extractor: PhantomData,
205 }
206 }
207}
208
209impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
210where
211 S: fmt::Debug,
212{
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 f.debug_struct("FromFnLayer")
215 .field("f", &format_args!("{}", type_name::<F>()))
217 .field("state", &self.state)
218 .finish()
219 }
220}
221
/// A middleware service built from a function.
///
/// Produced by `FromFnLayer::layer`; wraps the service `inner`.
pub struct FromFn<F, S, I, T> {
    /// The middleware function.
    f: F,
    /// The wrapped service.
    inner: I,
    /// State value handed to the extractors that run before `f`.
    state: S,
    /// Type-level marker for the extractor tuple `T`; no `T` value is stored.
    _extractor: PhantomData<fn() -> T>,
}
231
232impl<F, S, I, T> Clone for FromFn<F, S, I, T>
233where
234 F: Clone,
235 I: Clone,
236 S: Clone,
237{
238 fn clone(&self) -> Self {
239 Self {
240 f: self.f.clone(),
241 inner: self.inner.clone(),
242 state: self.state.clone(),
243 _extractor: self._extractor,
244 }
245 }
246}
247
// Generates a `Service<Request>` impl for `FromFn` whose extractor marker is
// the tuple `($($ty,)* $last,)`.
//
// Every `$ty` is extracted with `FromRequestParts`; `$last` is extracted with
// `FromRequest` and therefore receives the whole request. The middleware
// function is then called with the extracted values followed by a `Next`.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extracted values are bound to variables named
        // after the type parameters (`$ty`, `$last`).
        // `unused_mut`: presumably for expansions with an empty `$ty` list,
        // where `parts` is never borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S> + Send,
            Fut: Future<Output = Out> + Send + 'static,
            Out: IntoResponse + 'static,
            I: Service<Request, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated to the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request) -> Self::Future {
                // Move the instance stored in `self.inner` (the one `poll_ready`
                // was called on) into this request's future and leave a fresh
                // clone in its place.
                let not_ready_inner = self.inner.clone();
                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                // The future must be `'static`, so it owns its own copies of
                // the function and the state.
                let mut f = self.f.clone();
                let state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the parts-only extractors in declaration order; the
                    // first rejection becomes the response and neither `f` nor
                    // the inner service is called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request so the last extractor can consume it.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // Box the ready inner service and convert its responses to
                    // `Response`, giving `Next` a single concrete service type.
                    let inner = ServiceBuilder::new()
                        .boxed_clone()
                        .map_response(IntoResponse::into_response)
                        .service(ready_inner);
                    let next = Next { inner };

                    f($($ty,)* $last, next).await.into_response()
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// `all_the_tuples!` is defined outside this file; presumably it invokes
// `impl_service!` once per supported extractor-tuple arity.
all_the_tuples!(impl_service);
318
319impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
320where
321 S: fmt::Debug,
322 I: fmt::Debug,
323{
324 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325 f.debug_struct("FromFnLayer")
326 .field("f", &format_args!("{}", type_name::<F>()))
327 .field("inner", &self.inner)
328 .field("state", &self.state)
329 .finish()
330 }
331}
332
/// Handle to the wrapped inner service, passed to the middleware function as
/// its final argument.
#[derive(Debug, Clone)]
pub struct Next {
    /// The inner service, boxed, with its responses converted to `Response`.
    inner: BoxCloneService<Request, Response, Infallible>,
}
338
339impl Next {
340 pub async fn run(mut self, req: Request) -> Response {
342 match self.inner.call(req).await {
343 Ok(res) => res,
344 Err(err) => match err {},
345 }
346 }
347}
348
349impl Service<Request> for Next {
350 type Response = Response;
351 type Error = Infallible;
352 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
353
354 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355 self.inner.poll_ready(cx)
356 }
357
358 fn call(&mut self, req: Request) -> Self::Future {
359 self.inner.call(req)
360 }
361}
362
/// Response future for [`FromFn`].
pub struct ResponseFuture {
    /// Boxed future that runs the extractors and then the middleware function.
    inner: BoxFuture<'static, Response>,
}
367
368impl Future for ResponseFuture {
369 type Output = Result<Response, Infallible>;
370
371 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372 self.inner.as_mut().poll(cx).map(Ok)
373 }
374}
375
376impl fmt::Debug for ResponseFuture {
377 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378 f.debug_struct("ResponseFuture").finish()
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A `from_fn` middleware can modify the request before the handler sees it.
    #[crate::test]
    async fn basic() {
        // Middleware under test: tags the request, then forwards it.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let value = "ok".parse().unwrap();
            req.headers_mut().insert("x-axum-test", value);

            next.run(req).await
        }

        // Handler: echoes the header set by the middleware back as the body.
        async fn handle(headers: HeaderMap) -> String {
            let value = &headers["x-axum-test"];
            value.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let bytes = response.collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"ok");
    }
}