1use crate::{
2 extract::FromRequestParts,
3 response::{IntoResponse, Response},
4};
5use futures_util::{future::BoxFuture, ready};
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9 fmt,
10 future::Future,
11 marker::PhantomData,
12 pin::Pin,
13 task::{Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
92 from_extractor_with_state(())
93}
94
95pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
99 FromExtractorLayer {
100 state,
101 _marker: PhantomData,
102 }
103}
104
/// [`Layer`] that wraps services in [`FromExtractor`], so the extractor `E`
/// (given state `S`) runs before each request reaches the inner service.
///
/// Built with [`from_extractor`] or [`from_extractor_with_state`].
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // State cloned into every `FromExtractor` this layer produces.
    state: S,
    // `fn() -> E` records the extractor type without owning an `E`, so the
    // layer's auto traits (`Send`/`Sync`) do not depend on `E`.
    _marker: PhantomData<fn() -> E>,
}
116
117impl<E, S> Clone for FromExtractorLayer<E, S>
118where
119 S: Clone,
120{
121 fn clone(&self) -> Self {
122 Self {
123 state: self.state.clone(),
124 _marker: PhantomData,
125 }
126 }
127}
128
129impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
130where
131 S: fmt::Debug,
132{
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("FromExtractorLayer")
135 .field("state", &self.state)
136 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
137 .finish()
138 }
139}
140
141impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
142where
143 S: Clone,
144{
145 type Service = FromExtractor<T, E, S>;
146
147 fn layer(&self, inner: T) -> Self::Service {
148 FromExtractor {
149 inner,
150 state: self.state.clone(),
151 _extractor: PhantomData,
152 }
153 }
154}
155
/// Middleware that runs the extractor `E` on each request before calling `inner`.
///
/// If extraction succeeds, the extracted value is discarded and the request is
/// passed to `inner`. If it fails, the rejection is returned as the response and
/// `inner` is not called. Created by applying a [`FromExtractorLayer`].
pub struct FromExtractor<T, E, S> {
    // The wrapped service.
    inner: T,
    // State passed to `E::from_request_parts` on every request.
    state: S,
    // Marker only; `fn() -> E` keeps this type's `Send`/`Sync` independent of `E`.
    _extractor: PhantomData<fn() -> E>,
}
164
/// `FromExtractor` must stay `Send + Sync` even when the extractor type itself is not.
#[test]
fn traits() {
    use crate::test_helpers::*;

    type Middleware = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Middleware>();
    assert_sync::<Middleware>();
}
171
172impl<T, E, S> Clone for FromExtractor<T, E, S>
173where
174 T: Clone,
175 S: Clone,
176{
177 fn clone(&self) -> Self {
178 Self {
179 inner: self.inner.clone(),
180 state: self.state.clone(),
181 _extractor: PhantomData,
182 }
183 }
184}
185
186impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
187where
188 T: fmt::Debug,
189 S: fmt::Debug,
190{
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 f.debug_struct("FromExtractor")
193 .field("inner", &self.inner)
194 .field("state", &self.state)
195 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
196 .finish()
197 }
198}
199
200impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
201where
202 E: FromRequestParts<S> + 'static,
203 B: Send + 'static,
204 T: Service<Request<B>> + Clone,
205 T::Response: IntoResponse,
206 S: Clone + Send + Sync + 'static,
207{
208 type Response = Response;
209 type Error = T::Error;
210 type Future = ResponseFuture<B, T, E, S>;
211
212 #[inline]
213 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
214 self.inner.poll_ready(cx)
215 }
216
217 fn call(&mut self, req: Request<B>) -> Self::Future {
218 let state = self.state.clone();
219 let extract_future = Box::pin(async move {
220 let (mut parts, body) = req.into_parts();
221 let extracted = E::from_request_parts(&mut parts, &state).await;
222 let req = Request::from_parts(parts, body);
223 (req, extracted)
224 });
225
226 ResponseFuture {
227 state: State::Extracting {
228 future: extract_future,
229 },
230 svc: Some(self.inner.clone()),
231 }
232 }
233}
234
pin_project! {
    // No `Debug` impl: the boxed extractor future is not `Debug`, and
    // `T::Future` is not required to be.
    #[allow(missing_debug_implementations)]
    /// Response future for [`FromExtractor`].
    ///
    /// It first drives the extractor. On success it calls the inner service and
    /// drives that service's future. On failure it resolves to the rejection.
    pub struct ResponseFuture<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Current stage: running the extractor, or awaiting the inner service.
        #[pin]
        state: State<B, T, E, S>,
        // Inner service; taken (left `None`) when it is called.
        svc: Option<T>,
    }
}
248
pin_project! {
    // Internal state machine driven by `ResponseFuture::poll`.
    #[project = StateProj]
    enum State<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Running the extractor. The boxed future hands the request back along
        // with the extraction result so it can be forwarded to the inner service.
        Extracting {
            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
        },
        // Extraction succeeded; awaiting the inner service's response future.
        Call { #[pin] future: T::Future },
    }
}
262
263impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
264where
265 E: FromRequestParts<S>,
266 T: Service<Request<B>>,
267 T::Response: IntoResponse,
268{
269 type Output = Result<Response, T::Error>;
270
271 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
272 loop {
273 let mut this = self.as_mut().project();
274
275 let new_state = match this.state.as_mut().project() {
276 StateProj::Extracting { future } => {
277 let (req, extracted) = ready!(future.as_mut().poll(cx));
278
279 match extracted {
280 Ok(_) => {
281 let mut svc = this.svc.take().expect("future polled after completion");
282 let future = svc.call(req);
283 State::Call { future }
284 }
285 Err(err) => {
286 let res = err.into_response();
287 return Poll::Ready(Ok(res));
288 }
289 }
290 }
291 StateProj::Call { future } => {
292 return future
293 .poll(cx)
294 .map(|result| result.map(IntoResponse::into_response));
295 }
296 };
297
298 this.state.set(new_state);
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        // Accepts the request only if the `Authorization` header equals the secret.
        struct RequireAuth;

        #[async_trait]
        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(expected) = Secret::from_ref(state);
                let provided = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok());

                match provided {
                    Some(token) if token == expected => Ok(Self),
                    _ => Err(StatusCode::UNAUTHORIZED),
                }
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));
        let client = TestClient::new(app);

        let without_header = client.get("/").await;
        assert_eq!(without_header.status(), StatusCode::UNAUTHORIZED);

        let with_header = client
            .get("/")
            .header(header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(with_header.status(), StatusCode::OK);
    }

    // Compile-only check: this is never run (hence `dead_code`, and the extractor
    // body is `unimplemented!`). It only has to type-check the layer stacking below.
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        #[async_trait]
        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}
389}