use std::{
collections::HashMap,
hash::Hash,
marker::PhantomData,
pin::Pin,
task::{ready, Context, Poll},
time::Instant,
};
use futures::{future::BoxFuture, FutureExt, TryFutureExt, TryStream};
use rand::{distributions::Bernoulli, prelude::*, thread_rng};
use tower::Service;
use crate::{
traits::{DiffuseRequest, StemRequest},
DandelionConfig,
};
/// The errors the [`DandelionRouter`] can return from `poll_ready` or from a routing future.
#[derive(thiserror::Error, Debug)]
pub enum DandelionRouterError {
/// The stem peer a transaction was handed to returned an error from its service call.
#[error("Peer chosen to route stem txs to had an err: {0}.")]
PeerError(tower::BoxError),
/// The broadcast service returned an error, either while being polled for readiness
/// or while diffusing a transaction.
#[error("Broadcast service returned an err: {0}.")]
BroadcastError(tower::BoxError),
/// The outbound peer stream yielded an error while the router was asking it for peers.
#[error("The outbound peer stream returned an err: {0}.")]
OutboundPeerStreamError(tower::BoxError),
/// The outbound peer stream ended (yielded `None`) while the router still wanted peers.
#[error("The outbound peer discoverer exited.")]
OutboundPeerDiscoverExited,
}
/// An item yielded by the outbound peer stream the [`DandelionRouter`] pulls stem peers from.
pub enum OutboundPeer<Id, T> {
/// A peer: the router stores the value `T` in its stem peer map under the key `Id`.
Peer(Id, T),
/// The stream has no peer to hand out. The router logs a warning, stops requesting
/// peers for the current readiness check and carries on with the peers it already holds.
Exhausted,
}
/// The state of the [`DandelionRouter`] for an epoch.
///
/// The same type is used as the router's service response, where it reports how a
/// transaction ended up being routed.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum State {
/// As router state: incoming stem transactions are handed to the broadcast service.
/// As response: the transaction was handed to the broadcast service.
Fluff,
/// As router state: incoming stem transactions are forwarded to a stem peer.
/// As response: the transaction was forwarded to a stem peer.
Stem,
}
/// The routing state a transaction arrives in, as supplied by the caller of the router.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TxState<Id> {
    /// The transaction is in the fluff stage.
    Fluff,
    /// The transaction is in the stem stage.
    Stem {
        /// Identifier the router keys its stem route on.
        // NOTE(review): presumably the peer that relayed the tx to us — confirm with callers.
        from: Id,
    },
    /// The transaction is marked as local.
    // NOTE(review): presumably a tx originating from this node — confirm with callers.
    Local,
}

impl<Id> TxState<Id> {
    /// Returns `true` for [`TxState::Stem`] and [`TxState::Local`], `false` for
    /// [`TxState::Fluff`].
    pub const fn is_stem_stage(&self) -> bool {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            Self::Stem { .. } | Self::Local => true,
            Self::Fluff => false,
        }
    }
}
/// A request asking the [`DandelionRouter`] to route a transaction.
pub struct DandelionRouteReq<Tx, Id> {
/// The transaction to route.
pub tx: Tx,
/// The state the transaction is in; decides whether it is stemmed or fluffed.
pub state: TxState<Id>,
}
/// A service that routes transactions either to a stem peer or to the broadcast service.
///
/// `P` is the stream of outbound peers, `B` the broadcast service, `Id` the peer
/// identifier, `S` the per-peer service and `Tx` the transaction type.
pub struct DandelionRouter<P, B, Id, S, Tx> {
/// Pinned stream that is polled whenever the router needs more stem peers.
outbound_peer_discover: Pin<Box<P>>,
/// Service used to diffuse (fluff) transactions.
broadcast_svc: B,
/// The router's state for the current epoch.
current_state: State,
/// When the current epoch started; compared against `config.epoch_duration`.
epoch_start: Instant,
/// Key of the stem peer used for local transactions, picked lazily on first use.
local_route: Option<Id>,
/// Maps the `from` identifier of a stem transaction to the key of the stem peer
/// its transactions are forwarded to.
stem_routes: HashMap<Id, Id>,
/// The stem peers currently held, keyed by their identifier.
pub(crate) stem_peers: HashMap<Id, S>,
/// Distribution sampled at construction and on every new epoch; a `true` sample
/// selects [`State::Fluff`], `false` selects [`State::Stem`].
state_dist: Bernoulli,
/// The router's configuration.
config: DandelionConfig,
/// Span used as explicit parent for events emitted while handling requests.
span: tracing::Span,
/// Marker tying the router to its transaction type without storing one.
_tx: PhantomData<Tx>,
}
impl<Tx, Id, P, B, S> DandelionRouter<P, B, Id, S, Tx>
where
    Id: Hash + Eq + Clone,
    P: TryStream<Ok = OutboundPeer<Id, S>, Error = tower::BoxError>,
    B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>,
    B::Future: Send + 'static,
    S: Service<StemRequest<Tx>, Error = tower::BoxError>,
    S::Future: Send + 'static,
{
    /// Builds a router with no stem peers and no routes, starting a fresh epoch now.
    ///
    /// The initial state is drawn from a Bernoulli distribution parameterised by
    /// `config.fluff_probability`: a `true` sample starts in [`State::Fluff`],
    /// otherwise in [`State::Stem`].
    ///
    /// # Panics
    ///
    /// Panics if [`Bernoulli::new`] rejects `config.fluff_probability`.
    pub fn new(broadcast_svc: B, outbound_peer_discover: P, config: DandelionConfig) -> Self {
        let state_dist = Bernoulli::new(config.fluff_probability)
            .expect("Fluff probability was not between 0 and 1");

        let starts_fluffing = state_dist.sample(&mut thread_rng());
        let current_state = if starts_fluffing {
            State::Fluff
        } else {
            State::Stem
        };

        Self {
            outbound_peer_discover: Box::pin(outbound_peer_discover),
            broadcast_svc,
            current_state,
            epoch_start: Instant::now(),
            local_route: None,
            stem_routes: HashMap::new(),
            stem_peers: HashMap::new(),
            state_dist,
            config,
            span: tracing::debug_span!("dandelion_router", state = ?current_state),
            _tx: PhantomData,
        }
    }

    /// Pulls peers from the outbound peer stream until enough stem peers are held.
    ///
    /// "Enough" is `config.number_of_stems()` in the stem state and a single peer in
    /// the fluff state. Returns `Pending` when the stream is pending, an error when
    /// the stream errors or ends, and `Ready(Ok(()))` once enough peers are held or
    /// the stream reports [`OutboundPeer::Exhausted`].
    fn poll_prepare_graph(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), DandelionRouterError>> {
        let wanted = match self.current_state {
            State::Fluff => 1,
            State::Stem => self.config.number_of_stems(),
        };

        while self.stem_peers.len() < wanted {
            let next = ready!(self.outbound_peer_discover.as_mut().try_poll_next(cx));

            match next {
                Some(Ok(OutboundPeer::Peer(id, peer_svc))) => {
                    self.stem_peers.insert(id, peer_svc);
                }
                Some(Ok(OutboundPeer::Exhausted)) => {
                    tracing::warn!("Failed to retrieve enough outbound peers for optimal dandelion++, privacy may be degraded.");
                    return Poll::Ready(Ok(()));
                }
                Some(Err(e)) => {
                    return Poll::Ready(Err(DandelionRouterError::OutboundPeerStreamError(e)));
                }
                None => {
                    return Poll::Ready(Err(DandelionRouterError::OutboundPeerDiscoverExited));
                }
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Hands `tx` to the broadcast service; the future resolves to [`State::Fluff`]
    /// on success.
    fn fluff_tx(&mut self, tx: Tx) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
        let diffuse = self.broadcast_svc.call(DiffuseRequest(tx));

        diffuse
            .map_err(DandelionRouterError::BroadcastError)
            .map_ok(|_| State::Fluff)
            .boxed()
    }

    /// Picks a uniformly random key out of `stem_peers`.
    ///
    /// Takes the map rather than `&self` so it can be called while another field of
    /// the router is mutably borrowed.
    ///
    /// # Panics
    ///
    /// Panics if `stem_peers` is empty.
    fn choose_random_stem_peer(stem_peers: &HashMap<Id, S>) -> Id {
        let (id, _) = stem_peers
            .iter()
            .choose(&mut thread_rng())
            .expect("No peers in `stem_peers` was poll_ready called?");

        id.clone()
    }

    /// Forwards a stem transaction received from `from` to that sender's stem peer.
    ///
    /// A route for `from` is chosen at random on first use and reused afterwards.
    /// If no stem peers are held the transaction is fluffed instead.
    fn stem_tx(
        &mut self,
        tx: Tx,
        from: &Id,
    ) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
        if self.stem_peers.is_empty() {
            tracing::debug!("Stem peers are empty, fluffing stem transaction.");
            return self.fluff_tx(tx);
        }

        loop {
            let route = self
                .stem_routes
                .entry(from.clone())
                .or_insert_with(|| Self::choose_random_stem_peer(&self.stem_peers));

            if let Some(peer) = self.stem_peers.get_mut(route) {
                return peer
                    .call(StemRequest(tx))
                    .map_err(DandelionRouterError::PeerError)
                    .map_ok(|_| State::Stem)
                    .boxed();
            }

            // The routed-to peer is no longer held: drop the stale route and pick again.
            self.stem_routes.remove(from);
        }
    }

    /// Forwards a local transaction to the stem peer reserved for local transactions.
    ///
    /// The local route is chosen at random on first use and reused afterwards.
    /// If no stem peers are held the transaction is fluffed instead.
    fn stem_local_tx(&mut self, tx: Tx) -> BoxFuture<'static, Result<State, DandelionRouterError>> {
        if self.stem_peers.is_empty() {
            tracing::warn!("Stem peers are empty, no outbound connections to stem local tx to, fluffing instead, privacy will be degraded.");
            return self.fluff_tx(tx);
        }

        loop {
            let route = self
                .local_route
                .get_or_insert_with(|| Self::choose_random_stem_peer(&self.stem_peers));

            if let Some(peer) = self.stem_peers.get_mut(route) {
                return peer
                    .call(StemRequest(tx))
                    .map_err(DandelionRouterError::PeerError)
                    .map_ok(|_| State::Stem)
                    .boxed();
            }

            // The routed-to peer is no longer held: forget the route and pick again.
            self.local_route = None;
        }
    }
}
impl<Tx, Id, P, B, S> Service<DandelionRouteReq<Tx, Id>> for DandelionRouter<P, B, Id, S, Tx>
where
    Id: Hash + Eq + Clone,
    P: TryStream<Ok = OutboundPeer<Id, S>, Error = tower::BoxError>,
    B: Service<DiffuseRequest<Tx>, Error = tower::BoxError>,
    B::Future: Send + 'static,
    S: Service<StemRequest<Tx>, Error = tower::BoxError>,
    S::Future: Send + 'static,
{
    /// How the transaction was routed.
    type Response = State;
    type Error = DandelionRouterError;
    type Future = BoxFuture<'static, Result<State, DandelionRouterError>>;

    /// Prepares the router for the next [`Service::call`].
    ///
    /// In order:
    /// 1. If the current epoch has lasted longer than `config.epoch_duration`, all
    ///    stem peers and routes are dropped, a new state is sampled and a new epoch
    ///    starts.
    /// 2. Every held stem peer is polled for readiness; peers that return an error
    ///    are removed. If any peer is pending, `Pending` is returned.
    /// 3. The set of stem peers is topped up from the outbound peer stream.
    /// 4. The broadcast service is polled for readiness.
    ///
    /// # Errors
    ///
    /// Returns an error if the outbound peer stream errors or ends, or if the
    /// broadcast service fails its readiness check.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.epoch_start.elapsed() > self.config.epoch_duration {
            // Routes are only valid within an epoch, so start from a clean slate.
            self.stem_peers.clear();
            self.stem_routes.clear();
            self.local_route.take();

            self.current_state = if self.state_dist.sample(&mut thread_rng()) {
                State::Fluff
            } else {
                State::Stem
            };

            // Record through `Debug`, matching the `?current_state` capture used when
            // the span is created, instead of allocating a `String` every epoch.
            self.span
                .record("state", tracing::field::debug(self.current_state));
            tracing::debug!(parent: &self.span, "Starting new d++ epoch");

            self.epoch_start = Instant::now();
        }

        let mut peers_pending = false;
        let span = &self.span;

        self.stem_peers
            .retain(|_, peer_svc| match peer_svc.poll_ready(cx) {
                Poll::Ready(res) => res
                    .inspect_err(|e| {
                        tracing::debug!(
                            parent: span,
                            "Peer returned an error on `poll_ready`: {e}, removing from router.",
                        );
                    })
                    .is_ok(),
                Poll::Pending => {
                    // A pending peer has not failed, so it is kept.
                    peers_pending = true;
                    true
                }
            });

        if peers_pending {
            return Poll::Pending;
        }

        // Failed peers are gone now; make sure enough peers are still held.
        ready!(self.poll_prepare_graph(cx)?);

        ready!(self
            .broadcast_svc
            .poll_ready(cx)
            .map_err(DandelionRouterError::BroadcastError)?);

        Poll::Ready(Ok(()))
    }

    /// Routes the transaction in `req` according to its [`TxState`].
    ///
    /// - [`TxState::Fluff`]: handed to the broadcast service.
    /// - [`TxState::Stem`]: fluffed when the router is in [`State::Fluff`], otherwise
    ///   forwarded along the stem route kept for `from`.
    /// - [`TxState::Local`]: forwarded along the local stem route, whatever the
    ///   router's current state.
    ///
    /// The returned future resolves to the [`State`] describing what was done.
    fn call(&mut self, req: DandelionRouteReq<Tx, Id>) -> Self::Future {
        tracing::trace!(parent: &self.span, "Handling route request.");

        match req.state {
            TxState::Fluff => self.fluff_tx(req.tx),
            TxState::Stem { from } => match self.current_state {
                State::Fluff => {
                    tracing::debug!(parent: &self.span, "Fluffing stem tx.");
                    self.fluff_tx(req.tx)
                }
                State::Stem => {
                    tracing::trace!(parent: &self.span, "Steming transaction");
                    self.stem_tx(req.tx, &from)
                }
            },
            TxState::Local => {
                tracing::debug!(parent: &self.span, "Steming local tx.");
                self.stem_local_tx(req.tx)
            }
        }
    }
}