1use std::{
2 future::{ready, Future, Ready},
3 pin::{pin, Pin},
4 task::{Context, Poll},
5};
6
7use futures::{stream::FuturesUnordered, StreamExt};
8use indexmap::{IndexMap, IndexSet};
9use rand::{seq::index::sample, thread_rng};
10use tokio::sync::mpsc::Receiver;
11use tokio_util::sync::WaitForCancellationFutureOwned;
12use tower::Service;
13
14use cuprate_helper::cast::u64_to_usize;
15use cuprate_p2p_core::{
16 client::{Client, InternalPeerID},
17 ConnectionDirection, NetworkZone,
18};
19
20mod client_wrappers;
21
22pub use client_wrappers::ClientDropGuard;
23use client_wrappers::StoredClient;
24
/// A request that can be sent to the peer set service.
pub enum PeerSetRequest {
    /// Ask for the chain with the highest cumulative difficulty reported by
    /// any peer currently in the set.
    MostPoWSeen,
    /// Ask for the peers whose reported cumulative difficulty is strictly
    /// greater than the given value.
    PeersWithMorePoW(u128),
    /// Ask for a randomly chosen outbound peer to use as a stem peer.
    StemPeer,
}
38
39pub enum PeerSetResponse<N: NetworkZone> {
41 MostPoWSeen {
43 cumulative_difficulty: u128,
45 height: usize,
47 top_hash: [u8; 32],
49 },
50 PeersWithMorePoW(Vec<ClientDropGuard<N>>),
54 StemPeer(Option<ClientDropGuard<N>>),
58}
59
60#[pin_project::pin_project]
62struct ClosedConnectionFuture<N: NetworkZone> {
63 #[pin]
64 fut: WaitForCancellationFutureOwned,
65 id: Option<InternalPeerID<N::Addr>>,
66}
67
68impl<N: NetworkZone> Future for ClosedConnectionFuture<N> {
69 type Output = InternalPeerID<N::Addr>;
70 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
71 let this = self.project();
72
73 this.fut.poll(cx).map(|()| this.id.take().unwrap())
74 }
75}
76
77pub(crate) struct PeerSet<N: NetworkZone> {
79 peers: IndexMap<InternalPeerID<N::Addr>, StoredClient<N>>,
81 closed_connections: FuturesUnordered<ClosedConnectionFuture<N>>,
83 outbound_peers: IndexSet<InternalPeerID<N::Addr>>,
85 new_peers: Receiver<Client<N>>,
87}
88
89impl<N: NetworkZone> PeerSet<N> {
90 pub(crate) fn new(new_peers: Receiver<Client<N>>) -> Self {
91 Self {
92 peers: IndexMap::new(),
93 closed_connections: FuturesUnordered::new(),
94 outbound_peers: IndexSet::new(),
95 new_peers,
96 }
97 }
98
99 fn poll_new_peers(&mut self, cx: &mut Context<'_>) {
101 while let Poll::Ready(Some(new_peer)) = self.new_peers.poll_recv(cx) {
102 if new_peer.info.direction == ConnectionDirection::Outbound {
103 self.outbound_peers.insert(new_peer.info.id);
104 }
105
106 self.closed_connections.push(ClosedConnectionFuture {
107 fut: new_peer.info.handle.closed(),
108 id: Some(new_peer.info.id),
109 });
110
111 self.peers
112 .insert(new_peer.info.id, StoredClient::new(new_peer));
113 }
114 }
115
116 fn remove_dead_peers(&mut self, cx: &mut Context<'_>) {
118 while let Poll::Ready(Some(dead_peer)) = self.closed_connections.poll_next_unpin(cx) {
119 let Some(peer) = self.peers.swap_remove(&dead_peer) else {
120 continue;
121 };
122
123 if peer.client.info.direction == ConnectionDirection::Outbound {
124 self.outbound_peers.swap_remove(&peer.client.info.id);
125 }
126
127 self.peers.swap_remove(&dead_peer);
128 }
129 }
130
131 fn most_pow_seen(&self) -> PeerSetResponse<N> {
133 let most_pow_chain = self
134 .peers
135 .values()
136 .map(|peer| {
137 let core_sync_data = peer.client.info.core_sync_data.lock().unwrap();
138
139 (
140 core_sync_data.cumulative_difficulty(),
141 u64_to_usize(core_sync_data.current_height),
142 core_sync_data.top_id,
143 )
144 })
145 .max_by_key(|(cumulative_difficulty, ..)| *cumulative_difficulty)
146 .unwrap_or_default();
147
148 PeerSetResponse::MostPoWSeen {
149 cumulative_difficulty: most_pow_chain.0,
150 height: most_pow_chain.1,
151 top_hash: most_pow_chain.2,
152 }
153 }
154
155 fn peers_with_more_pow(&self, cumulative_difficulty: u128) -> PeerSetResponse<N> {
157 PeerSetResponse::PeersWithMorePoW(
158 self.peers
159 .values()
160 .filter(|&client| {
161 !client.is_downloading_blocks()
162 && client
163 .client
164 .info
165 .core_sync_data
166 .lock()
167 .unwrap()
168 .cumulative_difficulty()
169 > cumulative_difficulty
170 })
171 .map(StoredClient::downloading_blocks_guard)
172 .collect(),
173 )
174 }
175
176 fn random_peer_for_stem(&self) -> PeerSetResponse<N> {
178 PeerSetResponse::StemPeer(
179 sample(
180 &mut thread_rng(),
181 self.outbound_peers.len(),
182 self.outbound_peers.len(),
183 )
184 .into_iter()
185 .find_map(|i| {
186 let peer = self.outbound_peers.get_index(i).unwrap();
187 let client = self.peers.get(peer).unwrap();
188 (!client.is_a_stem_peer()).then(|| client.stem_peer_guard())
189 }),
190 )
191 }
192}
193
194impl<N: NetworkZone> Service<PeerSetRequest> for PeerSet<N> {
195 type Response = PeerSetResponse<N>;
196 type Error = tower::BoxError;
197 type Future = Ready<Result<Self::Response, Self::Error>>;
198
199 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200 self.poll_new_peers(cx);
201 self.remove_dead_peers(cx);
202
203 Poll::Ready(Ok(()))
206 }
207
208 fn call(&mut self, req: PeerSetRequest) -> Self::Future {
209 ready(match req {
210 PeerSetRequest::MostPoWSeen => Ok(self.most_pow_seen()),
211 PeerSetRequest::PeersWithMorePoW(cumulative_difficulty) => {
212 Ok(self.peers_with_more_pow(cumulative_difficulty))
213 }
214 PeerSetRequest::StemPeer => Ok(self.random_peer_for_stem()),
215 })
216 }
217}