cuprate_dandelion_tower/pool/manager.rs

use std::{
    collections::{HashMap, HashSet},
    hash::Hash,
    marker::PhantomData,
    time::Duration,
};

use futures::{FutureExt, StreamExt};
use rand::prelude::*;
use rand_distr::Exp;
use tokio::{
    sync::{mpsc, oneshot},
    task::JoinSet,
};
use tokio_util::time::{delay_queue, DelayQueue};
use tower::{Service, ServiceExt};
use tracing::Instrument;

use crate::{
    pool::IncomingTx,
    traits::{TxStoreRequest, TxStoreResponse},
    DandelionConfig, DandelionRouteReq, DandelionRouterError, State, TxState,
};

#[derive(Copy, Clone, Debug, thiserror::Error)]
#[error("The dandelion pool was shut down")]
pub struct DandelionPoolShutDown;

/// The dandelion++ pool manager.
///
/// See the [module docs](super) for more.
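///
/// Generic parameters: `P` is the backing transaction store service, `R` is the
/// dandelion++ router service, `Tx` is the transaction type, and `TxId` / `PeerId`
/// identify transactions and peers (see the bounds on the `impl` block below).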
pub struct DandelionPoolManager<P, R, Tx, TxId, PeerId> {
    /// The dandelion++ router.
    pub(crate) dandelion_router: R,
    /// The backing tx storage.
    pub(crate) backing_pool: P,
    /// The set of tasks that are running the future returned from `dandelion_router`.
    pub(crate) routing_set: JoinSet<(TxId, Result<State, TxState<PeerId>>)>,

    /// The origin of stem transactions.
    pub(crate) stem_origins: HashMap<TxId, HashSet<PeerId>>,

    /// Current stem pool embargo timers.
    pub(crate) embargo_timers: DelayQueue<TxId>,
    /// Maps a `TxId` to the [`delay_queue::Key`] of its embargo timer, so the timer
    /// can be cancelled when the tx is promoted.
    pub(crate) embargo_timer_keys: HashMap<TxId, delay_queue::Key>,
    /// The distribution to sample to get embargo timers.
    pub(crate) embargo_dist: Exp<f64>,

    /// The d++ config.
    pub(crate) config: DandelionConfig,

    pub(crate) _tx: PhantomData<Tx>,
}

impl<P, R, Tx, TxId, PeerId> DandelionPoolManager<P, R, Tx, TxId, PeerId>
where
    Tx: Clone + Send,
    TxId: Hash + Eq + Clone + Send + 'static,
    PeerId: Hash + Eq + Clone + Send + 'static,
    P: Service<TxStoreRequest<TxId>, Response = TxStoreResponse<Tx>, Error = tower::BoxError>,
    P::Future: Send + 'static,
    R: Service<DandelionRouteReq<Tx, PeerId>, Response = State, Error = DandelionRouterError>,
    R::Future: Send + 'static,
{
    /// Adds a new embargo timer to the running timers, with a duration pulled from [`Self::embargo_dist`].
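    ///
    /// Note: `rand_distr::Exp::new(rate)` produces samples with mean `1.0 / rate`, so the
    /// average embargo delay is fixed by however the caller constructed
    /// [`Self::embargo_dist`] (illustratively, something like `Exp::new(1.0 / average_embargo_secs)`,
    /// where `average_embargo_secs` is an assumed name, not one defined in this file).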
    fn add_embargo_timer_for_tx(&mut self, tx_id: TxId) {
        let embargo_timer = self.embargo_dist.sample(&mut thread_rng());
        tracing::debug!(
            "Setting embargo timer for stem tx: {} seconds.",
            embargo_timer
        );

        let key = self
            .embargo_timers
            .insert(tx_id.clone(), Duration::from_secs_f64(embargo_timer));
        self.embargo_timer_keys.insert(tx_id, key);
    }

    /// Stems the tx, recording the sending peer as a stem origin if one was given.
    ///
    /// This function does not add the tx to the backing pool.
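    ///
    /// The routing future is spawned onto [`Self::routing_set`]; if routing fails it
    /// yields the attempted [`TxState`] back to the main loop so the tx can be re-routed.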
    async fn stem_tx(
        &mut self,
        tx: Tx,
        tx_id: TxId,
        from: Option<PeerId>,
    ) -> Result<(), tower::BoxError> {
        if let Some(peer) = &from {
            self.stem_origins
                .entry(tx_id.clone())
                .or_default()
                .insert(peer.clone());
        }

        let state = from.map_or(TxState::Local, |from| TxState::Stem { from });

        let fut = self
            .dandelion_router
            .ready()
            .await?
            .call(DandelionRouteReq {
                tx,
                state: state.clone(),
            });

        self.routing_set
            .spawn(fut.map(|res| (tx_id, res.map_err(|_| state))));
        Ok(())
    }

    /// Fluffs a tx; this function does not add the tx to the backing pool.
    async fn fluff_tx(&mut self, tx: Tx, tx_id: TxId) -> Result<(), tower::BoxError> {
        let fut = self
            .dandelion_router
            .ready()
            .await?
            .call(DandelionRouteReq {
                tx,
                state: TxState::Fluff,
            });

        self.routing_set
            .spawn(fut.map(|res| (tx_id, res.map_err(|_| TxState::Fluff))));
        Ok(())
    }

    /// Handles an [`IncomingTx`].
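    ///
    /// Stem txs (from a peer or local) are stemmed again and given an embargo timer;
    /// a stem tx seen twice from the same peer, or a tx already in the fluff state,
    /// is fluffed instead.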
    async fn handle_incoming_tx(
        &mut self,
        tx: Tx,
        tx_state: TxState<PeerId>,
        tx_id: TxId,
    ) -> Result<(), tower::BoxError> {
        match tx_state {
            TxState::Stem { from } => {
                if self
                    .stem_origins
                    .get(&tx_id)
                    .is_some_and(|peers| peers.contains(&from))
                {
                    tracing::debug!("Received stem tx twice from the same peer, fluffing it");
                    // The same peer sent us this tx twice, so fluff it.
                    self.promote_and_fluff_tx(tx_id).await?;
                } else {
                    // This could be a new tx or one we have already stemmed; either way we
                    // stem it again, since only a repeat from the same peer (handled above)
                    // triggers a fluff.
                    tracing::debug!("Stemming incoming tx");
                    self.stem_tx(tx, tx_id.clone(), Some(from)).await?;
                    self.add_embargo_timer_for_tx(tx_id);
                }
            }
            TxState::Fluff => {
                tracing::debug!("Fluffing incoming tx");
                self.fluff_tx(tx, tx_id).await?;
            }
            TxState::Local => {
                tracing::debug!("Stemming local transaction");
                self.stem_tx(tx, tx_id.clone(), None).await?;
                self.add_embargo_timer_for_tx(tx_id);
            }
        }

        Ok(())
    }

    /// Promotes a tx to the clear pool.
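    ///
    /// This also cancels the tx's embargo timer and forgets its recorded stem origins.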
    async fn promote_tx(&mut self, tx_id: TxId) -> Result<(), tower::BoxError> {
        // Remove the tx from the maps used during the stem phase.
        self.stem_origins.remove(&tx_id);

        if let Some(key) = self.embargo_timer_keys.remove(&tx_id) {
            self.embargo_timers.try_remove(&key);
        }

        self.backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Promote(tx_id))
            .await?;

        Ok(())
    }

    /// Promotes a tx to the public fluff pool and fluffs the tx.
    async fn promote_and_fluff_tx(&mut self, tx_id: TxId) -> Result<(), tower::BoxError> {
        tracing::debug!("Promoting transaction to public pool and fluffing it.");

        let TxStoreResponse::Transaction(tx) = self
            .backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Get(tx_id.clone()))
            .await?
        else {
            panic!("Backing tx pool returned the wrong response for the request.");
        };

        let Some((tx, state)) = tx else {
            tracing::debug!("Could not find tx, skipping.");
            return Ok(());
        };

        if state == State::Fluff {
            tracing::debug!("Transaction already fluffed, skipping.");
            return Ok(());
        }

        self.promote_tx(tx_id.clone()).await?;
        self.fluff_tx(tx, tx_id).await
    }

    /// Returns a tx stored in the fluff _OR_ stem pool.
    async fn get_tx_from_pool(&mut self, tx_id: TxId) -> Result<Option<Tx>, tower::BoxError> {
        let TxStoreResponse::Transaction(tx) = self
            .backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Get(tx_id))
            .await?
        else {
            panic!("Backing tx pool returned the wrong response for the request.");
        };

        Ok(tx.map(|tx| tx.0))
    }

    #[expect(clippy::type_complexity)]
    /// Starts the [`DandelionPoolManager`].
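    ///
    /// Consumes `self` and runs until the `rx` channel is closed or an internal error
    /// occurs. Each message pairs an [`IncomingTx`] (plus a [`tracing::Span`]) with a
    /// [`oneshot::Sender`] that is completed once the tx has been handled. A rough,
    /// illustrative sending side (assuming `incoming_tx` is the matching
    /// [`mpsc::Sender`] and `tx` / `tx_id` are already built):
    ///
    /// ```ignore
    /// let (res_tx, res_rx) = oneshot::channel();
    /// incoming_tx
    ///     .send((
    ///         (IncomingTx { tx, tx_id, routing_state: TxState::Local }, tracing::Span::current()),
    ///         res_tx,
    ///     ))
    ///     .await
    ///     .map_err(|_| DandelionPoolShutDown)?;
    /// res_rx.await.map_err(|_| DandelionPoolShutDown)?;
    /// ```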
    pub(crate) async fn run(
        mut self,
        mut rx: mpsc::Receiver<(
            (IncomingTx<Tx, TxId, PeerId>, tracing::Span),
            oneshot::Sender<()>,
        )>,
    ) {
        tracing::debug!("Starting dandelion++ tx-pool, config: {:?}", self.config);

        loop {
            tokio::select! {
                // Biased to handle current txs before routing new ones.
                biased;
                Some(fired) = self.embargo_timers.next() => {
                    let span = tracing::debug_span!("embargo_timer_fired");
                    tracing::debug!(parent: &span, "Embargo timer fired, did not see stem tx in time.");

                    let tx_id = fired.into_inner();
                    if let Err(e) = self.promote_and_fluff_tx(tx_id).instrument(span.clone()).await {
                        tracing::error!(parent: &span, "Error handling fired embargo timer: {e}");
                        return;
                    }
                }
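                // A routing future finished: `Ok(State::Fluff)` means the router ended up
                // fluffing the tx, so promote it; `Err(tx_state)` means routing failed, so
                // re-fetch the tx and retry with the same state; `Ok(State::Stem)` needs no
                // further action.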
                Some(Ok((tx_id, res))) = self.routing_set.join_next() => {
                    let span = tracing::debug_span!("dandelion_routing_result");

                    let res = match res {
                        Ok(State::Fluff) => {
                            tracing::debug!(parent: &span, "Transaction was fluffed, promoting it to the public pool.");
                            self.promote_tx(tx_id).instrument(span.clone()).await
                        }
                        Err(tx_state) => {
                            tracing::debug!(parent: &span, "Error routing transaction, trying again.");

                            match self.get_tx_from_pool(tx_id.clone()).instrument(span.clone()).await {
                                Ok(Some(tx)) => match tx_state {
                                    TxState::Fluff => self.fluff_tx(tx, tx_id).instrument(span.clone()).await,
                                    TxState::Stem { from } => self.stem_tx(tx, tx_id, Some(from)).instrument(span.clone()).await,
                                    TxState::Local => self.stem_tx(tx, tx_id, None).instrument(span.clone()).await,
                                }
                                Err(e) => Err(e),
                                _ => continue,
                            }
                        }
                        Ok(State::Stem) => continue,
                    };

                    if let Err(e) = res {
                        tracing::error!(parent: &span, "Error handling transaction routing return: {e}");
                        return;
                    }
                }
                req = rx.recv() => {
                    let Some(((IncomingTx { tx, tx_id, routing_state }, span), res_tx)) = req else {
                        return;
                    };

                    let span = tracing::debug_span!(parent: &span, "dandelion_pool_manager");

                    tracing::debug!(parent: &span, "Received new tx to route.");

                    if let Err(e) = self.handle_incoming_tx(tx, routing_state, tx_id).instrument(span.clone()).await {
                        #[expect(clippy::let_underscore_must_use, reason = "dropped receivers can be ignored")]
                        let _ = res_tx.send(());

                        tracing::error!(parent: &span, "Error handling transaction in dandelion pool: {e}");
                        return;
                    }

                    #[expect(clippy::let_underscore_must_use)]
                    let _ = res_tx.send(());
                }
            }
        }
    }
}