1use crate::{extract::rejection::*, response::IntoResponseParts};
2use async_trait::async_trait;
3use axum_core::{
4 extract::FromRequestParts,
5 response::{IntoResponse, Response, ResponseParts},
6};
7use http::{request::Parts, Request};
8use std::{
9 convert::Infallible,
10 task::{Context, Poll},
11};
12use tower_service::Service;
13
/// Wrapper around a value that travels through request or response extensions.
///
/// - As an extractor ([`FromRequestParts`]) it clones the `T` stored in the request
///   extensions. It rejects with [`ExtensionRejection`] when no such value exists.
/// - As a response part ([`IntoResponseParts`]) or a full response ([`IntoResponse`]),
///   it inserts the wrapped value into the response extensions.
/// - As a `tower_layer::Layer` it wraps a service in [`AddExtension`]. That service
///   inserts a clone of the value into the extensions of every request.
#[must_use]
#[derive(Clone, Copy, Debug, Default)]
pub struct Extension<T>(pub T);
72
73#[async_trait]
74impl<T, S> FromRequestParts<S> for Extension<T>
75where
76 T: Clone + Send + Sync + 'static,
77 S: Send + Sync,
78{
79 type Rejection = ExtensionRejection;
80
81 async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
82 let value = req
83 .extensions
84 .get::<T>()
85 .ok_or_else(|| {
86 MissingExtension::from_err(format!(
87 "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
88 std::any::type_name::<T>()
89 ))
90 }).cloned()?;
91
92 Ok(Extension(value))
93 }
94}
95
// Invokes `axum_core`'s internal deref helper macro for `Extension`.
// NOTE(review): presumably this generates `Deref`/`DerefMut` to the wrapped value
// (inferred from the macro name); confirm against `axum_core::__impl_deref`.
axum_core::__impl_deref!(Extension);
97
98impl<T> IntoResponseParts for Extension<T>
99where
100 T: Clone + Send + Sync + 'static,
101{
102 type Error = Infallible;
103
104 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
105 res.extensions_mut().insert(self.0);
106 Ok(res)
107 }
108}
109
110impl<T> IntoResponse for Extension<T>
111where
112 T: Clone + Send + Sync + 'static,
113{
114 fn into_response(self) -> Response {
115 let mut res = ().into_response();
116 res.extensions_mut().insert(self.0);
117 res
118 }
119}
120
121impl<S, T> tower_layer::Layer<S> for Extension<T>
122where
123 T: Clone + Send + Sync + 'static,
124{
125 type Service = AddExtension<S, T>;
126
127 fn layer(&self, inner: S) -> Self::Service {
128 AddExtension {
129 inner,
130 value: self.0.clone(),
131 }
132 }
133}
134
/// Service that inserts a clone of `value` into every request's extensions before
/// handing the request to `inner`.
///
/// Built by the `tower_layer::Layer` implementation of [`Extension`].
#[derive(Debug, Clone, Copy)]
pub struct AddExtension<S, T> {
    /// Wrapped service that receives each request.
    pub(crate) inner: S,
    /// Value cloned into each request's extensions.
    pub(crate) value: T,
}
146
147impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
148where
149 S: Service<Request<ResBody>>,
150 T: Clone + Send + Sync + 'static,
151{
152 type Response = S::Response;
153 type Error = S::Error;
154 type Future = S::Future;
155
156 #[inline]
157 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 self.inner.poll_ready(cx)
159 }
160
161 fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
162 req.extensions_mut().insert(self.value.clone());
163 self.inner.call(req)
164 }
165}