1use std::pin::Pin;
6
7use futures::{
8 channel::oneshot,
9 stream::{Fuse, FusedStream},
10 SinkExt, Stream, StreamExt,
11};
12use tokio::{
13 sync::{mpsc, OwnedSemaphorePermit},
14 time::{sleep, timeout, Sleep},
15};
16use tokio_stream::wrappers::ReceiverStream;
17
18use cuprate_wire::{LevinCommand, Message, ProtocolMessage};
19
20use crate::{
21 client::request_handler::PeerRequestHandler,
22 constants::{REQUEST_HANDLER_TIMEOUT, REQUEST_TIMEOUT, SENDING_TIMEOUT},
23 handles::ConnectionGuard,
24 AddressBook, BroadcastMessage, CoreSyncSvc, MessageID, NetworkZone, PeerError, PeerRequest,
25 PeerResponse, ProtocolRequestHandler, ProtocolResponse, SharedError,
26};
27
/// A request handed to the connection task, together with the channel on which
/// the outcome is reported back.
pub(crate) struct ConnectionTaskRequest {
    /// The request to send to the peer.
    pub request: PeerRequest,
    /// Channel on which the result of the request (the response or an error) is sent.
    pub response_channel: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
    /// Permit accompanying the request.
    ///
    /// Must be `Some` when the request needs a response: the connection task
    /// panics otherwise, and keeps the permit in its state while it waits for
    /// the response.
    pub permit: Option<OwnedSemaphorePermit>,
}
37
/// The state of the connection task.
pub(crate) enum State {
    /// No request of ours is awaiting a response from the peer.
    WaitingForRequest,
    /// A request was sent to the peer and its response is still outstanding.
    WaitingForResponse {
        /// The ID of the request that is awaiting a response.
        request_id: MessageID,
        /// The channel on which the response (or an error) is delivered.
        tx: oneshot::Sender<Result<PeerResponse, tower::BoxError>>,
        /// The permit that came with the request; it is never read, only held
        /// until this state is replaced.
        _req_permit: OwnedSemaphorePermit,
    },
}
52
53const fn levin_command_response(message_id: MessageID, command: LevinCommand) -> bool {
57 matches!(
58 (message_id, command),
59 (MessageID::Handshake, LevinCommand::Handshake)
60 | (MessageID::TimedSync, LevinCommand::TimedSync)
61 | (MessageID::Ping, LevinCommand::Ping)
62 | (MessageID::SupportFlags, LevinCommand::SupportFlags)
63 | (MessageID::GetObjects, LevinCommand::GetObjectsResponse)
64 | (MessageID::GetChain, LevinCommand::ChainResponse)
65 | (MessageID::FluffyMissingTxs, LevinCommand::NewFluffyBlock)
66 | (
67 MessageID::GetTxPoolCompliment,
68 LevinCommand::NewTransactions
69 )
70 )
71}
72
/// The connection task for a single peer.
///
/// It forwards client requests and broadcasts to the peer through `peer_sink`
/// and passes requests coming from the peer to `peer_request_handler`.
pub(crate) struct Connection<Z: NetworkZone, A, CS, PR, BrdcstStrm> {
    /// The sink used to send messages to the peer.
    peer_sink: Z::Sink,

    /// Whether we are currently waiting for the peer to answer one of our requests.
    state: State,
    /// Timer started after a request that needs a response has been sent;
    /// cleared again once the response arrives.
    request_timeout: Option<Pin<Box<Sleep>>>,

    /// Incoming requests from the client side of this connection.
    client_rx: Fuse<ReceiverStream<ConnectionTaskRequest>>,
    /// Stream of messages to broadcast to the peer.
    broadcast_stream: Pin<Box<BrdcstStrm>>,

    /// Handler for requests received from the peer.
    peer_request_handler: PeerRequestHandler<Z, A, CS, PR>,

    /// Guard polled for shutdown requests and notified when the connection closes.
    connection_guard: ConnectionGuard,
    /// Shared slot that receives the error which shut the connection down.
    error: SharedError<PeerError>,
}
96
97impl<Z, A, CS, PR, BrdcstStrm> Connection<Z, A, CS, PR, BrdcstStrm>
98where
99 Z: NetworkZone,
100 A: AddressBook<Z>,
101 CS: CoreSyncSvc,
102 PR: ProtocolRequestHandler,
103 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
104{
105 pub(crate) fn new(
107 peer_sink: Z::Sink,
108 client_rx: mpsc::Receiver<ConnectionTaskRequest>,
109 broadcast_stream: BrdcstStrm,
110 peer_request_handler: PeerRequestHandler<Z, A, CS, PR>,
111 connection_guard: ConnectionGuard,
112 error: SharedError<PeerError>,
113 ) -> Self {
114 Self {
115 peer_sink,
116 state: State::WaitingForRequest,
117 request_timeout: None,
118 client_rx: ReceiverStream::new(client_rx).fuse(),
119 broadcast_stream: Box::pin(broadcast_stream),
120 peer_request_handler,
121 connection_guard,
122 error,
123 }
124 }
125
126 async fn send_message_to_peer(&mut self, mes: Message) -> Result<(), PeerError> {
129 tracing::debug!("Sending message: [{}] to peer", mes.command());
130
131 timeout(SENDING_TIMEOUT, self.peer_sink.send(mes.into()))
132 .await
133 .map_err(|_| PeerError::TimedOut)
134 .and_then(|res| res.map_err(PeerError::BucketError))
135 }
136
137 async fn handle_client_broadcast(&mut self, mes: BroadcastMessage) -> Result<(), PeerError> {
139 match mes {
140 BroadcastMessage::NewFluffyBlock(block) => {
141 self.send_message_to_peer(Message::Protocol(ProtocolMessage::NewFluffyBlock(block)))
142 .await
143 }
144 BroadcastMessage::NewTransactions(txs) => {
145 self.send_message_to_peer(Message::Protocol(ProtocolMessage::NewTransactions(txs)))
146 .await
147 }
148 }
149 }
150
151 async fn handle_client_request(&mut self, req: ConnectionTaskRequest) -> Result<(), PeerError> {
153 tracing::debug!("handling client request, id: {:?}", req.request.id());
154
155 if req.request.needs_response() {
156 assert!(
157 !matches!(self.state, State::WaitingForResponse { .. }),
158 "cannot handle more than 1 request at the same time"
159 );
160
161 self.state = State::WaitingForResponse {
162 request_id: req.request.id(),
163 tx: req.response_channel,
164 _req_permit: req
165 .permit
166 .expect("Client request should have a permit if a response is needed"),
167 };
168
169 self.send_message_to_peer(req.request.into()).await?;
170 self.request_timeout = Some(Box::pin(sleep(REQUEST_TIMEOUT)));
172 return Ok(());
173 }
174
175 let res = self.send_message_to_peer(req.request.into()).await;
178
179 if let Err(e) = res {
181 let err_str = e.to_string();
183 drop(req.response_channel.send(Err(err_str.into())));
184 return Err(e);
185 }
186
187 let resp = Ok(PeerResponse::Protocol(ProtocolResponse::NA));
189 drop(req.response_channel.send(resp));
190
191 Ok(())
192 }
193
194 async fn handle_peer_request(&mut self, req: PeerRequest) -> Result<(), PeerError> {
196 tracing::debug!("Received peer request: {:?}", req.id());
197
198 let res = timeout(
199 REQUEST_HANDLER_TIMEOUT,
200 self.peer_request_handler.handle_peer_request(req),
201 )
202 .await
203 .map_err(|_| {
204 tracing::warn!("Timed-out handling peer request, closing connection.");
205 PeerError::TimedOut
206 })??;
207
208 if let Ok(res) = res.try_into() {
210 self.send_message_to_peer(res).await?;
211 }
212
213 Ok(())
214 }
215
216 async fn handle_potential_response(&mut self, mes: Message) -> Result<(), PeerError> {
218 tracing::debug!("Received peer message, command: {:?}", mes.command());
219
220 if mes.is_request() {
223 return self.handle_peer_request(mes.try_into().unwrap()).await;
224 }
225
226 let State::WaitingForResponse { request_id, .. } = &self.state else {
227 panic!("Not in correct state, can't receive response!")
228 };
229
230 if levin_command_response(*request_id, mes.command()) {
232 let State::WaitingForResponse { tx, .. } =
235 std::mem::replace(&mut self.state, State::WaitingForRequest)
236 else {
237 panic!("Not in correct state, can't receive response!")
238 };
239
240 let resp = Ok(mes
241 .try_into()
242 .map_err(|_| PeerError::PeerSentInvalidMessage)?);
243
244 drop(tx.send(resp));
245
246 self.request_timeout = None;
247
248 Ok(())
249 } else {
250 self.handle_peer_request(
251 mes.try_into()
252 .map_err(|_| PeerError::PeerSentInvalidMessage)?,
253 )
254 .await
255 }
256 }
257
258 async fn state_waiting_for_request<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
260 where
261 Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
262 {
263 tracing::debug!("waiting for peer/client request.");
264
265 tokio::select! {
266 biased;
267 () = self.connection_guard.should_shutdown() => {
268 tracing::debug!("connection guard has shutdown, shutting down connection.");
269 Err(PeerError::ConnectionClosed)
270 }
271 broadcast_req = self.broadcast_stream.next() => {
272 if let Some(broadcast_req) = broadcast_req {
273 self.handle_client_broadcast(broadcast_req).await
274 } else {
275 Err(PeerError::ClientChannelClosed)
276 }
277 }
278 client_req = self.client_rx.next() => {
279 if let Some(client_req) = client_req {
280 self.handle_client_request(client_req).await
281 } else {
282 Err(PeerError::ClientChannelClosed)
283 }
284 },
285 peer_message = stream.next() => {
286 if let Some(peer_message) = peer_message {
287 self.handle_peer_request(peer_message?.try_into().map_err(|_| PeerError::PeerSentInvalidMessage)?).await
288 }else {
289 Err(PeerError::ClientChannelClosed)
290 }
291 },
292 }
293 }
294
295 async fn state_waiting_for_response<Str>(&mut self, stream: &mut Str) -> Result<(), PeerError>
297 where
298 Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
299 {
300 tracing::debug!("waiting for peer response.");
301
302 tokio::select! {
303 biased;
304 () = self.connection_guard.should_shutdown() => {
305 tracing::debug!("connection guard has shutdown, shutting down connection.");
306 Err(PeerError::ConnectionClosed)
307 }
308 () = self.request_timeout.as_mut().expect("Request timeout was not set!") => {
309 Err(PeerError::ClientChannelClosed)
310 }
311 broadcast_req = self.broadcast_stream.next() => {
312 if let Some(broadcast_req) = broadcast_req {
313 self.handle_client_broadcast(broadcast_req).await
314 } else {
315 Err(PeerError::ClientChannelClosed)
316 }
317 }
318 client_req = self.client_rx.next() => {
319 if let Some(client_req) = client_req {
322 self.handle_client_request(client_req).await
323 } else {
324 Err(PeerError::ClientChannelClosed)
325 }
326 },
327 peer_message = stream.next() => {
328 if let Some(peer_message) = peer_message {
329 self.handle_potential_response(peer_message?).await
330 } else {
331 Err(PeerError::ClientChannelClosed)
332 }
333 },
334 }
335 }
336
337 pub(crate) async fn run<Str>(
341 mut self,
342 mut stream: Str,
343 eager_protocol_messages: Vec<ProtocolMessage>,
344 ) where
345 Str: FusedStream<Item = Result<Message, cuprate_wire::BucketError>> + Unpin,
346 {
347 tracing::debug!(
348 "Handling eager messages len: {}",
349 eager_protocol_messages.len()
350 );
351 for message in eager_protocol_messages {
352 let message = Message::Protocol(message).try_into();
353
354 let res = match message {
355 Ok(mes) => self.handle_peer_request(mes).await,
356 Err(_) => Err(PeerError::PeerSentInvalidMessage),
357 };
358
359 if let Err(err) = res {
360 return self.shutdown(err);
361 }
362 }
363
364 loop {
365 let res = match self.state {
366 State::WaitingForRequest => self.state_waiting_for_request(&mut stream).await,
367 State::WaitingForResponse { .. } => {
368 self.state_waiting_for_response(&mut stream).await
369 }
370 };
371
372 if let Err(err) = res {
373 return self.shutdown(err);
374 }
375 }
376 }
377
378 #[expect(clippy::significant_drop_tightening)]
381 fn shutdown(mut self, err: PeerError) {
382 tracing::debug!("Connection task shutting down: {}", err);
383
384 let mut client_rx = self.client_rx.into_inner().into_inner();
385 client_rx.close();
386
387 let err_str = err.to_string();
388 if let Err(err) = self.error.try_insert_err(err) {
389 tracing::debug!("Shared error already contains an error: {}", err);
390 }
391
392 if let State::WaitingForResponse { tx, .. } =
393 std::mem::replace(&mut self.state, State::WaitingForRequest)
394 {
395 drop(tx.send(Err(err_str.clone().into())));
396 }
397
398 while let Ok(req) = client_rx.try_recv() {
399 drop(req.response_channel.send(Err(err_str.clone().into())));
400 }
401
402 self.connection_guard.connection_closed();
403 }
404}