postage/sync/mpmc_circular_buffer.rs

use std::{cmp::max, sync::atomic::AtomicUsize};

use crate::Context;
use atomic::Ordering;
use parking_lot::{Mutex, RwLock};

use super::notifier::Notifier;
use std::fmt::Debug;

// A lock-free multi-producer, multi-consumer circular buffer
// Each reader will see each value created exactly once.
// Cloned readers inherit the read location of the reader that was cloned.
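//
// A rough usage sketch (added commentary; the `Context` values come from the
// surrounding channel code and are shown only schematically here):
//
//     let (buffer, mut reader) = MpmcCircularBuffer::<usize>::new(4);
//
//     // a writer offers a value; `Pending` hands it back when the buffer is full
//     match buffer.try_write(1usize, &cx) {
//         TryWrite::Ready => {}
//         TryWrite::Pending(_value) => { /* retry once a slot is released */ }
//     }
//
//     // each reader observes each written value exactly once
//     if let TryRead::Ready(value) = reader.try_read(&buffer, &cx) {
//         assert_eq!(value, 1);
//     }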

pub struct MpmcCircularBuffer<T> {
    buffer: Box<[Slot<T>]>,
    head: AtomicUsize,
    maintenance: Mutex<()>,
    readers: AtomicUsize,
}
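
// Field notes (added commentary):
// - `head` holds the id of the next value to be written; slot ids start at 1,
//   and an id of 0 marks a slot that has never been written.
// - `readers` counts live readers; a writer may reuse a slot once the slot's
//   read count reaches this value.
// - `maintenance` serializes reader creation, cloning, and dropping so the
//   reader count and per-slot read counts stay consistent with each other.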

impl<T> Debug for MpmcCircularBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MpmcCircularBuffer")
            .field("buffer", &self.buffer)
            .field("head", &self.head)
            .field("readers", &self.readers)
            .finish()
    }
}

impl<T> MpmcCircularBuffer<T>
where
    T: Clone,
{
    pub fn new(capacity: usize) -> (Self, BufferReader) {
        // we require at least two slots, so that unique slots can be acquired and released
        let capacity = max(2, capacity);
        let mut vec = Vec::with_capacity(capacity);

        for _ in 0..capacity {
            vec.push(Slot::new(0));
        }

        let this = Self {
            buffer: vec.into_boxed_slice(),
            head: AtomicUsize::new(1),
            readers: AtomicUsize::new(1),
            maintenance: Mutex::new(()),
        };

        let reader = BufferReader { index: 1 };

        (this, reader)
    }
}

pub enum TryWrite<T> {
    Pending(T),
    Ready,
}

pub enum SlotTryWrite<T> {
    Pending(T),
    Ready,
    Written(T),
}
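
// Write protocol (added commentary): `SlotTryWrite::Written` means the head
// slot already holds this id or a later one, so the writer reloads `head` and
// retries with the returned value; `Pending` hands the value back after the
// writer has subscribed to the slot's release notifier; `Ready` means the
// value was stored and `head` was advanced.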

impl<T> MpmcCircularBuffer<T> {
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    pub fn try_write(&self, mut value: T, cx: &Context<'_>) -> TryWrite<T> {
        loop {
            let head_id = self.head.load(Ordering::Acquire);
            let head_slot = self.get_slot(head_id);

            #[cfg(feature = "debug")]
            log::debug!(
                "[{}] Attempting write with required readers {:?}, slot index {:?} with {:?} readers of {:?} required",
                head_id,
                &self.readers,
                head_slot.index,
                head_slot.reads,
                &self.readers
            );

            // try to write a value
            // if the write is accepted, the closure advances `head`
            // this minimizes the window in which `head` points at the claimed slot, and allows the move of `value` to occur after `head` has advanced
            let try_write = head_slot.try_write(head_id, value, &self.readers, cx, || {
                if let Err(_e) = self.head.compare_exchange(
                    head_id,
                    head_id + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    #[cfg(feature = "debug")]
                    log::warn!(
                        "[{}] Expected {} head value, found {}",
                        head_id,
                        head_id + 1,
                        _e
                    );
                }
            });

            match try_write {
                SlotTryWrite::Pending(v) => {
                    return TryWrite::Pending(v);
                }
                SlotTryWrite::Ready => {
                    #[cfg(feature = "debug")]
                    let slot_index = head_id % self.len();

                    #[cfg(feature = "debug")]
                    log::info!(
                        "[{}] Write complete in slot {}, head incremented from {} to {}",
                        head_id,
                        slot_index,
                        head_id,
                        head_id + 1
                    );

                    return TryWrite::Ready;
                }
                SlotTryWrite::Written(v) => {
                    value = v;
                    continue;
                }
            }
        }
    }

    pub fn new_reader(&self) -> BufferReader {
        let _maint = self.maintenance.lock();
        let index = self.head.load(Ordering::Acquire);
        self.readers.fetch_add(1, Ordering::AcqRel);

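        // the new reader starts at `head`, so credit it with a read on every
        // slot holding an id below `index`; otherwise writers would wait on
        // reads this reader will never perform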
        self.mark_read_in_range(0, index);

        #[cfg(feature = "debug")]
        log::info!("[{}] New reader", index);

        BufferReader { index }
    }

    fn mark_read_in_range(&self, min: usize, max: usize) {
        for slot in self.buffer.iter() {
            let readers = self.readers.load(Ordering::Acquire);
            slot.mark_read_in_range(min, max, readers);
        }
    }

    pub(in crate::sync::mpmc_circular_buffer) fn get_slot(&self, id: usize) -> &Slot<T> {
        let index = id % self.len();
        &self.buffer[index]
    }
}

#[derive(Debug)]
pub struct BufferReader {
    index: usize,
}

pub enum TryRead<T> {
    /// A value is ready
    Ready(T),
    /// A value is pending in this slot
    Pending,
}

impl BufferReader {
    pub fn try_read<T>(&mut self, buffer: &MpmcCircularBuffer<T>, cx: &Context<'_>) -> TryRead<T>
    where
        T: Clone,
    {
        let index = self.index;
        let slot = buffer.get_slot(index);

        let try_read = slot.try_read(index, &buffer.readers, cx);

        match &try_read {
            TryRead::Ready(_) => {
                self.index += 1;

                #[cfg(feature = "debug")]
                log::debug!(
                    "[{}] Read complete in slot {} with {:?} reads of {:?} required",
                    index,
                    index % buffer.len(),
                    slot.reads,
                    &buffer.readers,
                );
            }
            TryRead::Pending => {
                #[cfg(feature = "debug")]
                log::debug!("[{}] Read pending, slot: {:?}", index, slot);
            }
        }

        try_read
    }

    // To avoid the need for shared Arc references, clone and drop are written as methods instead of using std traits
    pub fn clone_with<T>(&self, buffer: &MpmcCircularBuffer<T>) -> Self {
        let _maint = buffer.maintenance.lock();
        buffer.readers.fetch_add(1, Ordering::AcqRel);

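        // the clone starts at this reader's position, so slots holding ids
        // below that position are credited with a read on its behalf, mirroring
        // new_reader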
        let index = self.index;
        buffer.mark_read_in_range(0, index);

        #[cfg(feature = "debug")]
        log::error!("[{}] Cloned reader", index);

        BufferReader { index }
    }

    pub fn drop_with<T>(&mut self, buffer: &MpmcCircularBuffer<T>) {
        let _maint = buffer.maintenance.lock();

        // first, cancel all reads that this reader has committed
        buffer
            .buffer
            .iter()
            .for_each(|slot| slot.decrement_read_in_range(0, self.index));

        // then decrement the reader count
        buffer.readers.fetch_sub(1, Ordering::AcqRel);

        // then go through the buffer, and release any slots that should be released
        for (_id, slot) in buffer.buffer.iter().enumerate() {
            #[cfg(feature = "debug")]
            log::debug!(
                "[{}] Dropping reader, notifying slot {} with reads {:?} of new reader count {:?}",
                self.index,
                _id,
                slot.reads,
                buffer.readers,
            );

            slot.notify_readers_decreased(&buffer.readers);
        }

        #[cfg(feature = "debug")]
        log::error!(
            "[{}] Dropped reader, readers reduced to {:?}",
            self.index,
            buffer.readers
        );
    }
}

pub struct Slot<T> {
    data: RwLock<Option<T>>,
    reads: AtomicUsize,
    index: AtomicUsize,
    on_write: Notifier,
    on_release: Notifier,
}
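
// Slot field notes (added commentary): `index` is the id of the value the slot
// currently holds (0 while the slot has never been written), `reads` counts the
// readers that have consumed that value, `on_write` wakes readers waiting for
// the slot to be filled, and `on_release` wakes writers waiting for all readers
// to finish with it.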

impl<T> Slot<T> {
    pub fn new(index: usize) -> Self {
        Self {
            data: RwLock::new(None),
            reads: AtomicUsize::new(0),
            index: AtomicUsize::new(index),
            on_write: Notifier::new(),
            on_release: Notifier::new(),
        }
    }

    pub fn try_write<OnWrite>(
        &self,
        index: usize,
        value: T,
        readers: &AtomicUsize,
        cx: &Context<'_>,
        on_write: OnWrite,
    ) -> SlotTryWrite<T>
    where
        OnWrite: FnOnce(),
    {
        loop {
            let prev_index = self.index.load(Ordering::Acquire);

            if prev_index >= index {
                return SlotTryWrite::Written(value);
            } else if prev_index != 0
                && self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
            {
                self.on_release.subscribe(cx);

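                // re-check after subscribing: if the slot advanced or its reads
                // caught up between the first check and the subscription, the
                // wakeup may already have fired, so retry instead of parking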
                if prev_index < self.index.load(Ordering::Acquire) {
                    #[cfg(feature = "debug")]
                    log::warn!(
                        "[{}] Slot index advanced during write, invalidating subscription",
                        index
                    );
                    continue;
                }

                if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
                    #[cfg(feature = "debug")]
                    log::warn!(
                        "[{}] Reads incremented during write, invalidating subscription",
                        index
                    );
                    continue;
                }

                return SlotTryWrite::Pending(value);
            }

            // lock the data, then update the index
            let mut data = self.data.write();
            if prev_index != 0
                && self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
            {
                #[cfg(feature = "debug")]
                log::warn!(
                    "[{}] Reads decreased during write (upgrading index {})",
                    index,
                    prev_index
                );
                continue;
            }

            if self
                .index
                .compare_exchange(prev_index, index, Ordering::AcqRel, Ordering::Relaxed)
                .is_err()
            {
                continue;
            }

            on_write();
            *data = Some(value);
            self.reads.store(0, Ordering::Release);
            self.on_write.notify();
            return SlotTryWrite::Ready;
        }
    }

    fn mark_read_in_range(&self, min: usize, max: usize, readers: usize) {
        // prevent the index from changing while maintenance is performed
        let _read = self.data.read();
        let index = self.index.load(Ordering::Acquire);
        if index >= min && index < max {
            let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);

            #[cfg(feature = "debug")]
            log::debug!(
                "[{}] Mark read in range occurred.  Increased reads to {} of required readers {}",
                index,
                reads,
                readers
            );

            if reads >= readers {
                self.on_release.notify();
            }
        }
    }

    fn decrement_read_in_range(&self, min: usize, max: usize) {
        // prevent the index from changing while maintenance is performed
        let _read = self.data.read();
        let index = self.index.load(Ordering::Acquire);
        if index >= min && index < max {
            loop {
                let reads = self.reads.load(Ordering::Acquire);
                if reads == 0 {
                    return;
                }

                if self
                    .reads
                    .compare_exchange(reads, reads - 1, Ordering::Acquire, Ordering::Relaxed)
                    .is_ok()
                {
                    #[cfg(feature = "debug")]
                    log::debug!(
                        "[{}] Mark decrement in range occurred.  Decreased reads to {}",
                        index,
                        reads - 1
                    );

                    return;
                }
            }
        }
    }

    fn notify_readers_decreased(&self, readers: &AtomicUsize) {
        if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
            self.on_release.notify();
        }
    }
}

impl<T> Slot<T>
where
    T: Clone,
{
    #[allow(clippy::comparison_chain)]
    pub fn try_read(&self, index: usize, readers: &AtomicUsize, cx: &Context<'_>) -> TryRead<T> {
        loop {
            let slot_index = self.index.load(Ordering::Acquire);
            if slot_index < index {
                self.on_write.subscribe(cx);

                // if the index has advanced, continue and attempt to read again
                if self.index.load(Ordering::Acquire) >= index {
                    continue;
                }

                return TryRead::Pending;
            } else if slot_index > index {
                #[cfg(feature = "debug")]
                log::error!(
                    "Slot index {} has advanced past reader position {}",
                    slot_index,
                    index
                );
                return TryRead::Pending;
            }

            let data_lock = self.data.read();

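            // hold the data read lock while counting this read and cloning the
            // value, so a writer that now sees enough reads cannot replace the
            // value until this read has finished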
            let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);
            #[cfg(feature = "debug")]
            log::debug!(
                "[{}] Read action occurred.  Increased reads to {}",
                index,
                reads
            );

            // the only way the slot could be uninitialized is if `index` is 0,
            // but readers are initialized with index: 1
            // if the slot index was 0, then the above code would have returned TryRead::Pending
            let data_ref = data_lock.as_ref().unwrap();
            let data_cloned = data_ref.clone();

            if reads >= readers.load(Ordering::Acquire) {
                self.on_release.notify();
            }

            break TryRead::Ready(data_cloned);
        }
    }
}

impl<T> Debug for Slot<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Slot")
            .field("reads", &self.reads)
            .field("index", &self.index)
            .finish()
    }
}
461}