cuprate_p2p/block_downloader/request_chain.rs

1use std::mem;
2
3use tokio::{task::JoinSet, time::timeout};
4use tower::{util::BoxCloneService, Service, ServiceExt};
5use tracing::{instrument, Instrument, Span};
6
7use cuprate_p2p_core::{
8    client::InternalPeerID, handles::ConnectionHandle, NetworkZone, PeerRequest, PeerResponse,
9    ProtocolRequest, ProtocolResponse,
10};
11use cuprate_wire::protocol::{ChainRequest, ChainResponse};
12
13use crate::{
14    block_downloader::{
15        chain_tracker::{ChainEntry, ChainTracker},
16        BlockDownloadError, ChainSvcRequest, ChainSvcResponse,
17    },
18    constants::{
19        BLOCK_DOWNLOADER_REQUEST_TIMEOUT, INITIAL_CHAIN_REQUESTS_TO_SEND,
20        MAX_BLOCKS_IDS_IN_CHAIN_ENTRY, MEDIUM_BAN,
21    },
22    peer_set::{ClientDropGuard, PeerSetRequest, PeerSetResponse},
23};
24
25/// Request a chain entry from a peer.
26///
27/// Because the block downloader only follows and downloads one chain we only have to send the block hash of
28/// top block we have found and the genesis block, this is then called `short_history`.
29pub(crate) async fn request_chain_entry_from_peer<N: NetworkZone>(
30    mut client: ClientDropGuard<N>,
31    short_history: [[u8; 32]; 2],
32) -> Result<(ClientDropGuard<N>, ChainEntry<N>), BlockDownloadError> {
33    let PeerResponse::Protocol(ProtocolResponse::GetChain(chain_res)) = client
34        .ready()
35        .await?
36        .call(PeerRequest::Protocol(ProtocolRequest::GetChain(
37            ChainRequest {
38                block_ids: short_history.into(),
39                prune: true,
40            },
41        )))
42        .await?
43    else {
44        panic!("Connection task returned wrong response!");
45    };
46
47    if chain_res.m_block_ids.is_empty()
48        || chain_res.m_block_ids.len() > MAX_BLOCKS_IDS_IN_CHAIN_ENTRY
49    {
50        client.info.handle.ban_peer(MEDIUM_BAN);
51        return Err(BlockDownloadError::PeersResponseWasInvalid);
52    }
53
54    // We must have at least one overlapping block.
55    if !(chain_res.m_block_ids[0] == short_history[0]
56        || chain_res.m_block_ids[0] == short_history[1])
57    {
58        client.info.handle.ban_peer(MEDIUM_BAN);
59        return Err(BlockDownloadError::PeersResponseWasInvalid);
60    }
61
62    // If the genesis is the overlapping block then this peer does not have our top tracked block in
63    // its chain.
64    if chain_res.m_block_ids[0] == short_history[1] {
65        return Err(BlockDownloadError::PeerDidNotHaveRequestedData);
66    }
67
68    let entry = ChainEntry {
69        ids: (&chain_res.m_block_ids).into(),
70        peer: client.info.id,
71        handle: client.info.handle.clone(),
72    };
73
74    Ok((client, entry))
75}
76
77/// Initial chain search, this function pulls [`INITIAL_CHAIN_REQUESTS_TO_SEND`] peers from the [`ClientPool`]
78/// and sends chain requests to all of them.
79///
80/// We then wait for their response and choose the peer who claims the highest cumulative difficulty.
81#[instrument(level = "error", skip_all)]
82pub async fn initial_chain_search<N: NetworkZone, C>(
83    peer_set: &mut BoxCloneService<PeerSetRequest, PeerSetResponse<N>, tower::BoxError>,
84    mut our_chain_svc: C,
85) -> Result<ChainTracker<N>, BlockDownloadError>
86where
87    C: Service<ChainSvcRequest<N>, Response = ChainSvcResponse<N>, Error = tower::BoxError>,
88{
89    tracing::debug!("Getting our chain history");
90    // Get our history.
91    let ChainSvcResponse::CompactHistory {
92        block_ids,
93        cumulative_difficulty,
94    } = our_chain_svc
95        .ready()
96        .await?
97        .call(ChainSvcRequest::CompactHistory)
98        .await?
99    else {
100        panic!("chain service sent wrong response.");
101    };
102
103    let our_genesis = *block_ids.last().expect("Blockchain had no genesis block.");
104
105    let PeerSetResponse::PeersWithMorePoW(clients) = peer_set
106        .ready()
107        .await?
108        .call(PeerSetRequest::PeersWithMorePoW(cumulative_difficulty))
109        .await?
110    else {
111        unreachable!();
112    };
113    let mut peers = clients.into_iter();
114
115    let mut futs = JoinSet::new();
116
117    let req = PeerRequest::Protocol(ProtocolRequest::GetChain(ChainRequest {
118        block_ids: block_ids.into(),
119        prune: false,
120    }));
121
122    tracing::debug!("Sending requests for chain entries.");
123
124    // Send the requests.
125    while futs.len() < INITIAL_CHAIN_REQUESTS_TO_SEND {
126        let Some(mut next_peer) = peers.next() else {
127            break;
128        };
129
130        let cloned_req = req.clone();
131        futs.spawn(timeout(
132            BLOCK_DOWNLOADER_REQUEST_TIMEOUT,
133            async move {
134                let PeerResponse::Protocol(ProtocolResponse::GetChain(chain_res)) =
135                    next_peer.ready().await?.call(cloned_req).await?
136                else {
137                    panic!("connection task returned wrong response!");
138                };
139
140                Ok::<_, tower::BoxError>((
141                    chain_res,
142                    next_peer.info.id,
143                    next_peer.info.handle.clone(),
144                ))
145            }
146            .instrument(Span::current()),
147        ));
148    }
149
150    let mut res: Option<(ChainResponse, InternalPeerID<_>, ConnectionHandle)> = None;
151
152    // Wait for the peers responses.
153    while let Some(task_res) = futs.join_next().await {
154        let Ok(Ok(task_res)) = task_res.unwrap() else {
155            continue;
156        };
157
158        match &mut res {
159            Some(res) => {
160                // res has already been set, replace it if this peer claims higher cumulative difficulty
161                if res.0.cumulative_difficulty() < task_res.0.cumulative_difficulty() {
162                    drop(mem::replace(res, task_res));
163                }
164            }
165            None => {
166                // res has not been set, set it now;
167                res = Some(task_res);
168            }
169        }
170    }
171
172    let Some((chain_res, peer_id, peer_handle)) = res else {
173        return Err(BlockDownloadError::FailedToFindAChainToFollow);
174    };
175
176    let hashes: Vec<[u8; 32]> = (&chain_res.m_block_ids).into();
177    // drop this to deallocate the [`Bytes`].
178    drop(chain_res);
179
180    tracing::debug!("Highest chin entry contained {} block Ids", hashes.len());
181
182    // Find the first unknown block in the batch.
183    let ChainSvcResponse::FindFirstUnknown(first_unknown_ret) = our_chain_svc
184        .ready()
185        .await?
186        .call(ChainSvcRequest::FindFirstUnknown(hashes.clone()))
187        .await?
188    else {
189        panic!("chain service sent wrong response.");
190    };
191
192    // We know all the blocks already
193    // TODO: The peer could still be on a different chain, however the chain might just be too far split.
194    let Some((first_unknown, expected_height)) = first_unknown_ret else {
195        return Err(BlockDownloadError::FailedToFindAChainToFollow);
196    };
197
198    // The peer must send at least one block we already know.
199    if first_unknown == 0 {
200        peer_handle.ban_peer(MEDIUM_BAN);
201        return Err(BlockDownloadError::PeerSentNoOverlappingBlocks);
202    }
203
204    let previous_id = hashes[first_unknown - 1];
205
206    let first_entry = ChainEntry {
207        ids: hashes[first_unknown..].to_vec(),
208        peer: peer_id,
209        handle: peer_handle,
210    };
211
212    tracing::debug!(
213        "Creating chain tracker with {} new block Ids",
214        first_entry.ids.len()
215    );
216
217    let tracker = ChainTracker::new(
218        first_entry,
219        expected_height,
220        our_genesis,
221        previous_id,
222        &mut our_chain_svc,
223    )
224    .await
225    .map_err(|_| BlockDownloadError::ChainInvalid)?;
226
227    Ok(tracker)
228}