1use std::{
2 borrow::Cow,
3 fmt,
4 io::Result as IoResult,
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
6 pin::Pin,
7 task::{Context, Poll},
8 vec,
9};
10
11use either::Either;
12pub use error::Error;
13use futures_util::{
14 future,
15 stream::{self, Once, Stream},
16};
17
18pub type Result<T> = std::result::Result<T, Error>;
19
20pub trait ToProxyAddrs {
27 type Output: Stream<Item = Result<SocketAddr>> + Unpin;
28
29 fn to_proxy_addrs(&self) -> Self::Output;
30}
31
32macro_rules! trivial_impl_to_proxy_addrs {
33 ($t: ty) => {
34 impl ToProxyAddrs for $t {
35 type Output = Once<future::Ready<Result<SocketAddr>>>;
36
37 fn to_proxy_addrs(&self) -> Self::Output {
38 stream::once(future::ready(Ok(SocketAddr::from(*self))))
39 }
40 }
41 };
42}
43
44trivial_impl_to_proxy_addrs!(SocketAddr);
45trivial_impl_to_proxy_addrs!((IpAddr, u16));
46trivial_impl_to_proxy_addrs!((Ipv4Addr, u16));
47trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
48trivial_impl_to_proxy_addrs!(SocketAddrV4);
49trivial_impl_to_proxy_addrs!(SocketAddrV6);
50
51impl ToProxyAddrs for &[SocketAddr] {
52 type Output = ProxyAddrsStream;
53
54 fn to_proxy_addrs(&self) -> Self::Output {
55 let addrs = self.to_vec();
56 ProxyAddrsStream(Some(IoResult::Ok(addrs.into_iter())))
57 }
58}
59
60impl ToProxyAddrs for str {
61 type Output = ProxyAddrsStream;
62
63 fn to_proxy_addrs(&self) -> Self::Output {
64 ProxyAddrsStream(Some(self.to_socket_addrs()))
65 }
66}
67
68impl ToProxyAddrs for (&str, u16) {
69 type Output = ProxyAddrsStream;
70
71 fn to_proxy_addrs(&self) -> Self::Output {
72 ProxyAddrsStream(Some(self.to_socket_addrs()))
73 }
74}
75
76impl<T: ToProxyAddrs + ?Sized> ToProxyAddrs for &T {
77 type Output = T::Output;
78
79 fn to_proxy_addrs(&self) -> Self::Output {
80 (**self).to_proxy_addrs()
81 }
82}
83
84pub struct ProxyAddrsStream(Option<IoResult<vec::IntoIter<SocketAddr>>>);
85
86impl Stream for ProxyAddrsStream {
87 type Item = Result<SocketAddr>;
88
89 fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
90 match self.0.as_mut() {
91 Some(Ok(iter)) => Poll::Ready(iter.next().map(Result::Ok)),
92 Some(Err(_)) => {
93 let err = self.0.take().unwrap().unwrap_err();
94 Poll::Ready(Some(Err(err.into())))
95 },
96 None => unreachable!(),
97 }
98 }
99}
100
/// Destination that the proxy is asked to connect to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TargetAddr<'a> {
    /// A concrete IP socket address.
    Ip(SocketAddr),

    /// A domain name (borrowed or owned) together with a port.
    Domain(Cow<'a, str>, u16),
}

impl fmt::Display for TargetAddr<'_> {
    /// Formats as `ip:port` (IPv6 in brackets) or `domain:port`.
    ///
    /// Width and fill flags of the outer formatter are not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Domain(host, port) => write!(f, "{}:{}", host, port),
            Self::Ip(socket) => write!(f, "{}", socket),
        }
    }
}

impl TargetAddr<'_> {
    /// Returns a copy of this address that owns all of its data.
    ///
    /// A borrowed domain is copied into a fresh `String`, so the result can
    /// outlive the input it was parsed from.
    pub fn to_owned(&self) -> TargetAddr<'static> {
        match *self {
            TargetAddr::Domain(ref host, port) => {
                let owned_host = String::from(&**host);
                TargetAddr::Domain(Cow::Owned(owned_host), port)
            }
            TargetAddr::Ip(socket) => TargetAddr::Ip(socket),
        }
    }
}
133
134impl ToSocketAddrs for TargetAddr<'_> {
135 type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
136
137 fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
138 Ok(match self {
139 TargetAddr::Ip(addr) => Either::Left(addr.to_socket_addrs()?),
140 TargetAddr::Domain(domain, port) => Either::Right((&**domain, *port).to_socket_addrs()?),
141 })
142 }
143}
144
145pub trait IntoTargetAddr<'a> {
147 fn into_target_addr(self) -> Result<TargetAddr<'a>>;
149}
150
151macro_rules! trivial_impl_into_target_addr {
152 ($t: ty) => {
153 impl<'a> IntoTargetAddr<'a> for $t {
154 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
155 Ok(TargetAddr::Ip(SocketAddr::from(self)))
156 }
157 }
158 };
159}
160
161trivial_impl_into_target_addr!(SocketAddr);
162trivial_impl_into_target_addr!((IpAddr, u16));
163trivial_impl_into_target_addr!((Ipv4Addr, u16));
164trivial_impl_into_target_addr!((Ipv6Addr, u16));
165trivial_impl_into_target_addr!(SocketAddrV4);
166trivial_impl_into_target_addr!(SocketAddrV6);
167
168impl<'a> IntoTargetAddr<'a> for TargetAddr<'a> {
169 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
170 Ok(self)
171 }
172}
173
174impl<'a> IntoTargetAddr<'a> for (&'a str, u16) {
175 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
176 if let Ok(addr) = self.0.parse::<IpAddr>() {
178 return (addr, self.1).into_target_addr();
179 }
180
181 if self.0.len() > 255 {
183 return Err(Error::InvalidTargetAddress("overlong domain"));
184 }
185 Ok(TargetAddr::Domain(self.0.into(), self.1))
188 }
189}
190
191impl<'a> IntoTargetAddr<'a> for &'a str {
192 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
193 if let Ok(addr) = self.parse::<SocketAddr>() {
195 return addr.into_target_addr();
196 }
197
198 let mut parts_iter = self.rsplitn(2, ':');
199 let port: u16 = parts_iter
200 .next()
201 .and_then(|port_str| port_str.parse().ok())
202 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
203 let domain = parts_iter
204 .next()
205 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
206 if domain.len() > 255 {
207 return Err(Error::InvalidTargetAddress("overlong domain"));
208 }
209 Ok(TargetAddr::Domain(domain.into(), port))
210 }
211}
212
213impl IntoTargetAddr<'static> for String {
214 fn into_target_addr(mut self) -> Result<TargetAddr<'static>> {
215 if let Ok(addr) = self.parse::<SocketAddr>() {
217 return addr.into_target_addr();
218 }
219
220 let mut parts_iter = self.rsplitn(2, ':');
221 let port: u16 = parts_iter
222 .next()
223 .and_then(|port_str| port_str.parse().ok())
224 .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
225 let domain_len = parts_iter
226 .next()
227 .ok_or(Error::InvalidTargetAddress("invalid address format"))?
228 .len();
229 if domain_len > 255 {
230 return Err(Error::InvalidTargetAddress("overlong domain"));
231 }
232 self.truncate(domain_len);
233 Ok(TargetAddr::Domain(self.into(), port))
234 }
235}
236
237impl IntoTargetAddr<'static> for (String, u16) {
238 fn into_target_addr(self) -> Result<TargetAddr<'static>> {
239 let addr = (self.0.as_str(), self.1).into_target_addr()?;
240 if let TargetAddr::Ip(addr) = addr {
241 Ok(TargetAddr::Ip(addr))
242 } else {
243 Ok(TargetAddr::Domain(self.0.into(), self.1))
244 }
245 }
246}
247
248impl<'a, T> IntoTargetAddr<'a> for &'a T
249where
250 T: IntoTargetAddr<'a> + Copy,
251{
252 fn into_target_addr(self) -> Result<TargetAddr<'a>> {
253 (*self).into_target_addr()
254 }
255}
256
/// Authentication choice for the proxy connection, with credentials when a
/// username/password is used.
#[derive(Debug)]
enum Authentication<'a> {
    /// Username/password credentials.
    Password { username: &'a str, password: &'a str },
    /// No authentication.
    None,
}

impl Authentication<'_> {
    /// One-byte method identifier for this authentication kind.
    ///
    /// The values equal the RFC 1928 method codes: `0x00` ("no authentication
    /// required") and `0x02` ("username/password").
    fn id(&self) -> u8 {
        match *self {
            Self::None => 0x00,
            Self::Password { .. } => 0x02,
        }
    }
}
272
273mod error;
274pub mod io;
275pub mod tcp;
276
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use futures_util::TryStreamExt;

    use super::*;

    /// Collects every address produced by `t`.
    ///
    /// The first stream error is returned as `Err` (stopping there) rather than
    /// panicking, so tests can assert on failures.
    fn to_proxy_addrs<T: ToProxyAddrs>(t: T) -> Result<Vec<SocketAddr>> {
        block_on(t.to_proxy_addrs().try_collect())
    }

    #[test]
    fn test_clone_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        let addr_clone = addr.clone();
        assert_eq!(addr, addr_clone);
        assert_eq!(addr.to_string(), addr_clone.to_string());
    }

    #[test]
    fn test_clone_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        let addr_clone = addr.clone();
        assert_eq!(addr, addr_clone);
        assert_eq!(addr.to_string(), addr_clone.to_string());
    }

    #[test]
    fn test_display_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        assert_eq!(format!("{}", addr), "127.0.0.1:8080");
    }

    #[test]
    fn test_display_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        assert_eq!(format!("{}", addr), "example.com:80");
    }

    #[test]
    fn test_to_string_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        assert_eq!(addr.to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn test_to_string_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        assert_eq!(addr.to_string(), "example.com:80");
    }

    #[test]
    fn test_to_owned_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        let owned: TargetAddr<'static> = addr.to_owned();
        assert_eq!(addr, owned);
        match owned {
            TargetAddr::Domain(Cow::Owned(ref s), 80) => assert_eq!(s, "example.com"),
            ref other => panic!("unexpected target address: {:?}", other),
        }
    }

    #[test]
    fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = to_proxy_addrs(addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T` impl is what gets exercised.
        let res = to_proxy_addrs(&addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addrs_to_proxy_addrs() -> Result<()> {
        let addrs = [
            SocketAddr::from(([1, 1, 1, 1], 443)),
            SocketAddr::from(([8, 8, 8, 8], 53)),
        ];
        let res = to_proxy_addrs(&addrs[..])?;
        assert_eq!(&res[..], &addrs);
        Ok(())
    }

    #[test]
    fn invalid_proxy_addr_str_yields_error() {
        // No `:` at all, so std rejects it without attempting a DNS lookup.
        assert!(to_proxy_addrs("not a socket address").is_err());
    }

    fn into_target_addr<'a, T>(t: T) -> Result<TargetAddr<'a>>
    where
        T: IntoTargetAddr<'a>,
    {
        t.into_target_addr()
    }

    #[test]
    fn converts_socket_addr_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = into_target_addr(addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the blanket `&T` impl is what gets exercised.
        let res = into_target_addr(&addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_str_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = format!("{}", addr);
        let res = into_target_addr(ip_str.as_str())?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_ip_str_and_port_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = format!("{}", addr.ip());
        let res = into_target_addr((ip_str.as_str(), addr.port()))?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_domain_to_target_addr() -> Result<()> {
        let domain = "www.example.com:80";
        let res = into_target_addr(domain)?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr(domain.to_owned())?;
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), res);
        Ok(())
    }

    #[test]
    fn converts_domain_and_port_to_target_addr() -> Result<()> {
        let domain = "www.example.com";
        let res = into_target_addr((domain, 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);
        Ok(())
    }

    #[test]
    fn converts_string_and_port_to_target_addr() -> Result<()> {
        let res = into_target_addr((String::from("1.1.1.1"), 443))?;
        assert_eq!(TargetAddr::Ip(SocketAddr::from(([1, 1, 1, 1], 443))), res);

        let res = into_target_addr((String::from("www.example.com"), 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);
        Ok(())
    }

    #[test]
    fn overlong_domain_to_target_addr_should_fail() {
        let domain = format!("www.{:a<1$}.com:80", 'a', 300);
        assert!(into_target_addr(domain.as_str()).is_err());
        assert!(into_target_addr(domain).is_err());
        let domain = format!("www.{:a<1$}.com", 'a', 300);
        assert!(into_target_addr((domain.as_str(), 80)).is_err());
        assert!(into_target_addr((domain, 80)).is_err());
    }

    #[test]
    fn addr_with_invalid_port_to_target_addr_should_fail() {
        let addr = "[ffff::1]:65536";
        assert!(into_target_addr(addr).is_err());
        let addr = "www.example.com:65536";
        assert!(into_target_addr(addr).is_err());
    }

    #[test]
    fn addr_without_port_to_target_addr_should_fail() {
        assert!(into_target_addr("www.example.com").is_err());
        assert!(into_target_addr(String::from("www.example.com")).is_err());
    }
}