cuprate_p2p/block_downloader/
request_chain.rs1use std::mem;
2
3use tokio::{task::JoinSet, time::timeout};
4use tower::{util::BoxCloneService, Service, ServiceExt};
5use tracing::{instrument, Instrument, Span};
6
7use cuprate_p2p_core::{
8 client::InternalPeerID, handles::ConnectionHandle, NetworkZone, PeerRequest, PeerResponse,
9 ProtocolRequest, ProtocolResponse,
10};
11use cuprate_wire::protocol::{ChainRequest, ChainResponse};
12
13use crate::{
14 block_downloader::{
15 chain_tracker::{ChainEntry, ChainTracker},
16 BlockDownloadError, ChainSvcRequest, ChainSvcResponse,
17 },
18 constants::{
19 BLOCK_DOWNLOADER_REQUEST_TIMEOUT, INITIAL_CHAIN_REQUESTS_TO_SEND,
20 MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, MEDIUM_BAN,
21 },
22 peer_set::{ClientDropGuard, PeerSetRequest, PeerSetResponse},
23};
24
25pub(crate) async fn request_chain_entry_from_peer<N: NetworkZone>(
30 mut client: ClientDropGuard<N>,
31 short_history: [[u8; 32]; 2],
32) -> Result<(ClientDropGuard<N>, ChainEntry<N>), BlockDownloadError> {
33 let PeerResponse::Protocol(ProtocolResponse::GetChain(chain_res)) = client
34 .ready()
35 .await?
36 .call(PeerRequest::Protocol(ProtocolRequest::GetChain(
37 ChainRequest {
38 block_ids: short_history.into(),
39 prune: true,
40 },
41 )))
42 .await?
43 else {
44 panic!("Connection task returned wrong response!");
45 };
46
47 if chain_res.m_block_ids.is_empty()
48 || chain_res.m_block_ids.len() > MAX_BLOCKS_IDS_IN_CHAIN_ENTRY
49 {
50 client.info.handle.ban_peer(MEDIUM_BAN);
51 return Err(BlockDownloadError::PeersResponseWasInvalid);
52 }
53
54 if !(chain_res.m_block_ids[0] == short_history[0]
56 || chain_res.m_block_ids[0] == short_history[1])
57 {
58 client.info.handle.ban_peer(MEDIUM_BAN);
59 return Err(BlockDownloadError::PeersResponseWasInvalid);
60 }
61
62 if chain_res.m_block_ids[0] == short_history[1] {
65 return Err(BlockDownloadError::PeerDidNotHaveRequestedData);
66 }
67
68 let entry = ChainEntry {
69 ids: (&chain_res.m_block_ids).into(),
70 peer: client.info.id,
71 handle: client.info.handle.clone(),
72 };
73
74 Ok((client, entry))
75}
76
77#[instrument(level = "error", skip_all)]
82pub async fn initial_chain_search<N: NetworkZone, C>(
83 peer_set: &mut BoxCloneService<PeerSetRequest, PeerSetResponse<N>, tower::BoxError>,
84 mut our_chain_svc: C,
85) -> Result<ChainTracker<N>, BlockDownloadError>
86where
87 C: Service<ChainSvcRequest<N>, Response = ChainSvcResponse<N>, Error = tower::BoxError>,
88{
89 tracing::debug!("Getting our chain history");
90 let ChainSvcResponse::CompactHistory {
92 block_ids,
93 cumulative_difficulty,
94 } = our_chain_svc
95 .ready()
96 .await?
97 .call(ChainSvcRequest::CompactHistory)
98 .await?
99 else {
100 panic!("chain service sent wrong response.");
101 };
102
103 let our_genesis = *block_ids.last().expect("Blockchain had no genesis block.");
104
105 let PeerSetResponse::PeersWithMorePoW(clients) = peer_set
106 .ready()
107 .await?
108 .call(PeerSetRequest::PeersWithMorePoW(cumulative_difficulty))
109 .await?
110 else {
111 unreachable!();
112 };
113 let mut peers = clients.into_iter();
114
115 let mut futs = JoinSet::new();
116
117 let req = PeerRequest::Protocol(ProtocolRequest::GetChain(ChainRequest {
118 block_ids: block_ids.into(),
119 prune: false,
120 }));
121
122 tracing::debug!("Sending requests for chain entries.");
123
124 while futs.len() < INITIAL_CHAIN_REQUESTS_TO_SEND {
126 let Some(mut next_peer) = peers.next() else {
127 break;
128 };
129
130 let cloned_req = req.clone();
131 futs.spawn(timeout(
132 BLOCK_DOWNLOADER_REQUEST_TIMEOUT,
133 async move {
134 let PeerResponse::Protocol(ProtocolResponse::GetChain(chain_res)) =
135 next_peer.ready().await?.call(cloned_req).await?
136 else {
137 panic!("connection task returned wrong response!");
138 };
139
140 Ok::<_, tower::BoxError>((
141 chain_res,
142 next_peer.info.id,
143 next_peer.info.handle.clone(),
144 ))
145 }
146 .instrument(Span::current()),
147 ));
148 }
149
150 let mut res: Option<(ChainResponse, InternalPeerID<_>, ConnectionHandle)> = None;
151
152 while let Some(task_res) = futs.join_next().await {
154 let Ok(Ok(task_res)) = task_res.unwrap() else {
155 continue;
156 };
157
158 match &mut res {
159 Some(res) => {
160 if res.0.cumulative_difficulty() < task_res.0.cumulative_difficulty() {
162 drop(mem::replace(res, task_res));
163 }
164 }
165 None => {
166 res = Some(task_res);
168 }
169 }
170 }
171
172 let Some((chain_res, peer_id, peer_handle)) = res else {
173 return Err(BlockDownloadError::FailedToFindAChainToFollow);
174 };
175
176 let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
177 drop(chain_res);
179
180 tracing::debug!("Highest chin entry contained {} block Ids", hashes.len());
181
182 let ChainSvcResponse::FindFirstUnknown(first_unknown_ret) = our_chain_svc
184 .ready()
185 .await?
186 .call(ChainSvcRequest::FindFirstUnknown(hashes.clone()))
187 .await?
188 else {
189 panic!("chain service sent wrong response.");
190 };
191
192 let Some((first_unknown, expected_height)) = first_unknown_ret else {
195 return Err(BlockDownloadError::FailedToFindAChainToFollow);
196 };
197
198 if first_unknown == 0 {
200 peer_handle.ban_peer(MEDIUM_BAN);
201 return Err(BlockDownloadError::PeerSentNoOverlappingBlocks);
202 }
203
204 let previous_id = hashes[first_unknown - 1];
205
206 let first_entry = ChainEntry {
207 ids: hashes[first_unknown..].to_vec(),
208 peer: peer_id,
209 handle: peer_handle,
210 };
211
212 tracing::debug!(
213 "Creating chain tracker with {} new block Ids",
214 first_entry.ids.len()
215 );
216
217 let tracker = ChainTracker::new(
218 first_entry,
219 expected_height,
220 our_genesis,
221 previous_id,
222 &mut our_chain_svc,
223 )
224 .await
225 .map_err(|_| BlockDownloadError::ChainInvalid)?;
226
227 Ok(tracker)
228}