1use std::{
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 sync::{Arc, Mutex},
12 task::{Context, Poll},
13};
14
15use futures::{FutureExt, SinkExt, Stream, StreamExt};
16use tokio::{
17 sync::{mpsc, OwnedSemaphorePermit, Semaphore},
18 time::{error::Elapsed, timeout},
19};
20use tower::{Service, ServiceExt};
21use tracing::{info_span, Instrument, Span};
22
23use cuprate_pruning::{PruningError, PruningSeed};
24use cuprate_wire::{
25 admin::{
26 HandshakeRequest, HandshakeResponse, PingResponse, SupportFlagsResponse,
27 PING_OK_RESPONSE_STATUS_TEXT,
28 },
29 common::PeerSupportFlags,
30 AdminRequestMessage, AdminResponseMessage, BasicNodeData, BucketError, LevinCommand, Message,
31};
32
33use crate::{
34 client::{
35 connection::Connection, request_handler::PeerRequestHandler,
36 timeout_monitor::connection_timeout_monitor_task, Client, InternalPeerID, PeerInformation,
37 },
38 constants::{
39 CLIENT_QUEUE_SIZE, HANDSHAKE_TIMEOUT, MAX_EAGER_PROTOCOL_MESSAGES,
40 MAX_PEERS_IN_PEER_LIST_MESSAGE, PING_TIMEOUT,
41 },
42 handles::HandleBuilder,
43 AddressBook, AddressBookRequest, AddressBookResponse, BroadcastMessage, ConnectionDirection,
44 CoreSyncDataRequest, CoreSyncDataResponse, CoreSyncSvc, NetZoneAddress, NetworkZone,
45 ProtocolRequestHandlerMaker, SharedError,
46};
47
48pub mod builder;
49pub use builder::HandshakerBuilder;
50
/// Errors that can occur while handshaking with a peer, or while running a standalone
/// [`ping`] against a peer.
#[derive(Debug, thiserror::Error)]
pub enum HandshakeError {
    /// The handshake did not complete before the timeout wrapping it elapsed
    /// (the `Service` impl wraps the handshake in `HANDSHAKE_TIMEOUT`).
    #[error("The handshake timed out")]
    TimedOut(#[from] Elapsed),
    /// The peer advertised the same node ID as ours (only checked when `Z::CHECK_NODE_ID`).
    #[error("Peer has the same node ID as us")]
    PeerHasSameNodeID,
    /// The peer's `network_id` does not match ours.
    #[error("Peer is on a different network")]
    IncorrectNetwork,
    /// An entry in the peer list the peer sent could not be converted into an address of
    /// this network zone.
    #[error("Peer sent a peer list with peers from different zones")]
    PeerSentIncorrectPeerList(#[from] crate::services::PeerListConversionError),
    /// The peer violated the handshake protocol; the payload is a short static description.
    #[error("Peer sent invalid message: {0}")]
    PeerSentInvalidMessage(&'static str),
    /// The pruning seed the peer sent in its core sync data failed to decode.
    #[error("The peers pruning seed is invalid.")]
    InvalidPruningSeed(#[from] PruningError),
    /// An error from the levin layer while reading from the peer. This is also used when
    /// the peer's stream ends unexpectedly.
    #[error("Levin bucket error: {0}")]
    LevinBucketError(#[from] BucketError),
    /// One of our own internal `tower` services returned an error.
    #[error("Internal service error: {0}")]
    InternalSvcErr(#[from] tower::BoxError),
    /// A raw I/O error. Presumably this comes from connecting to a peer (e.g. in
    /// [`ping`]); TODO confirm against `NetworkZone::connect_to_peer`.
    #[error("I/O error: {0}")]
    IO(#[from] std::io::Error),
}
72
/// A request to perform a handshake over an already-open connection.
pub struct DoHandshakeRequest<Z: NetworkZone> {
    /// Our internal identifier for the peer. For outbound connections this must be
    /// `InternalPeerID::KnownAddr`.
    pub addr: InternalPeerID<Z::Addr>,
    /// The stream of messages coming from the peer.
    pub peer_stream: Z::Stream,
    /// The sink used to send messages to the peer.
    pub peer_sink: Z::Sink,
    /// Whether we opened the connection (outbound) or the peer did (inbound).
    pub direction: ConnectionDirection,
    /// An optional semaphore permit, handed to the connection handle via
    /// `HandleBuilder::with_permit`. Presumably it is held for the lifetime of the
    /// connection to limit concurrent connections; verify against `HandleBuilder`.
    pub permit: Option<OwnedSemaphorePermit>,
}
86
/// A `tower::Service` that performs handshakes with peers and, on success, returns a
/// [`Client`] for the new connection.
#[derive(Debug, Clone)]
pub struct HandShaker<Z: NetworkZone, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr> {
    /// The address book service. It is told about new peers and connections.
    address_book: AdrBook,
    /// The service providing our core sync data, which is sent in handshake messages.
    core_sync_svc: CSync,
    /// Creates the per-connection protocol request handler.
    protocol_request_svc_maker: ProtoHdlrMkr,

    /// Our node data, sent to peers during the handshake.
    our_basic_node_data: BasicNodeData,

    /// Given a peer's ID, creates the stream of broadcast messages for that connection.
    broadcast_stream_maker: BrdcstStrmMkr,

    /// The parent span of every spawned connection task's span.
    connection_parent_span: Span,

    /// Marker tying this handshaker to the network zone `Z`.
    _zone: PhantomData<Z>,
}
108
109impl<Z: NetworkZone, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
110 HandShaker<Z, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
111{
112 const fn new(
114 address_book: AdrBook,
115 core_sync_svc: CSync,
116 protocol_request_svc_maker: ProtoHdlrMkr,
117 broadcast_stream_maker: BrdcstStrmMkr,
118 our_basic_node_data: BasicNodeData,
119 connection_parent_span: Span,
120 ) -> Self {
121 Self {
122 address_book,
123 core_sync_svc,
124 protocol_request_svc_maker,
125 broadcast_stream_maker,
126 our_basic_node_data,
127 connection_parent_span,
128 _zone: PhantomData,
129 }
130 }
131}
132
133impl<Z: NetworkZone, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr, BrdcstStrm>
134 Service<DoHandshakeRequest<Z>> for HandShaker<Z, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr>
135where
136 AdrBook: AddressBook<Z> + Clone,
137 CSync: CoreSyncSvc + Clone,
138 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z> + Clone,
139 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
140 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Clone + Send + 'static,
141{
142 type Response = Client<Z>;
143 type Error = HandshakeError;
144 type Future =
145 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
146
147 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148 Poll::Ready(Ok(()))
149 }
150
151 fn call(&mut self, req: DoHandshakeRequest<Z>) -> Self::Future {
152 let broadcast_stream_maker = self.broadcast_stream_maker.clone();
153
154 let address_book = self.address_book.clone();
155 let protocol_request_svc_maker = self.protocol_request_svc_maker.clone();
156 let core_sync_svc = self.core_sync_svc.clone();
157 let our_basic_node_data = self.our_basic_node_data.clone();
158
159 let connection_parent_span = self.connection_parent_span.clone();
160
161 let span = info_span!(parent: &Span::current(), "handshaker", addr=%req.addr);
162
163 async move {
164 timeout(
165 HANDSHAKE_TIMEOUT,
166 handshake(
167 req,
168 broadcast_stream_maker,
169 address_book,
170 core_sync_svc,
171 protocol_request_svc_maker,
172 our_basic_node_data,
173 connection_parent_span,
174 ),
175 )
176 .await?
177 }
178 .instrument(span)
179 .boxed()
180 }
181}
182
183pub async fn ping<N: NetworkZone>(addr: N::Addr) -> Result<u64, HandshakeError> {
187 tracing::debug!("Sending Ping to peer");
188
189 let (mut peer_stream, mut peer_sink) = N::connect_to_peer(addr).await?;
190
191 tracing::debug!("Made outbound connection to peer, sending ping.");
192
193 peer_sink
194 .send(Message::Request(AdminRequestMessage::Ping).into())
195 .await?;
196
197 if let Some(res) = peer_stream.next().await {
198 if let Message::Response(AdminResponseMessage::Ping(ping)) = res? {
199 if ping.status == PING_OK_RESPONSE_STATUS_TEXT {
200 tracing::debug!("Ping successful.");
201 return Ok(ping.peer_id);
202 }
203
204 tracing::debug!("Peer's ping response was not `OK`.");
205 return Err(HandshakeError::PeerSentInvalidMessage(
206 "Ping response was not `OK`",
207 ));
208 }
209
210 tracing::debug!("Peer sent invalid response to ping.");
211 return Err(HandshakeError::PeerSentInvalidMessage(
212 "Peer did not send correct response for ping.",
213 ));
214 }
215
216 tracing::debug!("Connection closed before ping response.");
217 Err(BucketError::IO(std::io::Error::new(
218 std::io::ErrorKind::ConnectionAborted,
219 "The peer stream returned None",
220 ))
221 .into())
222}
223
224async fn handshake<Z: NetworkZone, AdrBook, CSync, ProtoHdlrMkr, BrdcstStrmMkr, BrdcstStrm>(
226 req: DoHandshakeRequest<Z>,
227
228 broadcast_stream_maker: BrdcstStrmMkr,
229
230 mut address_book: AdrBook,
231 mut core_sync_svc: CSync,
232 mut protocol_request_svc_maker: ProtoHdlrMkr,
233 our_basic_node_data: BasicNodeData,
234 connection_parent_span: Span,
235) -> Result<Client<Z>, HandshakeError>
236where
237 AdrBook: AddressBook<Z> + Clone,
238 CSync: CoreSyncSvc + Clone,
239 ProtoHdlrMkr: ProtocolRequestHandlerMaker<Z>,
240 BrdcstStrm: Stream<Item = BroadcastMessage> + Send + 'static,
241 BrdcstStrmMkr: Fn(InternalPeerID<Z::Addr>) -> BrdcstStrm + Send + 'static,
242{
243 let DoHandshakeRequest {
244 addr,
245 mut peer_stream,
246 mut peer_sink,
247 direction,
248 permit,
249 } = req;
250
251 let mut eager_protocol_messages = Vec::new();
254
255 let (peer_core_sync, peer_node_data) = match direction {
256 ConnectionDirection::Inbound => {
257 tracing::debug!("waiting for handshake request.");
259
260 let Message::Request(AdminRequestMessage::Handshake(handshake_req)) =
261 wait_for_message::<Z>(
262 LevinCommand::Handshake,
263 true,
264 &mut peer_sink,
265 &mut peer_stream,
266 &mut eager_protocol_messages,
267 &our_basic_node_data,
268 )
269 .await?
270 else {
271 panic!("wait_for_message returned ok with wrong message.");
272 };
273
274 tracing::debug!("Received handshake request.");
275 (handshake_req.payload_data, handshake_req.node_data)
277 }
278 ConnectionDirection::Outbound => {
279 send_hs_request::<Z, _>(
281 &mut peer_sink,
282 &mut core_sync_svc,
283 our_basic_node_data.clone(),
284 )
285 .await?;
286
287 let Message::Response(AdminResponseMessage::Handshake(handshake_res)) =
289 wait_for_message::<Z>(
290 LevinCommand::Handshake,
291 false,
292 &mut peer_sink,
293 &mut peer_stream,
294 &mut eager_protocol_messages,
295 &our_basic_node_data,
296 )
297 .await?
298 else {
299 panic!("wait_for_message returned ok with wrong message.");
300 };
301
302 if handshake_res.local_peerlist_new.len() > MAX_PEERS_IN_PEER_LIST_MESSAGE {
303 tracing::debug!("peer sent too many peers in response, cancelling handshake");
304
305 return Err(HandshakeError::PeerSentInvalidMessage(
306 "Too many peers in peer list message (>250)",
307 ));
308 }
309
310 tracing::debug!(
311 "Telling address book about new peers, len: {}",
312 handshake_res.local_peerlist_new.len()
313 );
314
315 address_book
317 .ready()
318 .await?
319 .call(AddressBookRequest::IncomingPeerList(
320 handshake_res
321 .local_peerlist_new
322 .into_iter()
323 .map(TryInto::try_into)
324 .collect::<Result<_, _>>()?,
325 ))
326 .await?;
327
328 (handshake_res.payload_data, handshake_res.node_data)
329 }
330 };
331
332 if peer_node_data.network_id != our_basic_node_data.network_id {
333 return Err(HandshakeError::IncorrectNetwork);
334 }
335
336 if Z::CHECK_NODE_ID && peer_node_data.peer_id == our_basic_node_data.peer_id {
337 return Err(HandshakeError::PeerHasSameNodeID);
338 }
339
340 let pruning_seed = PruningSeed::decompress_p2p_rules(peer_core_sync.pruning_seed)?;
374
375 let public_address = 'check_out_addr: {
377 match direction {
378 ConnectionDirection::Inbound => {
379 send_hs_response::<Z, _, _>(
381 &mut peer_sink,
382 &mut core_sync_svc,
383 &mut address_book,
384 our_basic_node_data.clone(),
385 )
386 .await?;
387
388 if peer_node_data.my_port != 0 {
390 let InternalPeerID::KnownAddr(mut outbound_address) = addr else {
391 break 'check_out_addr None;
393 };
394
395 #[expect(
396 clippy::cast_possible_truncation,
397 reason = "u32 does not make sense as a port so just truncate it."
398 )]
399 outbound_address.set_port(peer_node_data.my_port as u16);
400
401 let Ok(Ok(ping_peer_id)) = timeout(
402 PING_TIMEOUT,
403 ping::<Z>(outbound_address).instrument(info_span!("ping")),
404 )
405 .await
406 else {
407 break 'check_out_addr None;
409 };
410
411 if ping_peer_id == peer_node_data.peer_id {
413 break 'check_out_addr Some(outbound_address);
414 }
415 }
416 None
418 }
419 ConnectionDirection::Outbound => {
420 let InternalPeerID::KnownAddr(outbound_addr) = addr else {
421 unreachable!("How could we make an outbound connection to an unknown address");
422 };
423
424 Some(outbound_addr)
426 }
427 }
428 };
429
430 tracing::debug!("Handshake complete.");
431
432 let (connection_guard, handle) = HandleBuilder::new().with_permit(permit).build();
433
434 address_book
436 .ready()
437 .await?
438 .call(AddressBookRequest::NewConnection {
439 internal_peer_id: addr,
440 public_address,
441 handle: handle.clone(),
442 id: peer_node_data.peer_id,
443 pruning_seed,
444 rpc_port: peer_node_data.rpc_port,
445 rpc_credits_per_hash: peer_node_data.rpc_credits_per_hash,
446 })
447 .await?;
448
449 let error_slot = SharedError::new();
451 let (connection_tx, client_rx) = mpsc::channel(CLIENT_QUEUE_SIZE);
452
453 let info = PeerInformation {
454 id: addr,
455 handle,
456 direction,
457 pruning_seed,
458 core_sync_data: Arc::new(Mutex::new(peer_core_sync)),
459 };
460
461 let protocol_request_handler = protocol_request_svc_maker
462 .as_service()
463 .ready()
464 .await?
465 .call(info.clone())
466 .await?;
467
468 let request_handler = PeerRequestHandler {
469 address_book_svc: address_book.clone(),
470 our_sync_svc: core_sync_svc.clone(),
471 protocol_request_handler,
472 our_basic_node_data,
473 peer_info: info.clone(),
474 };
475
476 let connection = Connection::<Z, _, _, _, _>::new(
477 peer_sink,
478 client_rx,
479 broadcast_stream_maker(addr),
480 request_handler,
481 connection_guard,
482 error_slot.clone(),
483 );
484
485 let connection_span =
486 tracing::error_span!(parent: &connection_parent_span, "connection", %addr);
487 let connection_handle = tokio::spawn(
488 connection
489 .run(peer_stream.fuse(), eager_protocol_messages)
490 .instrument(connection_span),
491 );
492
493 let semaphore = Arc::new(Semaphore::new(1));
494
495 let timeout_handle = tokio::spawn(connection_timeout_monitor_task(
496 info.clone(),
497 connection_tx.clone(),
498 Arc::clone(&semaphore),
499 address_book,
500 core_sync_svc,
501 ));
502
503 let client = Client::<Z>::new(
504 info,
505 connection_tx,
506 connection_handle,
507 timeout_handle,
508 semaphore,
509 error_slot,
510 );
511
512 Ok(client)
513}
514
515async fn send_hs_request<Z: NetworkZone, CSync>(
517 peer_sink: &mut Z::Sink,
518 core_sync_svc: &mut CSync,
519 our_basic_node_data: BasicNodeData,
520) -> Result<(), HandshakeError>
521where
522 CSync: CoreSyncSvc,
523{
524 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
525 .ready()
526 .await?
527 .call(CoreSyncDataRequest)
528 .await?;
529
530 let req = HandshakeRequest {
531 node_data: our_basic_node_data,
532 payload_data: our_core_sync_data,
533 };
534
535 tracing::debug!("Sending handshake request.");
536
537 peer_sink
538 .send(Message::Request(AdminRequestMessage::Handshake(req)).into())
539 .await?;
540
541 Ok(())
542}
543
544async fn send_hs_response<Z: NetworkZone, CSync, AdrBook>(
546 peer_sink: &mut Z::Sink,
547 core_sync_svc: &mut CSync,
548 address_book: &mut AdrBook,
549 our_basic_node_data: BasicNodeData,
550) -> Result<(), HandshakeError>
551where
552 AdrBook: AddressBook<Z>,
553 CSync: CoreSyncSvc,
554{
555 let CoreSyncDataResponse(our_core_sync_data) = core_sync_svc
556 .ready()
557 .await?
558 .call(CoreSyncDataRequest)
559 .await?;
560
561 let AddressBookResponse::Peers(our_peer_list) = address_book
562 .ready()
563 .await?
564 .call(AddressBookRequest::GetWhitePeers(
565 MAX_PEERS_IN_PEER_LIST_MESSAGE,
566 ))
567 .await?
568 else {
569 panic!("Address book sent incorrect response");
570 };
571
572 let res = HandshakeResponse {
573 node_data: our_basic_node_data,
574 payload_data: our_core_sync_data,
575 local_peerlist_new: our_peer_list.into_iter().map(Into::into).collect(),
576 };
577
578 tracing::debug!("Sending handshake response.");
579
580 peer_sink
581 .send(Message::Response(AdminResponseMessage::Handshake(res)).into())
582 .await?;
583
584 Ok(())
585}
586
587async fn wait_for_message<Z: NetworkZone>(
593 levin_command: LevinCommand,
594 request: bool,
595
596 peer_sink: &mut Z::Sink,
597 peer_stream: &mut Z::Stream,
598
599 eager_protocol_messages: &mut Vec<cuprate_wire::ProtocolMessage>,
600
601 our_basic_node_data: &BasicNodeData,
602) -> Result<Message, HandshakeError> {
603 let mut allow_support_flag_req = true;
604 let mut allow_ping = true;
605
606 while let Some(message) = peer_stream.next().await {
607 let message = message?;
608
609 match message {
610 Message::Protocol(protocol_message) => {
611 tracing::debug!(
612 "Received eager protocol message with ID: {}, adding to queue",
613 protocol_message.command()
614 );
615 eager_protocol_messages.push(protocol_message);
616 if eager_protocol_messages.len() > MAX_EAGER_PROTOCOL_MESSAGES {
617 tracing::debug!(
618 "Peer sent too many protocol messages before a handshake response."
619 );
620 return Err(HandshakeError::PeerSentInvalidMessage(
621 "Peer sent too many protocol messages",
622 ));
623 }
624 continue;
625 }
626 Message::Request(req_message) => {
627 if req_message.command() == levin_command && request {
628 return Ok(Message::Request(req_message));
629 }
630
631 match req_message {
632 AdminRequestMessage::SupportFlags => {
633 if !allow_support_flag_req {
634 return Err(HandshakeError::PeerSentInvalidMessage(
635 "Peer sent 2 support flag requests",
636 ));
637 }
638 send_support_flags::<Z>(peer_sink, our_basic_node_data.support_flags)
639 .await?;
640 allow_support_flag_req = false;
642 continue;
643 }
644 AdminRequestMessage::Ping => {
645 if !allow_ping {
646 return Err(HandshakeError::PeerSentInvalidMessage(
647 "Peer sent 2 ping requests",
648 ));
649 }
650
651 send_ping_response::<Z>(peer_sink, our_basic_node_data.peer_id).await?;
652
653 allow_ping = false;
655 continue;
656 }
657 _ => {
658 return Err(HandshakeError::PeerSentInvalidMessage(
659 "Peer sent an admin request before responding to the handshake",
660 ));
661 }
662 }
663 }
664 Message::Response(res_message) if !request => {
665 if res_message.command() == levin_command {
666 return Ok(Message::Response(res_message));
667 }
668
669 tracing::debug!("Received unexpected response: {}", res_message.command());
670 return Err(HandshakeError::PeerSentInvalidMessage(
671 "Peer sent an incorrect response",
672 ));
673 }
674
675 Message::Response(_) => Err(HandshakeError::PeerSentInvalidMessage(
676 "Peer sent an incorrect message",
677 )),
678 }?;
679 }
680
681 Err(BucketError::IO(std::io::Error::new(
682 std::io::ErrorKind::ConnectionAborted,
683 "The peer stream returned None",
684 ))
685 .into())
686}
687
688async fn send_support_flags<Z: NetworkZone>(
690 peer_sink: &mut Z::Sink,
691 support_flags: PeerSupportFlags,
692) -> Result<(), HandshakeError> {
693 tracing::debug!("Sending support flag response.");
694 Ok(peer_sink
695 .send(
696 Message::Response(AdminResponseMessage::SupportFlags(SupportFlagsResponse {
697 support_flags,
698 }))
699 .into(),
700 )
701 .await?)
702}
703
704async fn send_ping_response<Z: NetworkZone>(
706 peer_sink: &mut Z::Sink,
707 peer_id: u64,
708) -> Result<(), HandshakeError> {
709 tracing::debug!("Sending ping response.");
710 Ok(peer_sink
711 .send(
712 Message::Response(AdminResponseMessage::Ping(PingResponse {
713 status: PING_OK_RESPONSE_STATUS_TEXT,
714 peer_id,
715 }))
716 .into(),
717 )
718 .await?)
719}