1use std::sync::{Arc, Mutex};
5
6use crate::event::ChanMgrEventSender;
7use async_trait::async_trait;
8use tor_error::{internal, HasKind, HasRetryTime};
9use tor_linkspec::{HasChanMethod, OwnedChanTarget, PtTransportName};
10use tor_proto::channel::Channel;
11use tor_proto::memquota::ChannelAccount;
12use tracing::debug;
13
14#[derive(Clone)]
19pub struct BootstrapReporter(pub(crate) Arc<Mutex<ChanMgrEventSender>>);
20
21impl BootstrapReporter {
22 #[cfg(test)]
23 pub(crate) fn fake() -> Self {
25 let (snd, _rcv) = crate::event::channel();
26 Self(Arc::new(Mutex::new(snd)))
27 }
28}
29
30#[async_trait]
46pub trait ChannelFactory: Send + Sync {
47 async fn connect_via_transport(
57 &self,
58 target: &OwnedChanTarget,
59 reporter: BootstrapReporter,
60 memquota: ChannelAccount,
61 ) -> crate::Result<Arc<Channel>>;
62}
63
64#[async_trait]
69pub trait IncomingChannelFactory: Send + Sync {
70 type Stream: Send + Sync + 'static;
72
73 #[cfg(feature = "relay")]
76 async fn accept_from_transport(
77 &self,
78 peer: std::net::SocketAddr,
79 stream: Self::Stream,
80 memquota: ChannelAccount,
81 ) -> crate::Result<Arc<Channel>>;
82}
83
84#[async_trait]
85impl<CF> crate::mgr::AbstractChannelFactory for CF
86where
87 CF: ChannelFactory + IncomingChannelFactory + Sync,
88{
89 type Channel = tor_proto::channel::Channel;
90 type BuildSpec = OwnedChanTarget;
91 type Stream = CF::Stream;
92
93 async fn build_channel(
94 &self,
95 target: &Self::BuildSpec,
96 reporter: BootstrapReporter,
97 memquota: ChannelAccount,
98 ) -> crate::Result<Arc<Self::Channel>> {
99 debug!("Attempting to open a new channel to {target}");
100 self.connect_via_transport(target, reporter, memquota).await
101 }
102
103 #[cfg(feature = "relay")]
104 async fn build_channel_using_incoming(
105 &self,
106 peer: std::net::SocketAddr,
107 stream: Self::Stream,
108 memquota: ChannelAccount,
109 ) -> crate::Result<Arc<tor_proto::channel::Channel>> {
110 debug!("Attempting to open a new channel from {peer}");
111 self.accept_from_transport(peer, stream, memquota).await
112 }
113}
114
115pub trait AbstractPtError:
117 std::error::Error + HasKind + HasRetryTime + Send + Sync + std::fmt::Debug
118{
119}
120
121#[async_trait]
126pub trait AbstractPtMgr: Send + Sync {
127 async fn factory_for_transport(
129 &self,
130 transport: &PtTransportName,
131 ) -> Result<Option<Arc<dyn ChannelFactory + Send + Sync>>, Arc<dyn AbstractPtError>>;
132}
133
134#[async_trait]
135impl<P> AbstractPtMgr for Option<P>
136where
137 P: AbstractPtMgr,
138{
139 async fn factory_for_transport(
140 &self,
141 transport: &PtTransportName,
142 ) -> Result<Option<Arc<dyn ChannelFactory + Send + Sync>>, Arc<dyn AbstractPtError>> {
143 match self {
144 Some(mgr) => mgr.factory_for_transport(transport).await,
145 None => Ok(None),
146 }
147 }
148}
149
/// A channel factory that combines a default factory with an optional
/// pluggable-transport manager.
///
/// The `ChannelFactory` implementation for this type sends direct targets to
/// `default_factory`. With the `pt-client` feature enabled, it resolves
/// pluggable targets through `ptmgr`.
pub(crate) struct CompoundFactory<CF> {
    /// Manager consulted for pluggable-transport targets, if one is installed.
    ///
    /// Only present when the `pt-client` feature is enabled.
    #[cfg(feature = "pt-client")]
    ptmgr: Option<Arc<dyn AbstractPtMgr + 'static>>,
    /// Factory used for direct targets and for incoming streams.
    default_factory: Arc<CF>,
}
159
160impl<CF> Clone for CompoundFactory<CF> {
161 fn clone(&self) -> Self {
162 Self {
163 #[cfg(feature = "pt-client")]
164 ptmgr: self.ptmgr.as_ref().map(Arc::clone),
165 default_factory: Arc::clone(&self.default_factory),
166 }
167 }
168}
169
170#[async_trait]
171impl<CF: ChannelFactory> ChannelFactory for CompoundFactory<CF> {
172 async fn connect_via_transport(
173 &self,
174 target: &OwnedChanTarget,
175 reporter: BootstrapReporter,
176 memquota: ChannelAccount,
177 ) -> crate::Result<Arc<Channel>> {
178 use tor_linkspec::ChannelMethod::*;
179 let factory = match target.chan_method() {
180 Direct(_) => self.default_factory.clone(),
181 #[cfg(feature = "pt-client")]
182 Pluggable(a) => match self.ptmgr.as_ref() {
183 Some(mgr) => mgr
184 .factory_for_transport(a.transport())
185 .await
186 .map_err(crate::Error::Pt)?
187 .ok_or_else(|| crate::Error::NoSuchTransport(a.transport().clone().into()))?,
188 None => return Err(crate::Error::NoSuchTransport(a.transport().clone().into())),
189 },
190 #[allow(unreachable_patterns)]
191 _ => {
192 return Err(crate::Error::Internal(internal!(
193 "No support for channel method"
194 )))
195 }
196 };
197
198 factory
199 .connect_via_transport(target, reporter, memquota)
200 .await
201 }
202}
203
204#[async_trait]
205impl<CF: IncomingChannelFactory> IncomingChannelFactory for CompoundFactory<CF> {
206 type Stream = CF::Stream;
207
208 #[cfg(feature = "relay")]
209 async fn accept_from_transport(
210 &self,
211 peer: std::net::SocketAddr,
212 stream: Self::Stream,
213 memquota: ChannelAccount,
214 ) -> crate::Result<Arc<Channel>> {
215 self.default_factory
216 .accept_from_transport(peer, stream, memquota)
217 .await
218 }
219}
220
221impl<CF: ChannelFactory + 'static> CompoundFactory<CF> {
222 pub(crate) fn new(
225 default_factory: Arc<CF>,
226 #[cfg(feature = "pt-client")] ptmgr: Option<Arc<dyn AbstractPtMgr + 'static>>,
227 ) -> Self {
228 Self {
229 default_factory,
230 #[cfg(feature = "pt-client")]
231 ptmgr,
232 }
233 }
234
235 #[cfg(feature = "pt-client")]
236 pub(crate) fn replace_ptmgr(&mut self, ptmgr: Arc<dyn AbstractPtMgr + 'static>) {
238 self.ptmgr = Some(ptmgr);
239 }
240}