cuprate_async_buffer/lib.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
//! Async Buffer
//!
//! A bounded SPSC, FIFO, async buffer that supports arbitrary weights for values.
//!
//! Weight is used to bound the channel: on creation you specify a max total weight, and for
//! each value you send you specify its weight.
use std::{
cmp::min,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
};
use futures::{
channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
ready,
task::AtomicWaker,
Stream, StreamExt,
};
/// The ways adding an item to the buffer can fail.
#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)]
pub enum BufferError {
/// The capacity currently left in the buffer is lower than the weight needed for the item.
#[error("The buffer did not have enough capacity.")]
NotEnoughCapacity,
/// The item could not be queued because the other half of the buffer is gone.
#[error("The other end of the buffer disconnected.")]
Disconnected,
}
/// Creates a connected [`BufferAppender`] / [`BufferStream`] pair.
///
/// `max_item_weight` is not a limit on the number of items: it is the maximum combined weight
/// of all the items that may sit in the buffer at once.
///
/// Note that an empty buffer accepts a single item of any weight, i.e. if the limit is 5 and
/// the buffer holds no items, an item whose weight is >5 is still accepted (its weight is
/// treated as the limit).
pub fn new_buffer<T>(max_item_weight: usize) -> (BufferAppender<T>, BufferStream<T>) {
    let (sender, receiver) = unbounded();

    // Both halves share the appender's waker and the counter of remaining capacity.
    let waker = Arc::new(AtomicWaker::new());
    let remaining = Arc::new(AtomicUsize::new(max_item_weight));

    let appender = BufferAppender {
        queue: sender,
        sink_waker: Arc::clone(&waker),
        capacity: Arc::clone(&remaining),
        max_item_weight,
    };
    let stream = BufferStream {
        queue: receiver,
        sink_waker: waker,
        capacity: remaining,
    };

    (appender, stream)
}
/// The stream side of the buffer.
///
/// Every item taken out of the stream gives its weight back to the shared capacity and wakes
/// the [`BufferAppender`].
pub struct BufferStream<T> {
/// The internal queue of items, each paired with the weight that was deducted for it.
queue: UnboundedReceiver<(T, usize)>,
/// The waker for the [`BufferAppender`], woken whenever an item is taken out.
sink_waker: Arc<AtomicWaker>,
/// The capacity currently left in the buffer, shared with the [`BufferAppender`].
capacity: Arc<AtomicUsize>,
}
impl<T> Stream for BufferStream<T> {
    type Item = T;

    /// Yields the next queued item, handing its weight back to the buffer.
    ///
    /// Ends (`None`) when the inner queue ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(self.queue.poll_next_unpin(cx)) {
            Some((item, weight)) => {
                // The item has left the buffer: release the weight it was holding...
                self.capacity.fetch_add(weight, Ordering::AcqRel);
                // ...and let a waiting appender re-check the capacity.
                self.sink_waker.wake();
                Poll::Ready(Some(item))
            }
            None => Poll::Ready(None),
        }
    }
}
/// The appender/sink side of the buffer.
///
/// Items are added together with a weight; that weight is deducted from the shared capacity
/// until the [`BufferStream`] takes the item out again.
pub struct BufferAppender<T> {
/// The internal queue of items, each paired with the weight that was deducted for it.
queue: UnboundedSender<(T, usize)>,
/// Our waker, registered while waiting for capacity and woken by the [`BufferStream`].
sink_waker: Arc<AtomicWaker>,
/// The capacity currently left in the buffer, shared with the [`BufferStream`].
capacity: Arc<AtomicUsize>,
/// The max weight of an item, equal to the total allowed weight of the buffer.
///
/// Requested weights above this are clamped down to it.
max_item_weight: usize,
}
impl<T> BufferAppender<T> {
    /// Returns a future that resolves when the channel has enough capacity for
    /// a single message of `size_needed`.
    ///
    /// `size_needed` is clamped to the total weight of the buffer, so an empty buffer accepts
    /// a single item of any weight, i.e. if the limit is 5 and there are no items in the
    /// buffer then an item is accepted even if its weight is >5.
    pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> {
        let size_needed = min(self.max_item_weight, size_needed);

        BufferSinkReady {
            sink: self,
            size_needed,
        }
    }

    /// Attempts to add an item to the buffer without waiting.
    ///
    /// `size_needed` is clamped to the total weight of the buffer before it is compared with,
    /// and deducted from, the remaining capacity.
    ///
    /// # Errors
    /// - [`BufferError::Disconnected`] if the [`BufferStream`] was dropped. This is checked
    ///   first: a dropped stream never returns capacity, so reporting a lack of capacity
    ///   would invite the caller to retry forever.
    /// - [`BufferError::NotEnoughCapacity`] if the remaining capacity is lower than the
    ///   (clamped) weight.
    ///
    /// In both cases the item is dropped and the remaining capacity is left unchanged.
    pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> {
        if self.queue.is_closed() {
            return Err(BufferError::Disconnected);
        }

        let size_needed = min(self.max_item_weight, size_needed);

        // Reserve the weight in one atomic step; `checked_sub` refuses to go below zero, so
        // the counter can never wrap around.
        self.capacity
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |available| {
                available.checked_sub(size_needed)
            })
            .map_err(|_| BufferError::NotEnoughCapacity)?;

        if self.queue.unbounded_send((item, size_needed)).is_err() {
            // The stream went away between the check above and the send: hand the
            // reservation back so the counter only reflects items that were queued.
            self.capacity.fetch_add(size_needed, Ordering::AcqRel);
            return Err(BufferError::Disconnected);
        }

        Ok(())
    }

    /// Waits for capacity in the buffer and then sends the item.
    ///
    /// The returned future resolves to the result of [`BufferAppender::try_send`] once the
    /// capacity for `size_needed` is available.
    pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> {
        BufferSinkSend {
            ready: self.ready(size_needed),
            item: Some(item),
        }
    }
}
/// A [`Future`] for adding an item to the buffer.
///
/// It first waits on the inner [`BufferSinkReady`] and then resolves to the result of handing
/// the item to the buffer.
#[pin_project::pin_project]
pub struct BufferSinkSend<'a, T> {
/// A future that resolves when the channel has capacity.
#[pin]
ready: BufferSinkReady<'a, T>,
/// The item to send.
///
/// This is [`take`](Option::take)n and added to the buffer when there is enough capacity,
/// leaving `None` behind; polling again after that point panics.
item: Option<T>,
}
impl<T> Future for BufferSinkSend<'_, T> {
type Output = Result<(), BufferError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
let size_needed = this.ready.size_needed;
this.ready.as_mut().poll(cx).map(|_| {
this.ready
.sink
.try_send(this.item.take().unwrap(), size_needed)
})
}
}
/// A [`Future`] for waiting for capacity in the buffer.
///
/// Resolves once the capacity left in the buffer is at least `size_needed`.
pub struct BufferSinkReady<'a, T> {
/// The sink side of the buffer.
sink: &'a mut BufferAppender<T>,
/// The capacity needed.
///
/// This future will wait forever if this is higher than the total availability of the buffer;
/// [`BufferAppender::ready`] clamps the value to the buffer's total weight before building
/// this future.
size_needed: usize,
}
impl<T> Future for BufferSinkReady<'_, T> {
    type Output = ();

    /// Resolves as soon as the buffer has at least `size_needed` capacity left.
    ///
    /// NOTE(review): only the capacity is inspected here; whether anything wakes this future
    /// when the stream is dropped is not visible in this impl — confirm.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &*self;
        let has_room = || this.sink.capacity.load(Ordering::Acquire) >= this.size_needed;

        // Fast path: no need to touch the waker if there is room already.
        if has_room() {
            return Poll::Ready(());
        }

        this.sink.sink_waker.register(cx.waker());

        // Capacity may have been returned between the first check and the registration;
        // looking again after registering means that wake-up cannot be lost.
        if has_room() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}