1use std::{
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 sync::{Arc, Mutex},
12 task::{Context, Poll},
13};
14
15use futures::{FutureExt, SinkExt, Stream, StreamExt};
16use tokio::{
17 sync::{mpsc, OwnedSemaphorePermit, Semaphore},
18 time::{error::Elapsed, timeout},
19};
20use tower::{Service, ServiceExt};
21use tracing::{info_span, Instrument, Span};
22
23use cuprate_pruning::{PruningError, PruningSeed};
24use cuprate_wire::{
25 admin::{
26 HandshakeRequest, HandshakeResponse, PingResponse, SupportFlagsResponse,
27 PING_OK_RESPONSE_STATUS_TEXT,
28 },
29 common::PeerSupportFlags,
30 AdminRequestMessage, AdminResponseMessage, BasicNodeData, BucketError, LevinCommand, Message,
31};
32
33use crate::{
34 client::{
35 connection::Connection, request_handler::PeerRequestHandler,
36 timeout_monitor::connection_timeout_monitor_task, Client, InternalPeerID, PeerInformation,
37 },
38 constants::{
39 CLIENT_QUEUE_SIZE, HANDSHAKE_TIMEOUT, MAX_EAGER_PROTOCOL_MESSAGES,
40 MAX_PEERS_IN_PEER_LIST_MESSAGE, PING_TIMEOUT,
41 },
42 handles::HandleBuilder,
43 AddressBook, AddressBookRequest, AddressBookResponse, BroadcastMessage, ConnectionDirection,
44 CoreSyncDataRequest, CoreSyncDataResponse, CoreSyncSvc, NetZoneAddress, NetworkZone,
45 ProtocolRequestHandlerMaker, SharedError, Transport,
46};
47
48pub mod builder;
49pub use builder::HandshakerBuilder;
50
/// Errors that can occur while performing a handshake with a peer.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
    /// The overall handshake did not finish within `HANDSHAKE_TIMEOUT`.
    #[error("The handshake timed out")]
    TimedOut(#[from] Elapsed),
    /// The peer reported our own peer ID (only checked when `Z::CHECK_NODE_ID` is set).
    #[error("Peer has the same node ID as us")]
    PeerHasSameNodeID,
    /// The peer's `network_id` does not match ours.
    #[error("Peer is on a different network")]
    IncorrectNetwork,
    /// An entry of the peer's peer list could not be converted into an address of our zone.
    #[error("Peer sent a peer list with peers from different zones")]
    PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError),
    /// The peer sent a message that is not allowed at this point of the handshake.
    #[error("Peer sent invalid message: {0}")]
    PeerSentInvalidMessage(&'static str),
    /// The peer's compressed pruning seed could not be decoded.
    #[error("The peers pruning seed is invalid.")]
    InvalidPruningSeed(#[from] PruningError),
    /// An error reading or writing levin buckets on the peer's stream/sink.
    #[error("Levin bucket error: {0}")]
    LevinBucketError(#[from] BucketError),
    /// One of our internal services (address book, core sync, protocol handler maker) failed.
    #[error("Internal service error: {0}")]
    InternalSvcErr(#[from] tower::BoxError),
    /// An I/O error, presumably from establishing the transport connection — confirm
    /// against `Transport::connect_to_peer`'s error type.
    #[error("I/O error: {0}")]
    IO(#[from] std::io::Error),
}
72
/// A request for the [`HandShaker`] service to handshake with a peer over an already
/// established transport connection.
pub struct DoHandshakeRequest<Z: NetworkZone, T: Transport<Z>> {
    /// The peer's internal ID. For outbound connections this must be
    /// [`InternalPeerID::KnownAddr`]; the handshake treats anything else as unreachable.
    pub addr: InternalPeerID<Z::Addr>,
    /// The receiving half of the connection; messages from the peer are read from here.
    pub peer_stream: T::Stream,
    /// The sending half of the connection; messages to the peer are written here.
    pub peer_sink: T::Sink,
    /// Whether we accepted this connection (inbound) or initiated it (outbound).
    pub direction: ConnectionDirection,
    /// An optional semaphore permit handed to the connection handle via
    /// `HandleBuilder::with_permit`; presumably used to bound the number of
    /// concurrent connections — confirm with the caller.
    pub permit: Option<OwnedSemaphorePermit>,
}
86
/// A [`Service`] that performs handshakes with peers and returns a [`Client`] for each
/// successfully handshaken connection.
#[derive(Debug, Clone)]
pub struct HandShaker<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
{
    /// The address book service, told about new peers and new connections.
    address_book: AdrBook,
    /// The service that provides our core sync data for handshake messages.
    core_sync_svc: CSync,
    /// Creates a protocol request handler for each new connection.
    protocol_request_svc_maker: ProtoHdlrMkr,

    /// Our node data, sent to peers in handshake requests/responses.
    our_basic_node_data: BasicNodeData,

    /// Builds the broadcast message stream for each new connection, keyed by peer ID.
    broadcast_stream_maker: BrdcstStrmMkr,

    /// Parent span for the per-connection `connection` spans.
    connection_parent_span: Span,

    /// Transport configuration used for outbound connections, including the ping-back
    /// performed to verify an inbound peer's advertised port.
    transport_client_config: T::ClientConfig,

    /// Ties the handshaker to a network zone without storing a value of it.
    _zone: PhantomData<Z>,
}
112
113impl<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
114 HandShaker<Z, T, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
115{
116 const fn new(
118 address_book: AdrBook,
119 core_sync_svc: CSync,
120 protocol_request_svc_maker: ProtoHdlrMkr,
121 broadcast_stream_maker: BrdcstStrmMkr,
122 our_basic_node_data: BasicNodeData,
123 connection_parent_span: Span,
124 transport_client_config: T::ClientConfig,
125 ) -> Self {
126 Self {
127 address_book,
128 core_sync_svc,
129 protocol_request_svc_maker,
130 broadcast_stream_maker,
131 our_basic_node_data,
132 connection_parent_span,
133 transport_client_config,
134 _zone: PhantomData,
135 }
136 }
137
138 #[inline]
140 pub const fn transport_config(&self) -> &T::ClientConfig {
141 &self.transport_client_config
142 }
143}
144
145impl<Z: NetworkZone, T: Transport<Z>, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr, BrdcstStrm>
146 Service<DoHandshakeRequest<Z, T>>
147 for HandShaker<Z, T, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
148where
149 AdrBook: AddressBook<Z> + Clone,
150 CSync: CoreSyncSvc + Clone,
151 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z> + Clone,
152 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
153 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Clone + Send + 'static,
154{
155 type Response = Client<Z>;
156 type Error = HandshakeError;
157 type Future =
158 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
159
160 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 Poll::Ready(Ok(()))
162 }
163
164 fn call(&mut self, req: DoHandshakeRequest<Z, T>) -> Self::Future {
165 let broadcast_stream_maker = self.broadcast_stream_maker.clone();
166
167 let address_book = self.address_book.clone();
168 let protocol_request_svc_maker = self.protocol_request_svc_maker.clone();
169 let core_sync_svc = self.core_sync_svc.clone();
170 let our_basic_node_data = self.our_basic_node_data.clone();
171
172 let connection_parent_span = self.connection_parent_span.clone();
173
174 let transport_client_config = self.transport_client_config.clone();
175
176 let span = info_span!(parent: &Span::current(), "handshaker", addr=%req.addr);
177
178 async move {
179 timeout(
180 HANDSHAKE_TIMEOUT,
181 handshake(
182 req,
183 transport_client_config,
184 broadcast_stream_maker,
185 address_book,
186 core_sync_svc,
187 protocol_request_svc_maker,
188 our_basic_node_data,
189 connection_parent_span,
190 ),
191 )
192 .await?
193 }
194 .instrument(span)
195 .boxed()
196 }
197}
198
199pub async fn ping<N, T>(addr: N::Addr, config: &T::ClientConfig) -> Result<u64, HandshakeError>
203where
204 N: NetworkZone,
205 T: Transport<N>,
206{
207 tracing::debug!("Sending Ping to peer");
208
209 let (mut peer_stream, mut peer_sink) = T::connect_to_peer(addr, config).await?;
210
211 tracing::debug!("Made outbound connection to peer, sending ping.");
212
213 peer_sink
214 .send(Message::Request(AdminRequestMessage::Ping).into())
215 .await?;
216
217 if let Some(res) = peer_stream.next().await {
218 if let Message::Response(AdminResponseMessage::Ping(ping)) = res? {
219 if ping.status == PING_OK_RESPONSE_STATUS_TEXT {
220 tracing::debug!("Ping successful.");
221 return Ok(ping.peer_id);
222 }
223
224 tracing::debug!("Peer's ping response was not `OK`.");
225 return Err(HandshakeError::PeerSentInvalidMessage(
226 "Ping response was not `OK`",
227 ));
228 }
229
230 tracing::debug!("Peer sent invalid response to ping.");
231 return Err(HandshakeError::PeerSentInvalidMessage(
232 "Peer did not send correct response for ping.",
233 ));
234 }
235
236 tracing::debug!("Connection closed before ping response.");
237 Err(BucketError::IO(std::io::Error::new(
238 std::io::ErrorKind::ConnectionAborted,
239 "The peer stream returned None",
240 ))
241 .into())
242}
243
244#[expect(clippy::too_many_arguments)]
246async fn handshake<
247 Z: NetworkZone,
248 T: Transport<Z>,
249 AdrBook,
250 CSync,
251 ProtoHdlrMkr,
252 BrdcstStrmMkr,
253 BrdcstStrm,
254>(
255 req: DoHandshakeRequest<Z, T>,
256 transport_client_config: T::ClientConfig,
257
258 broadcast_stream_maker: BrdcstStrmMkr,
259
260 mut address_book: AdrBook,
261 mut core_sync_svc: CSync,
262 mut protocol_request_svc_maker: ProtoHdlrMkr,
263 our_basic_node_data: BasicNodeData,
264 connection_parent_span: Span,
265) -> Result<Client<Z>, HandshakeError>
266where
267 AdrBook: AddressBook<Z> + Clone,
268 CSync: CoreSyncSvc + Clone,
269 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z>,
270 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
271 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Send + 'static,
272{
273 let DoHandshakeRequest {
274 addr,
275 mut peer_stream,
276 mut peer_sink,
277 direction,
278 permit,
279 } = req;
280
281 let mut eager_protocol_messages = Vec::new();
284
285 let (peer_core_sync, peer_node_data) = match direction {
286 ConnectionDirection::Inbound => {
287 tracing::debug!("waiting for handshake request.");
289
290 let Message::Request(AdminRequestMessage::Handshake(handshake_req)) =
291 wait_for_message::<Z, T>(
292 LevinCommand::Handshake,
293 true,
294 &mut peer_sink,
295 &mut peer_stream,
296 &mut eager_protocol_messages,
297 &our_basic_node_data,
298 )
299 .await?
300 else {
301 panic!("wait_for_message returned ok with wrong message.");
302 };
303
304 tracing::debug!("Received handshake request.");
305 (handshake_req.payload_data, handshake_req.node_data)
307 }
308 ConnectionDirection::Outbound => {
309 send_hs_request::<Z, T, _>(
311 &mut peer_sink,
312 &mut core_sync_svc,
313 our_basic_node_data.clone(),
314 )
315 .await?;
316
317 let Message::Response(AdminResponseMessage::Handshake(handshake_res)) =
319 wait_for_message::<Z, T>(
320 LevinCommand::Handshake,
321 false,
322 &mut peer_sink,
323 &mut peer_stream,
324 &mut eager_protocol_messages,
325 &our_basic_node_data,
326 )
327 .await?
328 else {
329 panic!("wait_for_message returned ok with wrong message.");
330 };
331
332 if handshake_res.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE {
333 tracing::debug!("peer sent too many peers in response, cancelling handshake");
334
335 return Err(HandshakeError::PeerSentInvalidMessage(
336 "Too many peers in peer list message (>250)",
337 ));
338 }
339
340 tracing::debug!(
341 "Telling address book about new peers, len: {}",
342 handshake_res.local_peerlist_new.len()
343 );
344
345 address_book
347 .ready()
348 .await?
349 .call(AddressBookRequest::IncomingPeerList(
350 handshake_res
351 .local_peerlist_new
352 .into_iter()
353 .map(TryInto::try_into)
354 .collect::<Result<_, _>>()?,
355 ))
356 .await?;
357
358 (handshake_res.payload_data, handshake_res.node_data)
359 }
360 };
361
362 if peer_node_data.network_id != our_basic_node_data.network_id {
363 return Err(HandshakeError::IncorrectNetwork);
364 }
365
366 if Z::CHECK_NODE_ID && peer_node_data.peer_id == our_basic_node_data.peer_id {
367 return Err(HandshakeError::PeerHasSameNodeID);
368 }
369
370 let pruning_seed = PruningSeed::decompress_p2p_rules(peer_core_sync.pruning_seed)?;
404
405 let public_address = 'check_out_addr: {
407 match direction {
408 ConnectionDirection::Inbound => {
409 send_hs_response::<Z, T, _, _>(
411 &mut peer_sink,
412 &mut core_sync_svc,
413 &mut address_book,
414 our_basic_node_data.clone(),
415 )
416 .await?;
417
418 if peer_node_data.my_port != 0 {
420 let InternalPeerID::KnownAddr(mut outbound_address) = addr else {
421 break 'check_out_addr None;
423 };
424
425 #[expect(
426 clippy::cast_possible_truncation,
427 reason = "u32 does not make sense as a port so just truncate it."
428 )]
429 outbound_address.set_port(peer_node_data.my_port as u16);
430
431 let Ok(Ok(ping_peer_id)) = timeout(
432 PING_TIMEOUT,
433 ping::<Z, T>(outbound_address, &transport_client_config)
434 .instrument(info_span!("ping")),
435 )
436 .await
437 else {
438 break 'check_out_addr None;
440 };
441
442 if ping_peer_id == peer_node_data.peer_id {
444 break 'check_out_addr Some(outbound_address);
445 }
446 }
447 None
449 }
450 ConnectionDirection::Outbound => {
451 let InternalPeerID::KnownAddr(outbound_addr) = addr else {
452 unreachable!("How could we make an outbound connection to an unknown address");
453 };
454
455 Some(outbound_addr)
457 }
458 }
459 };
460
461 tracing::debug!("Handshake complete.");
462
463 let (connection_guard, handle) = HandleBuilder::new().with_permit(permit).build();
464
465 address_book
467 .ready()
468 .await?
469 .call(AddressBookRequest::NewConnection {
470 internal_peer_id: addr,
471 public_address,
472 handle: handle.clone(),
473 id: peer_node_data.peer_id,
474 pruning_seed,
475 rpc_port: peer_node_data.rpc_port,
476 rpc_credits_per_hash: peer_node_data.rpc_credits_per_hash,
477 })
478 .await?;
479
480 let error_slot = SharedError::new();
482 let (connection_tx, client_rx) = mpsc::channel(CLIENT_QUEUE_SIZE);
483
484 let info = PeerInformation {
485 id: addr,
486 handle,
487 direction,
488 pruning_seed,
489 core_sync_data: Arc::new(Mutex::new(peer_core_sync)),
490 };
491
492 let protocol_request_handler = protocol_request_svc_maker
493 .as_service()
494 .ready()
495 .await?
496 .call(info.clone())
497 .await?;
498
499 let request_handler = PeerRequestHandler {
500 address_book_svc: address_book.clone(),
501 our_sync_svc: core_sync_svc.clone(),
502 protocol_request_handler,
503 our_basic_node_data,
504 peer_info: info.clone(),
505 };
506
507 let connection = Connection::<Z, T, _, _, _, _>::new(
508 peer_sink,
509 client_rx,
510 broadcast_stream_maker(addr),
511 request_handler,
512 connection_guard,
513 error_slot.clone(),
514 );
515
516 let connection_span =
517 tracing::error_span!(parent: &connection_parent_span, "connection", %addr);
518 let connection_handle = tokio::spawn(
519 connection
520 .run(peer_stream.fuse(), eager_protocol_messages)
521 .instrument(connection_span),
522 );
523
524 let semaphore = Arc::new(Semaphore::new(1));
525
526 let timeout_handle = tokio::spawn(connection_timeout_monitor_task(
527 info.clone(),
528 connection_tx.clone(),
529 Arc::clone(&semaphore),
530 address_book,
531 core_sync_svc,
532 ));
533
534 let client = Client::<Z>::new(
535 info,
536 connection_tx,
537 connection_handle,
538 timeout_handle,
539 semaphore,
540 error_slot,
541 );
542
543 Ok(client)
544}
545
546async fn send_hs_request<Z, T, CSync>(
548 peer_sink: &mut T::Sink,
549 core_sync_svc: &mut CSync,
550 our_basic_node_data: BasicNodeData,
551) -> Result<(), HandshakeError>
552where
553 Z: NetworkZone,
554 T: Transport<Z>,
555 CSync: CoreSyncSvc,
556{
557 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
558 .ready()
559 .await?
560 .call(CoreSyncDataRequest)
561 .await?;
562
563 let req = HandshakeRequest {
564 node_data: our_basic_node_data,
565 payload_data: our_core_sync_data,
566 };
567
568 tracing::debug!("Sending handshake request.");
569
570 peer_sink
571 .send(Message::Request(AdminRequestMessage::Handshake(req)).into())
572 .await?;
573
574 Ok(())
575}
576
577async fn send_hs_response<Z, T, CSync, AdrBook>(
579 peer_sink: &mut T::Sink,
580 core_sync_svc: &mut CSync,
581 address_book: &mut AdrBook,
582 our_basic_node_data: BasicNodeData,
583) -> Result<(), HandshakeError>
584where
585 Z: NetworkZone,
586 T: Transport<Z>,
587 AdrBook: AddressBook<Z>,
588 CSync: CoreSyncSvc,
589{
590 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
591 .ready()
592 .await?
593 .call(CoreSyncDataRequest)
594 .await?;
595
596 let AddressBookResponse::Peers(our_peer_list) = address_book
597 .ready()
598 .await?
599 .call(AddressBookRequest::GetWhitePeers(
600 MAX_PEERS_IN_PEER_LIST_MESSAGE,
601 ))
602 .await?
603 else {
604 panic!("Address book sent incorrect response");
605 };
606
607 let res = HandshakeResponse {
608 node_data: our_basic_node_data,
609 payload_data: our_core_sync_data,
610 local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(),
611 };
612
613 tracing::debug!("Sending handshake response.");
614
615 peer_sink
616 .send(Message::Response(AdminResponseMessage::Handshake(res)).into())
617 .await?;
618
619 Ok(())
620}
621
622async fn wait_for_message<Z, T>(
628 levin_command: LevinCommand,
629 request: bool,
630
631 peer_sink: &mut T::Sink,
632 peer_stream: &mut T::Stream,
633
634 eager_protocol_messages: &mut Vec<cuprate_wire::ProtocolMessage>,
635
636 our_basic_node_data: &BasicNodeData,
637) -> Result<Message, HandshakeError>
638where
639 Z: NetworkZone,
640 T: Transport<Z>,
641{
642 let mut allow_support_flag_req = true;
643 let mut allow_ping = true;
644
645 while let Some(message) = peer_stream.next().await {
646 let message = message?;
647
648 match message {
649 Message::Protocol(protocol_message) => {
650 tracing::debug!(
651 "Received eager protocol message with ID: {}, adding to queue",
652 protocol_message.command()
653 );
654 eager_protocol_messages.push(protocol_message);
655 if eager_protocol_messages.len() > MAX_EAGER_PROTOCOL_MESSAGES {
656 tracing::debug!(
657 "Peer sent too many protocol messages before a handshake response."
658 );
659 return Err(HandshakeError::PeerSentInvalidMessage(
660 "Peer sent too many protocol messages",
661 ));
662 }
663 continue;
664 }
665 Message::Request(req_message) => {
666 if req_message.command() == levin_command && request {
667 return Ok(Message::Request(req_message));
668 }
669
670 match req_message {
671 AdminRequestMessage::SupportFlags => {
672 if !allow_support_flag_req {
673 return Err(HandshakeError::PeerSentInvalidMessage(
674 "Peer sent 2 support flag requests",
675 ));
676 }
677 send_support_flags::<Z, T>(peer_sink, our_basic_node_data.support_flags)
678 .await?;
679 allow_support_flag_req = false;
681 continue;
682 }
683 AdminRequestMessage::Ping => {
684 if !allow_ping {
685 return Err(HandshakeError::PeerSentInvalidMessage(
686 "Peer sent 2 ping requests",
687 ));
688 }
689
690 send_ping_response::<Z, T>(peer_sink, our_basic_node_data.peer_id).await?;
691
692 allow_ping = false;
694 continue;
695 }
696 _ => {
697 return Err(HandshakeError::PeerSentInvalidMessage(
698 "Peer sent an admin request before responding to the handshake",
699 ));
700 }
701 }
702 }
703 Message::Response(res_message) if !request => {
704 if res_message.command() == levin_command {
705 return Ok(Message::Response(res_message));
706 }
707
708 tracing::debug!("Received unexpected response: {}", res_message.command());
709 return Err(HandshakeError::PeerSentInvalidMessage(
710 "Peer sent an incorrect response",
711 ));
712 }
713
714 Message::Response(_) => Err(HandshakeError::PeerSentInvalidMessage(
715 "Peer sent an incorrect message",
716 )),
717 }?;
718 }
719
720 Err(BucketError::IO(std::io::Error::new(
721 std::io::ErrorKind::ConnectionAborted,
722 "The peer stream returned None",
723 ))
724 .into())
725}
726
727async fn send_support_flags<Z, T>(
729 peer_sink: &mut T::Sink,
730 support_flags: PeerSupportFlags,
731) -> Result<(), HandshakeError>
732where
733 Z: NetworkZone,
734 T: Transport<Z>,
735{
736 tracing::debug!("Sending support flag response.");
737 Ok(peer_sink
738 .send(
739 Message::Response(AdminResponseMessage::SupportFlags(SupportFlagsResponse {
740 support_flags,
741 }))
742 .into(),
743 )
744 .await?)
745}
746
747async fn send_ping_response<Z, T>(
749 peer_sink: &mut T::Sink,
750 peer_id: u64,
751) -> Result<(), HandshakeError>
752where
753 Z: NetworkZone,
754 T: Transport<Z>,
755{
756 tracing::debug!("Sending ping response.");
757 Ok(peer_sink
758 .send(
759 Message::Response(AdminResponseMessage::Ping(PingResponse {
760 status: PING_OK_RESPONSE_STATUS_TEXT,
761 peer_id,
762 }))
763 .into(),
764 )
765 .await?)
766}