axum_core/response/
into_response_parts.rs

1use super::{IntoResponse, Response};
2use http::{
3    header::{HeaderMap, HeaderName, HeaderValue},
4    Extensions, StatusCode,
5};
6use std::{convert::Infallible, fmt};
7
8/// Trait for adding headers and extensions to a response.
9///
10/// # Example
11///
12/// ```rust
13/// use axum::{
14///     response::{ResponseParts, IntoResponse, IntoResponseParts, Response},
15///     http::{StatusCode, header::{HeaderName, HeaderValue}},
16/// };
17///
18/// // Hypothetical helper type for setting a single header
19/// struct SetHeader<'a>(&'a str, &'a str);
20///
21/// impl<'a> IntoResponseParts for SetHeader<'a> {
22///     type Error = (StatusCode, String);
23///
24///     fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
25///         match (self.0.parse::<HeaderName>(), self.1.parse::<HeaderValue>()) {
26///             (Ok(name), Ok(value)) => {
27///                 res.headers_mut().insert(name, value);
28///             },
29///             (Err(_), _) => {
30///                 return Err((
31///                     StatusCode::INTERNAL_SERVER_ERROR,
32///                     format!("Invalid header name {}", self.0),
33///                 ));
34///             },
35///             (_, Err(_)) => {
36///                 return Err((
37///                     StatusCode::INTERNAL_SERVER_ERROR,
38///                     format!("Invalid header value {}", self.1),
39///                 ));
40///             },
41///         }
42///
43///         Ok(res)
44///     }
45/// }
46///
47/// // It's also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as
48/// // the response
49/// impl<'a> IntoResponse for SetHeader<'a> {
50///     fn into_response(self) -> Response {
51///         // This gives an empty response with the header
52///         (self, ()).into_response()
53///     }
54/// }
55///
56/// // We can now return `SetHeader` in responses
57/// //
58/// // Note that returning `impl IntoResponse` might be easier if the response has many parts to
59/// // it. The return type is written out here for clarity.
60/// async fn handler() -> (SetHeader<'static>, SetHeader<'static>, &'static str) {
61///     (
62///         SetHeader("server", "axum"),
63///         SetHeader("x-foo", "custom"),
64///         "body",
65///     )
66/// }
67///
68/// // Or on its own as the whole response
69/// async fn other_handler() -> SetHeader<'static> {
70///     SetHeader("x-foo", "custom")
71/// }
72/// ```
73pub trait IntoResponseParts {
74    /// The type returned in the event of an error.
75    ///
76    /// This can be used to fallibly convert types into headers or extensions.
77    type Error: IntoResponse;
78
79    /// Set parts of the response
80    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error>;
81}
82
83impl<T> IntoResponseParts for Option<T>
84where
85    T: IntoResponseParts,
86{
87    type Error = T::Error;
88
89    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
90        if let Some(inner) = self {
91            inner.into_response_parts(res)
92        } else {
93            Ok(res)
94        }
95    }
96}
97
98/// Parts of a response.
99///
100/// Used with [`IntoResponseParts`].
101#[derive(Debug)]
102pub struct ResponseParts {
103    pub(crate) res: Response,
104}
105
106impl ResponseParts {
107    /// Gets a reference to the response headers.
108    #[must_use]
109    pub fn headers(&self) -> &HeaderMap {
110        self.res.headers()
111    }
112
113    /// Gets a mutable reference to the response headers.
114    #[must_use]
115    pub fn headers_mut(&mut self) -> &mut HeaderMap {
116        self.res.headers_mut()
117    }
118
119    /// Gets a reference to the response extensions.
120    #[must_use]
121    pub fn extensions(&self) -> &Extensions {
122        self.res.extensions()
123    }
124
125    /// Gets a mutable reference to the response extensions.
126    #[must_use]
127    pub fn extensions_mut(&mut self) -> &mut Extensions {
128        self.res.extensions_mut()
129    }
130}
131
132impl IntoResponseParts for HeaderMap {
133    type Error = Infallible;
134
135    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
136        res.headers_mut().extend(self);
137        Ok(res)
138    }
139}
140
141impl<K, V, const N: usize> IntoResponseParts for [(K, V); N]
142where
143    K: TryInto<HeaderName>,
144    K::Error: fmt::Display,
145    V: TryInto<HeaderValue>,
146    V::Error: fmt::Display,
147{
148    type Error = TryIntoHeaderError<K::Error, V::Error>;
149
150    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
151        for (key, value) in self {
152            let key = key.try_into().map_err(TryIntoHeaderError::key)?;
153            let value = value.try_into().map_err(TryIntoHeaderError::value)?;
154            res.headers_mut().insert(key, value);
155        }
156
157        Ok(res)
158    }
159}
160
161/// Error returned if converting a value to a header fails.
162#[derive(Debug)]
163pub struct TryIntoHeaderError<K, V> {
164    kind: TryIntoHeaderErrorKind<K, V>,
165}
166
167impl<K, V> TryIntoHeaderError<K, V> {
168    pub(super) fn key(err: K) -> Self {
169        Self {
170            kind: TryIntoHeaderErrorKind::Key(err),
171        }
172    }
173
174    pub(super) fn value(err: V) -> Self {
175        Self {
176            kind: TryIntoHeaderErrorKind::Value(err),
177        }
178    }
179}
180
/// Records which half of a header pair failed to convert, along with the conversion error.
#[derive(Debug)]
enum TryIntoHeaderErrorKind<K, V> {
    /// The header name could not be converted.
    Key(K),
    /// The header value could not be converted.
    Value(V),
}
186
187impl<K, V> IntoResponse for TryIntoHeaderError<K, V>
188where
189    K: fmt::Display,
190    V: fmt::Display,
191{
192    fn into_response(self) -> Response {
193        match self.kind {
194            TryIntoHeaderErrorKind::Key(inner) => {
195                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
196            }
197            TryIntoHeaderErrorKind::Value(inner) => {
198                (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
199            }
200        }
201    }
202}
203
204impl<K, V> fmt::Display for TryIntoHeaderError<K, V> {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        match self.kind {
207            TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"),
208            TryIntoHeaderErrorKind::Value(_) => {
209                write!(f, "failed to convert value to a header value")
210            }
211        }
212    }
213}
214
215impl<K, V> std::error::Error for TryIntoHeaderError<K, V>
216where
217    K: std::error::Error + 'static,
218    V: std::error::Error + 'static,
219{
220    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
221        match &self.kind {
222            TryIntoHeaderErrorKind::Key(inner) => Some(inner),
223            TryIntoHeaderErrorKind::Value(inner) => Some(inner),
224        }
225    }
226}
227
228macro_rules! impl_into_response_parts {
229    ( $($ty:ident),* $(,)? ) => {
230        #[allow(non_snake_case)]
231        impl<$($ty,)*> IntoResponseParts for ($($ty,)*)
232        where
233            $( $ty: IntoResponseParts, )*
234        {
235            type Error = Response;
236
237            fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
238                let ($($ty,)*) = self;
239
240                $(
241                    let res = match $ty.into_response_parts(res) {
242                        Ok(res) => res,
243                        Err(err) => {
244                            return Err(err.into_response());
245                        }
246                    };
247                )*
248
249                Ok(res)
250            }
251        }
252    }
253}
254
255all_the_tuples_no_last_special_case!(impl_into_response_parts);
256
257impl IntoResponseParts for Extensions {
258    type Error = Infallible;
259
260    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
261        res.extensions_mut().extend(self);
262        Ok(res)
263    }
264}
265
266impl IntoResponseParts for () {
267    type Error = Infallible;
268
269    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
270        Ok(res)
271    }
272}