use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
use crate::runtime::park::CachedParkThread;
use crate::sync::mpsc::error::TryRecvError;
use crate::sync::mpsc::{bounded, list, unbounded};
use crate::sync::notify::Notify;
use crate::util::cacheline::CachePadded;

use std::fmt;
use std::panic;
use std::process;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
use std::task::Poll::{Pending, Ready};
use std::task::{ready, Context, Poll};

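/// Channel sender: a handle to the shared channel state, one per `Sender`.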
pub(crate) struct Tx<T, S> {
    inner: Arc<Chan<T, S>>,
}

impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Tx").field("inner", &self.inner).finish()
    }
}

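/// Channel receiver: the single consumer half, owning the pop side of the list.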
pub(crate) struct Rx<T, S: Semaphore> {
    inner: Arc<Chan<T, S>>,
}

impl<T, S: Semaphore + fmt::Debug> fmt::Debug for Rx<T, S> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Rx").field("inner", &self.inner).finish()
    }
}

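/// Abstracts how channel capacity is tracked, so that the same channel
/// machinery can back both the bounded and unbounded flavors.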
pub(crate) trait Semaphore {
    fn is_idle(&self) -> bool;

    fn add_permit(&self);

    fn add_permits(&self, n: usize);

    fn close(&self);

    fn is_closed(&self) -> bool;
}

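/// State shared by all senders and the receiver, kept behind a single `Arc`.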
pub(super) struct Chan<T, S> {
    /// Handle to the push half of the lock-free list.
    tx: CachePadded<list::Tx<T>>,

    /// Receiver waker. Notified when a value is pushed into the channel.
    rx_waker: CachePadded<AtomicWaker>,

    /// Notifies all tasks listening for the receiver being dropped.
    notify_rx_closed: Notify,

    /// Coordinates access to the channel's capacity.
    semaphore: S,

    /// Tracks the number of outstanding sender handles.
    ///
    /// When this drops to zero, the send half of the channel is closed.
    tx_count: AtomicUsize,

    /// Tracks the number of outstanding weak sender handles.
    tx_weak_count: AtomicUsize,

    /// Only accessed by the `Rx` handle.
    rx_fields: UnsafeCell<RxFields<T>>,
}

impl<T, S> fmt::Debug for Chan<T, S>
where
    S: fmt::Debug,
{
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Chan")
            .field("tx", &*self.tx)
            .field("semaphore", &self.semaphore)
            .field("rx_waker", &*self.rx_waker)
            .field("tx_count", &self.tx_count)
            .field("rx_fields", &"...")
            .finish()
    }
}

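/// Fields only accessed by the `Rx` handle.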
struct RxFields<T> {
    /// Channel receiver. This field is only accessed by the receiver half.
    list: list::Rx<T>,

    /// `true` if `Rx::close` has been called.
    rx_closed: bool,
}

impl<T> fmt::Debug for RxFields<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("RxFields")
            .field("list", &self.list)
            .field("rx_closed", &self.rx_closed)
            .finish()
    }
}

unsafe impl<T: Send, S: Send> Send for Chan<T, S> {}
unsafe impl<T: Send, S: Sync> Sync for Chan<T, S> {}
impl<T, S> panic::RefUnwindSafe for Chan<T, S> {}
impl<T, S> panic::UnwindSafe for Chan<T, S> {}

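/// Allocates the shared channel state and returns the `Tx`/`Rx` halves.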
pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
    let (tx, rx) = list::channel();

    let chan = Arc::new(Chan {
        notify_rx_closed: Notify::new(),
        tx: CachePadded::new(tx),
        semaphore,
        rx_waker: CachePadded::new(AtomicWaker::new()),
        tx_count: AtomicUsize::new(1),
        tx_weak_count: AtomicUsize::new(0),
        rx_fields: UnsafeCell::new(RxFields {
            list: rx,
            rx_closed: false,
        }),
    });

    (Tx::new(chan.clone()), Rx::new(chan))
}

impl<T, S> Tx<T, S> {
    fn new(chan: Arc<Chan<T, S>>) -> Tx<T, S> {
        Tx { inner: chan }
    }

    pub(super) fn strong_count(&self) -> usize {
        self.inner.tx_count.load(Acquire)
    }

    pub(super) fn weak_count(&self) -> usize {
        self.inner.tx_weak_count.load(Relaxed)
    }

    pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
        self.inner.increment_weak_count();

        self.inner.clone()
    }

    /// Returns the upgraded channel, or `None` if all strong senders have
    /// already been dropped.
    pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> {
        let mut tx_count = chan.tx_count.load(Acquire);

        loop {
            if tx_count == 0 {
                // The channel is closed; a weak sender cannot be upgraded.
                return None;
            }

            match chan
                .tx_count
                .compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
            {
                Ok(_) => return Some(Tx { inner: chan }),
                Err(prev_count) => tx_count = prev_count,
            }
        }
    }

    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }

    /// Send a value to the channel.
    pub(crate) fn send(&self, value: T) {
        self.inner.send(value);
    }

    /// Wake the receive half.
    pub(crate) fn wake_rx(&self) {
        self.inner.rx_waker.wake();
    }

    /// Returns `true` if senders belong to the same channel.
    pub(crate) fn same_channel(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl<T, S: Semaphore> Tx<T, S> {
    pub(crate) fn is_closed(&self) -> bool {
        self.inner.semaphore.is_closed()
    }

    pub(crate) async fn closed(&self) {
        // To avoid a race, request a notification first, *then* check whether
        // the semaphore is closed. If it is already closed, the notification
        // request is simply dropped.
        let notified = self.inner.notify_rx_closed.notified();

        if self.inner.semaphore.is_closed() {
            return;
        }
        notified.await;
    }
}

impl<T, S> Clone for Tx<T, S> {
    fn clone(&self) -> Tx<T, S> {
        // A Relaxed ordering is sufficient here: the caller holds a strong
        // reference via `self`, so the count cannot concurrently reach zero.
        self.inner.tx_count.fetch_add(1, Relaxed);

        Tx {
            inner: self.inner.clone(),
        }
    }
}

impl<T, S> Drop for Tx<T, S> {
    fn drop(&mut self) {
        if self.inner.tx_count.fetch_sub(1, AcqRel) != 1 {
            return;
        }

        // This was the last sender: close the list, which pushes a `Close`
        // marker for the receiver.
        self.inner.tx.close();

        // Notify the receiver.
        self.wake_rx();
    }
}

impl<T, S: Semaphore> Rx<T, S> {
    fn new(chan: Arc<Chan<T, S>>) -> Rx<T, S> {
        Rx { inner: chan }
    }

    /// Closes the receive half of the channel without dropping it.
    pub(crate) fn close(&mut self) {
        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            if rx_fields.rx_closed {
                return;
            }

            rx_fields.rx_closed = true;
        });

        self.inner.semaphore.close();
        self.inner.notify_rx_closed.notify_waiters();
    }

    /// The channel is closed either when `close` has been called or when all
    /// sender handles have been dropped.
    pub(crate) fn is_closed(&self) -> bool {
        self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.is_empty(&self.inner.tx)
        })
    }

    pub(crate) fn len(&self) -> usize {
        self.inner.rx_fields.with(|rx_fields_ptr| {
            let rx_fields = unsafe { &*rx_fields_ptr };
            rx_fields.list.len(&self.inner.tx)
        })
    }

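    /// Receives the next value, registering the current task's waker when the
    /// channel is still empty.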
    pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of the task budget.
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            macro_rules! try_recv {
                () => {
                    match rx_fields.list.pop(&self.inner.tx) {
                        Some(Read::Value(value)) => {
                            self.inner.semaphore.add_permit();
                            coop.made_progress();
                            return Ready(Some(value));
                        }
                        Some(Read::Closed) => {
                            assert!(self.inner.semaphore.is_idle());
                            coop.made_progress();
                            return Ready(None);
                        }
                        None => {} // fall through
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to
            // read and registering the task, so the channel has to be checked
            // a second time.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                coop.made_progress();
                Ready(None)
            } else {
                Pending
            }
        })
    }

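    /// Receives up to `limit` values into `buffer`, returning how many values
    /// were appended.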
    pub(crate) fn recv_many(
        &mut self,
        cx: &mut Context<'_>,
        buffer: &mut Vec<T>,
        limit: usize,
    ) -> Poll<usize> {
        use super::block::Read;

        ready!(crate::trace::trace_leaf(cx));

        // Keep track of the task budget.
        let coop = ready!(crate::runtime::coop::poll_proceed(cx));

        if limit == 0 {
            coop.made_progress();
            return Ready(0usize);
        }

        let mut remaining = limit;
        let initial_length = buffer.len();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };
            macro_rules! try_recv {
                () => {
                    while remaining > 0 {
                        match rx_fields.list.pop(&self.inner.tx) {
                            Some(Read::Value(value)) => {
                                remaining -= 1;
                                buffer.push(value);
                            }

                            Some(Read::Closed) => {
                                let number_added = buffer.len() - initial_length;
                                if number_added > 0 {
                                    self.inner.semaphore.add_permits(number_added);
                                }
                                assert!(self.inner.semaphore.is_idle());
                                coop.made_progress();
                                return Ready(number_added);
                            }

                            None => {
                                break; // fall through
                            }
                        }
                    }
                    let number_added = buffer.len() - initial_length;
                    if number_added > 0 {
                        self.inner.semaphore.add_permits(number_added);
                        coop.made_progress();
                        return Ready(number_added);
                    }
                };
            }

            try_recv!();

            self.inner.rx_waker.register_by_ref(cx.waker());

            // It is possible that a value was pushed between attempting to
            // read and registering the task, so the channel has to be checked
            // a second time.
            try_recv!();

            if rx_fields.rx_closed && self.inner.semaphore.is_idle() {
                assert!(buffer.is_empty());
                coop.made_progress();
                Ready(0usize)
            } else {
                Pending
            }
        })
    }

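    /// Attempts to receive a value without waiting asynchronously. If a send
    /// is mid-flight (the list reports `Busy`), the current thread is parked
    /// until that send completes.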
    pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
        use super::list::TryPopResult;

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            macro_rules! try_recv {
                () => {
                    match rx_fields.list.try_pop(&self.inner.tx) {
                        TryPopResult::Ok(value) => {
                            self.inner.semaphore.add_permit();
                            return Ok(value);
                        }
                        TryPopResult::Closed => return Err(TryRecvError::Disconnected),
                        TryPopResult::Empty => return Err(TryRecvError::Empty),
                        TryPopResult::Busy => {} // fall through
                    }
                };
            }

            try_recv!();

            // If a previous `poll_recv` call registered a waker, wake it here
            // so that registration racing with this call is not lost.
            self.inner.rx_waker.wake();

            // Park the thread until the in-flight send has completed.
            let mut park = CachedParkThread::new();
            let waker = park.waker().unwrap();
            loop {
                self.inner.rx_waker.register_by_ref(&waker);
                // The in-flight send may have completed by now, so check for
                // messages again before parking.
                try_recv!();
                park.park();
            }
        })
    }

    pub(super) fn semaphore(&self) -> &S {
        &self.inner.semaphore
    }

    pub(super) fn sender_strong_count(&self) -> usize {
        self.inner.tx_count.load(Acquire)
    }

    pub(super) fn sender_weak_count(&self) -> usize {
        self.inner.tx_weak_count.load(Relaxed)
    }
}

impl<T, S: Semaphore> Drop for Rx<T, S> {
    fn drop(&mut self) {
        use super::block::Read::Value;

        self.close();

        self.inner.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            // Drain any values that were never received, returning a permit
            // for each so pending senders are released.
            while let Some(Value(_)) = rx_fields.list.pop(&self.inner.tx) {
                self.inner.semaphore.add_permit();
            }
        });
    }
}

impl<T, S> Chan<T, S> {
    fn send(&self, value: T) {
        // Push the value.
        self.tx.push(value);

        // Notify the rx task.
        self.rx_waker.wake();
    }

    pub(super) fn decrement_weak_count(&self) {
        self.tx_weak_count.fetch_sub(1, Relaxed);
    }

    pub(super) fn increment_weak_count(&self) {
        self.tx_weak_count.fetch_add(1, Relaxed);
    }

    pub(super) fn strong_count(&self) -> usize {
        self.tx_count.load(Acquire)
    }

    pub(super) fn weak_count(&self) -> usize {
        self.tx_weak_count.load(Relaxed)
    }
}

impl<T, S> Drop for Chan<T, S> {
    fn drop(&mut self) {
        use super::block::Read::Value;

        // Safety: `Chan` is the only owner of `rx_fields`, and being inside
        // its own `Drop` means nothing else can touch them concurrently.
        self.rx_fields.with_mut(|rx_fields_ptr| {
            let rx_fields = unsafe { &mut *rx_fields_ptr };

            while let Some(Value(_)) = rx_fields.list.pop(&self.tx) {}
            unsafe { rx_fields.list.free_blocks() };
        });
    }
}

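// The bounded flavor delegates to an async semaphore: each queued value holds
// one permit, and the channel is idle once every permit is back with the
// semaphore (available permits == bound).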
impl Semaphore for bounded::Semaphore {
    fn add_permit(&self) {
        self.semaphore.release(1);
    }

    fn add_permits(&self, n: usize) {
        self.semaphore.release(n);
    }

    fn is_idle(&self) -> bool {
        self.semaphore.available_permits() == self.bound
    }

    fn close(&self) {
        self.semaphore.close();
    }

    fn is_closed(&self) -> bool {
        self.semaphore.is_closed()
    }
}

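// The unbounded flavor packs its state into a single atomic: bit 0 is the
// "closed" flag and the remaining bits count outstanding messages, so
// returning a permit subtracts 2. Returning more permits than there are
// messages would underflow the counter, which is treated as fatal.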
impl Semaphore for unbounded::Semaphore {
    fn add_permit(&self) {
        let prev = self.0.fetch_sub(2, Release);

        if prev >> 1 == 0 {
            // Something went wrong: the message counter underflowed.
            process::abort();
        }
    }

    fn add_permits(&self, n: usize) {
        let prev = self.0.fetch_sub(n << 1, Release);

        if (prev >> 1) < n {
            // Something went wrong: the message counter underflowed.
            process::abort();
        }
    }

    fn is_idle(&self) -> bool {
        self.0.load(Acquire) >> 1 == 0
    }

    fn close(&self) {
        self.0.fetch_or(1, Release);
    }

    fn is_closed(&self) -> bool {
        self.0.load(Acquire) & 1 == 1
    }
}