1use std::{cmp::max, sync::atomic::AtomicUsize};
2
3use crate::Context;
4use atomic::Ordering;
5use parking_lot::{Mutex, RwLock};
6
7use super::notifier::Notifier;
8use std::fmt::Debug;
9
10pub struct MpmcCircularBuffer<T> {
15 buffer: Box<[Slot<T>]>,
16 head: AtomicUsize,
17 maintenance: Mutex<()>,
18 readers: AtomicUsize,
19}
20
21impl<T> Debug for MpmcCircularBuffer<T> {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("MpmcCircularBuffer")
24 .field("buffer", &self.buffer)
25 .field("head", &self.head)
26 .field("readers", &self.readers)
27 .finish()
28 }
29}
30
31impl<T> MpmcCircularBuffer<T>
32where
33 T: Clone,
34{
35 pub fn new(capacity: usize) -> (Self, BufferReader) {
36 let capacity = max(2, capacity);
38 let mut vec = Vec::with_capacity(capacity);
39
40 for _ in 0..capacity {
41 vec.push(Slot::new(0));
42 }
43
44 let this = Self {
45 buffer: vec.into_boxed_slice(),
46 head: AtomicUsize::new(1),
47 readers: AtomicUsize::new(1),
48 maintenance: Mutex::new(()),
49 };
50
51 let reader = BufferReader { index: 1 };
52
53 (this, reader)
54 }
55}
56
/// Outcome of [`MpmcCircularBuffer::try_write`].
#[derive(Debug)]
pub enum TryWrite<T> {
    /// The target slot is still unread; the value is handed back to retry later.
    Pending(T),
    /// The value was stored and the head advanced.
    Ready,
}

/// Outcome of a single slot-level write attempt.
#[derive(Debug)]
pub enum SlotTryWrite<T> {
    /// The slot is still unread; the value is handed back.
    Pending(T),
    /// The value was stored in the slot.
    Ready,
    /// The slot already holds this id or a later one; the value is handed back
    /// so the caller can retry against a fresh head.
    Written(T),
}
67
68impl<T> MpmcCircularBuffer<T> {
69 pub fn len(&self) -> usize {
70 self.buffer.len()
71 }
72
73 pub fn try_write(&self, mut value: T, cx: &Context<'_>) -> TryWrite<T> {
74 loop {
75 let head_id = self.head.load(Ordering::Acquire);
76 let head_slot = self.get_slot(head_id);
77
78 #[cfg(feature = "debug")]
79 log::debug!(
80 "[{}] Attempting write with required readers {:?}, slot index {:?} with {:?} readers of {:?} required",
81 head_id,
82 &self.readers,
83 head_slot.index,
84 head_slot.reads,
85 &self.readers
86 );
87
88 let try_write = head_slot.try_write(head_id, value, &self.readers, cx, || {
92 if let Err(_e) = self.head.compare_exchange(
93 head_id,
94 head_id + 1,
95 Ordering::SeqCst,
96 Ordering::Relaxed,
97 ) {
98 #[cfg(feature = "debug")]
99 log::warn!(
100 "[{}] Expected {} head value, found {}",
101 head_id,
102 head_id + 1,
103 _e
104 );
105 }
106 });
107
108 match try_write {
109 SlotTryWrite::Pending(v) => {
110 return TryWrite::Pending(v);
111 }
112 SlotTryWrite::Ready => {
113 #[cfg(feature = "debug")]
114 let slot_index = head_id % self.len();
115
116 #[cfg(feature = "debug")]
117 log::info!(
118 "[{}] Write complete in slot {}, head incremented from {} to {}",
119 head_id,
120 slot_index,
121 head_id,
122 head_id + 1
123 );
124
125 return TryWrite::Ready;
126 }
127 SlotTryWrite::Written(v) => {
128 value = v;
129 continue;
130 }
131 }
132 }
133 }
134
135 pub fn new_reader(&self) -> BufferReader {
136 let _maint = self.maintenance.lock();
137 let index = self.head.load(Ordering::Acquire);
138 self.readers.fetch_add(1, Ordering::AcqRel);
139
140 self.mark_read_in_range(0, index);
141
142 #[cfg(feature = "debug")]
143 log::info!("[{}] New reader", index);
144
145 BufferReader { index }
146 }
147
148 fn mark_read_in_range(&self, min: usize, max: usize) {
149 for slot in self.buffer.iter() {
150 let readers = self.readers.load(Ordering::Acquire);
151 slot.mark_read_in_range(min, max, readers);
152 }
153 }
154
155 pub(in crate::sync::mpmc_circular_buffer) fn get_slot(&self, id: usize) -> &Slot<T> {
156 let index = id % self.len();
157 &self.buffer[index]
158 }
159}
160
/// A single reader's cursor into an [`MpmcCircularBuffer`].
#[derive(Debug)]
pub struct BufferReader {
    /// Id of the next value this reader will read.
    index: usize,
}
165
/// Outcome of a read attempt.
#[derive(Debug)]
pub enum TryRead<T> {
    /// A clone of the value at the reader's position.
    Ready(T),
    /// No value is available at the reader's position yet.
    Pending,
}
172
173impl BufferReader {
174 pub fn try_read<T>(&mut self, buffer: &MpmcCircularBuffer<T>, cx: &Context<'_>) -> TryRead<T>
175 where
176 T: Clone,
177 {
178 let index = self.index;
179 let slot = buffer.get_slot(index);
180
181 let try_read = slot.try_read(index, &buffer.readers, cx);
182
183 match &try_read {
184 TryRead::Ready(_) => {
185 self.index += 1;
186
187 #[cfg(feature = "debug")]
188 log::debug!(
189 "[{}] Read complete in slot {} with {:?} reads of {:?} required",
190 index,
191 index % buffer.len(),
192 slot.reads,
193 &buffer.readers,
194 );
195 }
196 TryRead::Pending => {
197 #[cfg(feature = "debug")]
198 log::debug!("[{}] Read pending, slot: {:?}", index, slot);
199 }
200 }
201
202 try_read
203 }
204
205 pub fn clone_with<T>(&self, buffer: &MpmcCircularBuffer<T>) -> Self {
207 let _maint = buffer.maintenance.lock();
208 buffer.readers.fetch_add(1, Ordering::AcqRel);
209
210 let index = self.index;
211 buffer.mark_read_in_range(0, index);
212
213 #[cfg(feature = "debug")]
214 log::error!("[{}] Cloned reader", index);
215
216 BufferReader { index }
217 }
218
219 pub fn drop_with<T>(&mut self, buffer: &MpmcCircularBuffer<T>) {
220 let _maint = buffer.maintenance.lock();
221
222 buffer
224 .buffer
225 .iter()
226 .for_each(|slot| slot.decrement_read_in_range(0, self.index));
227
228 buffer.readers.fetch_sub(1, Ordering::AcqRel);
230
231 for (_id, slot) in buffer.buffer.iter().enumerate() {
233 #[cfg(feature = "debug")]
234 log::debug!(
235 "[{}] Dropping reader, notifying slot {} with reads {:?} of new reader count {:?}",
236 self.index,
237 _id,
238 slot.reads,
239 buffer.readers,
240 );
241
242 slot.notify_readers_decreased(&buffer.readers);
243 }
244
245 #[cfg(feature = "debug")]
246 log::error!(
247 "[{}] Dropped reader, readers reduced to {:?}",
248 self.index,
249 buffer.readers
250 );
251 }
252}
253
254pub struct Slot<T> {
255 data: RwLock<Option<T>>,
256 reads: AtomicUsize,
257 index: AtomicUsize,
258 on_write: Notifier,
259 on_release: Notifier,
260}
261
262impl<T> Slot<T> {
263 pub fn new(index: usize) -> Self {
264 Self {
265 data: RwLock::new(None),
266 reads: AtomicUsize::new(0),
267 index: AtomicUsize::new(index),
268 on_write: Notifier::new(),
269 on_release: Notifier::new(),
270 }
271 }
272
273 pub fn try_write<OnWrite>(
274 &self,
275 index: usize,
276 value: T,
277 readers: &AtomicUsize,
278 cx: &Context<'_>,
279 on_write: OnWrite,
280 ) -> SlotTryWrite<T>
281 where
282 OnWrite: FnOnce(),
283 {
284 loop {
285 let prev_index = self.index.load(Ordering::Acquire);
286
287 if prev_index >= index {
288 return SlotTryWrite::Written(value);
289 } else if prev_index != 0
290 && self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
291 {
292 self.on_release.subscribe(cx);
293
294 if prev_index < self.index.load(Ordering::Acquire) {
295 #[cfg(feature = "debug")]
296 log::warn!(
297 "[{}] Slot index advanced during write, invalidating subscription",
298 index
299 );
300 continue;
301 }
302
303 if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
304 #[cfg(feature = "debug")]
305 log::warn!(
306 "[{}] Reads incremented during write, invalidating subscription",
307 index
308 );
309 continue;
310 }
311
312 return SlotTryWrite::Pending(value);
313 }
314
315 let mut data = self.data.write();
317 if prev_index != 0
318 && self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
319 {
320 #[cfg(feature = "debug")]
321 log::warn!(
322 "[{}] Reads decreased during write (upgrading index {})",
323 index,
324 prev_index
325 );
326 continue;
327 }
328
329 if self
330 .index
331 .compare_exchange(prev_index, index, Ordering::AcqRel, Ordering::Relaxed)
332 .is_err()
333 {
334 continue;
335 }
336
337 on_write();
338 *data = Some(value);
339 self.reads.store(0, Ordering::Release);
340 self.on_write.notify();
341 return SlotTryWrite::Ready;
342 }
343 }
344
345 fn mark_read_in_range(&self, min: usize, max: usize, readers: usize) {
346 let _read = self.data.read();
348 let index = self.index.load(Ordering::Acquire);
349 if index >= min && index < max {
350 let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);
351
352 #[cfg(feature = "debug")]
353 log::debug!(
354 "[{}] Mark read in range occurred. Increased reads to {} of required readers {}",
355 index,
356 reads,
357 readers
358 );
359
360 if reads >= readers {
361 self.on_release.notify();
362 }
363 }
364 }
365
366 fn decrement_read_in_range(&self, min: usize, max: usize) {
367 let _read = self.data.read();
369 let index = self.index.load(Ordering::Acquire);
370 if index >= min && index < max {
371 loop {
372 let reads = self.reads.load(Ordering::Acquire);
373 if reads == 0 {
374 return;
375 }
376
377 if self
378 .reads
379 .compare_exchange(reads, reads - 1, Ordering::Acquire, Ordering::Relaxed)
380 .is_ok()
381 {
382 #[cfg(feature = "debug")]
383 log::debug!(
384 "[{}] Mark decrement in range occurred. Decreased reads to {}",
385 index,
386 reads - 1
387 );
388
389 return;
390 }
391 }
392 }
393 }
394
395 fn notify_readers_decreased(&self, readers: &AtomicUsize) {
396 if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
397 self.on_release.notify();
398 }
399 }
400}
401
402impl<T> Slot<T>
403where
404 T: Clone,
405{
406 #[allow(clippy::comparison_chain)]
407 pub fn try_read(&self, index: usize, readers: &AtomicUsize, cx: &Context<'_>) -> TryRead<T> {
408 loop {
409 let slot_index = self.index.load(Ordering::Acquire);
410 if slot_index < index {
411 self.on_write.subscribe(cx);
412
413 if self.index.load(Ordering::Acquire) >= index {
415 continue;
416 }
417
418 return TryRead::Pending;
419 } else if slot_index > index {
420 #[cfg(feature = "debug")]
421 log::error!(
422 "Slot index {} has advanced past reader position {}",
423 slot_index,
424 index
425 );
426 return TryRead::Pending;
427 }
428
429 let data_lock = self.data.read();
430
431 let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);
432 #[cfg(feature = "debug")]
433 log::debug!(
434 "[{}] Read action occurred. Increased reads to {}",
435 index,
436 reads
437 );
438
439 let data_ref = data_lock.as_ref().unwrap();
443 let data_cloned = data_ref.clone();
444
445 if reads >= readers.load(Ordering::Acquire) {
446 self.on_release.notify();
447 }
448
449 break TryRead::Ready(data_cloned);
450 }
451 }
452}
453
454impl<T> Debug for Slot<T> {
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.debug_struct("Slot")
457 .field("reads", &self.reads)
458 .field("index", &self.index)
459 .finish()
460 }
461}