1use std::{
5 future::{ready, Future, Ready},
6 pin::{pin, Pin},
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10
11use bytes::Bytes;
12use futures::Stream;
13use rand::prelude::*;
14use rand_distr::Exp;
15use tokio::{
16 sync::{
17 broadcast::{self, error::TryRecvError},
18 watch,
19 },
20 time::{sleep_until, Instant, Sleep},
21};
22use tokio_stream::wrappers::WatchStream;
23use tower::Service;
24
25use cuprate_p2p_core::{
26 client::InternalPeerID, BroadcastMessage, ConnectionDirection, NetworkZone,
27};
28use cuprate_types::{BlockCompleteEntry, TransactionBlobs};
29use cuprate_wire::protocol::{NewFluffyBlock, NewTransactions};
30
31use crate::constants::{
32 DIFFUSION_FLUSH_AVERAGE_SECONDS_INBOUND, DIFFUSION_FLUSH_AVERAGE_SECONDS_OUTBOUND,
33 MAX_TXS_IN_BROADCAST_CHANNEL, SOFT_TX_MESSAGE_SIZE_SIZE_LIMIT,
34};
35
/// Configuration for the broadcast service.
///
/// Both values are the *average* delay between transaction diffusion flushes; the actual delay
/// before each flush is sampled from an exponential distribution with this mean.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct BroadcastConfig {
    /// Average delay between transaction flushes to outbound peers.
    pub diffusion_flush_average_seconds_outbound: Duration,
    /// Average delay between transaction flushes to inbound peers.
    pub diffusion_flush_average_seconds_inbound: Duration,
}
44
45impl Default for BroadcastConfig {
46 fn default() -> Self {
47 Self {
48 diffusion_flush_average_seconds_inbound: DIFFUSION_FLUSH_AVERAGE_SECONDS_INBOUND,
49 diffusion_flush_average_seconds_outbound: DIFFUSION_FLUSH_AVERAGE_SECONDS_OUTBOUND,
50 }
51 }
52}
53
54#[expect(clippy::type_complexity)]
61pub(crate) fn init_broadcast_channels<N: NetworkZone>(
62 config: BroadcastConfig,
63) -> (
64 BroadcastSvc<N>,
65 impl Fn(InternalPeerID<N::Addr>) -> BroadcastMessageStream<N> + Clone + Send + 'static,
66 impl Fn(InternalPeerID<N::Addr>) -> BroadcastMessageStream<N> + Clone + Send + 'static,
67) {
68 let outbound_dist = Exp::new(
69 1.0 / config
70 .diffusion_flush_average_seconds_outbound
71 .as_secs_f64(),
72 )
73 .unwrap();
74 let inbound_dist =
75 Exp::new(1.0 / config.diffusion_flush_average_seconds_inbound.as_secs_f64()).unwrap();
76
77 let (block_watch_sender, block_watch_receiver) = watch::channel(NewBlockInfo {
80 block_bytes: Default::default(),
81 current_blockchain_height: 0,
82 });
83
84 let (tx_broadcast_channel_outbound_sender, tx_broadcast_channel_outbound_receiver) =
86 broadcast::channel(MAX_TXS_IN_BROADCAST_CHANNEL);
87 let (tx_broadcast_channel_inbound_sender, tx_broadcast_channel_inbound_receiver) =
88 broadcast::channel(MAX_TXS_IN_BROADCAST_CHANNEL);
89
90 let broadcast_svc = BroadcastSvc {
92 new_block_watch: block_watch_sender,
93 tx_broadcast_channel_outbound: tx_broadcast_channel_outbound_sender,
94 tx_broadcast_channel_inbound: tx_broadcast_channel_inbound_sender,
95 };
96
97 let tx_channel_outbound_receiver_wrapped =
99 CloneableBroadcastReceiver(tx_broadcast_channel_outbound_receiver);
100 let tx_channel_inbound_receiver_wrapped =
101 CloneableBroadcastReceiver(tx_broadcast_channel_inbound_receiver);
102
103 let block_watch_receiver_cloned = block_watch_receiver.clone();
106 let outbound_stream_maker = move |addr| {
107 BroadcastMessageStream::new(
108 addr,
109 outbound_dist,
110 block_watch_receiver_cloned.clone(),
111 tx_channel_outbound_receiver_wrapped.clone().0,
112 )
113 };
114
115 let inbound_stream_maker = move |addr| {
116 BroadcastMessageStream::new(
117 addr,
118 inbound_dist,
119 block_watch_receiver.clone(),
120 tx_channel_inbound_receiver_wrapped.clone().0,
121 )
122 };
123
124 (broadcast_svc, outbound_stream_maker, inbound_stream_maker)
125}
126
127pub enum BroadcastRequest<N: NetworkZone> {
135 Block {
137 block_bytes: Bytes,
139 current_blockchain_height: u64,
141 },
142 Transaction {
146 tx_bytes: Bytes,
148 direction: Option<ConnectionDirection>,
150 received_from: Option<InternalPeerID<N::Addr>>,
152 },
153}
154
155#[derive(Clone)]
156pub struct BroadcastSvc<N: NetworkZone> {
157 new_block_watch: watch::Sender<NewBlockInfo>,
158 tx_broadcast_channel_outbound: broadcast::Sender<BroadcastTxInfo<N>>,
159 tx_broadcast_channel_inbound: broadcast::Sender<BroadcastTxInfo<N>>,
160}
161
162impl<N: NetworkZone> BroadcastSvc<N> {
163 pub fn mock() -> Self {
165 init_broadcast_channels(BroadcastConfig::default()).0
166 }
167}
168
169impl<N: NetworkZone> Service<BroadcastRequest<N>> for BroadcastSvc<N> {
170 type Response = ();
171 type Error = std::convert::Infallible;
172 type Future = Ready<Result<(), std::convert::Infallible>>;
173
174 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175 Poll::Ready(Ok(()))
176 }
177
178 fn call(&mut self, req: BroadcastRequest<N>) -> Self::Future {
179 match req {
180 BroadcastRequest::Block {
181 block_bytes,
182 current_blockchain_height,
183 } => {
184 tracing::debug!(
185 "queuing block at chain height {current_blockchain_height} for broadcast"
186 );
187
188 self.new_block_watch.send_replace(NewBlockInfo {
189 block_bytes,
190 current_blockchain_height,
191 });
192 }
193 BroadcastRequest::Transaction {
194 tx_bytes,
195 received_from,
196 direction,
197 } => {
198 let nex_tx_info = BroadcastTxInfo {
199 tx: tx_bytes,
200 received_from,
201 };
202
203 drop(match direction {
205 Some(ConnectionDirection::Inbound) => {
206 self.tx_broadcast_channel_inbound.send(nex_tx_info)
207 }
208 Some(ConnectionDirection::Outbound) => {
209 self.tx_broadcast_channel_outbound.send(nex_tx_info)
210 }
211 None => {
212 drop(self.tx_broadcast_channel_outbound.send(nex_tx_info.clone()));
213 self.tx_broadcast_channel_inbound.send(nex_tx_info)
214 }
215 });
216 }
217 }
218
219 ready(Ok(()))
220 }
221}
222
223struct CloneableBroadcastReceiver<T: Clone>(broadcast::Receiver<T>);
228
229impl<T: Clone> Clone for CloneableBroadcastReceiver<T> {
230 fn clone(&self) -> Self {
231 Self(self.0.resubscribe())
232 }
233}
234
235#[derive(Clone)]
237struct NewBlockInfo {
238 block_bytes: Bytes,
240 current_blockchain_height: u64,
242}
243
244#[derive(Clone)]
246struct BroadcastTxInfo<N: NetworkZone> {
247 tx: Bytes,
249 received_from: Option<InternalPeerID<N::Addr>>,
251}
252
253#[pin_project::pin_project]
257pub(crate) struct BroadcastMessageStream<N: NetworkZone> {
258 addr: InternalPeerID<N::Addr>,
260
261 #[pin]
263 new_block_watch: WatchStream<NewBlockInfo>,
264 tx_broadcast_channel: broadcast::Receiver<BroadcastTxInfo<N>>,
266
267 diffusion_flush_dist: Exp<f64>,
270 #[pin]
272 next_flush: Sleep,
273}
274
275impl<N: NetworkZone> BroadcastMessageStream<N> {
276 fn new(
278 addr: InternalPeerID<N::Addr>,
279 diffusion_flush_dist: Exp<f64>,
280 new_block_watch: watch::Receiver<NewBlockInfo>,
281 tx_broadcast_channel: broadcast::Receiver<BroadcastTxInfo<N>>,
282 ) -> Self {
283 let next_flush = Instant::now()
284 + Duration::from_secs_f64(diffusion_flush_dist.sample(&mut thread_rng()));
285
286 Self {
287 addr,
288 new_block_watch: WatchStream::from_changes(new_block_watch),
290 tx_broadcast_channel,
291 diffusion_flush_dist,
292 next_flush: sleep_until(next_flush),
293 }
294 }
295}
296
297impl<N: NetworkZone> Stream for BroadcastMessageStream<N> {
298 type Item = BroadcastMessage;
299
300 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301 let mut this = self.project();
302
303 if let Poll::Ready(res) = this.new_block_watch.poll_next(cx) {
305 let Some(block) = res else {
306 return Poll::Ready(None);
307 };
308
309 let block_mes = NewFluffyBlock {
310 b: BlockCompleteEntry {
311 pruned: false,
312 block: block.block_bytes,
313 block_weight: 0,
315 txs: TransactionBlobs::None,
316 },
317 current_blockchain_height: block.current_blockchain_height,
318 };
319
320 return Poll::Ready(Some(BroadcastMessage::NewFluffyBlock(block_mes)));
321 }
322
323 ready!(this.next_flush.as_mut().poll(cx));
324
325 let (txs, more_available) = get_txs_to_broadcast::<N>(this.addr, this.tx_broadcast_channel);
326
327 let next_flush = if more_available {
328 Instant::now()
330 } else {
331 Instant::now()
332 + Duration::from_secs_f64(this.diffusion_flush_dist.sample(&mut thread_rng()))
333 };
334
335 let next_flush = sleep_until(next_flush);
336 this.next_flush.set(next_flush);
337
338 if let Some(txs) = txs {
339 tracing::debug!(
340 "Diffusion flush timer expired, diffusing {} txs",
341 txs.txs.len()
342 );
343 Poll::Ready(Some(BroadcastMessage::NewTransactions(txs)))
345 } else {
346 tracing::trace!("Diffusion flush timer expired but no txs to diffuse");
347 #[expect(clippy::let_underscore_must_use)]
350 let _ = this.next_flush.poll(cx);
351 Poll::Pending
352 }
353 }
354}
355
356fn get_txs_to_broadcast<N: NetworkZone>(
359 addr: &InternalPeerID<N::Addr>,
360 broadcast_rx: &mut broadcast::Receiver<BroadcastTxInfo<N>>,
361) -> (Option<NewTransactions>, bool) {
362 let mut new_txs = NewTransactions {
363 txs: vec![],
364 dandelionpp_fluff: true,
365 padding: Bytes::new(),
366 };
367 let mut total_size = 0;
368
369 loop {
370 match broadcast_rx.try_recv() {
371 Ok(txs) => {
372 if txs.received_from.is_some_and(|from| &from == addr) {
373 continue;
375 }
376
377 total_size += txs.tx.len();
378
379 new_txs.txs.push(txs.tx);
380
381 if total_size > SOFT_TX_MESSAGE_SIZE_SIZE_LIMIT {
382 return (Some(new_txs), true);
383 }
384 }
385 Err(e) => match e {
386 TryRecvError::Empty | TryRecvError::Closed => {
387 if new_txs.txs.is_empty() {
388 return (None, false);
389 }
390 return (Some(new_txs), false);
391 }
392 TryRecvError::Lagged(lag) => {
393 tracing::debug!(
394 "{lag} transaction broadcast messages were missed, continuing."
395 );
396 continue;
397 }
398 },
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use std::{pin::pin, time::Duration};

    use bytes::Bytes;
    use futures::StreamExt;
    use tokio::time::timeout;
    use tower::{Service, ServiceExt};

    use cuprate_p2p_core::{client::InternalPeerID, BroadcastMessage, ConnectionDirection};
    use cuprate_test_utils::test_netzone::TestNetZone;

    use super::{init_broadcast_channels, BroadcastConfig, BroadcastRequest, BroadcastSvc};

    /// The network zone used by every test in this module.
    type Zone = TestNetZone<true>;

    /// Short flush delays so the diffusion timers fire quickly.
    const TEST_CONFIG: BroadcastConfig = BroadcastConfig {
        diffusion_flush_average_seconds_outbound: Duration::from_millis(100),
        diffusion_flush_average_seconds_inbound: Duration::from_millis(200),
    };

    /// Waits for `svc` to be ready, then submits `req`.
    async fn send(svc: &mut BroadcastSvc<Zone>, req: BroadcastRequest<Zone>) {
        svc.ready().await.unwrap().call(req).await.unwrap();
    }

    /// Asserts that `mes` is a transaction broadcast carrying exactly `expected`.
    fn assert_txs(mes: BroadcastMessage, expected: &[Bytes]) {
        match mes {
            BroadcastMessage::NewTransactions(tx) => assert_eq!(tx.txs.as_slice(), expected),
            BroadcastMessage::NewFluffyBlock(_) => panic!("Block broadcast?"),
        }
    }

    /// Transactions reach only streams of the requested direction; `None` reaches both.
    #[tokio::test]
    async fn tx_broadcast_direction_correct() {
        let (mut brcst, outbound_mkr, inbound_mkr) = init_broadcast_channels::<Zone>(TEST_CONFIG);

        let mut outbound_stream = pin!(outbound_mkr(InternalPeerID::Unknown(1)));
        let mut inbound_stream = pin!(inbound_mkr(InternalPeerID::Unknown(1)));

        let requests = [
            (Bytes::from_static(&[1]), Some(ConnectionDirection::Outbound)),
            (Bytes::from_static(&[2]), Some(ConnectionDirection::Inbound)),
            (Bytes::from_static(&[3]), None),
        ];

        for (tx_bytes, direction) in requests {
            send(
                &mut brcst,
                BroadcastRequest::Transaction {
                    tx_bytes,
                    direction,
                    received_from: None,
                },
            )
            .await;
        }

        let next = outbound_stream.next().await.unwrap();
        assert_txs(next, &[Bytes::from_static(&[1]), Bytes::from_static(&[3])]);

        let next = inbound_stream.next().await.unwrap();
        assert_txs(next, &[Bytes::from_static(&[2]), Bytes::from_static(&[3])]);
    }

    /// A block broadcast reaches both inbound and outbound streams.
    #[tokio::test]
    async fn block_broadcast_sent_to_all() {
        let (mut brcst, outbound_mkr, inbound_mkr) = init_broadcast_channels::<Zone>(TEST_CONFIG);

        let mut outbound_stream = pin!(outbound_mkr(InternalPeerID::Unknown(1)));
        let mut inbound_stream = pin!(inbound_mkr(InternalPeerID::Unknown(1)));

        send(
            &mut brcst,
            BroadcastRequest::Block {
                block_bytes: Default::default(),
                current_blockchain_height: 0,
            },
        )
        .await;

        let next = outbound_stream.next().await.unwrap();
        assert!(matches!(next, BroadcastMessage::NewFluffyBlock(_)));

        let next = inbound_stream.next().await.unwrap();
        assert!(matches!(next, BroadcastMessage::NewFluffyBlock(_)));
    }

    /// A transaction is never sent back to the peer it was received from.
    #[tokio::test]
    async fn tx_broadcast_skipped_for_received_from_peer() {
        let (mut brcst, outbound_mkr, inbound_mkr) = init_broadcast_channels::<Zone>(TEST_CONFIG);

        let mut outbound_stream = pin!(outbound_mkr(InternalPeerID::Unknown(1)));
        let mut outbound_stream_from = pin!(outbound_mkr(InternalPeerID::Unknown(0)));

        let mut inbound_stream = pin!(inbound_mkr(InternalPeerID::Unknown(1)));
        let mut inbound_stream_from = pin!(inbound_mkr(InternalPeerID::Unknown(0)));

        send(
            &mut brcst,
            BroadcastRequest::Transaction {
                tx_bytes: Bytes::from_static(&[1]),
                direction: None,
                received_from: Some(InternalPeerID::Unknown(0)),
            },
        )
        .await;

        let next = outbound_stream.next().await.unwrap();
        assert_txs(next, &[Bytes::from_static(&[1])]);

        let next = inbound_stream.next().await.unwrap();
        assert_txs(next, &[Bytes::from_static(&[1])]);

        // The originating peer's streams must stay silent.
        let from_peer =
            futures::future::select(inbound_stream_from.next(), outbound_stream_from.next());
        assert!(timeout(Duration::from_secs(2), from_peer).await.is_err());
    }
}