axum/extract/
host.rs

1use super::{
2    rejection::{FailedToResolveHost, HostRejection},
3    FromRequestParts,
4};
5use async_trait::async_trait;
6use http::{
7    header::{HeaderMap, FORWARDED},
8    request::Parts,
9};
10
/// Name of the de-facto standard header proxies use to pass on the originally requested host.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `Forwarded` header (the `host` parameter of its first element)
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
///
/// The first source that yields a value wins. A header whose value is not valid visible ASCII
/// is ignored and the next source is tried. Header-derived values are returned verbatim and may
/// therefore include a port (e.g. `example.com:8080`); the value taken from the request URI is
/// the bare host without a port.
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values, so make sure to validate them to avoid security issues.
///
/// If no source yields a host, extraction is rejected with
/// [`HostRejection::FailedToResolveHost`].
#[derive(Debug, Clone)]
pub struct Host(pub String);
25
26#[async_trait]
27impl<S> FromRequestParts<S> for Host
28where
29    S: Send + Sync,
30{
31    type Rejection = HostRejection;
32
33    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34        if let Some(host) = parse_forwarded(&parts.headers) {
35            return Ok(Host(host.to_owned()));
36        }
37
38        if let Some(host) = parts
39            .headers
40            .get(X_FORWARDED_HOST_HEADER_KEY)
41            .and_then(|host| host.to_str().ok())
42        {
43            return Ok(Host(host.to_owned()));
44        }
45
46        if let Some(host) = parts
47            .headers
48            .get(http::header::HOST)
49            .and_then(|host| host.to_str().ok())
50        {
51            return Ok(Host(host.to_owned()));
52        }
53
54        if let Some(host) = parts.uri.host() {
55            return Ok(Host(host.to_owned()));
56        }
57
58        Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59    }
60}
61
62#[allow(warnings)]
63fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64    // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
65    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66
67    // get the first set of values
68    let first_value = forwarded_values.split(',').nth(0)?;
69
70    // find the value of the `host` field
71    first_value.split(';').find_map(|pair| {
72        let (key, value) = pair.split_once('=')?;
73        key.trim()
74            .eq_ignore_ascii_case("host")
75            .then(|| value.trim().trim_matches('"'))
76    })
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::header::HeaderName;

    /// Builds a client for a router whose only handler echoes the extracted [`Host`] as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_other_headers() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=192.0.2.60;proto=http")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .header(http::header::HOST, "some-domain:123")
            .await
            .text()
            .await;
        assert_eq!(host, "192.0.2.60");
    }

    #[crate::test]
    async fn uri_host() {
        let host = test_client().get("/").await.text().await;
        assert!(host.contains("127.0.0.1"));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // is case insensitive
        let headers = header_map(&[(FORWARDED, "Host=192.0.2.60;Proto=http;By=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // whitespace around parameters and values is ignored
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; host = example.com")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap`, appending (not replacing) so repeated names are preserved in order.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}