1use super::{
2 rejection::{FailedToResolveHost, HostRejection},
3 FromRequestParts,
4};
5use async_trait::async_trait;
6use http::{
7 header::{HeaderMap, FORWARDED},
8 request::Parts,
9};
10
/// Name of the `X-Forwarded-Host` header, consulted after `Forwarded` and
/// before `Host` when resolving the request's host.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the host the request was addressed to.
///
/// The host is taken from the first of these sources that yields a value:
///
/// 1. the `host` parameter of the first element of the `Forwarded` header,
/// 2. the `X-Forwarded-Host` header,
/// 3. the `Host` header,
/// 4. the host component of the request URI (which does not include a port).
///
/// Header values that are not valid visible ASCII are skipped. If no source
/// yields a host, the request is rejected with
/// `HostRejection::FailedToResolveHost`.
///
/// Note that `Forwarded` and `X-Forwarded-Host` are ordinary request headers
/// that any client can send. Only rely on them when a trusted proxy in front
/// of the application sets, overwrites, or strips them.
#[derive(Debug, Clone)]
pub struct Host(pub String);
25
26#[async_trait]
27impl<S> FromRequestParts<S> for Host
28where
29 S: Send + Sync,
30{
31 type Rejection = HostRejection;
32
33 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34 if let Some(host) = parse_forwarded(&parts.headers) {
35 return Ok(Host(host.to_owned()));
36 }
37
38 if let Some(host) = parts
39 .headers
40 .get(X_FORWARDED_HOST_HEADER_KEY)
41 .and_then(|host| host.to_str().ok())
42 {
43 return Ok(Host(host.to_owned()));
44 }
45
46 if let Some(host) = parts
47 .headers
48 .get(http::header::HOST)
49 .and_then(|host| host.to_str().ok())
50 {
51 return Ok(Host(host.to_owned()));
52 }
53
54 if let Some(host) = parts.uri.host() {
55 return Ok(Host(host.to_owned()));
56 }
57
58 Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59 }
60}
61
62#[allow(warnings)]
63fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64 let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66
67 let first_value = forwarded_values.split(',').nth(0)?;
69
70 first_value.split(';').find_map(|pair| {
72 let (key, value) = pair.split_once('=')?;
73 key.trim()
74 .eq_ignore_ascii_case("host")
75 .then(|| value.trim().trim_matches('"'))
76 })
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::header::HeaderName;

    /// Client for a router whose only handler echoes the extracted host back as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=forwarded-domain:789;proto=http")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .header(http::header::HOST, "some-domain:123")
            .await
            .text()
            .await;
        assert_eq!(host, "forwarded-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        // The test itself sets no host-related headers.
        let host = test_client().get("/").await.text().await;
        assert!(host.contains("127.0.0.1"));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // parameter names are matched case-insensitively
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // parameter order and surrounding whitespace don't matter
        let headers = header_map(&[(FORWARDED, "proto=http; host=192.0.2.60 ;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // quoted values (needed for IPv6 with a port) are unquoted
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // only the first element of a comma-separated list is used
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // only the first `Forwarded` header is used
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter in the first element
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert_eq!(parse_forwarded(&headers), None);

        // a `host` parameter only in a later element is ignored
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43, host=example.com")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap`, appending (not replacing) repeated header names.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}