cuprate_p2p_core/
client.rs

1use std::{
2    fmt::{Debug, Display, Formatter},
3    sync::{Arc, Mutex},
4    task::{ready, Context, Poll},
5};
6
7use futures::channel::oneshot;
8use tokio::{
9    sync::{mpsc, OwnedSemaphorePermit, Semaphore},
10    task::JoinHandle,
11};
12use tokio_util::sync::{PollSemaphore, PollSender};
13use tower::{Service, ServiceExt};
14use tracing::Instrument;
15
16use cuprate_helper::asynch::InfallibleOneshotReceiver;
17use cuprate_pruning::PruningSeed;
18use cuprate_wire::CoreSyncData;
19
20use crate::{
21    handles::{ConnectionGuard, ConnectionHandle},
22    ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
23};
24
25mod connection;
26mod connector;
27pub mod handshaker;
28mod request_handler;
29mod timeout_monitor;
30mod weak;
31
32pub use connector::{ConnectRequest, Connector};
33pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder};
34pub use weak::{WeakBroadcastClient, WeakClient};
35
/// An internal identifier for a given peer, will be their address if known
/// or a random u128 if not.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum InternalPeerID<A> {
    /// A known address.
    KnownAddr(A),
    /// An unknown address (probably an inbound anonymity network connection).
    Unknown(u128),
}
45
46impl<A: Display> Display for InternalPeerID<A> {
47    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::KnownAddr(addr) => addr.fmt(f),
50            Self::Unknown(id) => f.write_str(&format!("Unknown, ID: {id}")),
51        }
52    }
53}
54
/// Information on a connected peer.
#[derive(Debug, Clone)]
pub struct PeerInformation<A> {
    /// The internal peer ID of this peer.
    pub id: InternalPeerID<A>,
    /// The [`ConnectionHandle`] for this peer, allows banning this peer and checking if it is still
    /// alive.
    pub handle: ConnectionHandle,
    /// The direction of this connection (inbound|outbound).
    pub direction: ConnectionDirection,
    /// The peer's [`PruningSeed`].
    pub pruning_seed: PruningSeed,
    /// The [`CoreSyncData`] of this peer.
    ///
    /// Data across fields are not necessarily related, so [`CoreSyncData::top_id`] is not always the
    /// block hash for the block at height one below [`CoreSyncData::current_height`].
    ///
    /// This value is behind a [`Mutex`] and is updated whenever the peer sends new information related
    /// to their sync state. It is publicly accessible to anyone who has a peer's [`Client`] handle. You
    /// probably should not mutate this value unless you are creating a custom [`ProtocolRequestHandler`](crate::ProtocolRequestHandler).
    pub core_sync_data: Arc<Mutex<CoreSyncData>>,
}
77
/// This represents a connection to a peer.
///
/// It allows sending requests to the peer, but only does minimal checks that the data returned
/// is the data asked for, i.e. for a certain request the only thing checked will be that the response
/// is the correct response for that request, not that the response contains the correct data.
pub struct Client<Z: NetworkZone> {
    /// Information on the connected peer.
    pub info: PeerInformation<Z::Addr>,

    /// The channel to the [`Connection`](connection::Connection) task.
    connection_tx: PollSender<connection::ConnectionTaskRequest>,
    /// The [`JoinHandle`] of the spawned connection task.
    connection_handle: JoinHandle<()>,
    /// The [`JoinHandle`] of the spawned timeout monitor task.
    timeout_handle: JoinHandle<Result<(), tower::BoxError>>,

    /// The semaphore that limits the requests sent to the peer.
    semaphore: PollSemaphore,
    /// A permit for the semaphore, will be [`Some`] after `poll_ready` returns ready.
    permit: Option<OwnedSemaphorePermit>,

    /// The error slot shared between the [`Client`] and [`Connection`](connection::Connection).
    error: SharedError<PeerError>,
}
102
impl<Z: NetworkZone> Drop for Client<Z> {
    /// Signals the connection to close when this [`Client`] is dropped,
    /// so the background tasks do not outlive their owning handle.
    fn drop(&mut self) {
        self.info.handle.send_close_signal();
    }
}
108
109impl<Z: NetworkZone> Client<Z> {
110    /// Creates a new [`Client`].
111    pub(crate) fn new(
112        info: PeerInformation<Z::Addr>,
113        connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
114        connection_handle: JoinHandle<()>,
115        timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
116        semaphore: Arc<Semaphore>,
117        error: SharedError<PeerError>,
118    ) -> Self {
119        Self {
120            info,
121            connection_tx: PollSender::new(connection_tx),
122            timeout_handle,
123            semaphore: PollSemaphore::new(semaphore),
124            permit: None,
125            connection_handle,
126            error,
127        }
128    }
129
130    /// Internal function to set an error on the [`SharedError`].
131    fn set_err(&self, err: PeerError) -> tower::BoxError {
132        let err_str = err.to_string();
133        match self.error.try_insert_err(err) {
134            Ok(()) => err_str,
135            Err(e) => e.to_string(),
136        }
137        .into()
138    }
139
140    /// Create a [`WeakClient`] for this [`Client`].
141    pub fn downgrade(&self) -> WeakClient<Z> {
142        WeakClient {
143            info: self.info.clone(),
144            connection_tx: self.connection_tx.clone(),
145            semaphore: self.semaphore.clone(),
146            permit: None,
147            error: self.error.clone(),
148        }
149    }
150}
151
impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
    type Response = PeerResponse;
    type Error = tower::BoxError;
    type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Fail fast if the connection task has already recorded an error.
        if let Some(err) = self.error.try_get_err() {
            return Poll::Ready(Err(err.to_string().into()));
        }

        // If either background task has exited, the channel is effectively dead,
        // so record and report `ClientChannelClosed`.
        if self.connection_handle.is_finished() || self.timeout_handle.is_finished() {
            let err = self.set_err(PeerError::ClientChannelClosed);
            return Poll::Ready(Err(err));
        }

        // Acquire a request permit (if we don't already hold one from an earlier
        // `poll_ready`); the permit is handed over to the connection task in
        // `call`, enforcing the in-flight request limit.
        if self.permit.is_none() {
            let permit = ready!(self.semaphore.poll_acquire(cx))
                .expect("Client semaphore should not be closed!");

            self.permit = Some(permit);
        }

        // Reserve a slot in the channel so the send in `call` cannot fail for
        // lack of capacity. The permit acquired above is retained across a
        // pending/failed reservation.
        if ready!(self.connection_tx.poll_reserve(cx)).is_err() {
            let err = self.set_err(PeerError::ClientChannelClosed);
            return Poll::Ready(Err(err));
        }

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: PeerRequest) -> Self::Future {
        // Take the permit acquired in `poll_ready`; it travels with the request
        // so it is released only once the request has been processed.
        let permit = self
            .permit
            .take()
            .expect("poll_ready did not return ready before call to call");

        // The response will come back over this oneshot channel.
        let (tx, rx) = oneshot::channel();
        let req = connection::ConnectionTaskRequest {
            response_channel: tx,
            request,
            permit: Some(permit),
        };

        if let Err(req) = self.connection_tx.send_item(req) {
            // The connection task could have closed between a call to `poll_ready` and the call to
            // `call`, which means if we don't handle the error here the receiver would panic.
            self.set_err(PeerError::ClientChannelClosed);

            // Complete the caller's future with the error instead of dropping
            // the oneshot sender (which would make the receiver panic).
            let resp = Err(PeerError::ClientChannelClosed.into());
            drop(req.into_inner().unwrap().response_channel.send(resp));
        }

        rx.into()
    }
}
207
208/// Creates a mock [`Client`] for testing purposes.
209///
210/// `request_handler` will be used to handle requests sent to the [`Client`]
211pub fn mock_client<Z: NetworkZone, S>(
212    info: PeerInformation<Z::Addr>,
213    connection_guard: ConnectionGuard,
214    mut request_handler: S,
215) -> Client<Z>
216where
217    S: Service<PeerRequest, Response = PeerResponse, Error = tower::BoxError> + Send + 'static,
218    S::Future: Send + 'static,
219{
220    let (tx, mut rx) = mpsc::channel(1);
221
222    let task_span = tracing::error_span!("mock_connection", addr = %info.id);
223
224    let task_handle = tokio::spawn(
225        async move {
226            let _guard = connection_guard;
227            loop {
228                let Some(req): Option<connection::ConnectionTaskRequest> = rx.recv().await else {
229                    tracing::debug!("Channel closed, closing mock connection");
230                    return;
231                };
232
233                tracing::debug!("Received new request: {:?}", req.request.id());
234                let res = request_handler
235                    .ready()
236                    .await
237                    .unwrap()
238                    .call(req.request)
239                    .await
240                    .unwrap();
241
242                tracing::debug!("Sending back response");
243
244                drop(req.response_channel.send(Ok(res)));
245            }
246        }
247        .instrument(task_span),
248    );
249
250    let timeout_task = tokio::spawn(futures::future::pending());
251    let semaphore = Arc::new(Semaphore::new(1));
252    let error_slot = SharedError::new();
253
254    Client::new(info, tx, task_handle, timeout_task, semaphore, error_slot)
255}