cuprate_p2p_core/client/
connection.rs

1//! The Connection Task
2//!
3//! This module handles routing requests from a [`Client`](crate::client::Client) or a broadcast channel to
4//! a peer. This module also handles routing requests from the connected peer to a request handler.
5use std::pin::Pin;
6
7use futures::{
8    channel::oneshot,
9    stream::{Fuse, FusedStream},
10    SinkExt, Stream, StreamExt,
11};
12use tokio::{
13    sync::{mpsc, OwnedSemaphorePermit},
14    time::{sleep, timeout, Sleep},
15};
16use tokio_stream::wrappers::ReceiverStream;
17
18use cuprate_wire::{LevinCommand, Message, ProtocolMessage};
19
20use crate::{
21    client::request_handler::PeerRequestHandler,
22    constants::{REQUEST_HANDLER_TIMEOUT, REQUEST_TIMEOUT, SENDING_TIMEOUT},
23    handles::ConnectionGuard,
24    AddressBook, BroadcastMessage, CoreSyncSvc, MessageID, NetworkZone, PeerError, PeerRequest,
25    PeerResponse, ProtocolRequestHandler, ProtocolResponse, SharedError,
26};
27
28/// A request to the connection task from a [`Client`](crate::client::Client).
29pub(crate) struct ConnectionTaskRequest {
30    /// The request.
31    pub request: PeerRequest,
32    /// The response channel.
33    pub response_channel: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
34    /// A permit for this request
35    pub permit: Option<OwnedSemaphorePermit>,
36}
37
38/// The connection state.
39pub(crate) enum State {
40    /// Waiting for a request from Cuprate or the connected peer.
41    WaitingForRequest,
42    /// Waiting for a response from the peer.
43    WaitingForResponse {
44        /// The requests ID.
45        request_id: MessageID,
46        /// The channel to send the response down.
47        tx: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
48        /// A permit for this request.
49        _req_permit: OwnedSemaphorePermit,
50    },
51}
52
53/// Returns if the [`LevinCommand`] is the correct response message for our request.
54///
55/// e.g. that we didn't get a block for a txs request.
56const fn levin_command_response(message_id: MessageID, command: LevinCommand) -> bool {
57    matches!(
58        (message_id, command),
59        (MessageID::Handshake, LevinCommand::Handshake)
60            | (MessageID::TimedSync, LevinCommand::TimedSync)
61            | (MessageID::Ping, LevinCommand::Ping)
62            | (MessageID::SupportFlags, LevinCommand::SupportFlags)
63            | (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
64            | (MessageID::GetChain, LevinCommand::ChainResponse)
65            | (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
66            | (
67                MessageID::GetTxPoolCompliment,
68                LevinCommand::NewTransactions
69            )
70    )
71}
72
73/// This represents a connection to a peer.
74pub(crate) struct Connection<Z: NetworkZone, A, CS, PR, BrdcstStrm> {
75    /// The peer sink - where we send messages to the peer.
76    peer_sink: Z::Sink,
77
78    /// The connections current state.
79    state: State,
80    /// Will be [`Some`] if we are expecting a response from the peer.
81    request_timeout: Option<Pin<Box<Sleep>>>,
82
83    /// The client channel where requests from Cuprate to this peer will come from for us to route.
84    client_rx: Fuse<ReceiverStream<ConnectionTaskRequest>>,
85    /// A stream of messages to broadcast from Cuprate.
86    broadcast_stream: Pin<Box<BrdcstStrm>>,
87
88    /// The inner handler for any requests that come from the requested peer.
89    peer_request_handler: PeerRequestHandler<Z, A, CS, PR>,
90
91    /// The connection guard which will send signals to other parts of Cuprate when this connection is dropped.
92    connection_guard: ConnectionGuard,
93    /// An error slot which is shared with the client.
94    error: SharedError<PeerError>,
95}
96
97impl<Z, A, CS, PR, BrdcstStrm> Connection<Z, A, CS, PR, BrdcstStrm>
98where
99    Z: NetworkZone,
100    A: AddressBook<Z>,
101    CS: CoreSyncSvc,
102    PR: ProtocolRequestHandler,
103    BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
104{
105    /// Create a new connection struct.
106    pub(crate) fn new(
107        peer_sink: Z::Sink,
108        client_rx: mpsc::Receiver<ConnectionTaskRequest>,
109        broadcast_stream: BrdcstStrm,
110        peer_request_handler: PeerRequestHandler<Z, A, CS, PR>,
111        connection_guard: ConnectionGuard,
112        error: SharedError<PeerError>,
113    ) -> Self {
114        Self {
115            peer_sink,
116            state: State::WaitingForRequest,
117            request_timeout: None,
118            client_rx: ReceiverStream::new(client_rx).fuse(),
119            broadcast_stream: Box::pin(broadcast_stream),
120            peer_request_handler,
121            connection_guard,
122            error,
123        }
124    }
125
126    /// Sends a message to the peer, this function implements a timeout, so we don't get stuck sending a message to the
127    /// peer.
128    async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> {
129        tracing::debug!("Sending message: [{}] to peer", mes.command());
130
131        timeout(SENDING_TIMEOUT, self.peer_sink.send(mes.into()))
132            .await
133            .map_err(|_| PeerError::TimedOut)
134            .and_then(|res| res.map_err(PeerError::BucketError))
135    }
136
137    /// Handles a broadcast request from Cuprate.
138    async fn handle_client_broadcast(&mut self, mes: BroadcastMessage) -> Result<(), PeerError> {
139        match mes {
140            BroadcastMessage::NewFluffyBlock(block) => {
141                self.send_message_to_peer(Message::Protocol(ProtocolMessage::NewFluffyBlock(block)))
142                    .await
143            }
144            BroadcastMessage::NewTransactions(txs) => {
145                self.send_message_to_peer(Message::Protocol(ProtocolMessage::NewTransactions(txs)))
146                    .await
147            }
148        }
149    }
150
151    /// Handles a request from Cuprate, unlike a broadcast this request will be directed specifically at this peer.
152    async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> {
153        tracing::debug!("handling client request, id: {:?}", req.request.id());
154
155        if req.request.needs_response() {
156            assert!(
157                !matches!(self.state, State::WaitingForResponse { .. }),
158                "cannot handle more than 1 request at the same time"
159            );
160
161            self.state = State::WaitingForResponse {
162                request_id: req.request.id(),
163                tx: req.response_channel,
164                _req_permit: req
165                    .permit
166                    .expect("Client request should have a permit if a response is needed"),
167            };
168
169            self.send_message_to_peer(req.request.into()).await?;
170            // Set the timeout after sending the message, TODO: Is this a good idea.
171            self.request_timeout = Some(Box::pin(sleep(REQUEST_TIMEOUT)));
172            return Ok(());
173        }
174
175        // INVARIANT: From now this function cannot exit early without sending a response back down the
176        // response channel.
177        let res = self.send_message_to_peer(req.request.into()).await;
178
179        // send the response now, the request does not need a response from the peer.
180        if let Err(e) = res {
181            // can't clone the error so turn it to a string first, hacky but oh well.
182            let err_str = e.to_string();
183            drop(req.response_channel.send(Err(err_str.into())));
184            return Err(e);
185        }
186
187        // We still need to respond even if the response is this.
188        let resp = Ok(PeerResponse::Protocol(ProtocolResponse::NA));
189        drop(req.response_channel.send(resp));
190
191        Ok(())
192    }
193
194    /// Handles a request from the connected peer to this node.
195    async fn handle_peer_request(&mut self, req: PeerRequest) -> Result<(), PeerError> {
196        tracing::debug!("Received peer request: {:?}", req.id());
197
198        let res = timeout(
199            REQUEST_HANDLER_TIMEOUT,
200            self.peer_request_handler.handle_peer_request(req),
201        )
202        .await
203        .map_err(|_| {
204            tracing::warn!("Timed-out handling peer request, closing connection.");
205            PeerError::TimedOut
206        })??;
207
208        // This will be an error if a response does not need to be sent
209        if let Ok(res) = res.try_into() {
210            self.send_message_to_peer(res).await?;
211        }
212
213        Ok(())
214    }
215
216    /// Handles a message from a peer when we are in [`State::WaitingForResponse`].
217    async fn handle_potential_response(&mut self, mes: Message) -> Result<(), PeerError> {
218        tracing::debug!("Received peer message, command: {:?}", mes.command());
219
220        // If the message is defiantly a request then there is no way it can be a response to
221        // our request.
222        if mes.is_request() {
223            return self.handle_peer_request(mes.try_into().unwrap()).await;
224        }
225
226        let State::WaitingForResponse { request_id, .. } = &self.state else {
227            panic!("Not in correct state, can't receive response!")
228        };
229
230        // Check if the message is a response to our request.
231        if levin_command_response(*request_id, mes.command()) {
232            // TODO: Do more checks before returning response.
233
234            let State::WaitingForResponse { tx, .. } =
235                std::mem::replace(&mut self.state, State::WaitingForRequest)
236            else {
237                panic!("Not in correct state, can't receive response!")
238            };
239
240            let resp = Ok(mes
241                .try_into()
242                .map_err(|_| PeerError::PeerSentInvalidMessage)?);
243
244            drop(tx.send(resp));
245
246            self.request_timeout = None;
247
248            Ok(())
249        } else {
250            self.handle_peer_request(
251                mes.try_into()
252                    .map_err(|_| PeerError::PeerSentInvalidMessage)?,
253            )
254            .await
255        }
256    }
257
258    /// The main-loop for when we are in [`State::WaitingForRequest`].
259    async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
260    where
261        Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
262    {
263        tracing::debug!("waiting for peer/client request.");
264
265        tokio::select! {
266            biased;
267            () = self.connection_guard.should_shutdown() => {
268                tracing::debug!("connection guard has shutdown, shutting down connection.");
269                Err(PeerError::ConnectionClosed)
270            }
271            broadcast_req = self.broadcast_stream.next() => {
272                if let Some(broadcast_req) = broadcast_req {
273                    self.handle_client_broadcast(broadcast_req).await
274                } else {
275                    Err(PeerError::ClientChannelClosed)
276                }
277            }
278            client_req = self.client_rx.next() => {
279                if let Some(client_req) = client_req {
280                    self.handle_client_request(client_req).await
281                } else {
282                    Err(PeerError::ClientChannelClosed)
283                }
284            },
285            peer_message = stream.next() => {
286                if let Some(peer_message) = peer_message {
287                    self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await
288                }else {
289                    Err(PeerError::ClientChannelClosed)
290                }
291            },
292        }
293    }
294
295    /// The main-loop for when we are in [`State::WaitingForResponse`].
296    async fn state_waiting_for_response<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
297    where
298        Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
299    {
300        tracing::debug!("waiting for peer response.");
301
302        tokio::select! {
303            biased;
304            () = self.connection_guard.should_shutdown() => {
305                tracing::debug!("connection guard has shutdown, shutting down connection.");
306                Err(PeerError::ConnectionClosed)
307            }
308            () = self.request_timeout.as_mut().expect("Request timeout was not set!") => {
309                Err(PeerError::ClientChannelClosed)
310            }
311            broadcast_req = self.broadcast_stream.next() => {
312                if let Some(broadcast_req) = broadcast_req {
313                    self.handle_client_broadcast(broadcast_req).await
314                } else {
315                    Err(PeerError::ClientChannelClosed)
316                }
317            }
318            client_req = self.client_rx.next() => {
319                // Although we can only handle 1 request from the client at a time, this channel is also used
320                // for specific broadcasts to this peer so we need to handle those here as well.
321                if let Some(client_req) = client_req {
322                    self.handle_client_request(client_req).await
323                } else {
324                    Err(PeerError::ClientChannelClosed)
325                }
326            },
327            peer_message = stream.next() => {
328                if let Some(peer_message) = peer_message {
329                    self.handle_potential_response(peer_message?).await
330                } else {
331                    Err(PeerError::ClientChannelClosed)
332                }
333            },
334        }
335    }
336
337    /// Runs the Connection handler logic, this should be put in a separate task.
338    ///
339    /// `eager_protocol_messages` are protocol messages that we received during a handshake.
340    pub(crate) async fn run<Str>(
341        mut self,
342        mut stream: Str,
343        eager_protocol_messages: Vec<ProtocolMessage>,
344    ) where
345        Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
346    {
347        tracing::debug!(
348            "Handling eager messages len: {}",
349            eager_protocol_messages.len()
350        );
351        for message in eager_protocol_messages {
352            let message = Message::Protocol(message).try_into();
353
354            let res = match message {
355                Ok(mes) => self.handle_peer_request(mes).await,
356                Err(_) => Err(PeerError::PeerSentInvalidMessage),
357            };
358
359            if let Err(err) = res {
360                return self.shutdown(err);
361            }
362        }
363
364        loop {
365            let res = match self.state {
366                State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await,
367                State::WaitingForResponse { .. } => {
368                    self.state_waiting_for_response(&mut stream).await
369                }
370            };
371
372            if let Err(err) = res {
373                return self.shutdown(err);
374            }
375        }
376    }
377
378    /// Shutdowns the connection, flushing pending requests and setting the error slot, if it hasn't been
379    /// set already.
380    #[expect(clippy::significant_drop_tightening)]
381    fn shutdown(mut self, err: PeerError) {
382        tracing::debug!("Connection task shutting down: {}", err);
383
384        let mut client_rx = self.client_rx.into_inner().into_inner();
385        client_rx.close();
386
387        let err_str = err.to_string();
388        if let Err(err) = self.error.try_insert_err(err) {
389            tracing::debug!("Shared error already contains an error: {}", err);
390        }
391
392        if let State::WaitingForResponse { tx, .. } =
393            std::mem::replace(&mut self.state, State::WaitingForRequest)
394        {
395            drop(tx.send(Err(err_str.clone().into())));
396        }
397
398        while let Ok(req) = client_rx.try_recv() {
399            drop(req.response_channel.send(Err(err_str.clone().into())));
400        }
401
402        self.connection_guard.connection_closed();
403    }
404}