axum/
extension.rs

1use crate::{extract::rejection::*, response::IntoResponseParts};
2use async_trait::async_trait;
3use axum_core::{
4    extract::FromRequestParts,
5    response::{IntoResponse, Response, ResponseParts},
6};
7use http::{request::Parts, Request};
8use std::{
9    convert::Infallible,
10    task::{Context, Poll},
11};
12use tower_service::Service;
13
/// Extractor and response for extensions.
///
/// # As extractor
///
/// This is commonly used to share state across handlers.
///
/// The extractor hands out a clone of the value stored in the request's
/// extensions, so `T` must be `Clone`; wrapping shared state in an `Arc`
/// keeps that clone cheap.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into the extensions of all
///     // incoming requests.
///     .layer(Extension(state));
/// # let _: Router = app;
/// ```
///
/// If the extension is missing, the request is rejected with a `500 Internal
/// Server Error` response.
///
/// # As response
///
/// Response extensions can be used to share state with middleware.
///
/// ```rust
/// use axum::Extension;
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);
72
73#[async_trait]
74impl<T, S> FromRequestParts<S> for Extension<T>
75where
76    T: Clone + Send + Sync + 'static,
77    S: Send + Sync,
78{
79    type Rejection = ExtensionRejection;
80
81    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
82        let value = req
83            .extensions
84            .get::<T>()
85            .ok_or_else(|| {
86                MissingExtension::from_err(format!(
87                    "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
88                    std::any::type_name::<T>()
89                ))
90            }).cloned()?;
91
92        Ok(Extension(value))
93    }
94}
95
96axum_core::__impl_deref!(Extension);
97
98impl<T> IntoResponseParts for Extension<T>
99where
100    T: Clone + Send + Sync + 'static,
101{
102    type Error = Infallible;
103
104    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
105        res.extensions_mut().insert(self.0);
106        Ok(res)
107    }
108}
109
110impl<T> IntoResponse for Extension<T>
111where
112    T: Clone + Send + Sync + 'static,
113{
114    fn into_response(self) -> Response {
115        let mut res = ().into_response();
116        res.extensions_mut().insert(self.0);
117        res
118    }
119}
120
121impl<S, T> tower_layer::Layer<S> for Extension<T>
122where
123    T: Clone + Send + Sync + 'static,
124{
125    type Service = AddExtension<S, T>;
126
127    fn layer(&self, inner: S) -> Self::Service {
128        AddExtension {
129            inner,
130            value: self.0.clone(),
131        }
132    }
133}
134
/// Middleware that inserts a clone of a stored value into the
/// [request extensions] of every request before handing it to the wrapped
/// service.
///
/// This is the service produced when an [`Extension`] is used as a layer.
/// See [Sharing state with handlers](index.html#sharing-state-with-handlers)
/// for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives each request.
    pub(crate) inner: S,
    /// The value cloned into each request's extensions.
    pub(crate) value: T,
}
146
147impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
148where
149    S: Service<Request<ResBody>>,
150    T: Clone + Send + Sync + 'static,
151{
152    type Response = S::Response;
153    type Error = S::Error;
154    type Future = S::Future;
155
156    #[inline]
157    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158        self.inner.poll_ready(cx)
159    }
160
161    fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
162        req.extensions_mut().insert(self.value.clone());
163        self.inner.call(req)
164    }
165}