// axum/extract/state.rs

use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
use std::{
    convert::Infallible,
    ops::{Deref, DerefMut},
};

/// Extractor for state.
///
/// See ["Accessing state in middleware"][state-from-middleware] for how to
/// access state in middleware.
///
/// State is global and used in every request a router with state receives.
/// For accessing data derived from requests, such as authorization data, see [`Extension`].
///
/// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
/// [`Extension`]: crate::Extension
///
/// # With `Router`
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// // the application state
/// //
/// // here you can put configuration, database connection pools, or whatever
/// // state you need
/// //
/// // see "When states need to implement `Clone`" for more details on why we need
/// // `#[derive(Clone)]` here.
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::new()
///     .route("/", get(handler))
///     // provide the state so the router can access it
///     .with_state(state);
///
/// async fn handler(
///     // access the state via the `State` extractor
///     // extracting a state of the wrong type results in a compile error
///     State(state): State<AppState>,
/// ) {
///     // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// Note that `State` is an extractor, so be sure to put it before any body
/// extractors, see ["the order of extractors"][order-of-extractors].
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// ## Combining stateful routers
///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
/// When combining [`Router`]s with one of these methods, the [`Router`]s must have
/// the same state type. Generally, this can be inferred automatically:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that will be nested within another
/// let api = Router::new()
///     .route("/posts", get(posts_handler));
///
/// let app = Router::new()
///     .nest("/api", api)
///     .with_state(state);
///
/// async fn posts_handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// However, if you are composing [`Router`]s that are defined in separate scopes,
/// you may need to annotate the [`State`] type explicitly:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// fn make_app() -> Router {
///     let state = AppState {};
///
///     Router::new()
///         .nest("/api", make_api())
///         .with_state(state) // the outer Router's state is inferred
/// }
///
/// // the inner Router must specify its state type to compose with the
/// // outer router
/// fn make_api() -> Router<AppState> {
///     Router::new()
///         .route("/posts", get(posts_handler))
/// }
///
/// async fn posts_handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// # let _: axum::Router = make_app();
/// ```
///
/// In short, a [`Router`]'s generic state type defaults to `()`
/// (no state) unless [`Router::with_state`] is called or the value
/// of the generic type is given explicitly.
///
/// [`Router`]: crate::Router
/// [`Router::merge`]: crate::Router::merge
/// [`Router::nest`]: crate::Router::nest
/// [`Router::with_state`]: crate::Router::with_state
///
/// # With `MethodRouter`
///
/// ```
/// use axum::{routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// let method_router_with_state = get(handler)
///     // provide the state so the handler can access it
///     .with_state(state);
/// # let _: axum::routing::MethodRouter = method_router_with_state;
///
/// async fn handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// ```
///
/// # With `Handler`
///
/// ```
/// use axum::{routing::get, handler::Handler, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// async fn handler(State(state): State<AppState>) {
///     // use `state`...
/// }
///
/// // provide the state so the handler can access it
/// let handler_with_state = handler.with_state(state);
///
/// # async {
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler_with_state.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Substates
///
/// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
/// struct AppState {
///     // that holds some api specific state
///     api_state: ApiState,
/// }
///
/// // the api specific state
/// #[derive(Clone)]
/// struct ApiState {}
///
/// // support converting an `AppState` into an `ApiState`
/// impl FromRef<AppState> for ApiState {
///     fn from_ref(app_state: &AppState) -> ApiState {
///         app_state.api_state.clone()
///     }
/// }
///
/// let state = AppState {
///     api_state: ApiState {},
/// };
///
/// let app = Router::new()
///     .route("/", get(handler))
///     .route("/api/users", get(api_users))
///     .with_state(state);
///
/// async fn api_users(
///     // access the api specific state
///     State(api_state): State<ApiState>,
/// ) {
/// }
///
/// async fn handler(
///     // we can still access the top level state
///     State(state): State<AppState>,
/// ) {
/// }
/// # let _: axum::Router = app;
/// ```
///
/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequestParts, FromRef};
/// use http::request::Parts;
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
/// // the extractor your library provides
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for MyLibraryExtractor
/// where
///     // keep `S` generic but require that it can produce a `MyLibraryState`
///     // this means users will have to implement `FromRef<UserState> for MyLibraryState`
///     MyLibraryState: FromRef<S>,
///     S: Send + Sync,
/// {
///     type Rejection = Infallible;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // get a `MyLibraryState` from a reference to the state
///         let state = MyLibraryState::from_ref(state);
///
///         // ...
///         # todo!()
///     }
/// }
///
/// // the state your library needs
/// struct MyLibraryState {
///     // ...
/// }
/// ```
///
/// # When states need to implement `Clone`
///
/// Your top level state type must implement `Clone` to be extractable with `State`:
///
/// ```
/// use axum::extract::State;
///
/// // no substates, so to extract to `State<AppState>` we must implement `Clone` for `AppState`
/// #[derive(Clone)]
/// struct AppState {}
///
/// async fn handler(State(state): State<AppState>) {
///     // ...
/// }
/// ```
///
/// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`].
///
/// This is also true if you're extracting substates, unless you _never_ extract the top level
/// state itself:
///
/// ```
/// use axum::extract::{State, FromRef};
///
/// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to
/// // implement `Clone`
/// struct AppState {
///     inner: InnerState,
/// }
///
/// #[derive(Clone)]
/// struct InnerState {}
///
/// impl FromRef<AppState> for InnerState {
///     fn from_ref(app_state: &AppState) -> InnerState {
///         app_state.inner.clone()
///     }
/// }
///
/// async fn api_users(State(inner): State<InnerState>) {
///     // ...
/// }
/// ```
///
/// In general however we recommend you implement `Clone` for all your state types to avoid
/// potential type errors.
///
/// # Shared mutable state
///
/// [As state is global within a `Router`][global] you can't directly get a mutable reference to
/// the state.
///
/// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
/// your use case. See [the tokio docs] for more details.
///
/// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
/// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
/// consider using a `tokio::sync::Mutex` instead.
///
/// ## Example
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use std::sync::{Arc, Mutex};
///
/// #[derive(Clone)]
/// struct AppState {
///     data: Arc<Mutex<String>>,
/// }
///
/// async fn handler(State(state): State<AppState>) {
///     {
///         let mut data = state.data.lock().expect("mutex was poisoned");
///         *data = "updated foo".to_owned();
///     }
///
///     // ...
/// }
///
/// let state = AppState {
///     data: Arc::new(Mutex::new("foo".to_owned())),
/// };
///
/// let app = Router::new()
///     .route("/", get(handler))
///     .with_state(state);
/// # let _: Router = app;
/// ```
///
/// [global]: crate::Router::with_state
/// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
#[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S);

351#[async_trait]
352impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
353where
354    InnerState: FromRef<OuterState>,
355    OuterState: Send + Sync,
356{
357    type Rejection = Infallible;
358
359    async fn from_request_parts(
360        _parts: &mut Parts,
361        state: &OuterState,
362    ) -> Result<Self, Self::Rejection> {
363        let inner_state = InnerState::from_ref(state);
364        Ok(Self(inner_state))
365    }
366}
367
368impl<S> Deref for State<S> {
369    type Target = S;
370
371    fn deref(&self) -> &Self::Target {
372        &self.0
373    }
374}
375
376impl<S> DerefMut for State<S> {
377    fn deref_mut(&mut self) -> &mut Self::Target {
378        &mut self.0
379    }
380}