tokio/sync/mpsc/
chan.rs

1use crate::loom::cell::UnsafeCell;
2use crate::loom::future::AtomicWaker;
3use crate::loom::sync::atomic::AtomicUsize;
4use crate::loom::sync::Arc;
5use crate::runtime::park::CachedParkThread;
6use crate::sync::mpsc::error::TryRecvError;
7use crate::sync::mpsc::{bounded, list, unbounded};
8use crate::sync::notify::Notify;
9use crate::util::cacheline::CachePadded;
10
11use std::fmt;
12use std::panic;
13use std::process;
14use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
15use std::task::Poll::{Pending, Ready};
16use std::task::{ready, Context, Poll};
17
/// Channel sender.
///
/// A cheap, cloneable handle to the shared [`Chan`] state. Cloning and
/// dropping adjust `Chan::tx_count`; when the last `Tx` is dropped the send
/// half of the channel is closed (see the `Drop` impl below).
pub(crate) struct Tx<T, S> {
    // Shared channel state; also co-owned by `Rx` and any weak senders.
    inner: Arc<Chan<T, S>>,
}
22
23impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
24    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
25        fmt.debug_struct("Tx").field("inner", &self.inner).finish()
26    }
27}
28
/// Channel receiver.
///
/// There is at most one `Rx` per channel; it is the only handle allowed to
/// touch `Chan::rx_fields`.
pub(crate) struct Rx<T, S: Semaphore> {
    // Shared channel state; also co-owned by the senders.
    inner: Arc<Chan<T, S>>,
}
33
34impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
35    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
36        fmt.debug_struct("Rx").field("inner", &self.inner).finish()
37    }
38}
39
/// Strategy for tracking channel capacity.
///
/// Implemented for both the bounded and unbounded flavors at the bottom of
/// this file, letting `Chan`/`Rx` stay generic over the two.
pub(crate) trait Semaphore {
    /// Returns `true` when no permits are outstanding, i.e. there are no
    /// in-flight messages the receiver has yet to account for.
    fn is_idle(&self) -> bool;

    /// Returns a single permit; called by the receiver after popping a value.
    fn add_permit(&self);

    /// Returns `n` permits at once; used by the batched `recv_many` path.
    fn add_permits(&self, n: usize);

    /// Closes the semaphore, causing pending and future sends to fail.
    fn close(&self);

    /// Returns `true` if [`close`](Self::close) has been called.
    fn is_closed(&self) -> bool;
}
51
/// Shared channel state, jointly owned (via `Arc`) by every `Tx`, weak
/// sender, and the `Rx` handle.
pub(super) struct Chan<T, S> {
    /// Handle to the push half of the lock-free list.
    ///
    /// Cache-padded, presumably to avoid false sharing with the
    /// receiver-side fields below — the padding itself lives in
    /// `util::cacheline`.
    tx: CachePadded<list::Tx<T>>,

    /// Receiver waker. Notified when a value is pushed into the channel.
    rx_waker: CachePadded<AtomicWaker>,

    /// Notifies all tasks listening for the receiver being dropped.
    notify_rx_closed: Notify,

    /// Coordinates access to channel's capacity.
    semaphore: S,

    /// Tracks the number of outstanding sender handles.
    ///
    /// When this drops to zero, the send half of the channel is closed.
    tx_count: AtomicUsize,

    /// Tracks the number of outstanding weak sender handles.
    tx_weak_count: AtomicUsize,

    /// Only accessed by `Rx` handle.
    ///
    /// This access discipline is what makes the manual `Send`/`Sync` impls
    /// below sound.
    rx_fields: UnsafeCell<RxFields<T>>,
}
76
77impl<T, S> fmt::Debug for Chan<T, S>
78where
79    S: fmt::Debug,
80{
81    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
82        fmt.debug_struct("Chan")
83            .field("tx", &*self.tx)
84            .field("semaphore", &self.semaphore)
85            .field("rx_waker", &*self.rx_waker)
86            .field("tx_count", &self.tx_count)
87            .field("rx_fields", &"...")
88            .finish()
89    }
90}
91
/// Fields only accessed by `Rx` handle.
struct RxFields<T> {
    /// Channel receiver. This field is only accessed by the `Receiver` type.
    list: list::Rx<T>,

    /// `true` if `Rx::close` is called.
    // Once set it is never cleared; `recv`/`recv_many` combine it with
    // `Semaphore::is_idle` to decide when to report the channel as finished.
    rx_closed: bool,
}
100
101impl<T> fmt::Debug for RxFields<T> {
102    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
103        fmt.debug_struct("RxFields")
104            .field("list", &self.list)
105            .field("rx_closed", &self.rx_closed)
106            .finish()
107    }
108}
109
// SAFETY: `rx_fields` is an `UnsafeCell`, which suppresses the automatic
// `Send`/`Sync` impls. Sharing `Chan` across threads is sound because
// `rx_fields` is only ever accessed through the single `Rx` handle (see the
// field comment), so its contents are never aliased without synchronization.
unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
// `UnsafeCell` is also `!RefUnwindSafe`; unwind safety is asserted manually
// (unconditionally in `T` and `S`) so channel handles stay usable across
// `catch_unwind` boundaries.
impl<T, S> panic::RefUnwindSafe for Chan<T, S> {}
impl<T, S> panic::UnwindSafe for Chan<T, S> {}
114
115pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
116    let (tx, rx) = list::channel();
117
118    let chan = Arc::new(Chan {
119        notify_rx_closed: Notify::new(),
120        tx: CachePadded::new(tx),
121        semaphore,
122        rx_waker: CachePadded::new(AtomicWaker::new()),
123        tx_count: AtomicUsize::new(1),
124        tx_weak_count: AtomicUsize::new(0),
125        rx_fields: UnsafeCell::new(RxFields {
126            list: rx,
127            rx_closed: false,
128        }),
129    });
130
131    (Tx::new(chan.clone()), Rx::new(chan))
132}
133
134// ===== impl Tx =====
135
impl<T, S> Tx<T, S> {
    /// Wraps the shared state in a new strong sender handle.
    ///
    /// Does NOT bump `tx_count`; callers (`channel`, `upgrade`) are
    /// responsible for accounting.
    fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
        Tx { inner: chan }
    }

    /// Returns the number of outstanding strong sender handles.
    pub(super) fn strong_count(&self) -> usize {
        self.inner.tx_count.load(Acquire)
    }

    /// Returns the number of outstanding weak sender handles.
    // Relaxed is enough here: the weak count is informational only and
    // guards no other data.
    pub(super) fn weak_count(&self) -> usize {
        self.inner.tx_weak_count.load(Relaxed)
    }

    /// Creates a weak reference to the channel: bumps the weak count and
    /// hands back the shared state for the caller to wrap in a weak sender.
    pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
        self.inner.increment_weak_count();

        self.inner.clone()
    }

    // Returns the upgraded channel or None if the upgrade failed.
    //
    // Uses a CAS loop so the increment only succeeds while `tx_count` is
    // still non-zero; once it has reached zero the send half is closed and
    // must never be "revived". `compare_exchange_weak` may fail spuriously,
    // which is fine inside the retry loop.
    pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> {
        let mut tx_count = chan.tx_count.load(Acquire);

        loop {
            if tx_count == 0 {
                // channel is closed
                return None;
            }

            match chan
                .tx_count
                .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
            {
                Ok(_) => return Some(Tx { inner: chan }),
                Err(prev_count) => tx_count = prev_count,
            }
        }
    }

    /// Returns a reference to the capacity-tracking semaphore.
    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }

    /// Send a message and notify the receiver.
    pub(crate) fn send(&self, value: T) {
        self.inner.send(value);
    }

    /// Wake the receive half
    pub(crate) fn wake_rx(&self) {
        self.inner.rx_waker.wake();
    }

    /// Returns `true` if senders belong to the same channel.
    // Pointer identity of the shared `Chan` allocation defines channel
    // identity.
    pub(crate) fn same_channel(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}
194
impl<T, S: Semaphore> Tx<T, S> {
    /// Returns `true` if the channel has been closed (the semaphore is the
    /// single source of truth for the closed state on the send side).
    pub(crate) fn is_closed(&self) -> bool {
        self.inner.semaphore.is_closed()
    }

    /// Completes when the receive half of the channel is closed.
    pub(crate) async fn closed(&self) {
        // In order to avoid a race condition, we first request a notification,
        // **then** check whether the semaphore is closed. If the semaphore is
        // closed the notification request is dropped.
        let notified = self.inner.notify_rx_closed.notified();

        if self.inner.semaphore.is_closed() {
            return;
        }
        notified.await;
    }
}
212
impl<T, S> Clone for Tx<T, S> {
    /// Creates another strong sender handle, bumping `tx_count`.
    fn clone(&self) -> Tx<T, S> {
        // Using a Relaxed ordering here is sufficient as the caller holds a
        // strong ref to `self`, preventing a concurrent decrement to zero.
        // (Contrast with `upgrade`, which must CAS because it cannot assume
        // the count is non-zero.)
        self.inner.tx_count.fetch_add(1, Relaxed);

        Tx {
            inner: self.inner.clone(),
        }
    }
}
224
impl<T, S> Drop for Tx<T, S> {
    /// Decrements the strong sender count; the last sender closes the send
    /// half and wakes the receiver so it can observe the close.
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value, so `1` means we were the
        // last strong sender. AcqRel makes this drop synchronize with all
        // prior sends/clones.
        if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
            return;
        }

        // Close the list, which sends a `Close` message
        self.inner.tx.close();

        // Notify the receiver
        self.wake_rx();
    }
}
238
239// ===== impl Rx =====
240
impl<T, S: Semaphore> Rx<T, S> {
    /// Wraps the shared state in the (unique) receiver handle.
    fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
        Rx { inner: chan }
    }

    /// Closes the receive half: marks `rx_closed`, closes the semaphore so
    /// sends fail, and wakes every task waiting in `Tx::closed`.
    ///
    /// Idempotent: a second call returns early from the closure, but note it
    /// would still re-run the semaphore close / notify below — both of which
    /// are safe to repeat.
    pub(crate) fn close(&mut self) {
        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: `rx_fields` is only accessed by the `Rx` handle, and
            // we hold `&mut self`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            if rx_fields.rx_closed {
                return;
            }

            rx_fields.rx_closed = true;
        });

        self.inner.semaphore.close();
        self.inner.notify_rx_closed.notify_waiters();
    }

    /// Returns `true` if the channel is closed from the receiver's point of
    /// view.
    pub(crate) fn is_closed(&self) -> bool {
        // There are two internal states that can represent a closed channel
        //
        //  1. When `close` is called.
        //  In this case, the inner semaphore will be closed.
        //
        //  2. When all senders are dropped.
        //  In this case, the semaphore remains unclosed, and the `index` in the list won't
        //  reach the tail position. It is necessary to check the list if the last block is
        //  `closed`.
        self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0
    }

    /// Returns `true` if no messages are currently buffered.
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`.
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.is_empty(&self.inner.tx)
        })
    }

    /// Returns the number of currently buffered messages.
    pub(crate) fn len(&self) -> usize {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`.
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.len(&self.inner.tx)
        })
    }

    /// Receive the next value
    ///
    /// Returns `Ready(Some(value))` on success, `Ready(None)` when the
    /// channel is closed and drained, and `Pending` (with the waker
    /// registered) otherwise.
    pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of task budget
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // A macro (not a closure/fn) so that `return` exits `recv`'s
            // closure directly and `coop`/`self` are borrowed naturally.
            macro_rules! try_recv {
                () => {
                    match rx_fields.list.pop(&self.inner.tx) {
                        Some(Read::Value(value)) => {
                            // Popping a value frees one slot of capacity.
                            self.inner.semaphore.add_permit();
                            coop.made_progress();
                            return Ready(Some(value));
                        }
                        Some(Read::Closed) => {
                            // TODO: This check may not be required as it most
                            // likely can only return `true` at this point. A
                            // channel is closed when all tx handles are
                            // dropped. Dropping a tx handle releases memory,
                            // which ensures that if dropping the tx handle is
                            // visible, then all messages sent are also visible.
                            assert!(self.inner.semaphore.is_idle());
                            coop.made_progress();
                            return Ready(None);
                        }
                        None => {} // fall through
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to read
            // and registering the task, so we have to check the channel a
            // second time here.
            try_recv!();

            // Nothing buffered: the channel is finished only if the receiver
            // closed it AND every permit has been returned.
            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                coop.made_progress();
                Ready(None)
            } else {
                Pending
            }
        })
    }

    /// Receives up to `limit` values into `buffer`
    ///
    /// For `limit > 0`, receives up to limit values into `buffer`.
    /// For `limit == 0`, immediately returns Ready(0).
    ///
    /// Permits are returned in one batch (`add_permits`) rather than one per
    /// value, and the return value counts only the values appended by this
    /// call (`buffer.len() - initial_length`).
    pub(crate) fn recv_many(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &mut Vec<T>,
        limit: usize,
    ) -> Poll<usize> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of task budget
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        if limit == 0 {
            coop.made_progress();
            return Ready(0usize);
        }

        let mut remaining = limit;
        let initial_length = buffer.len();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };
            // Same register-then-recheck shape as `recv` above, batched.
            macro_rules! try_recv {
                () => {
                    while remaining > 0 {
                        match rx_fields.list.pop(&self.inner.tx) {
                            Some(Read::Value(value)) => {
                                remaining -= 1;
                                buffer.push(value);
                            }

                            Some(Read::Closed) => {
                                let number_added = buffer.len() - initial_length;
                                if number_added > 0 {
                                    self.inner.semaphore.add_permits(number_added);
                                }
                                // TODO: This check may not be required as it most
                                // likely can only return `true` at this point. A
                                // channel is closed when all tx handles are
                                // dropped. Dropping a tx handle releases memory,
                                // which ensures that if dropping the tx handle is
                                // visible, then all messages sent are also visible.
                                assert!(self.inner.semaphore.is_idle());
                                coop.made_progress();
                                return Ready(number_added);
                            }

                            None => {
                                break; // fall through
                            }
                        }
                    }
                    let number_added = buffer.len() - initial_length;
                    if number_added > 0 {
                        self.inner.semaphore.add_permits(number_added);
                        coop.made_progress();
                        return Ready(number_added);
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to read
            // and registering the task, so we have to check the channel a
            // second time here.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                // NOTE(review): reaching here means nothing was appended this
                // call, so this assert additionally assumes the caller's
                // `buffer` was empty on entry — confirm all callers uphold
                // that before relying on it.
                assert!(buffer.is_empty());
                coop.made_progress();
                Ready(0usize)
            } else {
                Pending
            }
        })
    }

    /// Try to receive the next value.
    ///
    /// Synchronous variant of `recv`. On `Busy` (a send is mid-flight) this
    /// parks the calling thread until the in-progress send completes, so it
    /// must not be called from an async context that forbids blocking.
    pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
        use super::list::TryPopResult;

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            macro_rules! try_recv {
                () => {
                    match rx_fields.list.try_pop(&self.inner.tx) {
                        TryPopResult::Ok(value) => {
                            self.inner.semaphore.add_permit();
                            return Ok(value);
                        }
                        TryPopResult::Closed => return Err(TryRecvError::Disconnected),
                        TryPopResult::Empty => return Err(TryRecvError::Empty),
                        TryPopResult::Busy => {} // fall through
                    }
                };
            }

            try_recv!();

            // If a previous `poll_recv` call has set a waker, we wake it here.
            // This allows us to put our own CachedParkThread waker in the
            // AtomicWaker slot instead.
            //
            // This is not a spurious wakeup to `poll_recv` since we just got a
            // Busy from `try_pop`, which only happens if there are messages in
            // the queue.
            self.inner.rx_waker.wake();

            // Park the thread until the problematic send has completed.
            let mut park = CachedParkThread::new();
            let waker = park.waker().unwrap();
            loop {
                self.inner.rx_waker.register_by_ref(&waker);
                // It is possible that the problematic send has now completed,
                // so we have to check for messages again.
                try_recv!();
                park.park();
            }
        })
    }

    /// Returns a reference to the capacity-tracking semaphore.
    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }

    /// Returns the number of outstanding strong sender handles.
    pub(super) fn sender_strong_count(&self) -> usize {
        self.inner.tx_count.load(Acquire)
    }

    /// Returns the number of outstanding weak sender handles.
    pub(super) fn sender_weak_count(&self) -> usize {
        self.inner.tx_weak_count.load(Relaxed)
    }
}
484
impl<T, S: Semaphore> Drop for Rx<T, S> {
    /// Closes the channel, then drains every buffered value so their permits
    /// are returned (unparking any senders blocked on capacity).
    fn drop(&mut self) {
        use super::block::Read::Value;

        self.close();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            // Safety: only the `Rx` handle accesses `rx_fields`, and we are
            // in its destructor.
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Each dropped value frees one slot of capacity.
            while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) {
                self.inner.semaphore.add_permit();
            }
        });
    }
}
500
501// ===== impl Chan =====
502
impl<T, S> Chan<T, S> {
    /// Pushes `value` onto the list and wakes the receiver.
    ///
    /// The push must happen before the wake so the woken receiver can
    /// observe the value.
    fn send(&self, value: T) {
        // Push the value
        self.tx.push(value);

        // Notify the rx task
        self.rx_waker.wake();
    }

    /// Decrements the weak sender count (called when a weak sender drops).
    // Relaxed: the weak count is informational only.
    pub(super) fn decrement_weak_count(&self) {
        self.tx_weak_count.fetch_sub(1, Relaxed);
    }

    /// Increments the weak sender count (called by `Tx::downgrade`).
    pub(super) fn increment_weak_count(&self) {
        self.tx_weak_count.fetch_add(1, Relaxed);
    }

    /// Returns the number of outstanding strong sender handles.
    pub(super) fn strong_count(&self) -> usize {
        self.tx_count.load(Acquire)
    }

    /// Returns the number of outstanding weak sender handles.
    pub(super) fn weak_count(&self) -> usize {
        self.tx_weak_count.load(Relaxed)
    }
}
528
impl<T, S> Drop for Chan<T, S> {
    /// Drops any values still buffered, then releases the list's block
    /// storage.
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Safety: the only owner of the rx fields is Chan, and being
        // inside its own Drop means we're the last ones to touch it.
        self.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Drain (and drop) remaining values; no permit bookkeeping is
            // needed here since every handle is already gone.
            while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {}
            // Safety: all values have been popped, and no other handle can
            // touch the blocks anymore.
            unsafe { rx_fields.list.free_blocks() };
        });
    }
}
543
544// ===== impl Semaphore for (::Semaphore, capacity) =====
545
impl Semaphore for bounded::Semaphore {
    /// Returns one permit to the underlying async semaphore.
    fn add_permit(&self) {
        self.semaphore.release(1);
    }

    /// Returns `n` permits in a single release.
    fn add_permits(&self, n: usize) {
        self.semaphore.release(n)
    }

    /// Idle when every permit is back, i.e. no messages are in flight.
    fn is_idle(&self) -> bool {
        self.semaphore.available_permits() == self.bound
    }

    /// Closes the underlying semaphore, failing pending/future acquires.
    fn close(&self) {
        self.semaphore.close();
    }

    /// Returns `true` if the underlying semaphore is closed.
    fn is_closed(&self) -> bool {
        self.semaphore.is_closed()
    }
}
567
568// ===== impl Semaphore for AtomicUsize =====
569
// The unbounded semaphore packs two things into one `AtomicUsize`:
// bit 0 is the "closed" flag, and the remaining bits (value >> 1) count the
// number of in-flight messages. Hence the `<< 1` / `>> 1` arithmetic below.
impl Semaphore for unbounded::Semaphore {
    /// Accounts for one received message by subtracting one count
    /// (`2` = one unit in the upper bits).
    fn add_permit(&self) {
        let prev = self.0.fetch_sub(2, Release);

        // Underflow means the message count was already zero — the
        // bookkeeping is corrupted, so fail hard.
        if prev >> 1 == 0 {
            // Something went wrong
            process::abort();
        }
    }

    /// Accounts for `n` received messages at once.
    fn add_permits(&self, n: usize) {
        let prev = self.0.fetch_sub(n << 1, Release);

        // Abort on underflow, as above.
        if (prev >> 1) < n {
            // Something went wrong
            process::abort();
        }
    }

    /// Idle when the message count (upper bits) is zero.
    fn is_idle(&self) -> bool {
        self.0.load(Acquire) >> 1 == 0
    }

    /// Sets the closed flag (bit 0) without touching the count.
    fn close(&self) {
        self.0.fetch_or(1, Release);
    }

    /// Reads the closed flag (bit 0).
    fn is_closed(&self) -> bool {
        self.0.load(Acquire) & 1 == 1
    }
}