cuprate_async_buffer/
lib.rs

1//! Async Buffer
2//!
3//! A bounded SPSC, FIFO, async buffer that supports arbitrary weights for values.
4//!
//! Weight is used to bound the channel: on creation you specify a max total weight, and each value
//! is sent with its own weight.
7use std::{
8    cmp::min,
9    future::Future,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc,
14    },
15    task::{Context, Poll},
16};
17
18use futures::{
19    channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
20    ready,
21    task::AtomicWaker,
22    Stream, StreamExt,
23};
24
25#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)]
26pub enum BufferError {
27    #[error("The buffer did not have enough capacity.")]
28    NotEnoughCapacity,
29    #[error("The other end of the buffer disconnected.")]
30    Disconnected,
31}
32
33/// Initializes a new buffer with the provided capacity.
34///
35/// The capacity inputted is not the max number of items, it is the max combined weight of all items
36/// in the buffer.
37///
38/// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted.
39/// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be
40/// accepted.
41pub fn new_buffer<T>(max_item_weight: usize) -> (BufferAppender<T>, BufferStream<T>) {
42    let (tx, rx) = unbounded();
43    let sink_waker = Arc::new(AtomicWaker::new());
44    let capacity_atomic = Arc::new(AtomicUsize::new(max_item_weight));
45
46    (
47        BufferAppender {
48            queue: tx,
49            sink_waker: sink_waker.clone(),
50            capacity: capacity_atomic.clone(),
51            max_item_weight,
52        },
53        BufferStream {
54            queue: rx,
55            sink_waker,
56            capacity: capacity_atomic,
57        },
58    )
59}
60
61/// The stream side of the buffer.
62pub struct BufferStream<T> {
63    /// The internal queue of items.
64    queue: UnboundedReceiver<(T, usize)>,
65    /// The waker for the [`BufferAppender`]
66    sink_waker: Arc<AtomicWaker>,
67    /// The current capacity of the buffer.
68    capacity: Arc<AtomicUsize>,
69}
70
71impl<T> Stream for BufferStream<T> {
72    type Item = T;
73
74    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
75        let Some((item, size)) = ready!(self.queue.poll_next_unpin(cx)) else {
76            return Poll::Ready(None);
77        };
78
79        // add the capacity back to the buffer.
80        self.capacity.fetch_add(size, Ordering::AcqRel);
81        // wake the sink.
82        self.sink_waker.wake();
83
84        Poll::Ready(Some(item))
85    }
86}
87
88/// The appender/sink side of the buffer.
89pub struct BufferAppender<T> {
90    /// The internal queue of items.
91    queue: UnboundedSender<(T, usize)>,
92    /// Our waker.
93    sink_waker: Arc<AtomicWaker>,
94    /// The current capacity of the buffer.
95    capacity: Arc<AtomicUsize>,
96    /// The max weight of an item, equal to the total allowed weight of the buffer.
97    max_item_weight: usize,
98}
99
100impl<T> BufferAppender<T> {
101    /// Returns a future that resolves when the channel has enough capacity for
102    /// a single message of `size_needed`.
103    ///
104    /// It should be noted that if there are no items in the buffer then a single item of any capacity is accepted.
105    /// i.e. if the capacity is 5 and there are no items in the buffer then any item even if it's weight is >5 will be
106    /// accepted.
107    pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> {
108        let size_needed = min(self.max_item_weight, size_needed);
109
110        BufferSinkReady {
111            sink: self,
112            size_needed,
113        }
114    }
115
116    /// Attempts to add an item to the buffer.
117    ///
118    /// # Errors
119    /// Returns an error if there is not enough capacity or the [`BufferStream`] was dropped.
120    pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> {
121        let size_needed = min(self.max_item_weight, size_needed);
122
123        if self.capacity.load(Ordering::Acquire) < size_needed {
124            return Err(BufferError::NotEnoughCapacity);
125        }
126
127        let prev_size = self.capacity.fetch_sub(size_needed, Ordering::AcqRel);
128
129        // make sure we haven't wrapped the capacity around.
130        assert!(prev_size >= size_needed);
131
132        self.queue
133            .unbounded_send((item, size_needed))
134            .map_err(|_| BufferError::Disconnected)?;
135
136        Ok(())
137    }
138
139    /// Waits for capacity in the buffer and then sends the item.
140    pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> {
141        BufferSinkSend {
142            ready: self.ready(size_needed),
143            item: Some(item),
144        }
145    }
146}
147
148/// A [`Future`] for adding an item to the buffer.
149#[pin_project::pin_project]
150pub struct BufferSinkSend<'a, T> {
151    /// A future that resolves when the channel has capacity.
152    #[pin]
153    ready: BufferSinkReady<'a, T>,
154    /// The item to send.
155    ///
156    /// This is [`take`](Option::take)n and added to the buffer when there is enough capacity.
157    item: Option<T>,
158}
159
160impl<T> Future for BufferSinkSend<'_, T> {
161    type Output = Result<(), BufferError>;
162
163    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164        let mut this = self.project();
165
166        let size_needed = this.ready.size_needed;
167
168        this.ready.as_mut().poll(cx).map(|_| {
169            this.ready
170                .sink
171                .try_send(this.item.take().unwrap(), size_needed)
172        })
173    }
174}
175
176/// A [`Future`] for waiting for capacity in the buffer.
177pub struct BufferSinkReady<'a, T> {
178    /// The sink side of the buffer.
179    sink: &'a mut BufferAppender<T>,
180    /// The capacity needed.
181    ///
182    /// This future will wait forever if this is higher than the total availability of the buffer.
183    size_needed: usize,
184}
185
186impl<T> Future for BufferSinkReady<'_, T> {
187    type Output = ();
188
189    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190        // Check before setting the waker just in case it has capacity now,
191        if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
192            return Poll::Ready(());
193        }
194
195        // set the waker
196        self.sink.sink_waker.register(cx.waker());
197
198        // check the capacity again to avoid a race condition that would result in lost notifications.
199        if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
200            Poll::Ready(())
201        } else {
202            Poll::Pending
203        }
204    }
205}