use std::{
fmt::{Debug, Display, Formatter},
sync::{Arc, Mutex},
task::{ready, Context, Poll},
};
use futures::channel::oneshot;
use tokio::{
sync::{mpsc, OwnedSemaphorePermit, Semaphore},
task::JoinHandle,
};
use tokio_util::sync::PollSemaphore;
use tower::{Service, ServiceExt};
use tracing::Instrument;
use cuprate_helper::asynch::InfallibleOneshotReceiver;
use cuprate_pruning::PruningSeed;
use cuprate_wire::CoreSyncData;
use crate::{
handles::{ConnectionGuard, ConnectionHandle},
ConnectionDirection, NetworkZone, PeerError, PeerRequest, PeerResponse, SharedError,
};
mod connection;
mod connector;
pub mod handshaker;
mod request_handler;
mod timeout_monitor;
mod weak;
pub use connector::{ConnectRequest, Connector};
pub use handshaker::{DoHandshakeRequest, HandshakeError, HandshakerBuilder};
pub use weak::WeakClient;
/// An identifier for a peer, generic over the address type `A`.
///
/// The [`Display`] output is the address itself for [`Self::KnownAddr`] and
/// `Unknown, ID: <id>` for [`Self::Unknown`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum InternalPeerID<A> {
    /// A peer identified by an address of type `A`.
    KnownAddr(A),
    /// A peer identified only by a numeric ID.
    ///
    /// NOTE(review): presumably used when the peer's address is not known
    /// (e.g. some inbound connections) — confirm against the code creating these.
    Unknown(u128),
}

impl<A: Display> Display for InternalPeerID<A> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            // Delegate so the address honours whatever flags the caller set on `f`.
            Self::KnownAddr(addr) => addr.fmt(f),
            // Write straight into the formatter rather than building a temporary `String`.
            Self::Unknown(id) => write!(f, "Unknown, ID: {id}"),
        }
    }
}
/// Information about a peer and its connection.
///
/// This struct is `Clone`; a [`Client`] hands a clone of it to every
/// [`WeakClient`] created through `Client::downgrade`.
#[derive(Debug, Clone)]
pub struct PeerInformation<A> {
/// The [`InternalPeerID`] identifying this peer.
pub id: InternalPeerID<A>,
/// The [`ConnectionHandle`] for this peer's connection.
///
/// NOTE(review): presumably lets holders observe or end the connection —
/// see `crate::handles` for the exact capabilities.
pub handle: ConnectionHandle,
/// The [`ConnectionDirection`] of the connection to this peer.
pub direction: ConnectionDirection,
/// The [`PruningSeed`] associated with this peer.
pub pruning_seed: PruningSeed,
/// The [`CoreSyncData`] associated with this peer.
///
/// Stored behind `Arc<Mutex<_>>`, so every clone of this struct refers to
/// the same underlying value.
pub core_sync_data: Arc<Mutex<CoreSyncData>>,
}
/// A handle used to send [`PeerRequest`]s to a peer.
///
/// Requests are submitted through the [`Service`] implementation, which
/// forwards them over a channel as `connection::ConnectionTaskRequest`s and
/// returns a future resolving to the response.
pub struct Client<Z: NetworkZone> {
/// Information on the peer this client talks to.
pub info: PeerInformation<Z::Addr>,
/// Sending half of the channel that requests are forwarded over.
connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
/// Join handle of the spawned connection task; once it has finished,
/// `poll_ready` reports an error.
connection_handle: JoinHandle<()>,
/// Join handle of a second spawned task; once it has finished,
/// `poll_ready` reports an error.
///
/// NOTE(review): presumably the task from the `timeout_monitor` module — confirm.
timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
/// Semaphore from which a permit must be acquired before a request is sent.
semaphore: PollSemaphore,
/// Permit acquired by `poll_ready` and consumed by the following `call`.
permit: Option<OwnedSemaphorePermit>,
/// Shared error slot; once it holds an error, `poll_ready` returns that error.
error: SharedError<PeerError>,
}
impl<Z: NetworkZone> Client<Z> {
    /// Builds a [`Client`] from its parts.
    ///
    /// The semaphore is wrapped in a [`PollSemaphore`] and the client starts
    /// out holding no permit.
    pub(crate) fn new(
        info: PeerInformation<Z::Addr>,
        connection_tx: mpsc::Sender<connection::ConnectionTaskRequest>,
        connection_handle: JoinHandle<()>,
        timeout_handle: JoinHandle<Result<(), tower::BoxError>>,
        semaphore: Arc<Semaphore>,
        error: SharedError<PeerError>,
    ) -> Self {
        let semaphore = PollSemaphore::new(semaphore);

        Self {
            info,
            connection_tx,
            connection_handle,
            timeout_handle,
            semaphore,
            permit: None,
            error,
        }
    }

    /// Attempts to record `err` in the shared error slot.
    ///
    /// Returns a boxed error carrying the message of `err` when the insert
    /// succeeded, or the message of the error handed back by
    /// `try_insert_err` when it did not.
    fn set_err(&self, err: PeerError) -> tower::BoxError {
        // Render the message up front: `err` is moved into the slot below.
        let message = err.to_string();

        self.error
            .try_insert_err(err)
            .map_or_else(|rejected| rejected.to_string(), |()| message)
            .into()
    }

    /// Creates a [`WeakClient`] for the same peer.
    ///
    /// The weak client gets a clone of the peer information, semaphore and
    /// error slot, a downgraded request channel, and no permit.
    pub fn downgrade(&self) -> WeakClient<Z> {
        let info = self.info.clone();
        let connection_tx = self.connection_tx.downgrade();
        let semaphore = self.semaphore.clone();
        let error = self.error.clone();

        WeakClient {
            info,
            connection_tx,
            semaphore,
            permit: None,
            error,
        }
    }
}
impl<Z: NetworkZone> Service<PeerRequest> for Client<Z> {
    type Response = PeerResponse;
    type Error = tower::BoxError;
    type Future = InfallibleOneshotReceiver<Result<Self::Response, Self::Error>>;

    /// Reports whether a request can be sent.
    ///
    /// Checks, in order: an error already in the shared slot, whether either
    /// spawned task has finished, and whether a permit is already held.
    /// Otherwise it polls the semaphore for a permit and stores it for `call`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if let Some(err) = self.error.try_get_err() {
            return Poll::Ready(Err(err.to_string().into()));
        }

        let task_finished =
            self.connection_handle.is_finished() || self.timeout_handle.is_finished();
        if task_finished {
            return Poll::Ready(Err(self.set_err(PeerError::ClientChannelClosed)));
        }

        // Only go to the semaphore when no permit is left from an earlier poll.
        if self.permit.is_none() {
            let acquired = ready!(self.semaphore.poll_acquire(cx))
                .expect("Client semaphore should not be closed!");
            self.permit = Some(acquired);
        }

        Poll::Ready(Ok(()))
    }

    /// Forwards `request` over the request channel and returns a future for
    /// the response.
    ///
    /// # Panics
    ///
    /// Panics if no permit is held, i.e. `poll_ready` did not return ready
    /// beforehand.
    fn call(&mut self, request: PeerRequest) -> Self::Future {
        use mpsc::error::TrySendError;

        let permit = self
            .permit
            .take()
            .expect("poll_ready did not return ready before call to call");

        let (response_tx, response_rx) = oneshot::channel();
        let task_request = connection::ConnectionTaskRequest {
            response_channel: response_tx,
            request,
            permit: Some(permit),
        };

        // A failed send hands the request back, so the waiting receiver can
        // still be answered with an error rather than left without a sender.
        if let Err(TrySendError::Closed(rejected) | TrySendError::Full(rejected)) =
            self.connection_tx.try_send(task_request)
        {
            self.set_err(PeerError::ClientChannelClosed);
            drop(
                rejected
                    .response_channel
                    .send(Err(PeerError::ClientChannelClosed.into())),
            );
        }

        response_rx.into()
    }
}
/// Creates a mock [`Client`] whose requests are answered by `request_handler`.
///
/// A task is spawned that owns `connection_guard` and serves each incoming
/// request by driving `request_handler` to readiness, calling it, and sending
/// the result back on the request's response channel. The task ends when the
/// request channel closes.
///
/// The spawned task panics if `request_handler` returns an error from either
/// readiness or the call itself.
pub fn mock_client<Z: NetworkZone, S>(
    info: PeerInformation<Z::Addr>,
    connection_guard: ConnectionGuard,
    mut request_handler: S,
) -> Client<Z>
where
    S: Service<PeerRequest, Response = PeerResponse, Error = tower::BoxError> + Send + 'static,
    S::Future: Send + 'static,
{
    let (request_tx, mut request_rx) = mpsc::channel::<connection::ConnectionTaskRequest>(1);

    let span = tracing::error_span!("mock_connection", addr = %info.id);

    let mock_connection = async move {
        // Hold the guard for as long as the mock connection task runs.
        let _guard = connection_guard;

        while let Some(req) = request_rx.recv().await {
            tracing::debug!("Received new request: {:?}", req.request.id());

            let handler = request_handler.ready().await.unwrap();
            let res = handler.call(req.request).await.unwrap();

            tracing::debug!("Sending back response");
            drop(req.response_channel.send(Ok(res)));
        }

        tracing::debug!("Channel closed, closing mock connection");
    };
    let connection_task = tokio::spawn(mock_connection.instrument(span));

    // The mock's second task never completes.
    let timeout_task = tokio::spawn(futures::future::pending());

    Client::new(
        info,
        request_tx,
        connection_task,
        timeout_task,
        Arc::new(Semaphore::new(1)),
        SharedError::new(),
    )
}