cuprate_dandelion_tower/pool/
manager.rs

use std::{
    collections::{HashMap, HashSet},
    hash::Hash,
    marker::PhantomData,
    time::Duration,
};

use futures::{FutureExt, StreamExt};
use rand::prelude::*;
use rand_distr::Exp;
use tokio::{
    sync::{mpsc, oneshot},
    task::JoinSet,
};
use tokio_util::time::DelayQueue;
use tower::{Service, ServiceExt};

use crate::{
    pool::IncomingTx,
    traits::{TxStoreRequest, TxStoreResponse},
    DandelionConfig, DandelionRouteReq, DandelionRouterError, State, TxState,
};

#[derive(Copy, Clone, Debug, thiserror::Error)]
#[error("The dandelion pool was shut down")]
pub struct DandelionPoolShutDown;

/// The dandelion++ pool manager.
///
/// See the [module docs](super) for more.
pub struct DandelionPoolManager<P, R, Tx, TxId, PeerId> {
    /// The dandelion++ router.
    pub(crate) dandelion_router: R,
    /// The backing tx storage.
    pub(crate) backing_pool: P,
    /// The set of tasks running the futures returned from `dandelion_router`.
    pub(crate) routing_set: JoinSet<(TxId, Result<State, TxState<PeerId>>)>,

    /// The origins of stem transactions: the peers that sent us each stem tx.
    pub(crate) stem_origins: HashMap<TxId, HashSet<PeerId>>,

    /// Current stem pool embargo timers.
    pub(crate) embargo_timers: DelayQueue<TxId>,
    /// The distribution to sample to get embargo timers.
    pub(crate) embargo_dist: Exp<f64>,

    /// The d++ config.
    pub(crate) config: DandelionConfig,

    /// Marker for the tx type this pool handles.
    pub(crate) _tx: PhantomData<Tx>,
}

impl<P, R, Tx, TxId, PeerId> DandelionPoolManager<P, R, Tx, TxId, PeerId>
where
    Tx: Clone + Send,
    TxId: Hash + Eq + Clone + Send + 'static,
    PeerId: Hash + Eq + Clone + Send + 'static,
    P: Service<TxStoreRequest<TxId>, Response = TxStoreResponse<Tx>, Error = tower::BoxError>,
    P::Future: Send + 'static,
    R: Service<DandelionRouteReq<Tx, PeerId>, Response = State, Error = DandelionRouterError>,
    R::Future: Send + 'static,
{
    /// Adds a new embargo timer to the running timers, with a duration sampled from [`Self::embargo_dist`].
    fn add_embargo_timer_for_tx(&mut self, tx_id: TxId) {
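        // Embargo timers are exponentially distributed: by the memoryless property,
        // every node still holding the stem tx is equally likely to be the next to
        // fluff it, regardless of when each node started its own timer.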
        let embargo_timer = self.embargo_dist.sample(&mut thread_rng());
        tracing::debug!(
            "Setting embargo timer for stem tx: {} seconds.",
            embargo_timer
        );

        self.embargo_timers
            .insert(tx_id, Duration::from_secs_f64(embargo_timer));
    }

    /// Stems the tx, setting the stem origin if it wasn't already set.
    ///
    /// This function does not add the tx to the backing pool.
    async fn stem_tx(
        &mut self,
        tx: Tx,
        tx_id: TxId,
        from: Option<PeerId>,
    ) -> Result<(), tower::BoxError> {
        if let Some(peer) = &from {
            self.stem_origins
                .entry(tx_id.clone())
                .or_default()
                .insert(peer.clone());
        }

        let state = from.map_or(TxState::Local, |from| TxState::Stem { from });

        let fut = self
            .dandelion_router
            .ready()
            .await?
            .call(DandelionRouteReq {
                tx,
                state: state.clone(),
            });
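        // Spawn the routing future, keeping the tx's routing state so the run loop
        // can retry the route with the same state if it fails.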
        self.routing_set
            .spawn(fut.map(|res| (tx_id, res.map_err(|_| state))));
        Ok(())
    }

    /// Fluffs a tx.
    ///
    /// This function does not add the tx to the backing pool.
    async fn fluff_tx(&mut self, tx: Tx, tx_id: TxId) -> Result<(), tower::BoxError> {
        let fut = self
            .dandelion_router
            .ready()
            .await?
            .call(DandelionRouteReq {
                tx,
                state: TxState::Fluff,
            });

        self.routing_set
            .spawn(fut.map(|res| (tx_id, res.map_err(|_| TxState::Fluff))));
        Ok(())
    }

    /// Handles an [`IncomingTx`].
    async fn handle_incoming_tx(
        &mut self,
        tx: Tx,
        tx_state: TxState<PeerId>,
        tx_id: TxId,
    ) -> Result<(), tower::BoxError> {
        match tx_state {
            TxState::Stem { from } => {
                if self
                    .stem_origins
                    .get(&tx_id)
                    .is_some_and(|peers| peers.contains(&from))
                {
                    // The same peer sent us the tx twice, so fluff it.
                    tracing::debug!("Received stem tx twice from the same peer, fluffing it");
                    self.promote_and_fluff_tx(tx_id).await?;
                } else {
                    // This could be a new tx, or one that was already stemmed; we stem it again
                    // unless the same peer sends us the tx twice.
                    tracing::debug!("Stemming incoming tx");
                    self.stem_tx(tx, tx_id.clone(), Some(from)).await?;
                    self.add_embargo_timer_for_tx(tx_id);
                }
            }
            TxState::Fluff => {
                tracing::debug!("Fluffing incoming tx");
                self.fluff_tx(tx, tx_id).await?;
            }
            TxState::Local => {
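                // Our own txs always start in the stem phase to protect this node's privacy.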
                tracing::debug!("Stemming local transaction");
                self.stem_tx(tx, tx_id.clone(), None).await?;
                self.add_embargo_timer_for_tx(tx_id);
            }
        }

        Ok(())
    }

    /// Promotes a tx to the public (clear) pool.
    async fn promote_tx(&mut self, tx_id: TxId) -> Result<(), tower::BoxError> {
        // Remove the tx from the maps used during the stem phase.
        self.stem_origins.remove(&tx_id);

        // The `DelayQueue` key is *not* the tx_id, it is returned on insert, so just leave the
        // timer in the queue. These timers should be relatively short, so it shouldn't be a problem.
        //self.embargo_timers.try_remove(&tx_id);

        self.backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Promote(tx_id))
            .await?;

        Ok(())
    }

    /// Promotes a tx to the public fluff pool and fluffs the tx.
    async fn promote_and_fluff_tx(&mut self, tx_id: TxId) -> Result<(), tower::BoxError> {
        tracing::debug!("Promoting transaction to public pool and fluffing it.");

        let TxStoreResponse::Transaction(tx) = self
            .backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Get(tx_id.clone()))
            .await?
        else {
            panic!("Backing tx pool responded with wrong response for request.");
        };
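        // The tx may have been removed from the backing pool since the stem phase
        // began, so a miss here is not an error.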
        let Some((tx, state)) = tx else {
            tracing::debug!("Could not find tx, skipping.");
            return Ok(());
        };

        if state == State::Fluff {
            tracing::debug!("Transaction already fluffed, skipping.");
            return Ok(());
        }

        self.promote_tx(tx_id.clone()).await?;
        self.fluff_tx(tx, tx_id).await
    }

    /// Returns a tx stored in the fluff _OR_ stem pool.
    async fn get_tx_from_pool(&mut self, tx_id: TxId) -> Result<Option<Tx>, tower::BoxError> {
        let TxStoreResponse::Transaction(tx) = self
            .backing_pool
            .ready()
            .await?
            .call(TxStoreRequest::Get(tx_id))
            .await?
        else {
            panic!("Backing tx pool responded with wrong response for request.");
        };

        Ok(tx.map(|tx| tx.0))
    }

    /// Starts the [`DandelionPoolManager`].
    ///
    /// This future runs until all senders of `rx` are dropped or an internal error occurs.
    pub(crate) async fn run(
        mut self,
        mut rx: mpsc::Receiver<(IncomingTx<Tx, TxId, PeerId>, oneshot::Sender<()>)>,
    ) {
        tracing::debug!("Starting dandelion++ tx-pool, config: {:?}", self.config);

        loop {
            tracing::trace!("Waiting for next event.");
            tokio::select! {
                // Biased so embargo timers and in-flight routing results are handled
                // before new txs are accepted.
                biased;
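                // An embargo timer firing means we did not see the stem tx fluffed by
                // the network in time, so this node fluffs it itself.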
                Some(fired) = self.embargo_timers.next() => {
                    tracing::debug!("Embargo timer fired, did not see stem tx fluffed in time.");

                    let tx_id = fired.into_inner();
                    if let Err(e) = self.promote_and_fluff_tx(tx_id).await {
                        tracing::error!("Error handling fired embargo timer: {e}");
                        return;
                    }
                }
                Some(Ok((tx_id, res))) = self.routing_set.join_next() => {
                    tracing::trace!("Received d++ routing result.");

                    let res = match res {
                        Ok(State::Fluff) => {
                            tracing::debug!("Transaction was fluffed, promoting it to the public pool.");
                            self.promote_tx(tx_id).await
                        }
                        Err(tx_state) => {
                            tracing::debug!("Error routing transaction, trying again.");

                            match self.get_tx_from_pool(tx_id.clone()).await {
                                Ok(Some(tx)) => match tx_state {
                                    TxState::Fluff => self.fluff_tx(tx, tx_id).await,
                                    TxState::Stem { from } => self.stem_tx(tx, tx_id, Some(from)).await,
                                    TxState::Local => self.stem_tx(tx, tx_id, None).await,
                                }
                                Err(e) => Err(e),
                                // `Ok(None)`: the tx is no longer in the pool, nothing to re-route.
                                _ => continue,
                            }
                        }
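                        // A successful stem needs no action here: the tx stays in the
                        // stem pool until its embargo timer fires or it is seen again.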
                        Ok(State::Stem) => continue,
                    };

                    if let Err(e) = res {
                        tracing::error!("Error handling transaction routing result: {e}");
                        return;
                    }
                }
                req = rx.recv() => {
                    tracing::debug!("Received new tx to route.");
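                    // `None` means all senders were dropped, so shut the pool down.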
                    let Some((IncomingTx { tx, tx_id, routing_state }, res_tx)) = req else {
                        return;
                    };

                    if let Err(e) = self.handle_incoming_tx(tx, routing_state, tx_id).await {
                        #[expect(clippy::let_underscore_must_use, reason = "dropped receivers can be ignored")]
                        let _ = res_tx.send(());

                        tracing::error!("Error handling transaction in dandelion pool: {e}");
                        return;
                    }

                    #[expect(clippy::let_underscore_must_use, reason = "dropped receivers can be ignored")]
                    let _ = res_tx.send(());
                }
            }
        }
    }
}