tor_chanmgr/transport/
default.rs1use std::{net::SocketAddr, sync::Arc, time::Duration};
5
6use async_trait::async_trait;
7use futures::{stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt};
8use safelog::sensitive as sv;
9use tor_error::bad_api_usage;
10use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11use tor_rtcompat::{NetStreamProvider, Runtime};
12use tracing::trace;
13
14use crate::Error;
15
16#[derive(Clone, Debug)]
22pub(crate) struct DefaultTransport<R: Runtime> {
23 runtime: R,
25}
26
27impl<R: Runtime> DefaultTransport<R> {
28 pub(crate) fn new(runtime: R) -> Self {
30 Self { runtime }
31 }
32}
33
34#[async_trait]
35impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
36 type Stream = <R as NetStreamProvider>::Stream;
37
38 async fn connect(
41 &self,
42 target: &OwnedChanTarget,
43 ) -> crate::Result<(OwnedChanTarget, Self::Stream)> {
44 let direct_addrs: Vec<_> = match target.chan_method() {
45 ChannelMethod::Direct(addrs) => addrs,
46 #[allow(unreachable_patterns)]
47 _ => {
48 return Err(Error::UnusableTarget(bad_api_usage!(
49 "Used default transport implementation for an unsupported transport."
50 )))
51 }
52 };
53
54 trace!("Launching direct connection for {}", target);
55
56 let (stream, addr) = connect_to_one(&self.runtime, &direct_addrs).await?;
57 let mut using_target = target.clone();
58 let _ignore = using_target.chan_method_mut().retain_addrs(|a| a == &addr);
59
60 Ok((using_target, stream))
61 }
62}
63
/// Stagger between the starts of successive connection attempts when a
/// relay has several addresses: attempt `i` (0-based) is launched after
/// `i * CONNECTION_DELAY`.
///
/// This is a `const` rather than a `static`: it is a small `Copy` value that
/// needs neither a fixed address nor interior mutability.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);

67async fn connect_to_one<R: Runtime>(
71 rt: &R,
72 addrs: &[SocketAddr],
73) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
74 if addrs.is_empty() {
76 return Err(Error::UnusableTarget(bad_api_usage!(
77 "No addresses for chosen relay"
78 )));
79 }
80
81 let mut connections = addrs
89 .iter()
90 .enumerate()
91 .map(|(i, a)| {
92 let delay = rt.sleep(CONNECTION_DELAY * i as u32);
93 delay.then(move |_| {
94 tracing::debug!("Connecting to {}", a);
95 rt.connect(a)
96 .map_ok(move |stream| (stream, *a))
97 .map_err(move |e| (e, *a))
98 })
99 })
100 .collect::<FuturesUnordered<_>>();
101
102 let mut ret = None;
103 let mut errors = vec![];
104
105 while let Some(result) = connections.next().await {
106 match result {
107 Ok(s) => {
108 ret = Some(s);
110 break;
111 }
112 Err((e, a)) => {
113 tor_error::warn_report!(e, "Connection to {} failed", sv(a));
116 errors.push((e, a));
117 }
118 }
119 }
120
121 drop(connections);
123
124 ret.ok_or_else(|| Error::ChannelBuild {
125 addresses: errors
126 .into_iter()
127 .map(|(e, a)| (sv(a), Arc::new(e)))
128 .collect(),
129 })
130}
131
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use std::str::FromStr;

    use tor_rtcompat::{test_with_one_runtime, SleepProviderExt};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` on a mock network containing listening,
    /// unreachable and black-holed addresses.
    #[test]
    fn test_connect_one() {
        let client_ip = "192.0.1.16".parse().unwrap();
        // A server listens here.
        let listening = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // No mock host has this address; connecting is expected to fail.
        let refusing = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // Black-holed; attempts are expected to hang (the test expects a timeout).
        let blackholed = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // A second listening address on the same server.
        let listening2 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            let network = MockNetwork::new();

            let client_rt = network
                .builder()
                .add_address(client_ip)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(listening.ip())
                .add_address(listening2.ip())
                .runtime(rt.clone());
            let _listener = server_rt.mock_net().listen(&listening).await.unwrap();
            let _listener2 = server_rt.mock_net().listen(&listening2).await.unwrap();
            network.add_blackhole(blackholed).unwrap();

            // No addresses at all: must be an error.
            assert!(connect_to_one(&client_rt, &[]).await.is_err());

            // Whenever the listening address is present, that is what we reach.
            let reachable_sets: [&[SocketAddr]; 7] = [
                &[listening],
                &[listening, refusing],
                &[refusing, listening],
                &[listening, blackholed],
                &[blackholed, listening],
                &[listening, refusing, blackholed],
                &[blackholed, refusing, listening],
            ];
            for addrs in reachable_sets {
                let (_conn, reached) = connect_to_one(&client_rt, addrs).await.unwrap();
                assert_eq!(reached, listening);
            }

            // Without a listening address nothing succeeds. If the black-holed
            // address is involved we time out; otherwise we get an error back.
            let unreachable_sets: [&[SocketAddr]; 4] = [
                &[refusing],
                &[refusing, blackholed],
                &[blackholed, refusing],
                &[blackholed],
            ];
            for addrs in unreachable_sets {
                let outcome = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addrs),
                    )
                    .await;
                match outcome {
                    Err(_timed_out) => assert!(addrs.contains(&blackholed)),
                    Ok(inner) => {
                        assert!(!addrs.contains(&blackholed));
                        assert!(inner.is_err());
                    }
                }
            }

            // With two listening addresses, the one listed first wins, because
            // the later attempt is delayed.
            let (_conn, reached) = connect_to_one(&client_rt, &[listening, listening2])
                .await
                .unwrap();
            assert_eq!(reached, listening);
            let (_conn, reached) = connect_to_one(&client_rt, &[listening2, listening])
                .await
                .unwrap();
            assert_eq!(reached, listening2);
        });
    }
}