tokio_socks/
lib.rs

1use std::{
2    borrow::Cow,
3    fmt,
4    io::Result as IoResult,
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
6    pin::Pin,
7    task::{Context, Poll},
8    vec,
9};
10
11use either::Either;
12pub use error::Error;
13use futures_util::{
14    future,
15    stream::{self, Once, Stream},
16};
17
/// Convenience alias for results whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
19
/// A trait for objects which can be converted or resolved to one or more
/// `SocketAddr` values: the addresses of the proxy server to connect to.
///
/// This trait is similar to `std::net::ToSocketAddrs` but allows asynchronous
/// name resolution. (The string-based implementations in this module still
/// resolve synchronously through `std::net::ToSocketAddrs`.)
pub trait ToProxyAddrs {
    /// Stream yielding each candidate proxy address, or an error if the
    /// conversion/resolution failed.
    type Output: Stream<Item = Result<SocketAddr>> + Unpin;

    /// Converts or resolves `self` into a stream of proxy addresses.
    fn to_proxy_addrs(&self) -> Self::Output;
}
31
32macro_rules! trivial_impl_to_proxy_addrs {
33    ($t: ty) => {
34        impl ToProxyAddrs for $t {
35            type Output = Once<future::Ready<Result<SocketAddr>>>;
36
37            fn to_proxy_addrs(&self) -> Self::Output {
38                stream::once(future::ready(Ok(SocketAddr::from(*self))))
39            }
40        }
41    };
42}
43
44trivial_impl_to_proxy_addrs!(SocketAddr);
45trivial_impl_to_proxy_addrs!((IpAddr, u16));
46trivial_impl_to_proxy_addrs!((Ipv4Addr, u16));
47trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
48trivial_impl_to_proxy_addrs!(SocketAddrV4);
49trivial_impl_to_proxy_addrs!(SocketAddrV6);
50
51impl ToProxyAddrs for &[SocketAddr] {
52    type Output = ProxyAddrsStream;
53
54    fn to_proxy_addrs(&self) -> Self::Output {
55        let addrs = self.to_vec();
56        ProxyAddrsStream(Some(IoResult::Ok(addrs.into_iter())))
57    }
58}
59
60impl ToProxyAddrs for str {
61    type Output = ProxyAddrsStream;
62
63    fn to_proxy_addrs(&self) -> Self::Output {
64        ProxyAddrsStream(Some(self.to_socket_addrs()))
65    }
66}
67
68impl ToProxyAddrs for (&str, u16) {
69    type Output = ProxyAddrsStream;
70
71    fn to_proxy_addrs(&self) -> Self::Output {
72        ProxyAddrsStream(Some(self.to_socket_addrs()))
73    }
74}
75
76impl<T: ToProxyAddrs + ?Sized> ToProxyAddrs for &T {
77    type Output = T::Output;
78
79    fn to_proxy_addrs(&self) -> Self::Output {
80        (**self).to_proxy_addrs()
81    }
82}
83
84pub struct ProxyAddrsStream(Option<IoResult<vec::IntoIter<SocketAddr>>>);
85
86impl Stream for ProxyAddrsStream {
87    type Item = Result<SocketAddr>;
88
89    fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
90        match self.0.as_mut() {
91            Some(Ok(iter)) => Poll::Ready(iter.next().map(Result::Ok)),
92            Some(Err(_)) => {
93                let err = self.0.take().unwrap().unwrap_err();
94                Poll::Ready(Some(Err(err.into())))
95            },
96            None => unreachable!(),
97        }
98    }
99}
100
/// A SOCKS connection target.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum TargetAddr<'a> {
    /// Connect to an IP address.
    Ip(SocketAddr),

    /// Connect to a fully-qualified domain name.
    ///
    /// The domain name will be passed along to the proxy server and DNS lookup
    /// will happen there.
    Domain(Cow<'a, str>, u16),
}

/// Renders the target as `ip:port` (IPv6 addresses bracketed, as std does) or
/// as `domain:port`. Width/fill flags of the outer formatter are not applied.
impl fmt::Display for TargetAddr<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            TargetAddr::Domain(ref host, port) => write!(f, "{}:{}", host, port),
            TargetAddr::Ip(ref sock) => write!(f, "{}", sock),
        }
    }
}

impl TargetAddr<'_> {
    /// Returns a copy of this target that owns its domain string, which
    /// removes the dependency on the borrowed lifetime.
    pub fn to_owned(&self) -> TargetAddr<'static> {
        match *self {
            TargetAddr::Ip(sock) => TargetAddr::Ip(sock),
            TargetAddr::Domain(ref host, port) => {
                let host: String = String::from(&**host);
                TargetAddr::Domain(Cow::Owned(host), port)
            }
        }
    }
}
133
134impl ToSocketAddrs for TargetAddr<'_> {
135    type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;
136
137    fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
138        Ok(match self {
139            TargetAddr::Ip(addr) => Either::Left(addr.to_socket_addrs()?),
140            TargetAddr::Domain(domain, port) => Either::Right((&**domain, *port).to_socket_addrs()?),
141        })
142    }
143}
144
/// A trait for objects that can be converted to `TargetAddr`.
pub trait IntoTargetAddr<'a> {
    /// Converts the value of self to a `TargetAddr`.
    ///
    /// # Errors
    ///
    /// The string-based implementations in this module return
    /// `Error::InvalidTargetAddress` for a malformed `"host:port"` value or a
    /// domain longer than 255 bytes; the socket-address implementations never fail.
    fn into_target_addr(self) -> Result<TargetAddr<'a>>;
}
150
151macro_rules! trivial_impl_into_target_addr {
152    ($t: ty) => {
153        impl<'a> IntoTargetAddr<'a> for $t {
154            fn into_target_addr(self) -> Result<TargetAddr<'a>> {
155                Ok(TargetAddr::Ip(SocketAddr::from(self)))
156            }
157        }
158    };
159}
160
161trivial_impl_into_target_addr!(SocketAddr);
162trivial_impl_into_target_addr!((IpAddr, u16));
163trivial_impl_into_target_addr!((Ipv4Addr, u16));
164trivial_impl_into_target_addr!((Ipv6Addr, u16));
165trivial_impl_into_target_addr!(SocketAddrV4);
166trivial_impl_into_target_addr!(SocketAddrV6);
167
168impl<'a> IntoTargetAddr<'a> for TargetAddr<'a> {
169    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
170        Ok(self)
171    }
172}
173
174impl<'a> IntoTargetAddr<'a> for (&'a str, u16) {
175    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
176        // Try IP address first
177        if let Ok(addr) = self.0.parse::<IpAddr>() {
178            return (addr, self.1).into_target_addr();
179        }
180
181        // Treat as domain name
182        if self.0.len() > 255 {
183            return Err(Error::InvalidTargetAddress("overlong domain"));
184        }
185        // TODO: Should we validate the domain format here?
186
187        Ok(TargetAddr::Domain(self.0.into(), self.1))
188    }
189}
190
191impl<'a> IntoTargetAddr<'a> for &'a str {
192    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
193        // Try IP address first
194        if let Ok(addr) = self.parse::<SocketAddr>() {
195            return addr.into_target_addr();
196        }
197
198        let mut parts_iter = self.rsplitn(2, ':');
199        let port: u16 = parts_iter
200            .next()
201            .and_then(|port_str| port_str.parse().ok())
202            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
203        let domain = parts_iter
204            .next()
205            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
206        if domain.len() > 255 {
207            return Err(Error::InvalidTargetAddress("overlong domain"));
208        }
209        Ok(TargetAddr::Domain(domain.into(), port))
210    }
211}
212
213impl IntoTargetAddr<'static> for String {
214    fn into_target_addr(mut self) -> Result<TargetAddr<'static>> {
215        // Try IP address first
216        if let Ok(addr) = self.parse::<SocketAddr>() {
217            return addr.into_target_addr();
218        }
219
220        let mut parts_iter = self.rsplitn(2, ':');
221        let port: u16 = parts_iter
222            .next()
223            .and_then(|port_str| port_str.parse().ok())
224            .ok_or(Error::InvalidTargetAddress("invalid address format"))?;
225        let domain_len = parts_iter
226            .next()
227            .ok_or(Error::InvalidTargetAddress("invalid address format"))?
228            .len();
229        if domain_len > 255 {
230            return Err(Error::InvalidTargetAddress("overlong domain"));
231        }
232        self.truncate(domain_len);
233        Ok(TargetAddr::Domain(self.into(), port))
234    }
235}
236
237impl IntoTargetAddr<'static> for (String, u16) {
238    fn into_target_addr(self) -> Result<TargetAddr<'static>> {
239        let addr = (self.0.as_str(), self.1).into_target_addr()?;
240        if let TargetAddr::Ip(addr) = addr {
241            Ok(TargetAddr::Ip(addr))
242        } else {
243            Ok(TargetAddr::Domain(self.0.into(), self.1))
244        }
245    }
246}
247
248impl<'a, T> IntoTargetAddr<'a> for &'a T
249where
250    T: IntoTargetAddr<'a> + Copy,
251{
252    fn into_target_addr(self) -> Result<TargetAddr<'a>> {
253        (*self).into_target_addr()
254    }
255}
256
/// Authentication methods.
#[derive(Debug)]
enum Authentication<'a> {
    /// Username/password credentials.
    Password { username: &'a str, password: &'a str },
    /// No authentication.
    None,
}

impl Authentication<'_> {
    /// Method identifier byte for this authentication method
    /// (RFC 1928: `0x00` = no authentication required, `0x02` = username/password).
    fn id(&self) -> u8 {
        match *self {
            Authentication::None => 0x00,
            Authentication::Password { .. } => 0x02,
        }
    }
}
272
273mod error;
274pub mod io;
275pub mod tcp;
276
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use futures_util::StreamExt;

    use super::*;

    /// Drains the proxy-address stream, returning the first error, if any,
    /// instead of panicking on it.
    fn to_proxy_addrs<T: ToProxyAddrs>(t: T) -> Result<Vec<SocketAddr>> {
        block_on(t.to_proxy_addrs().collect::<Vec<_>>())
            .into_iter()
            .collect()
    }

    #[test]
    fn test_clone_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        let addr_clone = addr.clone();
        assert_eq!(addr, addr_clone);
        assert_eq!(addr.to_string(), addr_clone.to_string());
    }

    #[test]
    fn test_clone_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        let addr_clone = addr.clone();
        assert_eq!(addr, addr_clone);
        assert_eq!(addr.to_string(), addr_clone.to_string());
    }

    #[test]
    fn test_display_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        assert_eq!(format!("{}", addr), "127.0.0.1:8080");
    }

    #[test]
    fn test_display_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        assert_eq!(format!("{}", addr), "example.com:80");
    }

    #[test]
    fn test_to_string_ip() {
        let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
        assert_eq!(addr.to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn test_to_string_domain() {
        let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
        assert_eq!(addr.to_string(), "example.com:80");
    }

    #[test]
    fn to_owned_detaches_domain_lifetime() {
        let host = String::from("example.com");
        let owned: TargetAddr<'static> = TargetAddr::Domain(Cow::Borrowed(host.as_str()), 80).to_owned();
        drop(host);
        assert_eq!(owned, TargetAddr::Domain(Cow::Borrowed("example.com"), 80));
    }

    #[test]
    fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = to_proxy_addrs(addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_proxy_addrs() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the `&T` forwarding impl is actually exercised.
        let res = to_proxy_addrs(&addr)?;
        assert_eq!(&res[..], &[addr]);
        Ok(())
    }

    #[test]
    fn converts_socket_addrs_to_proxy_addrs() -> Result<()> {
        let addrs = [
            SocketAddr::from(([1, 1, 1, 1], 443)),
            SocketAddr::from(([8, 8, 8, 8], 53)),
        ];
        let res = to_proxy_addrs(&addrs[..])?;
        assert_eq!(&res[..], &addrs);
        Ok(())
    }

    fn into_target_addr<'a, T>(t: T) -> Result<TargetAddr<'a>>
    where
        T: IntoTargetAddr<'a>,
    {
        t.into_target_addr()
    }

    #[test]
    fn converts_socket_addr_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let res = into_target_addr(addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_ref_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        // Pass a reference so the `&'a T` (T: Copy) impl is actually exercised.
        let res = into_target_addr(&addr)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_socket_addr_str_to_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = addr.to_string();
        let res = into_target_addr(ip_str.as_str())?;
        assert_eq!(TargetAddr::Ip(addr), res);

        let res = into_target_addr(ip_str)?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_ip_str_and_port_target_addr() -> Result<()> {
        let addr = SocketAddr::from(([1, 1, 1, 1], 443));
        let ip_str = addr.ip().to_string();
        let res = into_target_addr((ip_str.as_str(), addr.port()))?;
        assert_eq!(TargetAddr::Ip(addr), res);

        let res = into_target_addr((ip_str, addr.port()))?;
        assert_eq!(TargetAddr::Ip(addr), res);
        Ok(())
    }

    #[test]
    fn converts_domain_to_target_addr() -> Result<()> {
        let domain = "www.example.com:80";
        let res = into_target_addr(domain)?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr(domain.to_owned())?;
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), res);
        Ok(())
    }

    #[test]
    fn converts_domain_and_port_to_target_addr() -> Result<()> {
        let domain = "www.example.com";
        let res = into_target_addr((domain, 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Borrowed("www.example.com"), 80), res);

        let res = into_target_addr((domain.to_owned(), 80))?;
        assert_eq!(TargetAddr::Domain(Cow::Owned("www.example.com".to_owned()), 80), res);
        Ok(())
    }

    #[test]
    fn overlong_domain_to_target_addr_should_fail() {
        let domain = format!("www.{:a<1$}.com:80", 'a', 300);
        assert!(into_target_addr(domain.as_str()).is_err());
        let domain = format!("www.{:a<1$}.com", 'a', 300);
        assert!(into_target_addr((domain.as_str(), 80)).is_err());
    }

    #[test]
    fn domain_length_limit_is_255_bytes() {
        let at_limit = "a".repeat(255);
        assert!(into_target_addr((at_limit.as_str(), 80)).is_ok());
        assert!(into_target_addr(format!("{}:80", at_limit)).is_ok());

        let over_limit = "a".repeat(256);
        assert!(into_target_addr((over_limit.as_str(), 80)).is_err());
        assert!(into_target_addr(format!("{}:80", over_limit)).is_err());
    }

    #[test]
    fn addr_with_invalid_port_to_target_addr_should_fail() {
        let addr = "[ffff::1]:65536";
        assert!(into_target_addr(addr).is_err());
        let addr = "www.example.com:65536";
        assert!(into_target_addr(addr).is_err());
    }

    #[test]
    fn addr_without_port_to_target_addr_should_fail() {
        assert!(into_target_addr("www.example.com").is_err());
        assert!(into_target_addr(String::from("www.example.com")).is_err());
    }
}