cuprate_async_buffer/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
//! Async Buffer
//!
//! A bounded SPSC, FIFO, async buffer that supports arbitrary weights for values.
//!
//! Weight is used to bound the channel: on creation you specify a max weight, and each value you
//! send is given its own weight.
use std::{
    cmp::min,
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{Context, Poll},
};

use futures::{
    channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
    ready,
    task::AtomicWaker,
    Stream, StreamExt,
};

#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)]
pub enum BufferError {
    #[error("The buffer did not have enough capacity.")]
    NotEnoughCapacity,
    #[error("The other end of the buffer disconnected.")]
    Disconnected,
}

/// Creates a buffer whose total weight is bounded by `max_item_weight`.
///
/// The bound limits the combined weight of the queued items, not how many items there are.
///
/// Weights larger than `max_item_weight` are treated as exactly `max_item_weight`. An item of any
/// weight is therefore accepted once the whole capacity is free. For example, with a capacity of 5,
/// an item whose weight is 7 is accepted as soon as all 5 units are available.
pub fn new_buffer<T>(max_item_weight: usize) -> (BufferAppender<T>, BufferStream<T>) {
    let (queue_tx, queue_rx) = unbounded();
    let waker = Arc::new(AtomicWaker::new());
    let capacity = Arc::new(AtomicUsize::new(max_item_weight));

    // Both halves share the same waker slot and capacity counter.
    let stream = BufferStream {
        queue: queue_rx,
        sink_waker: Arc::clone(&waker),
        capacity: Arc::clone(&capacity),
    };
    let appender = BufferAppender {
        queue: queue_tx,
        sink_waker: waker,
        capacity,
        max_item_weight,
    };

    (appender, stream)
}

/// Receiving half of the buffer, yielding items in the order they were appended.
///
/// Every item it yields hands that item's weight back to the shared capacity. It then wakes the
/// [`BufferAppender`] in case it is waiting for room.
pub struct BufferStream<T> {
    /// Channel carrying each item together with the weight it was charged.
    queue: UnboundedReceiver<(T, usize)>,
    /// Waker slot registered by the [`BufferAppender`] while it waits for capacity.
    sink_waker: Arc<AtomicWaker>,
    /// Remaining weight capacity, shared with the [`BufferAppender`].
    capacity: Arc<AtomicUsize>,
}

impl<T> Stream for BufferStream<T> {
    type Item = T;

    /// Yields the next item.
    ///
    /// Its weight is returned to the shared capacity, and the appender is woken before the item
    /// is handed out. Yields `None` once the channel is exhausted.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match ready!(self.queue.poll_next_unpin(cx)) {
            None => Poll::Ready(None),
            Some((item, weight)) => {
                // Release the weight first so the woken appender sees the extra room.
                self.capacity.fetch_add(weight, Ordering::AcqRel);
                self.sink_waker.wake();
                Poll::Ready(Some(item))
            }
        }
    }
}

/// Sending half of the buffer.
///
/// Each item is charged a weight, clamped to `max_item_weight`, against a capacity shared with
/// the [`BufferStream`]. The stream refills that capacity as it yields items.
pub struct BufferAppender<T> {
    /// Channel carrying each item together with the weight it was charged.
    queue: UnboundedSender<(T, usize)>,
    /// Waker slot this side registers while waiting; the [`BufferStream`] wakes it.
    sink_waker: Arc<AtomicWaker>,
    /// Remaining weight capacity, shared with the [`BufferStream`].
    capacity: Arc<AtomicUsize>,
    /// Largest weight charged for a single item; also the buffer's total capacity.
    max_item_weight: usize,
}

impl<T> BufferAppender<T> {
    /// Clamps an item's weight to the buffer's total capacity.
    ///
    /// This is what lets an item heavier than the whole buffer still be accepted once the buffer
    /// is fully drained, instead of never fitting.
    fn clamp_weight(&self, weight: usize) -> usize {
        min(self.max_item_weight, weight)
    }

    /// Returns a future that resolves when the channel has enough capacity for a single message
    /// of weight `size_needed`.
    ///
    /// Weights above the buffer's total capacity are clamped to that capacity. Such an item is
    /// therefore accepted once the full capacity is free. For example, with a capacity of 5, an
    /// item of weight 7 waits for all 5 units and is then accepted.
    pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> {
        let size_needed = self.clamp_weight(size_needed);

        BufferSinkReady {
            sink: self,
            size_needed,
        }
    }

    /// Attempts to add an item to the buffer without waiting.
    ///
    /// `size_needed` is clamped exactly as in [`BufferAppender::ready`].
    ///
    /// # Errors
    /// - [`BufferError::Disconnected`] if the [`BufferStream`] was dropped. This is checked first,
    ///   because waiting for capacity from a stream that no longer exists would never end.
    /// - [`BufferError::NotEnoughCapacity`] if the remaining capacity is below the clamped weight.
    ///
    /// # Panics
    /// Panics if the capacity counter would wrap. That would be an internal accounting bug.
    pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> {
        if self.queue.is_closed() {
            return Err(BufferError::Disconnected);
        }

        let size_needed = self.clamp_weight(size_needed);

        if self.capacity.load(Ordering::Acquire) < size_needed {
            return Err(BufferError::NotEnoughCapacity);
        }

        // Capacity is only ever taken here, and this method requires `&mut self`. The stream
        // only adds capacity back, so the amount checked above cannot have shrunk since.
        let prev_size = self.capacity.fetch_sub(size_needed, Ordering::AcqRel);
        assert!(prev_size >= size_needed);

        if self.queue.unbounded_send((item, size_needed)).is_err() {
            // The stream went away after the check above. Return the reservation so the
            // capacity count stays consistent.
            self.capacity.fetch_add(size_needed, Ordering::AcqRel);
            return Err(BufferError::Disconnected);
        }

        Ok(())
    }

    /// Waits for capacity in the buffer and then sends the item.
    ///
    /// The weight is clamped exactly as in [`BufferAppender::ready`]. The returned future
    /// resolves to the result of [`BufferAppender::try_send`].
    pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> {
        BufferSinkSend {
            ready: self.ready(size_needed),
            item: Some(item),
        }
    }
}

/// A [`Future`] that waits for capacity and then adds an item to the buffer.
///
/// Created by [`BufferAppender::send`]. Like every future, it does nothing until polled.
#[pin_project::pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct BufferSinkSend<'a, T> {
    /// Resolves once the buffer can accept the item.
    #[pin]
    ready: BufferSinkReady<'a, T>,
    /// The item to send.
    ///
    /// It is [`take`](Option::take)n and handed to the buffer once `ready` resolves.
    item: Option<T>,
}

impl<T> Future for BufferSinkSend<'_, T> {
    type Output = Result<(), BufferError>;

    /// Waits for the inner ready future, then performs a single `try_send` of the stored item.
    ///
    /// # Panics
    /// Panics if polled again after it has already completed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        if this.ready.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        // Already clamped by `BufferAppender::ready`; re-clamping in `try_send` is a no-op.
        let size_needed = this.ready.size_needed;
        let item = this
            .item
            .take()
            .expect("`BufferSinkSend` polled after completion");

        Poll::Ready(this.ready.sink.try_send(item, size_needed))
    }
}

/// A [`Future`] for waiting for capacity in the buffer.
pub struct BufferSinkReady<'a, T> {
    /// The sink side of the buffer.
    sink: &'a mut BufferAppender<T>,
    /// The capacity needed.
    ///
    /// This future will wait forever if this is higher than the total availability of the buffer.
    size_needed: usize,
}

impl<T> Future for BufferSinkReady<'_, T> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Check before setting the waker just in case it has capacity now,
        if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
            return Poll::Ready(());
        }

        // set the waker
        self.sink.sink_waker.register(cx.waker());

        // check the capacity again to avoid a race condition that would result in lost notifications.
        if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}