1use std::{
2 fmt::{Debug, Display, Formatter},
3 sync::{Arc, Mutex},
4 task::{ready, Context, Poll},
5};
6
7use futures::channel::oneshot;
8use tokio::{
9 sync::{mpsc, OwnedSemaphorePermit, Semaphore},
10 task::JoinHandle,
11};
12use tokio_util::sync::{PollSemaphore, PollSender};
13use tower::{Service, ServiceExt};
14use tracing::Instrument;
15
16use cuprate_helper::asynch::InfallibleOneshotReceiver;
17use cuprate_pruning::PruningSeed;
18use cuprate_wire::CoreSyncData;
19
20use crate::{
21 handles::{ConnectionGuard, ConnectionHandle},
22 ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
23};
24
25mod connection;
26mod connector;
27pub mod handshaker;
28mod request_handler;
29mod timeout_monitor;
30mod weak;
31
32pub use connector::{ConnectRequest, Connector};
33pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder};
34pub use weak::{WeakBroadcastClient, WeakClient};
35
/// Identifier used internally to refer to a peer.
///
/// A peer whose network address is known is identified by that address;
/// otherwise an opaque numeric ID is used instead (how that ID is chosen is up
/// to the code constructing it).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum InternalPeerID<A> {
    /// The peer's address is known and serves as its identifier.
    KnownAddr(A),
    /// The peer's address is unknown; it is identified by this number instead.
    Unknown(u128),
}

impl<A: Display> Display for InternalPeerID<A> {
    /// Formats a known peer as its address and an unknown peer as
    /// `Unknown, ID: <id>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::KnownAddr(addr) => Display::fmt(addr, f),
            // `write!` formats directly into `f`, avoiding the temporary `String`
            // that `format!` followed by `write_str` would allocate.
            Self::Unknown(id) => write!(f, "Unknown, ID: {id}"),
        }
    }
}
54
55#[derive(Debug, Clone)]
57pub struct PeerInformation<A> {
58 pub id: InternalPeerID<A>,
60 pub handle: ConnectionHandle,
63 pub direction: ConnectionDirection,
65 pub pruning_seed: PruningSeed,
67 pub core_sync_data: Arc<Mutex<CoreSyncData>>,
76}
77
78pub struct Client<Z: NetworkZone> {
84 pub info: PeerInformation<Z::Addr>,
86
87 connection_tx: PollSender<connection::ConnectionTaskRequest>,
89 connection_handle: JoinHandle<()>,
91 timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
93
94 semaphore: PollSemaphore,
96 permit: Option<OwnedSemaphorePermit>,
98
99 error: SharedError<PeerError>,
101}
102
103impl<Z: NetworkZone> Drop for Client<Z> {
104 fn drop(&mut self) {
105 self.info.handle.send_close_signal();
106 }
107}
108
109impl<Z: NetworkZone> Client<Z> {
110 pub(crate) fn new(
112 info: PeerInformation<Z::Addr>,
113 connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
114 connection_handle: JoinHandle<()>,
115 timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
116 semaphore: Arc<Semaphore>,
117 error: SharedError<PeerError>,
118 ) -> Self {
119 Self {
120 info,
121 connection_tx: PollSender::new(connection_tx),
122 timeout_handle,
123 semaphore: PollSemaphore::new(semaphore),
124 permit: None,
125 connection_handle,
126 error,
127 }
128 }
129
130 fn set_err(&self, err: PeerError) -> tower::BoxError {
132 let err_str = err.to_string();
133 match self.error.try_insert_err(err) {
134 Ok(()) => err_str,
135 Err(e) => e.to_string(),
136 }
137 .into()
138 }
139
140 pub fn downgrade(&self) -> WeakClient<Z> {
142 WeakClient {
143 info: self.info.clone(),
144 connection_tx: self.connection_tx.clone(),
145 semaphore: self.semaphore.clone(),
146 permit: None,
147 error: self.error.clone(),
148 }
149 }
150}
151
152impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
153 type Response = PeerResponse;
154 type Error = tower::BoxError;
155 type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;
156
157 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 if let Some(err) = self.error.try_get_err() {
159 return Poll::Ready(Err(err.to_string().into()));
160 }
161
162 if self.connection_handle.is_finished() || self.timeout_handle.is_finished() {
163 let err = self.set_err(PeerError::ClientChannelClosed);
164 return Poll::Ready(Err(err));
165 }
166
167 if self.permit.is_none() {
168 let permit = ready!(self.semaphore.poll_acquire(cx))
169 .expect("Client semaphore should not be closed!");
170
171 self.permit = Some(permit);
172 }
173
174 if ready!(self.connection_tx.poll_reserve(cx)).is_err() {
175 let err = self.set_err(PeerError::ClientChannelClosed);
176 return Poll::Ready(Err(err));
177 }
178
179 Poll::Ready(Ok(()))
180 }
181
182 fn call(&mut self, request: PeerRequest) -> Self::Future {
183 let permit = self
184 .permit
185 .take()
186 .expect("poll_ready did not return ready before call to call");
187
188 let (tx, rx) = oneshot::channel();
189 let req = connection::ConnectionTaskRequest {
190 response_channel: tx,
191 request,
192 permit: Some(permit),
193 };
194
195 if let Err(req) = self.connection_tx.send_item(req) {
196 self.set_err(PeerError::ClientChannelClosed);
199
200 let resp = Err(PeerError::ClientChannelClosed.into());
201 drop(req.into_inner().unwrap().response_channel.send(resp));
202 }
203
204 rx.into()
205 }
206}
207
208pub fn mock_client<Z: NetworkZone, S>(
212 info: PeerInformation<Z::Addr>,
213 connection_guard: ConnectionGuard,
214 mut request_handler: S,
215) -> Client<Z>
216where
217 S: Service<PeerRequest, Response = PeerResponse, Error = tower::BoxError> + Send + 'static,
218 S::Future: Send + 'static,
219{
220 let (tx, mut rx) = mpsc::channel(1);
221
222 let task_span = tracing::error_span!("mock_connection", addr = %info.id);
223
224 let task_handle = tokio::spawn(
225 async move {
226 let _guard = connection_guard;
227 loop {
228 let Some(req): Option<connection::ConnectionTaskRequest> = rx.recv().await else {
229 tracing::debug!("Channel closed, closing mock connection");
230 return;
231 };
232
233 tracing::debug!("Received new request: {:?}", req.request.id());
234 let res = request_handler
235 .ready()
236 .await
237 .unwrap()
238 .call(req.request)
239 .await
240 .unwrap();
241
242 tracing::debug!("Sending back response");
243
244 drop(req.response_channel.send(Ok(res)));
245 }
246 }
247 .instrument(task_span),
248 );
249
250 let timeout_task = tokio::spawn(futures::future::pending());
251 let semaphore = Arc::new(Semaphore::new(1));
252 let error_slot = SharedError::new();
253
254 Client::new(info, tx, task_handle, timeout_task, semaphore, error_slot)
255}