1use std::{
2 fmt::{Debug, Display, Formatter},
3 sync::{Arc, Mutex},
4 task::{ready, Context, Poll},
5};
6
7use futures::channel::oneshot;
8use tokio::{
9 sync::{mpsc, OwnedSemaphorePermit, Semaphore},
10 task::JoinHandle,
11};
12use tokio_util::sync::{PollSemaphore, PollSender};
13use tower::{Service, ServiceExt};
14use tracing::Instrument;
15
16use cuprate_helper::asynch::InfallibleOneshotReceiver;
17use cuprate_pruning::PruningSeed;
18use cuprate_wire::{BasicNodeData, CoreSyncData};
19
20use crate::{
21 handles::{ConnectionGuard, ConnectionHandle},
22 ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
23};
24
25mod connection;
26mod connector;
27pub mod handshaker;
28mod request_handler;
29mod timeout_monitor;
30mod weak;
31
32pub use connector::{ConnectRequest, Connector};
33pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder};
34pub use weak::{WeakBroadcastClient, WeakClient};
35
/// Internal identifier for a peer.
///
/// A peer is identified either by its address `A` or, when no address is known, by an
/// opaque 16-byte ID.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum InternalPeerID<A> {
    /// The peer is identified by its address.
    KnownAddr(A),
    /// The peer's address is not known; it is identified by this 16-byte ID instead.
    Unknown([u8; 16]),
}
45
46impl<A: Display> Display for InternalPeerID<A> {
47 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Self::KnownAddr(addr) => addr.fmt(f),
50 Self::Unknown(id) => f.write_str(&format!("Unknown, ID: {}", hex::encode(id))),
51 }
52 }
53}
54
/// Information about a peer, as held by a [`Client`] in its `info` field.
#[derive(Debug, Clone)]
pub struct PeerInformation<A> {
    /// The internal identifier of this peer.
    pub id: InternalPeerID<A>,
    /// The [`ConnectionHandle`] for this peer's connection.
    pub handle: ConnectionHandle,
    /// The [`ConnectionDirection`] of this connection.
    pub direction: ConnectionDirection,
    /// The peer's [`PruningSeed`].
    pub pruning_seed: PruningSeed,
    /// The peer's [`BasicNodeData`].
    pub basic_node_data: BasicNodeData,
    /// The peer's [`CoreSyncData`].
    ///
    /// Stored behind an `Arc<Mutex<_>>`, so every clone of this struct shares the same value.
    pub core_sync_data: Arc<Mutex<CoreSyncData>>,
}
79
/// A client for sending [`PeerRequest`]s to a connected peer.
///
/// Requests are issued through this type's [`Service`] implementation. Dropping the client
/// sends the close signal on `info.handle`.
pub struct Client<Z: NetworkZone> {
    /// Information about the peer this client is connected to.
    pub info: PeerInformation<Z::Addr>,

    /// Sender used to pass requests to the connection task.
    connection_tx: PollSender<connection::ConnectionTaskRequest>,
    /// Handle to the connection task; if it has finished, `poll_ready` reports
    /// `PeerError::ClientChannelClosed`.
    connection_handle: JoinHandle<()>,
    /// Handle to the timeout task; if it has finished, `poll_ready` reports
    /// `PeerError::ClientChannelClosed`.
    timeout_handle: JoinHandle<Result<(), tower::BoxError>>,

    /// Semaphore from which `poll_ready` acquires a permit before each request.
    semaphore: PollSemaphore,
    /// The permit acquired by `poll_ready`, handed over to the request in `call`.
    permit: Option<OwnedSemaphorePermit>,

    /// Shared slot holding the error recorded for this peer, if any.
    error: SharedError<PeerError>,
}
104
105impl<Z: NetworkZone> Drop for Client<Z> {
106 fn drop(&mut self) {
107 self.info.handle.send_close_signal();
108 }
109}
110
111impl<Z: NetworkZone> Client<Z> {
112 pub(crate) fn new(
114 info: PeerInformation<Z::Addr>,
115 connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
116 connection_handle: JoinHandle<()>,
117 timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
118 semaphore: Arc<Semaphore>,
119 error: SharedError<PeerError>,
120 ) -> Self {
121 Self {
122 info,
123 connection_tx: PollSender::new(connection_tx),
124 timeout_handle,
125 semaphore: PollSemaphore::new(semaphore),
126 permit: None,
127 connection_handle,
128 error,
129 }
130 }
131
132 fn set_err(&self, err: PeerError) -> tower::BoxError {
134 let err_str = err.to_string();
135 match self.error.try_insert_err(err) {
136 Ok(()) => err_str,
137 Err(e) => e.to_string(),
138 }
139 .into()
140 }
141
142 pub fn downgrade(&self) -> WeakClient<Z> {
144 WeakClient {
145 info: self.info.clone(),
146 connection_tx: self.connection_tx.clone(),
147 semaphore: self.semaphore.clone(),
148 permit: None,
149 error: self.error.clone(),
150 }
151 }
152}
153
154impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
155 type Response = PeerResponse;
156 type Error = tower::BoxError;
157 type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;
158
159 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 if let Some(err) = self.error.try_get_err() {
161 return Poll::Ready(Err(err.to_string().into()));
162 }
163
164 if self.connection_handle.is_finished() || self.timeout_handle.is_finished() {
165 let err = self.set_err(PeerError::ClientChannelClosed);
166 return Poll::Ready(Err(err));
167 }
168
169 if self.permit.is_none() {
170 let permit = ready!(self.semaphore.poll_acquire(cx))
171 .expect("Client semaphore should not be closed!");
172
173 self.permit = Some(permit);
174 }
175
176 if ready!(self.connection_tx.poll_reserve(cx)).is_err() {
177 let err = self.set_err(PeerError::ClientChannelClosed);
178 return Poll::Ready(Err(err));
179 }
180
181 Poll::Ready(Ok(()))
182 }
183
184 fn call(&mut self, request: PeerRequest) -> Self::Future {
185 let permit = self
186 .permit
187 .take()
188 .expect("poll_ready did not return ready before call to call");
189
190 let (tx, rx) = oneshot::channel();
191 let req = connection::ConnectionTaskRequest {
192 response_channel: tx,
193 request,
194 permit: Some(permit),
195 };
196
197 if let Err(req) = self.connection_tx.send_item(req) {
198 self.set_err(PeerError::ClientChannelClosed);
201
202 let resp = Err(PeerError::ClientChannelClosed.into());
203 drop(req.into_inner().unwrap().response_channel.send(resp));
204 }
205
206 rx.into()
207 }
208}
209
210pub fn mock_client<Z: NetworkZone, S>(
214 info: PeerInformation<Z::Addr>,
215 connection_guard: ConnectionGuard,
216 mut request_handler: S,
217) -> Client<Z>
218where
219 S: Service<PeerRequest, Response = PeerResponse, Error = tower::BoxError> + Send + 'static,
220 S::Future: Send + 'static,
221{
222 let (tx, mut rx) = mpsc::channel(1);
223
224 let task_span = tracing::error_span!("mock_connection", addr = %info.id);
225
226 let task_handle = tokio::spawn(
227 async move {
228 let _guard = connection_guard;
229 loop {
230 let Some(req): Option<connection::ConnectionTaskRequest> = rx.recv().await else {
231 tracing::debug!("Channel closed, closing mock connection");
232 return;
233 };
234
235 tracing::debug!("Received new request: {:?}", req.request.id());
236 let res = request_handler
237 .ready()
238 .await
239 .unwrap()
240 .call(req.request)
241 .await
242 .unwrap();
243
244 tracing::debug!("Sending back response");
245
246 drop(req.response_channel.send(Ok(res)));
247 }
248 }
249 .instrument(task_span),
250 );
251
252 let timeout_task = tokio::spawn(futures::future::pending());
253 let semaphore = Arc::new(Semaphore::new(1));
254 let error_slot = SharedError::new();
255
256 Client::new(info, tx, task_handle, timeout_task, semaphore, error_slot)
257}