cuprate_async_buffer/
lib.rs1use std::{
8 cmp::min,
9 future::Future,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc,
14 },
15 task::{Context, Poll},
16};
17
18use futures::{
19 channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
20 ready,
21 task::AtomicWaker,
22 Stream, StreamExt,
23};
24
25#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)]
26pub enum BufferError {
27 #[error("The buffer did not have enough capacity.")]
28 NotEnoughCapacity,
29 #[error("The other end of the buffer disconnected.")]
30 Disconnected,
31}
32
33pub fn new_buffer<T>(max_item_weight: usize) -> (BufferAppender<T>, BufferStream<T>) {
42 let (tx, rx) = unbounded();
43 let sink_waker = Arc::new(AtomicWaker::new());
44 let capacity_atomic = Arc::new(AtomicUsize::new(max_item_weight));
45
46 (
47 BufferAppender {
48 queue: tx,
49 sink_waker: sink_waker.clone(),
50 capacity: capacity_atomic.clone(),
51 max_item_weight,
52 },
53 BufferStream {
54 queue: rx,
55 sink_waker,
56 capacity: capacity_atomic,
57 },
58 )
59}
60
61pub struct BufferStream<T> {
63 queue: UnboundedReceiver<(T, usize)>,
65 sink_waker: Arc<AtomicWaker>,
67 capacity: Arc<AtomicUsize>,
69}
70
71impl<T> Stream for BufferStream<T> {
72 type Item = T;
73
74 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
75 let Some((item, size)) = ready!(self.queue.poll_next_unpin(cx)) else {
76 return Poll::Ready(None);
77 };
78
79 self.capacity.fetch_add(size, Ordering::AcqRel);
81 self.sink_waker.wake();
83
84 Poll::Ready(Some(item))
85 }
86}
87
88pub struct BufferAppender<T> {
90 queue: UnboundedSender<(T, usize)>,
92 sink_waker: Arc<AtomicWaker>,
94 capacity: Arc<AtomicUsize>,
96 max_item_weight: usize,
98}
99
100impl<T> BufferAppender<T> {
101 pub fn ready(&mut self, size_needed: usize) -> BufferSinkReady<'_, T> {
108 let size_needed = min(self.max_item_weight, size_needed);
109
110 BufferSinkReady {
111 sink: self,
112 size_needed,
113 }
114 }
115
116 pub fn try_send(&mut self, item: T, size_needed: usize) -> Result<(), BufferError> {
121 let size_needed = min(self.max_item_weight, size_needed);
122
123 if self.capacity.load(Ordering::Acquire) < size_needed {
124 return Err(BufferError::NotEnoughCapacity);
125 }
126
127 let prev_size = self.capacity.fetch_sub(size_needed, Ordering::AcqRel);
128
129 assert!(prev_size >= size_needed);
131
132 self.queue
133 .unbounded_send((item, size_needed))
134 .map_err(|_| BufferError::Disconnected)?;
135
136 Ok(())
137 }
138
139 pub fn send(&mut self, item: T, size_needed: usize) -> BufferSinkSend<'_, T> {
141 BufferSinkSend {
142 ready: self.ready(size_needed),
143 item: Some(item),
144 }
145 }
146}
147
148#[pin_project::pin_project]
150pub struct BufferSinkSend<'a, T> {
151 #[pin]
153 ready: BufferSinkReady<'a, T>,
154 item: Option<T>,
158}
159
160impl<T> Future for BufferSinkSend<'_, T> {
161 type Output = Result<(), BufferError>;
162
163 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
164 let mut this = self.project();
165
166 let size_needed = this.ready.size_needed;
167
168 this.ready.as_mut().poll(cx).map(|_| {
169 this.ready
170 .sink
171 .try_send(this.item.take().unwrap(), size_needed)
172 })
173 }
174}
175
176pub struct BufferSinkReady<'a, T> {
178 sink: &'a mut BufferAppender<T>,
180 size_needed: usize,
184}
185
186impl<T> Future for BufferSinkReady<'_, T> {
187 type Output = ();
188
189 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190 if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
192 return Poll::Ready(());
193 }
194
195 self.sink.sink_waker.register(cx.waker());
197
198 if self.sink.capacity.load(Ordering::Acquire) >= self.size_needed {
200 Poll::Ready(())
201 } else {
202 Poll::Pending
203 }
204 }
205}