// cuprate_p2p_core/client.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    sync::{Arc, Mutex},
4    task::{ready, Context, Poll},
5};
6
7use futures::channel::oneshot;
8use tokio::{
9    sync::{mpsc, OwnedSemaphorePermit, Semaphore},
10    task::JoinHandle,
11};
12use tokio_util::sync::{PollSemaphore, PollSender};
13use tower::{Service, ServiceExt};
14use tracing::Instrument;
15
16use cuprate_helper::asynch::InfallibleOneshotReceiver;
17use cuprate_pruning::PruningSeed;
18use cuprate_wire::{BasicNodeData, CoreSyncData};
19
20use crate::{
21    handles::{ConnectionGuard, ConnectionHandle},
22    ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
23};
24
25mod connection;
26mod connector;
27pub mod handshaker;
28mod request_handler;
29mod timeout_monitor;
30mod weak;
31
32pub use connector::{ConnectRequest, Connector};
33pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder};
34pub use weak::{WeakBroadcastClient, WeakClient};
35
/// An internal identifier for a given peer, will be their address if known
/// or a random u128 if not.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum InternalPeerID<A> {
    /// A known address.
    KnownAddr(A),
    /// An unknown address (probably an inbound anonymity network connection),
    /// identified by a random 16-byte (u128) ID instead.
    Unknown([u8; 16]),
}
45
46impl<A: Display> Display for InternalPeerID<A> {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::KnownAddr(addr) => addr.fmt(f),
50            Self::Unknown(id) => f.write_str(&format!("Unknown, ID: {}", hex::encode(id))),
51        }
52    }
53}
54
/// Information on a connected peer.
#[derive(Debug, Clone)]
pub struct PeerInformation<A> {
    /// The internal peer ID of this peer.
    pub id: InternalPeerID<A>,
    /// The [`ConnectionHandle`] for this peer, allows banning this peer and checking if it is still
    /// alive.
    pub handle: ConnectionHandle,
    /// The direction of this connection (inbound|outbound).
    pub direction: ConnectionDirection,
    /// The peer's [`PruningSeed`].
    pub pruning_seed: PruningSeed,
    /// The peer's [`BasicNodeData`].
    pub basic_node_data: BasicNodeData,
    /// The [`CoreSyncData`] of this peer.
    ///
    /// Data across fields are not necessarily related, so [`CoreSyncData::top_id`] is not always the
    /// block hash for the block at height one below [`CoreSyncData::current_height`].
    ///
    /// This value is behind a [`Mutex`] and is updated whenever the peer sends new information related
    /// to their sync state. It is publicly accessible to anyone who has a peers [`Client`] handle. You
    /// probably should not mutate this value unless you are creating a custom [`ProtocolRequestHandler`](crate::ProtocolRequestHandler).
    ///
    /// Note: the `Arc` means every clone of this [`PeerInformation`] shares the
    /// same sync-data slot.
    pub core_sync_data: Arc<Mutex<CoreSyncData>>,
}
79
/// This represents a connection to a peer.
///
/// It allows sending requests to the peer, but only does minimal checks that the data returned
/// is the data asked for, i.e. for a certain request the only thing checked will be that the response
/// is the correct response for that request, not that the response contains the correct data.
pub struct Client<Z: NetworkZone> {
    /// Information on the connected peer.
    pub info: PeerInformation<Z::Addr>,

    /// The channel to the [`Connection`](connection::Connection) task.
    connection_tx: PollSender<connection::ConnectionTaskRequest>,
    /// The [`JoinHandle`] of the spawned connection task.
    connection_handle: JoinHandle<()>,
    /// The [`JoinHandle`] of the spawned timeout monitor task.
    timeout_handle: JoinHandle<Result<(), tower::BoxError>>,

    /// The semaphore that limits the requests sent to the peer.
    semaphore: PollSemaphore,
    /// A permit for the semaphore, will be [`Some`] after `poll_ready` returns ready.
    permit: Option<OwnedSemaphorePermit>,

    /// The error slot shared between the [`Client`] and [`Connection`](connection::Connection).
    error: SharedError<PeerError>,
}
104
impl<Z: NetworkZone> Drop for Client<Z> {
    fn drop(&mut self) {
        // Signal the connection to close when the client handle is dropped, so
        // the peer connection does not outlive its owner.
        self.info.handle.send_close_signal();
    }
}
110
111impl<Z: NetworkZone> Client<Z> {
112    /// Creates a new [`Client`].
113    pub(crate) fn new(
114        info: PeerInformation<Z::Addr>,
115        connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
116        connection_handle: JoinHandle<()>,
117        timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
118        semaphore: Arc<Semaphore>,
119        error: SharedError<PeerError>,
120    ) -> Self {
121        Self {
122            info,
123            connection_tx: PollSender::new(connection_tx),
124            timeout_handle,
125            semaphore: PollSemaphore::new(semaphore),
126            permit: None,
127            connection_handle,
128            error,
129        }
130    }
131
132    /// Internal function to set an error on the [`SharedError`].
133    fn set_err(&self, err: PeerError) -> tower::BoxError {
134        let err_str = err.to_string();
135        match self.error.try_insert_err(err) {
136            Ok(()) => err_str,
137            Err(e) => e.to_string(),
138        }
139        .into()
140    }
141
142    /// Create a [`WeakClient`] for this [`Client`].
143    pub fn downgrade(&self) -> WeakClient<Z> {
144        WeakClient {
145            info: self.info.clone(),
146            connection_tx: self.connection_tx.clone(),
147            semaphore: self.semaphore.clone(),
148            permit: None,
149            error: self.error.clone(),
150        }
151    }
152}
153
impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
    type Response = PeerResponse;
    type Error = tower::BoxError;
    type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Fail fast if the connection task has already recorded an error.
        if let Some(err) = self.error.try_get_err() {
            return Poll::Ready(Err(err.to_string().into()));
        }

        // If either background task has exited, this connection is dead.
        if self.connection_handle.is_finished() || self.timeout_handle.is_finished() {
            let err = self.set_err(PeerError::ClientChannelClosed);
            return Poll::Ready(Err(err));
        }

        // Acquire a semaphore permit, limiting in-flight requests to this peer.
        // The permit is stashed in `self.permit` so `call` can hand it to the
        // connection task alongside the request.
        if self.permit.is_none() {
            let permit = ready!(self.semaphore.poll_acquire(cx))
                .expect("Client semaphore should not be closed!");

            self.permit = Some(permit);
        }

        // Reserve a slot in the channel to the connection task so that
        // `send_item` in `call` cannot fail due to a full channel.
        if ready!(self.connection_tx.poll_reserve(cx)).is_err() {
            let err = self.set_err(PeerError::ClientChannelClosed);
            return Poll::Ready(Err(err));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: PeerRequest) -> Self::Future {
        // `poll_ready` must have acquired a permit before `call` is invoked
        // (standard `tower::Service` contract).
        let permit = self
            .permit
            .take()
            .expect("poll_ready did not return ready before call to call");

        let (tx, rx) = oneshot::channel();
        let req = connection::ConnectionTaskRequest {
            response_channel: tx,
            request,
            permit: Some(permit),
        };

        if let Err(req) = self.connection_tx.send_item(req) {
            // The connection task could have closed between a call to `poll_ready` and the call to
            // `call`, which means if we don't handle the error here the receiver would panic.
            self.set_err(PeerError::ClientChannelClosed);

            let resp = Err(PeerError::ClientChannelClosed.into());
            drop(req.into_inner().unwrap().response_channel.send(resp));
        }

        rx.into()
    }
}
209
210/// Creates a mock [`Client`] for testing purposes.
211///
212/// `request_handler` will be used to handle requests sent to the [`Client`]
213pub fn mock_client<Z: NetworkZone, S>(
214    info: PeerInformation<Z::Addr>,
215    connection_guard: ConnectionGuard,
216    mut request_handler: S,
217) -> Client<Z>
218where
219    S: Service<PeerRequest, Response = PeerResponse, Error = tower::BoxError> + Send + 'static,
220    S::Future: Send + 'static,
221{
222    let (tx, mut rx) = mpsc::channel(1);
223
224    let task_span = tracing::error_span!("mock_connection", addr = %info.id);
225
226    let task_handle = tokio::spawn(
227        async move {
228            let _guard = connection_guard;
229            loop {
230                let Some(req): Option<connection::ConnectionTaskRequest> = rx.recv().await else {
231                    tracing::debug!("Channel closed, closing mock connection");
232                    return;
233                };
234
235                tracing::debug!("Received new request: {:?}", req.request.id());
236                let res = request_handler
237                    .ready()
238                    .await
239                    .unwrap()
240                    .call(req.request)
241                    .await
242                    .unwrap();
243
244                tracing::debug!("Sending back response");
245
246                drop(req.response_channel.send(Ok(res)));
247            }
248        }
249        .instrument(task_span),
250    );
251
252    let timeout_task = tokio::spawn(futures::future::pending());
253    let semaphore = Arc::new(Semaphore::new(1));
254    let error_slot = SharedError::new();
255
256    Client::new(info, tx, task_handle, timeout_task, semaphore, error_slot)
257}