hyper_util/client/legacy/connect/
http.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Socket/connection options. Shared between clones; setters go through
    // `config_mut`, which uses `Arc::make_mut` so a clone is never affected by
    // later reconfiguration of another clone.
    config: Arc<Config>,
    // Resolver used to turn host names into socket addresses.
    resolver: R,
}
37
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # fn doc(res: http::Response<()>) {
/// use hyper_util::client::legacy::connect::HttpInfo;
///
/// // res = http::Response
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Peer address of the TCP stream (from `TcpStream::peer_addr`).
    remote_addr: SocketAddr,
    // Local address of the TCP stream (from `TcpStream::local_addr`).
    local_addr: SocketAddr,
}
66
/// Socket and connection options used by `HttpConnector`.
///
/// `Clone` is required so `Arc::make_mut` can copy the config when a shared
/// connector is reconfigured.
#[derive(Clone)]
struct Config {
    // Connect timeout, split evenly across the addresses of each attempt
    // group (see `ConnectingTcpRemote::new`); `None` means no timeout.
    connect_timeout: Option<Duration>,
    // When true, only `http` URIs are accepted.
    enforce_http: bool,
    // Delay before the fallback address group is started; `None` disables
    // Happy Eyeballs and tries all addresses as one group.
    happy_eyeballs_timeout: Option<Duration>,
    // Keepalive idle time / interval / retries.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Local addresses to bind before connecting; the one matching the
    // destination's address family is used.
    local_address_ipv4: Option<Ipv4Addr>,
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` once the stream is connected.
    nodelay: bool,
    // Whether to enable `SO_REUSEADDR` on new sockets.
    reuse_address: bool,
    // Optional `SO_SNDBUF` / `SO_RCVBUF` sizes.
    send_buffer_size: Option<usize>,
    recv_buffer_size: Option<usize>,
    // Interface name for `SO_BINDTODEVICE` (Android/Fuchsia/Linux only).
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    interface: Option<String>,
    // `TCP_USER_TIMEOUT` value (Android/Fuchsia/Linux only).
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    tcp_user_timeout: Option<Duration>,
}
84
/// User-supplied TCP keepalive settings; each `None` leaves that setting unset.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe is sent.
    time: Option<Duration>,
    // Interval between successive unacknowledged probes.
    interval: Option<Duration>,
    // Number of unacknowledged probes before the peer is declared unavailable.
    retries: Option<u32>,
}
91
92impl TcpKeepaliveConfig {
93    /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
94    fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
95        let mut dirty = false;
96        let mut ka = TcpKeepalive::new();
97        if let Some(time) = self.time {
98            ka = ka.with_time(time);
99            dirty = true
100        }
101        if let Some(interval) = self.interval {
102            ka = Self::ka_with_interval(ka, interval, &mut dirty)
103        };
104        if let Some(retries) = self.retries {
105            ka = Self::ka_with_retries(ka, retries, &mut dirty)
106        };
107        if dirty {
108            Some(ka)
109        } else {
110            None
111        }
112    }
113
114    #[cfg(not(any(
115        target_os = "aix",
116        target_os = "openbsd",
117        target_os = "redox",
118        target_os = "solaris"
119    )))]
120    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
121        *dirty = true;
122        ka.with_interval(interval)
123    }
124
125    #[cfg(any(
126        target_os = "aix",
127        target_os = "openbsd",
128        target_os = "redox",
129        target_os = "solaris"
130    ))]
131    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
132        ka // no-op as keepalive interval is not supported on this platform
133    }
134
135    #[cfg(not(any(
136        target_os = "aix",
137        target_os = "openbsd",
138        target_os = "redox",
139        target_os = "solaris",
140        target_os = "windows"
141    )))]
142    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
143        *dirty = true;
144        ka.with_retries(retries)
145    }
146
147    #[cfg(any(
148        target_os = "aix",
149        target_os = "openbsd",
150        target_os = "redox",
151        target_os = "solaris",
152        target_os = "windows"
153    ))]
154    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
155        ka // no-op as keepalive retries is not supported on this platform
156    }
157}
158
159// ===== impl HttpConnector =====
160
161impl HttpConnector {
162    /// Construct a new HttpConnector.
163    pub fn new() -> HttpConnector {
164        HttpConnector::new_with_resolver(GaiResolver::new())
165    }
166}
167
168impl<R> HttpConnector<R> {
169    /// Construct a new HttpConnector.
170    ///
171    /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
172    pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
173        HttpConnector {
174            config: Arc::new(Config {
175                connect_timeout: None,
176                enforce_http: true,
177                happy_eyeballs_timeout: Some(Duration::from_millis(300)),
178                tcp_keepalive_config: TcpKeepaliveConfig::default(),
179                local_address_ipv4: None,
180                local_address_ipv6: None,
181                nodelay: false,
182                reuse_address: false,
183                send_buffer_size: None,
184                recv_buffer_size: None,
185                #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
186                interface: None,
187                #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
188                tcp_user_timeout: None,
189            }),
190            resolver,
191        }
192    }
193
194    /// Option to enforce all `Uri`s have the `http` scheme.
195    ///
196    /// Enabled by default.
197    #[inline]
198    pub fn enforce_http(&mut self, is_enforced: bool) {
199        self.config_mut().enforce_http = is_enforced;
200    }
201
202    /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration
203    /// to remain idle before sending TCP keepalive probes.
204    ///
205    /// If `None`, keepalive is disabled.
206    ///
207    /// Default is `None`.
208    #[inline]
209    pub fn set_keepalive(&mut self, time: Option<Duration>) {
210        self.config_mut().tcp_keepalive_config.time = time;
211    }
212
213    /// Set the duration between two successive TCP keepalive retransmissions,
214    /// if acknowledgement to the previous keepalive transmission is not received.
215    #[inline]
216    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
217        self.config_mut().tcp_keepalive_config.interval = interval;
218    }
219
220    /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
221    #[inline]
222    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
223        self.config_mut().tcp_keepalive_config.retries = retries;
224    }
225
226    /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
227    ///
228    /// Default is `false`.
229    #[inline]
230    pub fn set_nodelay(&mut self, nodelay: bool) {
231        self.config_mut().nodelay = nodelay;
232    }
233
234    /// Sets the value of the SO_SNDBUF option on the socket.
235    #[inline]
236    pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
237        self.config_mut().send_buffer_size = size;
238    }
239
240    /// Sets the value of the SO_RCVBUF option on the socket.
241    #[inline]
242    pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
243        self.config_mut().recv_buffer_size = size;
244    }
245
246    /// Set that all sockets are bound to the configured address before connection.
247    ///
248    /// If `None`, the sockets will not be bound.
249    ///
250    /// Default is `None`.
251    #[inline]
252    pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
253        let (v4, v6) = match addr {
254            Some(IpAddr::V4(a)) => (Some(a), None),
255            Some(IpAddr::V6(a)) => (None, Some(a)),
256            _ => (None, None),
257        };
258
259        let cfg = self.config_mut();
260
261        cfg.local_address_ipv4 = v4;
262        cfg.local_address_ipv6 = v6;
263    }
264
265    /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
266    /// preferences) before connection.
267    #[inline]
268    pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
269        let cfg = self.config_mut();
270
271        cfg.local_address_ipv4 = Some(addr_ipv4);
272        cfg.local_address_ipv6 = Some(addr_ipv6);
273    }
274
275    /// Set the connect timeout.
276    ///
277    /// If a domain resolves to multiple IP addresses, the timeout will be
278    /// evenly divided across them.
279    ///
280    /// Default is `None`.
281    #[inline]
282    pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
283        self.config_mut().connect_timeout = dur;
284    }
285
286    /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
287    ///
288    /// If hostname resolves to both IPv4 and IPv6 addresses and connection
289    /// cannot be established using preferred address family before timeout
290    /// elapses, then connector will in parallel attempt connection using other
291    /// address family.
292    ///
293    /// If `None`, parallel connection attempts are disabled.
294    ///
295    /// Default is 300 milliseconds.
296    ///
297    /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
298    #[inline]
299    pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
300        self.config_mut().happy_eyeballs_timeout = dur;
301    }
302
303    /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
304    ///
305    /// Default is `false`.
306    #[inline]
307    pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
308        self.config_mut().reuse_address = reuse_address;
309        self
310    }
311
312    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
313    ///
314    /// If a socket is bound to an interface, only packets received from that particular
315    /// interface are processed by the socket. Note that this only works for some socket
316    /// types, particularly AF_INET sockets.
317    ///
318    /// On Linux it can be used to specify a [VRF], but the binary needs
319    /// to either have `CAP_NET_RAW` or to be run as root.
320    ///
321    /// This function is only available on Android、Fuchsia and Linux.
322    ///
323    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
324    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
325    #[inline]
326    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
327        self.config_mut().interface = Some(interface.into());
328        self
329    }
330
331    /// Sets the value of the TCP_USER_TIMEOUT option on the socket.
332    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
333    #[inline]
334    pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
335        self.config_mut().tcp_user_timeout = time;
336    }
337
338    // private
339
340    fn config_mut(&mut self) -> &mut Config {
341        // If the are HttpConnector clones, this will clone the inner
342        // config. So mutating the config won't ever affect previous
343        // clones.
344        Arc::make_mut(&mut self.config)
345    }
346}
347
/// Error message used when `enforce_http` is on and the URI's scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `enforce_http` is off and the URI has no scheme.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the URI has no host.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
351
352// R: Debug required for now to allow adding it to debug output later...
353impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        f.debug_struct("HttpConnector").finish()
356    }
357}
358
359impl<R> tower_service::Service<Uri> for HttpConnector<R>
360where
361    R: Resolve + Clone + Send + Sync + 'static,
362    R::Future: Send,
363{
364    type Response = TokioIo<TcpStream>;
365    type Error = ConnectError;
366    type Future = HttpConnecting<R>;
367
368    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
369        futures_util::ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
370        Poll::Ready(Ok(()))
371    }
372
373    fn call(&mut self, dst: Uri) -> Self::Future {
374        let mut self_ = self.clone();
375        HttpConnecting {
376            fut: Box::pin(async move { self_.call_async(dst).await }),
377            _marker: PhantomData,
378        }
379    }
380}
381
382fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
383    trace!(
384        "Http::connect; scheme={:?}, host={:?}, port={:?}",
385        dst.scheme(),
386        dst.host(),
387        dst.port(),
388    );
389
390    if config.enforce_http {
391        if dst.scheme() != Some(&Scheme::HTTP) {
392            return Err(ConnectError {
393                msg: INVALID_NOT_HTTP.into(),
394                cause: None,
395            });
396        }
397    } else if dst.scheme().is_none() {
398        return Err(ConnectError {
399            msg: INVALID_MISSING_SCHEME.into(),
400            cause: None,
401        });
402    }
403
404    let host = match dst.host() {
405        Some(s) => s,
406        None => {
407            return Err(ConnectError {
408                msg: INVALID_MISSING_HOST.into(),
409                cause: None,
410            })
411        }
412    };
413    let port = match dst.port() {
414        Some(port) => port.as_u16(),
415        None => {
416            if dst.scheme() == Some(&Scheme::HTTPS) {
417                443
418            } else {
419                80
420            }
421        }
422    };
423
424    Ok((host, port))
425}
426
427impl<R> HttpConnector<R>
428where
429    R: Resolve,
430{
431    async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
432        let config = &self.config;
433
434        let (host, port) = get_host_port(config, &dst)?;
435        let host = host.trim_start_matches('[').trim_end_matches(']');
436
437        // If the host is already an IP addr (v4 or v6),
438        // skip resolving the dns and start connecting right away.
439        let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
440            addrs
441        } else {
442            let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
443                .await
444                .map_err(ConnectError::dns)?;
445            let addrs = addrs
446                .map(|mut addr| {
447                    set_port(&mut addr, port, dst.port().is_some());
448
449                    addr
450                })
451                .collect();
452            dns::SocketAddrs::new(addrs)
453        };
454
455        let c = ConnectingTcp::new(addrs, config);
456
457        let sock = c.connect().await?;
458
459        if let Err(e) = sock.set_nodelay(config.nodelay) {
460            warn!("tcp set_nodelay error: {}", e);
461        }
462
463        Ok(TokioIo::new(sock))
464    }
465}
466
467impl Connection for TcpStream {
468    fn connected(&self) -> Connected {
469        let connected = Connected::new();
470        if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
471            connected.extra(HttpInfo {
472                remote_addr,
473                local_addr,
474            })
475        } else {
476            connected
477        }
478    }
479}
480
481// Implement `Connection` for generic `TokioIo<T>` so that external crates can
482// implement their own `HttpConnector` with `TokioIo<CustomTcpStream>`.
483impl<T> Connection for TokioIo<T>
484where
485    T: Connection,
486{
487    fn connected(&self) -> Connected {
488        self.inner().connected()
489    }
490}
491
492impl HttpInfo {
493    /// Get the remote address of the transport used.
494    pub fn remote_addr(&self) -> SocketAddr {
495        self.remote_addr
496    }
497
498    /// Get the local address of the transport used.
499    pub fn local_addr(&self) -> SocketAddr {
500        self.local_addr
501    }
502}
503
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    //
    // The resolver type `R` is only carried as `PhantomData`; the boxed
    // future does all the work.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connection future created in `Service::call`.
        #[pin]
        fut: BoxConnecting,
        // Ties this future to the resolver type without storing one.
        _marker: PhantomData<R>,
    }
}
518
/// Output of an `HttpConnecting` future.
type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
/// Type-erased, `Send` future producing a `ConnectResult`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
521
522impl<R: Resolve> Future for HttpConnecting<R> {
523    type Output = ConnectResult;
524
525    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
526        self.project().fut.poll(cx)
527    }
528}
529
// Not publicly exported (so missing_docs doesn't trigger).
/// Error produced by `HttpConnector` when validating the URI, resolving the
/// host, or opening/connecting the socket fails.
pub struct ConnectError {
    // Description of the step that failed, e.g. "dns error" or "tcp connect error".
    msg: Box<str>,
    // Underlying error, if any; exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
535
536impl ConnectError {
537    fn new<S, E>(msg: S, cause: E) -> ConnectError
538    where
539        S: Into<Box<str>>,
540        E: Into<Box<dyn StdError + Send + Sync>>,
541    {
542        ConnectError {
543            msg: msg.into(),
544            cause: Some(cause.into()),
545        }
546    }
547
548    fn dns<E>(cause: E) -> ConnectError
549    where
550        E: Into<Box<dyn StdError + Send + Sync>>,
551    {
552        ConnectError::new("dns error", cause)
553    }
554
555    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
556    where
557        S: Into<Box<str>>,
558        E: Into<Box<dyn StdError + Send + Sync>>,
559    {
560        move |cause| ConnectError::new(msg, cause)
561    }
562}
563
564impl fmt::Debug for ConnectError {
565    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566        if let Some(ref cause) = self.cause {
567            f.debug_tuple("ConnectError")
568                .field(&self.msg)
569                .field(cause)
570                .finish()
571        } else {
572            self.msg.fmt(f)
573        }
574    }
575}
576
577impl fmt::Display for ConnectError {
578    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
579        f.write_str(&self.msg)?;
580
581        if let Some(ref cause) = self.cause {
582            write!(f, ": {}", cause)?;
583        }
584
585        Ok(())
586    }
587}
588
589impl StdError for ConnectError {
590    fn source(&self) -> Option<&(dyn StdError + 'static)> {
591        self.cause.as_ref().map(|e| &**e as _)
592    }
593}
594
/// A planned connection: a preferred group of addresses plus an optional,
/// delayed fallback group (Happy Eyeballs), borrowing the connector's config.
struct ConnectingTcp<'a> {
    // Addresses tried first.
    preferred: ConnectingTcpRemote,
    // Secondary group from `split_by_preference`; `None` when Happy Eyeballs
    // is disabled or that group is empty.
    fallback: Option<ConnectingTcpFallback>,
    // Socket options applied to every attempt.
    config: &'a Config,
}
600
601impl<'a> ConnectingTcp<'a> {
602    fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
603        if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
604            let (preferred_addrs, fallback_addrs) = remote_addrs
605                .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
606            if fallback_addrs.is_empty() {
607                return ConnectingTcp {
608                    preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
609                    fallback: None,
610                    config,
611                };
612            }
613
614            ConnectingTcp {
615                preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
616                fallback: Some(ConnectingTcpFallback {
617                    delay: tokio::time::sleep(fallback_timeout),
618                    remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
619                }),
620                config,
621            }
622        } else {
623            ConnectingTcp {
624                preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
625                fallback: None,
626                config,
627            }
628        }
629    }
630}
631
/// The Happy Eyeballs fallback: a timer and the addresses to try once it fires.
struct ConnectingTcpFallback {
    // Delay before the fallback attempt is started (created from
    // `happy_eyeballs_timeout`).
    delay: Sleep,
    // Fallback addresses and their per-address timeout.
    remote: ConnectingTcpRemote,
}
636
/// An ordered group of addresses to try one after another.
struct ConnectingTcpRemote {
    addrs: dns::SocketAddrs,
    // Timeout for each individual address (already divided by the number of
    // addresses); `None` means no timeout.
    connect_timeout: Option<Duration>,
}
641
642impl ConnectingTcpRemote {
643    fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
644        let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
645
646        Self {
647            addrs,
648            connect_timeout,
649        }
650    }
651}
652
653impl ConnectingTcpRemote {
654    async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
655        let mut err = None;
656        for addr in &mut self.addrs {
657            debug!("connecting to {}", addr);
658            match connect(&addr, config, self.connect_timeout)?.await {
659                Ok(tcp) => {
660                    debug!("connected to {}", addr);
661                    return Ok(tcp);
662                }
663                Err(e) => {
664                    trace!("connect error for {}: {:?}", addr, e);
665                    err = Some(e);
666                }
667            }
668        }
669
670        match err {
671            Some(e) => Err(e),
672            None => Err(ConnectError::new(
673                "tcp connect error",
674                std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
675            )),
676        }
677    }
678}
679
680fn bind_local_address(
681    socket: &socket2::Socket,
682    dst_addr: &SocketAddr,
683    local_addr_ipv4: &Option<Ipv4Addr>,
684    local_addr_ipv6: &Option<Ipv6Addr>,
685) -> io::Result<()> {
686    match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
687        (SocketAddr::V4(_), Some(addr), _) => {
688            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
689        }
690        (SocketAddr::V6(_), _, Some(addr)) => {
691            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
692        }
693        _ => {
694            if cfg!(windows) {
695                // Windows requires a socket be bound before calling connect
696                let any: SocketAddr = match *dst_addr {
697                    SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
698                    SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
699                };
700                socket.bind(&any.into())?;
701            }
702        }
703    }
704
705    Ok(())
706}
707
708fn connect(
709    addr: &SocketAddr,
710    config: &Config,
711    connect_timeout: Option<Duration>,
712) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
713    // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
714    // keepalive timeout, it would be nice to use that instead of socket2,
715    // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
716    use socket2::{Domain, Protocol, Socket, Type};
717
718    let domain = Domain::for_address(*addr);
719    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
720        .map_err(ConnectError::m("tcp open error"))?;
721
722    // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
723    // responsible for ensuring O_NONBLOCK is set.
724    socket
725        .set_nonblocking(true)
726        .map_err(ConnectError::m("tcp set_nonblocking error"))?;
727
728    if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
729        if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
730            warn!("tcp set_keepalive error: {}", e);
731        }
732    }
733
734    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
735    // That this only works for some socket types, particularly AF_INET sockets.
736    if let Some(interface) = &config.interface {
737        socket
738            .bind_device(Some(interface.as_bytes()))
739            .map_err(ConnectError::m("tcp bind interface error"))?;
740    }
741
742    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
743    if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
744        if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
745            warn!("tcp set_tcp_user_timeout error: {}", e);
746        }
747    }
748
749    bind_local_address(
750        &socket,
751        addr,
752        &config.local_address_ipv4,
753        &config.local_address_ipv6,
754    )
755    .map_err(ConnectError::m("tcp bind local error"))?;
756
757    #[cfg(unix)]
758    let socket = unsafe {
759        // Safety: `from_raw_fd` is only safe to call if ownership of the raw
760        // file descriptor is transferred. Since we call `into_raw_fd` on the
761        // socket2 socket, it gives up ownership of the fd and will not close
762        // it, so this is safe.
763        use std::os::unix::io::{FromRawFd, IntoRawFd};
764        TcpSocket::from_raw_fd(socket.into_raw_fd())
765    };
766    #[cfg(windows)]
767    let socket = unsafe {
768        // Safety: `from_raw_socket` is only safe to call if ownership of the raw
769        // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
770        // socket2 socket, it gives up ownership of the SOCKET and will not close
771        // it, so this is safe.
772        use std::os::windows::io::{FromRawSocket, IntoRawSocket};
773        TcpSocket::from_raw_socket(socket.into_raw_socket())
774    };
775
776    if config.reuse_address {
777        if let Err(e) = socket.set_reuseaddr(true) {
778            warn!("tcp set_reuse_address error: {}", e);
779        }
780    }
781
782    if let Some(size) = config.send_buffer_size {
783        if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
784            warn!("tcp set_buffer_size error: {}", e);
785        }
786    }
787
788    if let Some(size) = config.recv_buffer_size {
789        if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
790            warn!("tcp set_recv_buffer_size error: {}", e);
791        }
792    }
793
794    let connect = socket.connect(*addr);
795    Ok(async move {
796        match connect_timeout {
797            Some(dur) => match tokio::time::timeout(dur, connect).await {
798                Ok(Ok(s)) => Ok(s),
799                Ok(Err(e)) => Err(e),
800                Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
801            },
802            None => connect.await,
803        }
804        .map_err(ConnectError::m("tcp connect error"))
805    })
806}
807
impl ConnectingTcp<'_> {
    /// Runs the planned connection attempt(s) and returns the first stream
    /// that connects.
    ///
    /// Without a fallback, the preferred addresses are simply tried in order.
    /// With a fallback, the preferred attempt runs alone until
    /// `fallback.delay` completes. After that, both attempts are polled
    /// concurrently and the first to finish is taken. If the first result is
    /// an error, the remaining attempt is awaited and its result returned
    /// instead. This also covers the preferred attempt failing before the
    /// delay fires, in which case the fallback is started without waiting
    /// for the timer.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created here but, being an `async fn` future, it makes no
                // progress until it is first polled below.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            // Preferred attempt finished before the delay; the
                            // fallback has not been polled yet and is what remains.
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
846
/// Chooses the port for a resolved address.
///
/// An explicit port from the URI always overrides the resolved one. Without
/// one, a non-zero port supplied by a custom resolver is kept, and a zero
/// port is replaced by `host_port` (the scheme's default).
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let keep_resolved_port = !explicit && addr.port() != 0;
    if !keep_resolved_port {
        addr.set_port(host_port);
    }
}
855
856#[cfg(test)]
857mod tests {
858    use std::io;
859    use std::net::SocketAddr;
860
861    use ::http::Uri;
862
863    use crate::client::legacy::connect::http::TcpKeepaliveConfig;
864
865    use super::super::sealed::{Connect, ConnectSvc};
866    use super::{Config, ConnectError, HttpConnector};
867
868    use super::set_port;
869
870    async fn connect<C>(
871        connector: C,
872        dst: Uri,
873    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
874    where
875        C: Connect,
876    {
877        connector.connect(super::super::sealed::Internal, dst).await
878    }
879
880    #[tokio::test]
881    #[cfg_attr(miri, ignore)]
882    async fn test_errors_enforce_http() {
883        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
884        let connector = HttpConnector::new();
885
886        let err = connect(connector, dst).await.unwrap_err();
887        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
888    }
889
890    #[cfg(any(target_os = "linux", target_os = "macos"))]
891    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
892        use std::net::{IpAddr, TcpListener};
893
894        let mut ip_v4 = None;
895        let mut ip_v6 = None;
896
897        let ips = pnet_datalink::interfaces()
898            .into_iter()
899            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
900
901        for ip in ips {
902            match ip {
903                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
904                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
905                _ => (),
906            }
907
908            if ip_v4.is_some() && ip_v6.is_some() {
909                break;
910            }
911        }
912
913        (ip_v4, ip_v6)
914    }
915
916    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
917    fn default_interface() -> Option<String> {
918        pnet_datalink::interfaces()
919            .iter()
920            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
921            .map(|e| e.name.clone())
922    }
923
924    #[tokio::test]
925    #[cfg_attr(miri, ignore)]
926    async fn test_errors_missing_scheme() {
927        let dst = "example.domain".parse().unwrap();
928        let mut connector = HttpConnector::new();
929        connector.enforce_http(false);
930
931        let err = connect(connector, dst).await.unwrap_err();
932        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
933    }
934
935    // NOTE: pnet crate that we use in this test doesn't compile on Windows
936    #[cfg(any(target_os = "linux", target_os = "macos"))]
937    #[cfg_attr(miri, ignore)]
938    #[tokio::test]
939    async fn local_address() {
940        use std::net::{IpAddr, TcpListener};
941
942        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
943        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
944        let port = server4.local_addr().unwrap().port();
945        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
946
947        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
948            let mut connector = HttpConnector::new();
949
950            match (bind_ip_v4, bind_ip_v6) {
951                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
952                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
953                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
954                _ => unreachable!(),
955            }
956
957            connect(connector, dst.parse().unwrap()).await.unwrap();
958
959            let (_, client_addr) = server.accept().unwrap();
960
961            assert_eq!(client_addr.ip(), expected_ip);
962        };
963
964        if let Some(ip) = bind_ip_v4 {
965            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
966        }
967
968        if let Some(ip) = bind_ip_v6 {
969            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
970        }
971    }
972
973    // NOTE: pnet crate that we use in this test doesn't compile on Windows
974    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
975    #[tokio::test]
976    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
977    async fn interface() {
978        use socket2::{Domain, Protocol, Socket, Type};
979        use std::net::TcpListener;
980
981        let interface: Option<String> = default_interface();
982
983        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
984        let port = server4.local_addr().unwrap().port();
985
986        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
987
988        let assert_interface_name =
989            |dst: String,
990             server: TcpListener,
991             bind_iface: Option<String>,
992             expected_interface: Option<String>| async move {
993                let mut connector = HttpConnector::new();
994                if let Some(iface) = bind_iface {
995                    connector.set_interface(iface);
996                }
997
998                connect(connector, dst.parse().unwrap()).await.unwrap();
999                let domain = Domain::for_address(server.local_addr().unwrap());
1000                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();
1001
1002                assert_eq!(
1003                    socket.device().unwrap().as_deref(),
1004                    expected_interface.as_deref().map(|val| val.as_bytes())
1005                );
1006            };
1007
1008        assert_interface_name(
1009            format!("http://127.0.0.1:{}", port),
1010            server4,
1011            interface.clone(),
1012            interface.clone(),
1013        )
1014        .await;
1015        assert_interface_name(
1016            format!("http://[::1]:{}", port),
1017            server6,
1018            interface.clone(),
1019            interface.clone(),
1020        )
1021        .await;
1022    }
1023
1024    #[test]
1025    #[ignore] // TODO
1026    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
1027    fn client_happy_eyeballs() {
1028        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
1029        use std::time::{Duration, Instant};
1030
1031        use super::dns;
1032        use super::ConnectingTcp;
1033
1034        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
1035        let addr = server4.local_addr().unwrap();
1036        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
1037        let rt = tokio::runtime::Builder::new_current_thread()
1038            .enable_all()
1039            .build()
1040            .unwrap();
1041
1042        let local_timeout = Duration::default();
1043        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
1044        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
1045        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
1046            + Duration::from_millis(250);
1047
1048        let scenarios = &[
1049            // Fast primary, without fallback.
1050            (&[local_ipv4_addr()][..], 4, local_timeout, false),
1051            (&[local_ipv6_addr()][..], 6, local_timeout, false),
1052            // Fast primary, with (unused) fallback.
1053            (
1054                &[local_ipv4_addr(), local_ipv6_addr()][..],
1055                4,
1056                local_timeout,
1057                false,
1058            ),
1059            (
1060                &[local_ipv6_addr(), local_ipv4_addr()][..],
1061                6,
1062                local_timeout,
1063                false,
1064            ),
1065            // Unreachable + fast primary, without fallback.
1066            (
1067                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
1068                4,
1069                unreachable_v4_timeout,
1070                false,
1071            ),
1072            (
1073                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
1074                6,
1075                unreachable_v6_timeout,
1076                false,
1077            ),
1078            // Unreachable + fast primary, with (unused) fallback.
1079            (
1080                &[
1081                    unreachable_ipv4_addr(),
1082                    local_ipv4_addr(),
1083                    local_ipv6_addr(),
1084                ][..],
1085                4,
1086                unreachable_v4_timeout,
1087                false,
1088            ),
1089            (
1090                &[
1091                    unreachable_ipv6_addr(),
1092                    local_ipv6_addr(),
1093                    local_ipv4_addr(),
1094                ][..],
1095                6,
1096                unreachable_v6_timeout,
1097                true,
1098            ),
1099            // Slow primary, with (used) fallback.
1100            (
1101                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
1102                6,
1103                fallback_timeout,
1104                false,
1105            ),
1106            (
1107                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
1108                4,
1109                fallback_timeout,
1110                true,
1111            ),
1112            // Slow primary, with (used) unreachable + fast fallback.
1113            (
1114                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
1115                6,
1116                fallback_timeout + unreachable_v6_timeout,
1117                false,
1118            ),
1119            (
1120                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
1121                4,
1122                fallback_timeout + unreachable_v4_timeout,
1123                true,
1124            ),
1125        ];
1126
1127        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
1128        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
1129        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
1130
1131        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
1132            if needs_ipv6_access && !ipv6_accessible {
1133                continue;
1134            }
1135
1136            let (start, stream) = rt
1137                .block_on(async move {
1138                    let addrs = hosts
1139                        .iter()
1140                        .map(|host| (host.clone(), addr.port()).into())
1141                        .collect();
1142                    let cfg = Config {
1143                        local_address_ipv4: None,
1144                        local_address_ipv6: None,
1145                        connect_timeout: None,
1146                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
1147                        happy_eyeballs_timeout: Some(fallback_timeout),
1148                        nodelay: false,
1149                        reuse_address: false,
1150                        enforce_http: false,
1151                        send_buffer_size: None,
1152                        recv_buffer_size: None,
1153                        #[cfg(any(
1154                            target_os = "android",
1155                            target_os = "fuchsia",
1156                            target_os = "linux"
1157                        ))]
1158                        interface: None,
1159                        #[cfg(any(
1160                            target_os = "android",
1161                            target_os = "fuchsia",
1162                            target_os = "linux"
1163                        ))]
1164                        tcp_user_timeout: None,
1165                    };
1166                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
1167                    let start = Instant::now();
1168                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
1169                })
1170                .unwrap();
1171            let res = if stream.peer_addr().unwrap().is_ipv4() {
1172                4
1173            } else {
1174                6
1175            };
1176            let duration = start.elapsed();
1177
1178            // Allow actual duration to be +/- 150ms off.
1179            let min_duration = if timeout >= Duration::from_millis(150) {
1180                timeout - Duration::from_millis(150)
1181            } else {
1182                Duration::default()
1183            };
1184            let max_duration = timeout + Duration::from_millis(150);
1185
1186            assert_eq!(res, family);
1187            assert!(duration >= min_duration);
1188            assert!(duration <= max_duration);
1189        }
1190
1191        fn local_ipv4_addr() -> IpAddr {
1192            Ipv4Addr::new(127, 0, 0, 1).into()
1193        }
1194
1195        fn local_ipv6_addr() -> IpAddr {
1196            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
1197        }
1198
1199        fn unreachable_ipv4_addr() -> IpAddr {
1200            Ipv4Addr::new(127, 0, 0, 2).into()
1201        }
1202
1203        fn unreachable_ipv6_addr() -> IpAddr {
1204            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
1205        }
1206
1207        fn slow_ipv4_addr() -> IpAddr {
1208            // RFC 6890 reserved IPv4 address.
1209            Ipv4Addr::new(198, 18, 0, 25).into()
1210        }
1211
1212        fn slow_ipv6_addr() -> IpAddr {
1213            // RFC 6890 reserved IPv6 address.
1214            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
1215        }
1216
1217        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
1218            let start = Instant::now();
1219            let result =
1220                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
1221
1222            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
1223            let duration = start.elapsed();
1224            (reachable, duration)
1225        }
1226    }
1227
1228    use std::time::Duration;
1229
1230    #[test]
1231    fn no_tcp_keepalive_config() {
1232        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
1233    }
1234
1235    #[test]
1236    fn tcp_keepalive_time_config() {
1237        let mut kac = TcpKeepaliveConfig::default();
1238        kac.time = Some(Duration::from_secs(60));
1239        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1240            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
1241        } else {
1242            panic!("test failed");
1243        }
1244    }
1245
1246    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
1247    #[test]
1248    fn tcp_keepalive_interval_config() {
1249        let mut kac = TcpKeepaliveConfig::default();
1250        kac.interval = Some(Duration::from_secs(1));
1251        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1252            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
1253        } else {
1254            panic!("test failed");
1255        }
1256    }
1257
1258    #[cfg(not(any(
1259        target_os = "openbsd",
1260        target_os = "redox",
1261        target_os = "solaris",
1262        target_os = "windows"
1263    )))]
1264    #[test]
1265    fn tcp_keepalive_retries_config() {
1266        let mut kac = TcpKeepaliveConfig::default();
1267        kac.retries = Some(3);
1268        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1269            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
1270        } else {
1271            panic!("test failed");
1272        }
1273    }
1274
1275    #[test]
1276    fn test_set_port() {
1277        // Respect explicit ports no matter what the resolved port is.
1278        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1279        set_port(&mut addr, 42, true);
1280        assert_eq!(addr.port(), 42);
1281
1282        // Ignore default  host port, and use the socket port instead.
1283        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1284        set_port(&mut addr, 443, false);
1285        assert_eq!(addr.port(), 6881);
1286
1287        // Use the default port if the resolved port is `0`.
1288        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
1289        set_port(&mut addr, 443, false);
1290        assert_eq!(addr.port(), 443);
1291    }
1292}