1use std::{
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 sync::{Arc, Mutex},
12 task::{Context, Poll},
13};
14
15use futures::{FutureExt, SinkExt, Stream, StreamExt};
16use tokio::{
17 sync::{mpsc, OwnedSemaphorePermit, Semaphore},
18 time::{error::Elapsed, timeout},
19};
20use tower::{Service, ServiceExt};
21use tracing::{info_span, Instrument, Span};
22
23use cuprate_pruning::{PruningError, PruningSeed};
24use cuprate_wire::{
25 admin::{
26 HandshakeRequest, HandshakeResponse, PingResponse, SupportFlagsResponse,
27 PING_OK_RESPONSE_STATUS_TEXT,
28 },
29 common::PeerSupportFlags,
30 AdminRequestMessage, AdminResponseMessage, BasicNodeData, BucketError, LevinCommand, Message,
31};
32
33use crate::{
34 client::{
35 connection::Connection, request_handler::PeerRequestHandler,
36 timeout_monitor::connection_timeout_monitor_task, Client, InternalPeerID, PeerInformation,
37 },
38 constants::{
39 CLIENT_QUEUE_SIZE, HANDSHAKE_TIMEOUT, MAX_EAGER_PROTOCOL_MESSAGES,
40 MAX_PEERS_IN_PEER_LIST_MESSAGE, PING_TIMEOUT,
41 },
42 handles::HandleBuilder,
43 AddressBook, AddressBookRequest, AddressBookResponse, BroadcastMessage, ConnectionDirection,
44 CoreSyncDataRequest, CoreSyncDataResponse, CoreSyncSvc, NetZoneAddress, NetworkZone,
45 ProtocolRequestHandlerMaker, SharedError, Transport,
46};
47
48pub mod builder;
49pub use builder::HandshakerBuilder;
50
/// Errors that can occur while performing a handshake with a peer.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
    /// A `tokio::time::timeout` elapsed, e.g. the overall handshake timeout.
    #[error("The handshake timed out")]
    TimedOut(#[from] Elapsed),
    /// The peer's node ID equals ours (only checked when the zone sets `CHECK_NODE_ID`).
    #[error("Peer has the same node ID as us")]
    PeerHasSameNodeID,
    /// The peer's `network_id` does not match ours.
    #[error("Peer is on a different network")]
    IncorrectNetwork,
    /// Converting the peer-list entries the peer sent us failed.
    #[error("Peer sent a peer list with peers from different zones")]
    PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError),
    /// The peer sent a message that is not allowed at this point, or is malformed.
    #[error("Peer sent invalid message: {0}")]
    PeerSentInvalidMessage(&'static str),
    /// The pruning seed in the peer's core sync data could not be decoded.
    #[error("The peers pruning seed is invalid.")]
    InvalidPruningSeed(#[from] PruningError),
    /// An error from the levin bucket layer, i.e. the peer stream or sink.
    #[error("Levin bucket error: {0}")]
    LevinBucketError(#[from] BucketError),
    /// One of our internal services (address book, core sync, protocol handler maker) failed.
    #[error("Internal service error: {0}")]
    InternalSvcErr(#[from] tower::BoxError),
    /// An I/O error, e.g. when opening the outbound connection used for pinging.
    #[error("I/O error: {0}")]
    IO(#[from] std::io::Error),
}
72
73pub struct DoHandshakeRequest<Z: NetworkZone, T: Transport<Z>> {
75 pub addr: InternalPeerID<Z::Addr>,
77 pub peer_stream: T::Stream,
79 pub peer_sink: T::Sink,
81 pub direction: ConnectionDirection,
83 pub permit: Option<OwnedSemaphorePermit>,
85}
86
87#[derive(Debug, Clone)]
89pub struct HandShaker<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
90{
91 address_book: AdrBook,
93 core_sync_svc: CSync,
95 protocol_request_svc_maker: ProtoHdlrMkr,
97
98 our_basic_node_data: BasicNodeData,
100
101 broadcast_stream_maker: BrdcstStrmMkr,
103
104 connection_parent_span: Span,
105
106 transport_client_config: T::ClientConfig,
108
109 _zone: PhantomData<Z>,
111}
112
113impl<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
114 HandShaker<Z, T, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
115{
116 const fn new(
118 address_book: AdrBook,
119 core_sync_svc: CSync,
120 protocol_request_svc_maker: ProtoHdlrMkr,
121 broadcast_stream_maker: BrdcstStrmMkr,
122 our_basic_node_data: BasicNodeData,
123 connection_parent_span: Span,
124 transport_client_config: T::ClientConfig,
125 ) -> Self {
126 Self {
127 address_book,
128 core_sync_svc,
129 protocol_request_svc_maker,
130 broadcast_stream_maker,
131 our_basic_node_data,
132 connection_parent_span,
133 transport_client_config,
134 _zone: PhantomData,
135 }
136 }
137
138 #[inline]
140 pub const fn transport_config(&self) -> &T::ClientConfig {
141 &self.transport_client_config
142 }
143}
144
145impl<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr, BrdcstStrm>
146 Service<DoHandshakeRequest<Z, T>>
147 for HandShaker<Z, T, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
148where
149 AdrBook: AddressBook<Z> + Clone,
150 CSync: CoreSyncSvc + Clone,
151 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z> + Clone,
152 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
153 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Clone + Send + 'static,
154{
155 type Response = Client<Z>;
156 type Error = HandshakeError;
157 type Future =
158 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
159
160 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 Poll::Ready(Ok(()))
162 }
163
164 fn call(&mut self, req: DoHandshakeRequest<Z, T>) -> Self::Future {
165 let broadcast_stream_maker = self.broadcast_stream_maker.clone();
166
167 let address_book = self.address_book.clone();
168 let protocol_request_svc_maker = self.protocol_request_svc_maker.clone();
169 let core_sync_svc = self.core_sync_svc.clone();
170 let our_basic_node_data = self.our_basic_node_data.clone();
171
172 let connection_parent_span = self.connection_parent_span.clone();
173
174 let transport_client_config = self.transport_client_config.clone();
175
176 let span = info_span!(parent: &Span::current(), "handshaker", addr=%req.addr);
177
178 async move {
179 timeout(
180 HANDSHAKE_TIMEOUT,
181 handshake(
182 req,
183 transport_client_config,
184 broadcast_stream_maker,
185 address_book,
186 core_sync_svc,
187 protocol_request_svc_maker,
188 our_basic_node_data,
189 connection_parent_span,
190 ),
191 )
192 .await?
193 }
194 .instrument(span)
195 .boxed()
196 }
197}
198
199pub async fn ping<N, T>(addr: N::Addr, config: &T::ClientConfig) -> Result<u64, HandshakeError>
203where
204 N: NetworkZone,
205 T: Transport<N>,
206{
207 tracing::debug!("Sending Ping to peer");
208
209 let (mut peer_stream, mut peer_sink) = T::connect_to_peer(addr, config).await?;
210
211 tracing::debug!("Made outbound connection to peer, sending ping.");
212
213 peer_sink
214 .send(Message::Request(AdminRequestMessage::Ping).into())
215 .await?;
216
217 if let Some(res) = peer_stream.next().await {
218 if let Message::Response(AdminResponseMessage::Ping(ping)) = res? {
219 if ping.status == PING_OK_RESPONSE_STATUS_TEXT {
220 tracing::debug!("Ping successful.");
221 return Ok(ping.peer_id);
222 }
223
224 tracing::debug!("Peer's ping response was not `OK`.");
225 return Err(HandshakeError::PeerSentInvalidMessage(
226 "Ping response was not `OK`",
227 ));
228 }
229
230 tracing::debug!("Peer sent invalid response to ping.");
231 return Err(HandshakeError::PeerSentInvalidMessage(
232 "Peer did not send correct response for ping.",
233 ));
234 }
235
236 tracing::debug!("Connection closed before ping response.");
237 Err(BucketError::IO(std::io::Error::new(
238 std::io::ErrorKind::ConnectionAborted,
239 "The peer stream returned None",
240 ))
241 .into())
242}
243
244#[expect(clippy::too_many_arguments)]
246async fn handshake<
247 Z: NetworkZone,
248 T: Transport<Z>,
249 AdrBook,
250 CSync,
251 ProtoHdlrMkr,
252 BrdcstStrmMkr,
253 BrdcstStrm,
254>(
255 req: DoHandshakeRequest<Z, T>,
256 transport_client_config: T::ClientConfig,
257
258 broadcast_stream_maker: BrdcstStrmMkr,
259
260 mut address_book: AdrBook,
261 mut core_sync_svc: CSync,
262 mut protocol_request_svc_maker: ProtoHdlrMkr,
263 our_basic_node_data: BasicNodeData,
264 connection_parent_span: Span,
265) -> Result<Client<Z>, HandshakeError>
266where
267 AdrBook: AddressBook<Z> + Clone,
268 CSync: CoreSyncSvc + Clone,
269 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z>,
270 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
271 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Send + 'static,
272{
273 let DoHandshakeRequest {
274 addr,
275 mut peer_stream,
276 mut peer_sink,
277 direction,
278 permit,
279 } = req;
280
281 let mut eager_protocol_messages = Vec::new();
284
285 let (peer_core_sync, peer_node_data) = match direction {
286 ConnectionDirection::Inbound => {
287 tracing::debug!("waiting for handshake request.");
289
290 let Message::Request(AdminRequestMessage::Handshake(handshake_req)) =
291 wait_for_message::<Z, T>(
292 LevinCommand::Handshake,
293 true,
294 &mut peer_sink,
295 &mut peer_stream,
296 &mut eager_protocol_messages,
297 &our_basic_node_data,
298 )
299 .await?
300 else {
301 panic!("wait_for_message returned ok with wrong message.");
302 };
303
304 tracing::debug!("Received handshake request.");
305 (handshake_req.payload_data, handshake_req.node_data)
307 }
308 ConnectionDirection::Outbound => {
309 send_hs_request::<Z, T, _>(
311 &mut peer_sink,
312 &mut core_sync_svc,
313 our_basic_node_data.clone(),
314 )
315 .await?;
316
317 let Message::Response(AdminResponseMessage::Handshake(handshake_res)) =
319 wait_for_message::<Z, T>(
320 LevinCommand::Handshake,
321 false,
322 &mut peer_sink,
323 &mut peer_stream,
324 &mut eager_protocol_messages,
325 &our_basic_node_data,
326 )
327 .await?
328 else {
329 panic!("wait_for_message returned ok with wrong message.");
330 };
331
332 if handshake_res.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE {
333 tracing::debug!("peer sent too many peers in response, cancelling handshake");
334
335 return Err(HandshakeError::PeerSentInvalidMessage(
336 "Too many peers in peer list message (>250)",
337 ));
338 }
339
340 tracing::debug!(
341 "Telling address book about new peers, len: {}",
342 handshake_res.local_peerlist_new.len()
343 );
344
345 address_book
347 .ready()
348 .await?
349 .call(AddressBookRequest::IncomingPeerList(
350 addr,
351 handshake_res
352 .local_peerlist_new
353 .into_iter()
354 .map(TryInto::try_into)
355 .collect::<Result<_, _>>()?,
356 ))
357 .await?;
358
359 (handshake_res.payload_data, handshake_res.node_data)
360 }
361 };
362
363 if peer_node_data.network_id != our_basic_node_data.network_id {
364 return Err(HandshakeError::IncorrectNetwork);
365 }
366
367 if Z::CHECK_NODE_ID && peer_node_data.peer_id == our_basic_node_data.peer_id {
368 return Err(HandshakeError::PeerHasSameNodeID);
369 }
370
371 let pruning_seed = PruningSeed::decompress_p2p_rules(peer_core_sync.pruning_seed)?;
405
406 let public_address = 'check_out_addr: {
408 match direction {
409 ConnectionDirection::Inbound => {
410 send_hs_response::<Z, T, _, _>(
412 &mut peer_sink,
413 &mut core_sync_svc,
414 &mut address_book,
415 our_basic_node_data.clone(),
416 )
417 .await?;
418
419 if peer_node_data.my_port != 0 {
421 let InternalPeerID::KnownAddr(mut outbound_address) = addr else {
422 break 'check_out_addr None;
424 };
425
426 #[expect(
427 clippy::cast_possible_truncation,
428 reason = "u32 does not make sense as a port so just truncate it."
429 )]
430 outbound_address.set_port(peer_node_data.my_port as u16);
431
432 let Ok(Ok(ping_peer_id)) = timeout(
433 PING_TIMEOUT,
434 ping::<Z, T>(outbound_address, &transport_client_config)
435 .instrument(info_span!("ping")),
436 )
437 .await
438 else {
439 break 'check_out_addr None;
441 };
442
443 if ping_peer_id == peer_node_data.peer_id {
445 break 'check_out_addr Some(outbound_address);
446 }
447 }
448 None
450 }
451 ConnectionDirection::Outbound => {
452 let InternalPeerID::KnownAddr(outbound_addr) = addr else {
453 unreachable!("How could we make an outbound connection to an unknown address");
454 };
455
456 Some(outbound_addr)
458 }
459 }
460 };
461
462 tracing::debug!("Handshake complete.");
463
464 let (connection_guard, handle) = HandleBuilder::new().with_permit(permit).build();
465
466 address_book
468 .ready()
469 .await?
470 .call(AddressBookRequest::NewConnection {
471 internal_peer_id: addr,
472 public_address,
473 handle: handle.clone(),
474 id: peer_node_data.peer_id,
475 pruning_seed,
476 rpc_port: peer_node_data.rpc_port,
477 rpc_credits_per_hash: peer_node_data.rpc_credits_per_hash,
478 })
479 .await?;
480
481 let error_slot = SharedError::new();
483 let (connection_tx, client_rx) = mpsc::channel(CLIENT_QUEUE_SIZE);
484
485 let info = PeerInformation {
486 id: addr,
487 handle,
488 direction,
489 pruning_seed,
490 basic_node_data: peer_node_data,
491 core_sync_data: Arc::new(Mutex::new(peer_core_sync)),
492 };
493
494 let protocol_request_handler = protocol_request_svc_maker
495 .ready()
496 .await?
497 .call(info.clone())
498 .await?;
499
500 let request_handler = PeerRequestHandler {
501 address_book_svc: address_book.clone(),
502 our_sync_svc: core_sync_svc.clone(),
503 protocol_request_handler,
504 our_basic_node_data,
505 peer_info: info.clone(),
506 };
507
508 let connection = Connection::<Z, T, _, _, _, _>::new(
509 peer_sink,
510 client_rx,
511 broadcast_stream_maker(addr),
512 request_handler,
513 connection_guard,
514 error_slot.clone(),
515 );
516
517 let connection_span =
518 tracing::error_span!(parent: &connection_parent_span, "connection", %addr);
519 let connection_handle = tokio::spawn(
520 connection
521 .run(peer_stream.fuse(), eager_protocol_messages)
522 .instrument(connection_span),
523 );
524
525 let semaphore = Arc::new(Semaphore::new(1));
526
527 let timeout_handle = tokio::spawn(connection_timeout_monitor_task(
528 info.clone(),
529 connection_tx.clone(),
530 Arc::clone(&semaphore),
531 address_book,
532 core_sync_svc,
533 ));
534
535 let client = Client::<Z>::new(
536 info,
537 connection_tx,
538 connection_handle,
539 timeout_handle,
540 semaphore,
541 error_slot,
542 );
543
544 Ok(client)
545}
546
547async fn send_hs_request<Z, T, CSync>(
549 peer_sink: &mut T::Sink,
550 core_sync_svc: &mut CSync,
551 our_basic_node_data: BasicNodeData,
552) -> Result<(), HandshakeError>
553where
554 Z: NetworkZone,
555 T: Transport<Z>,
556 CSync: CoreSyncSvc,
557{
558 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
559 .ready()
560 .await?
561 .call(CoreSyncDataRequest)
562 .await?;
563
564 let req = HandshakeRequest {
565 node_data: our_basic_node_data,
566 payload_data: our_core_sync_data,
567 };
568
569 tracing::debug!("Sending handshake request.");
570
571 peer_sink
572 .send(Message::Request(AdminRequestMessage::Handshake(req)).into())
573 .await?;
574
575 Ok(())
576}
577
578async fn send_hs_response<Z, T, CSync, AdrBook>(
580 peer_sink: &mut T::Sink,
581 core_sync_svc: &mut CSync,
582 address_book: &mut AdrBook,
583 our_basic_node_data: BasicNodeData,
584) -> Result<(), HandshakeError>
585where
586 Z: NetworkZone,
587 T: Transport<Z>,
588 AdrBook: AddressBook<Z>,
589 CSync: CoreSyncSvc,
590{
591 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
592 .ready()
593 .await?
594 .call(CoreSyncDataRequest)
595 .await?;
596
597 let AddressBookResponse::Peers(our_peer_list) = address_book
598 .ready()
599 .await?
600 .call(AddressBookRequest::GetWhitePeers(
601 MAX_PEERS_IN_PEER_LIST_MESSAGE,
602 ))
603 .await?
604 else {
605 panic!("Address book sent incorrect response");
606 };
607
608 let res = HandshakeResponse {
609 node_data: our_basic_node_data,
610 payload_data: our_core_sync_data,
611 local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(),
612 };
613
614 tracing::debug!("Sending handshake response.");
615
616 peer_sink
617 .send(Message::Response(AdminResponseMessage::Handshake(res)).into())
618 .await?;
619
620 Ok(())
621}
622
623async fn wait_for_message<Z, T>(
629 levin_command: LevinCommand,
630 request: bool,
631
632 peer_sink: &mut T::Sink,
633 peer_stream: &mut T::Stream,
634
635 eager_protocol_messages: &mut Vec<cuprate_wire::ProtocolMessage>,
636
637 our_basic_node_data: &BasicNodeData,
638) -> Result<Message, HandshakeError>
639where
640 Z: NetworkZone,
641 T: Transport<Z>,
642{
643 let mut allow_support_flag_req = true;
644 let mut allow_ping = true;
645
646 while let Some(message) = peer_stream.next().await {
647 let message = message?;
648
649 match message {
650 Message::Protocol(protocol_message) => {
651 tracing::debug!(
652 "Received eager protocol message with ID: {}, adding to queue",
653 protocol_message.command()
654 );
655 eager_protocol_messages.push(protocol_message);
656 if eager_protocol_messages.len() > MAX_EAGER_PROTOCOL_MESSAGES {
657 tracing::debug!(
658 "Peer sent too many protocol messages before a handshake response."
659 );
660 return Err(HandshakeError::PeerSentInvalidMessage(
661 "Peer sent too many protocol messages",
662 ));
663 }
664 continue;
665 }
666 Message::Request(req_message) => {
667 if req_message.command() == levin_command && request {
668 return Ok(Message::Request(req_message));
669 }
670
671 match req_message {
672 AdminRequestMessage::SupportFlags => {
673 if !allow_support_flag_req {
674 return Err(HandshakeError::PeerSentInvalidMessage(
675 "Peer sent 2 support flag requests",
676 ));
677 }
678 send_support_flags::<Z, T>(peer_sink, our_basic_node_data.support_flags)
679 .await?;
680 allow_support_flag_req = false;
682 continue;
683 }
684 AdminRequestMessage::Ping => {
685 if !allow_ping {
686 return Err(HandshakeError::PeerSentInvalidMessage(
687 "Peer sent 2 ping requests",
688 ));
689 }
690
691 send_ping_response::<Z, T>(peer_sink, our_basic_node_data.peer_id).await?;
692
693 allow_ping = false;
695 continue;
696 }
697 AdminRequestMessage::Handshake(_) | AdminRequestMessage::TimedSync(_) => {
698 return Err(HandshakeError::PeerSentInvalidMessage(
699 "Peer sent an admin request before responding to the handshake",
700 ));
701 }
702 }
703 }
704 Message::Response(res_message) if !request => {
705 if res_message.command() == levin_command {
706 return Ok(Message::Response(res_message));
707 }
708
709 tracing::debug!("Received unexpected response: {}", res_message.command());
710 return Err(HandshakeError::PeerSentInvalidMessage(
711 "Peer sent an incorrect response",
712 ));
713 }
714
715 Message::Response(_) => Err(HandshakeError::PeerSentInvalidMessage(
716 "Peer sent an incorrect message",
717 )),
718 }?;
719 }
720
721 Err(BucketError::IO(std::io::Error::new(
722 std::io::ErrorKind::ConnectionAborted,
723 "The peer stream returned None",
724 ))
725 .into())
726}
727
728async fn send_support_flags<Z, T>(
730 peer_sink: &mut T::Sink,
731 support_flags: PeerSupportFlags,
732) -> Result<(), HandshakeError>
733where
734 Z: NetworkZone,
735 T: Transport<Z>,
736{
737 tracing::debug!("Sending support flag response.");
738 Ok(peer_sink
739 .send(
740 Message::Response(AdminResponseMessage::SupportFlags(SupportFlagsResponse {
741 support_flags,
742 }))
743 .into(),
744 )
745 .await?)
746}
747
748async fn send_ping_response<Z, T>(
750 peer_sink: &mut T::Sink,
751 peer_id: u64,
752) -> Result<(), HandshakeError>
753where
754 Z: NetworkZone,
755 T: Transport<Z>,
756{
757 tracing::debug!("Sending ping response.");
758 Ok(peer_sink
759 .send(
760 Message::Response(AdminResponseMessage::Ping(PingResponse {
761 status: PING_OK_RESPONSE_STATUS_TEXT,
762 peer_id,
763 }))
764 .into(),
765 )
766 .await?)
767}