1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
24#[derive(Clone)]
33pub struct HttpConnector<R = GaiResolver> {
34 config: Arc<Config>,
35 resolver: R,
36}
37
/// Connection metadata attached (via `Connected::extra`) to TCP streams that
/// were established by `HttpConnector`.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the peer the stream is connected to.
    remote_addr: SocketAddr,
    // Address of the local end of the stream.
    local_addr: SocketAddr,
}
66
67#[derive(Clone)]
68struct Config {
69 connect_timeout: Option<Duration>,
70 enforce_http: bool,
71 happy_eyeballs_timeout: Option<Duration>,
72 tcp_keepalive_config: TcpKeepaliveConfig,
73 local_address_ipv4: Option<Ipv4Addr>,
74 local_address_ipv6: Option<Ipv6Addr>,
75 nodelay: bool,
76 reuse_address: bool,
77 send_buffer_size: Option<usize>,
78 recv_buffer_size: Option<usize>,
79 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
80 interface: Option<String>,
81 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
82 tcp_user_timeout: Option<Duration>,
83}
84
/// User-requested TCP keepalive parameters; every field is optional and the
/// default is "nothing configured".
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe.
    time: Option<Duration>,
    // Time between successive probes.
    interval: Option<Duration>,
    // Number of unanswered probes before the connection is dropped.
    retries: Option<u32>,
}
91
92impl TcpKeepaliveConfig {
93 fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
95 let mut dirty = false;
96 let mut ka = TcpKeepalive::new();
97 if let Some(time) = self.time {
98 ka = ka.with_time(time);
99 dirty = true
100 }
101 if let Some(interval) = self.interval {
102 ka = Self::ka_with_interval(ka, interval, &mut dirty)
103 };
104 if let Some(retries) = self.retries {
105 ka = Self::ka_with_retries(ka, retries, &mut dirty)
106 };
107 if dirty {
108 Some(ka)
109 } else {
110 None
111 }
112 }
113
114 #[cfg(not(any(
115 target_os = "aix",
116 target_os = "openbsd",
117 target_os = "redox",
118 target_os = "solaris"
119 )))]
120 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
121 *dirty = true;
122 ka.with_interval(interval)
123 }
124
125 #[cfg(any(
126 target_os = "aix",
127 target_os = "openbsd",
128 target_os = "redox",
129 target_os = "solaris"
130 ))]
131 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
132 ka }
134
135 #[cfg(not(any(
136 target_os = "aix",
137 target_os = "openbsd",
138 target_os = "redox",
139 target_os = "solaris",
140 target_os = "windows"
141 )))]
142 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
143 *dirty = true;
144 ka.with_retries(retries)
145 }
146
147 #[cfg(any(
148 target_os = "aix",
149 target_os = "openbsd",
150 target_os = "redox",
151 target_os = "solaris",
152 target_os = "windows"
153 ))]
154 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
155 ka }
157}
158
159impl HttpConnector {
162 pub fn new() -> HttpConnector {
164 HttpConnector::new_with_resolver(GaiResolver::new())
165 }
166}
167
168impl<R> HttpConnector<R> {
169 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
173 HttpConnector {
174 config: Arc::new(Config {
175 connect_timeout: None,
176 enforce_http: true,
177 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
178 tcp_keepalive_config: TcpKeepaliveConfig::default(),
179 local_address_ipv4: None,
180 local_address_ipv6: None,
181 nodelay: false,
182 reuse_address: false,
183 send_buffer_size: None,
184 recv_buffer_size: None,
185 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
186 interface: None,
187 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
188 tcp_user_timeout: None,
189 }),
190 resolver,
191 }
192 }
193
194 #[inline]
198 pub fn enforce_http(&mut self, is_enforced: bool) {
199 self.config_mut().enforce_http = is_enforced;
200 }
201
202 #[inline]
209 pub fn set_keepalive(&mut self, time: Option<Duration>) {
210 self.config_mut().tcp_keepalive_config.time = time;
211 }
212
213 #[inline]
216 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
217 self.config_mut().tcp_keepalive_config.interval = interval;
218 }
219
220 #[inline]
222 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
223 self.config_mut().tcp_keepalive_config.retries = retries;
224 }
225
226 #[inline]
230 pub fn set_nodelay(&mut self, nodelay: bool) {
231 self.config_mut().nodelay = nodelay;
232 }
233
234 #[inline]
236 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
237 self.config_mut().send_buffer_size = size;
238 }
239
240 #[inline]
242 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
243 self.config_mut().recv_buffer_size = size;
244 }
245
246 #[inline]
252 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
253 let (v4, v6) = match addr {
254 Some(IpAddr::V4(a)) => (Some(a), None),
255 Some(IpAddr::V6(a)) => (None, Some(a)),
256 _ => (None, None),
257 };
258
259 let cfg = self.config_mut();
260
261 cfg.local_address_ipv4 = v4;
262 cfg.local_address_ipv6 = v6;
263 }
264
265 #[inline]
268 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
269 let cfg = self.config_mut();
270
271 cfg.local_address_ipv4 = Some(addr_ipv4);
272 cfg.local_address_ipv6 = Some(addr_ipv6);
273 }
274
275 #[inline]
282 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
283 self.config_mut().connect_timeout = dur;
284 }
285
286 #[inline]
299 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
300 self.config_mut().happy_eyeballs_timeout = dur;
301 }
302
303 #[inline]
307 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
308 self.config_mut().reuse_address = reuse_address;
309 self
310 }
311
312 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
325 #[inline]
326 pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
327 self.config_mut().interface = Some(interface.into());
328 self
329 }
330
331 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
333 #[inline]
334 pub fn set_tcp_user_timeout(&mut self, time: Option<Duration>) {
335 self.config_mut().tcp_user_timeout = time;
336 }
337
338 fn config_mut(&mut self) -> &mut Config {
341 Arc::make_mut(&mut self.config)
345 }
346}
347
/// Error message used when `enforce_http` is on and the scheme is not `http`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message used when `enforce_http` is off and the URI has no scheme.
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message used when the URI has no host component.
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
351
352impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 f.debug_struct("HttpConnector").finish()
356 }
357}
358
359impl<R> tower_service::Service<Uri> for HttpConnector<R>
360where
361 R: Resolve + Clone + Send + Sync + 'static,
362 R::Future: Send,
363{
364 type Response = TokioIo<TcpStream>;
365 type Error = ConnectError;
366 type Future = HttpConnecting<R>;
367
368 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
369 futures_util::ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
370 Poll::Ready(Ok(()))
371 }
372
373 fn call(&mut self, dst: Uri) -> Self::Future {
374 let mut self_ = self.clone();
375 HttpConnecting {
376 fut: Box::pin(async move { self_.call_async(dst).await }),
377 _marker: PhantomData,
378 }
379 }
380}
381
382fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
383 trace!(
384 "Http::connect; scheme={:?}, host={:?}, port={:?}",
385 dst.scheme(),
386 dst.host(),
387 dst.port(),
388 );
389
390 if config.enforce_http {
391 if dst.scheme() != Some(&Scheme::HTTP) {
392 return Err(ConnectError {
393 msg: INVALID_NOT_HTTP.into(),
394 cause: None,
395 });
396 }
397 } else if dst.scheme().is_none() {
398 return Err(ConnectError {
399 msg: INVALID_MISSING_SCHEME.into(),
400 cause: None,
401 });
402 }
403
404 let host = match dst.host() {
405 Some(s) => s,
406 None => {
407 return Err(ConnectError {
408 msg: INVALID_MISSING_HOST.into(),
409 cause: None,
410 })
411 }
412 };
413 let port = match dst.port() {
414 Some(port) => port.as_u16(),
415 None => {
416 if dst.scheme() == Some(&Scheme::HTTPS) {
417 443
418 } else {
419 80
420 }
421 }
422 };
423
424 Ok((host, port))
425}
426
427impl<R> HttpConnector<R>
428where
429 R: Resolve,
430{
431 async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
432 let config = &self.config;
433
434 let (host, port) = get_host_port(config, &dst)?;
435 let host = host.trim_start_matches('[').trim_end_matches(']');
436
437 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
440 addrs
441 } else {
442 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
443 .await
444 .map_err(ConnectError::dns)?;
445 let addrs = addrs
446 .map(|mut addr| {
447 set_port(&mut addr, port, dst.port().is_some());
448
449 addr
450 })
451 .collect();
452 dns::SocketAddrs::new(addrs)
453 };
454
455 let c = ConnectingTcp::new(addrs, config);
456
457 let sock = c.connect().await?;
458
459 if let Err(e) = sock.set_nodelay(config.nodelay) {
460 warn!("tcp set_nodelay error: {}", e);
461 }
462
463 Ok(TokioIo::new(sock))
464 }
465}
466
467impl Connection for TcpStream {
468 fn connected(&self) -> Connected {
469 let connected = Connected::new();
470 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
471 connected.extra(HttpInfo {
472 remote_addr,
473 local_addr,
474 })
475 } else {
476 connected
477 }
478 }
479}
480
481impl<T> Connection for TokioIo<T>
484where
485 T: Connection,
486{
487 fn connected(&self) -> Connected {
488 self.inner().connected()
489 }
490}
491
492impl HttpInfo {
493 pub fn remote_addr(&self) -> SocketAddr {
495 self.remote_addr
496 }
497
498 pub fn local_addr(&self) -> SocketAddr {
500 self.local_addr
501 }
502}
503
504pin_project! {
505 #[must_use = "futures do nothing unless polled"]
511 #[allow(missing_debug_implementations)]
512 pub struct HttpConnecting<R> {
513 #[pin]
514 fut: BoxConnecting,
515 _marker: PhantomData<R>,
516 }
517}
518
519type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
520type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
521
522impl<R: Resolve> Future for HttpConnecting<R> {
523 type Output = ConnectResult;
524
525 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
526 self.project().fut.poll(cx)
527 }
528}
529
/// Error produced when `HttpConnector` fails to establish a connection.
///
/// Holds a short description and, optionally, the underlying error, which is
/// exposed through `std::error::Error::source`.
pub struct ConnectError {
    // Human-readable summary, e.g. "tcp connect error".
    msg: Box<str>,
    // The error that caused this one, if any.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
535
536impl ConnectError {
537 fn new<S, E>(msg: S, cause: E) -> ConnectError
538 where
539 S: Into<Box<str>>,
540 E: Into<Box<dyn StdError + Send + Sync>>,
541 {
542 ConnectError {
543 msg: msg.into(),
544 cause: Some(cause.into()),
545 }
546 }
547
548 fn dns<E>(cause: E) -> ConnectError
549 where
550 E: Into<Box<dyn StdError + Send + Sync>>,
551 {
552 ConnectError::new("dns error", cause)
553 }
554
555 fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
556 where
557 S: Into<Box<str>>,
558 E: Into<Box<dyn StdError + Send + Sync>>,
559 {
560 move |cause| ConnectError::new(msg, cause)
561 }
562}
563
564impl fmt::Debug for ConnectError {
565 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566 if let Some(ref cause) = self.cause {
567 f.debug_tuple("ConnectError")
568 .field(&self.msg)
569 .field(cause)
570 .finish()
571 } else {
572 self.msg.fmt(f)
573 }
574 }
575}
576
577impl fmt::Display for ConnectError {
578 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
579 f.write_str(&self.msg)?;
580
581 if let Some(ref cause) = self.cause {
582 write!(f, ": {}", cause)?;
583 }
584
585 Ok(())
586 }
587}
588
589impl StdError for ConnectError {
590 fn source(&self) -> Option<&(dyn StdError + 'static)> {
591 self.cause.as_ref().map(|e| &**e as _)
592 }
593}
594
595struct ConnectingTcp<'a> {
596 preferred: ConnectingTcpRemote,
597 fallback: Option<ConnectingTcpFallback>,
598 config: &'a Config,
599}
600
601impl<'a> ConnectingTcp<'a> {
602 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
603 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
604 let (preferred_addrs, fallback_addrs) = remote_addrs
605 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
606 if fallback_addrs.is_empty() {
607 return ConnectingTcp {
608 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
609 fallback: None,
610 config,
611 };
612 }
613
614 ConnectingTcp {
615 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
616 fallback: Some(ConnectingTcpFallback {
617 delay: tokio::time::sleep(fallback_timeout),
618 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
619 }),
620 config,
621 }
622 } else {
623 ConnectingTcp {
624 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
625 fallback: None,
626 config,
627 }
628 }
629 }
630}
631
632struct ConnectingTcpFallback {
633 delay: Sleep,
634 remote: ConnectingTcpRemote,
635}
636
637struct ConnectingTcpRemote {
638 addrs: dns::SocketAddrs,
639 connect_timeout: Option<Duration>,
640}
641
642impl ConnectingTcpRemote {
643 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
644 let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
645
646 Self {
647 addrs,
648 connect_timeout,
649 }
650 }
651}
652
653impl ConnectingTcpRemote {
654 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
655 let mut err = None;
656 for addr in &mut self.addrs {
657 debug!("connecting to {}", addr);
658 match connect(&addr, config, self.connect_timeout)?.await {
659 Ok(tcp) => {
660 debug!("connected to {}", addr);
661 return Ok(tcp);
662 }
663 Err(e) => {
664 trace!("connect error for {}: {:?}", addr, e);
665 err = Some(e);
666 }
667 }
668 }
669
670 match err {
671 Some(e) => Err(e),
672 None => Err(ConnectError::new(
673 "tcp connect error",
674 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
675 )),
676 }
677 }
678}
679
680fn bind_local_address(
681 socket: &socket2::Socket,
682 dst_addr: &SocketAddr,
683 local_addr_ipv4: &Option<Ipv4Addr>,
684 local_addr_ipv6: &Option<Ipv6Addr>,
685) -> io::Result<()> {
686 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
687 (SocketAddr::V4(_), Some(addr), _) => {
688 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
689 }
690 (SocketAddr::V6(_), _, Some(addr)) => {
691 socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
692 }
693 _ => {
694 if cfg!(windows) {
695 let any: SocketAddr = match *dst_addr {
697 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
698 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
699 };
700 socket.bind(&any.into())?;
701 }
702 }
703 }
704
705 Ok(())
706}
707
708fn connect(
709 addr: &SocketAddr,
710 config: &Config,
711 connect_timeout: Option<Duration>,
712) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
713 use socket2::{Domain, Protocol, Socket, Type};
717
718 let domain = Domain::for_address(*addr);
719 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
720 .map_err(ConnectError::m("tcp open error"))?;
721
722 socket
725 .set_nonblocking(true)
726 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
727
728 if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
729 if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
730 warn!("tcp set_keepalive error: {}", e);
731 }
732 }
733
734 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
735 if let Some(interface) = &config.interface {
737 socket
738 .bind_device(Some(interface.as_bytes()))
739 .map_err(ConnectError::m("tcp bind interface error"))?;
740 }
741
742 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
743 if let Some(tcp_user_timeout) = &config.tcp_user_timeout {
744 if let Err(e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) {
745 warn!("tcp set_tcp_user_timeout error: {}", e);
746 }
747 }
748
749 bind_local_address(
750 &socket,
751 addr,
752 &config.local_address_ipv4,
753 &config.local_address_ipv6,
754 )
755 .map_err(ConnectError::m("tcp bind local error"))?;
756
757 #[cfg(unix)]
758 let socket = unsafe {
759 use std::os::unix::io::{FromRawFd, IntoRawFd};
764 TcpSocket::from_raw_fd(socket.into_raw_fd())
765 };
766 #[cfg(windows)]
767 let socket = unsafe {
768 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
773 TcpSocket::from_raw_socket(socket.into_raw_socket())
774 };
775
776 if config.reuse_address {
777 if let Err(e) = socket.set_reuseaddr(true) {
778 warn!("tcp set_reuse_address error: {}", e);
779 }
780 }
781
782 if let Some(size) = config.send_buffer_size {
783 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
784 warn!("tcp set_buffer_size error: {}", e);
785 }
786 }
787
788 if let Some(size) = config.recv_buffer_size {
789 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
790 warn!("tcp set_recv_buffer_size error: {}", e);
791 }
792 }
793
794 let connect = socket.connect(*addr);
795 Ok(async move {
796 match connect_timeout {
797 Some(dur) => match tokio::time::timeout(dur, connect).await {
798 Ok(Ok(s)) => Ok(s),
799 Ok(Err(e)) => Err(e),
800 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
801 },
802 None => connect.await,
803 }
804 .map_err(ConnectError::m("tcp connect error"))
805 })
806}
807
808impl ConnectingTcp<'_> {
809 async fn connect(mut self) -> Result<TcpStream, ConnectError> {
810 match self.fallback {
811 None => self.preferred.connect(self.config).await,
812 Some(mut fallback) => {
813 let preferred_fut = self.preferred.connect(self.config);
814 futures_util::pin_mut!(preferred_fut);
815
816 let fallback_fut = fallback.remote.connect(self.config);
817 futures_util::pin_mut!(fallback_fut);
818
819 let fallback_delay = fallback.delay;
820 futures_util::pin_mut!(fallback_delay);
821
822 let (result, future) =
823 match futures_util::future::select(preferred_fut, fallback_delay).await {
824 Either::Left((result, _fallback_delay)) => {
825 (result, Either::Right(fallback_fut))
826 }
827 Either::Right(((), preferred_fut)) => {
828 futures_util::future::select(preferred_fut, fallback_fut)
830 .await
831 .factor_first()
832 }
833 };
834
835 if result.is_err() {
836 future.await
839 } else {
840 result
841 }
842 }
843 }
844 }
845}
846
/// Decides which port a resolved address should use.
///
/// An explicit URI port (`explicit == true`) always wins. Otherwise a non-zero
/// port returned by the resolver is kept, and a zero port is replaced by
/// `host_port` (the scheme default computed from the URI).
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let keep_resolved_port = !explicit && addr.port() != 0;
    if !keep_resolved_port {
        addr.set_port(host_port);
    }
}
855
#[cfg(test)]
mod tests {
    use std::io;
    use std::net::SocketAddr;
    use std::time::Duration;

    use ::http::Uri;

    use crate::client::legacy::connect::http::TcpKeepaliveConfig;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    use super::set_port;

    /// Drives `connector` through the sealed `Connect` trait.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_enforce_http() {
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
        let connector = HttpConnector::new();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    /// Finds one bindable local IPv4 and one bindable local IPv6 address, if any.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut ip_v4 = None;
        let mut ip_v6 = None;

        let ips = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));

        for ip in ips {
            match ip {
                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
                _ => (),
            }

            if ip_v4.is_some() && ip_v6.is_some() {
                break;
            }
        }

        (ip_v4, ip_v6)
    }

    /// Name of the first interface that is up, not loopback, and has addresses.
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    fn default_interface() -> Option<String> {
        pnet_datalink::interfaces()
            .iter()
            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
            .map(|e| e.name.clone())
    }

    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_errors_missing_scheme() {
        let dst = "example.domain".parse().unwrap();
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    #[tokio::test]
    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
    async fn interface() {
        use socket2::{Domain, Protocol, Socket, Type};
        use std::net::TcpListener;

        let interface: Option<String> = default_interface();

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();

        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_interface_name =
            |dst: String,
             server: TcpListener,
             bind_iface: Option<String>,
             expected_interface: Option<String>| async move {
                let mut connector = HttpConnector::new();
                if let Some(iface) = bind_iface {
                    connector.set_interface(iface);
                }

                connect(connector, dst.parse().unwrap()).await.unwrap();
                let domain = Domain::for_address(server.local_addr().unwrap());
                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();

                assert_eq!(
                    socket.device().unwrap().as_deref(),
                    expected_interface.as_deref().map(|val| val.as_bytes())
                );
            };

        assert_interface_name(
            format!("http://127.0.0.1:{}", port),
            server4,
            interface.clone(),
            interface.clone(),
        )
        .await;
        assert_interface_name(
            format!("http://[::1]:{}", port),
            server6,
            interface.clone(),
            interface.clone(),
        )
        .await;
    }

    #[test]
    // NOTE(review): unconditionally ignored, which makes the feature-gated
    // `cfg_attr` below redundant; remove this line to re-enable under the feature.
    #[ignore]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // (hosts, expected family, expected duration, needs IPv6 connectivity)
        let scenarios = &[
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        interface: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        tcp_user_timeout: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow the measured duration to be off by up to 150ms either way.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Returns (reachable-or-timed-out, elapsed) for a 1s connect to `addr:80`.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }

    #[test]
    fn no_tcp_keepalive_config() {
        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.time = Some(Duration::from_secs(60));
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
        } else {
            panic!("test failed");
        }
    }

    // Excludes exactly the platforms where `ka_with_interval` is a no-op
    // (including aix), since `into_tcpkeepalive` returns `None` there.
    #[cfg(not(any(
        target_os = "aix",
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris"
    )))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.interval = Some(Duration::from_secs(1));
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
        } else {
            panic!("test failed");
        }
    }

    // Excludes exactly the platforms where `ka_with_retries` is a no-op
    // (including aix), since `into_tcpkeepalive` returns `None` there.
    #[cfg(not(any(
        target_os = "aix",
        target_os = "openbsd",
        target_os = "redox",
        target_os = "solaris",
        target_os = "windows"
    )))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let mut kac = TcpKeepaliveConfig::default();
        kac.retries = Some(3);
        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
        } else {
            panic!("test failed");
        }
    }

    #[test]
    fn test_set_port() {
        // Explicit URI ports win no matter what the resolver returned.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 42, true);
        assert_eq!(addr.port(), 42);

        // Without an explicit port, a non-zero resolved port is kept.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 6881);

        // A resolved port of 0 is replaced by the scheme default.
        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
        set_port(&mut addr, 443, false);
        assert_eq!(addr.port(), 443);
    }
}