1use std::fmt;
6
7use super::SendMessage;
8use crate::{
9 sink::{PollSend, Sink},
10 stream::{PollRecv, Stream},
11 sync::{shared, ReceiverShared, SenderShared},
12};
13use crossbeam_queue::ArrayQueue;
14use static_assertions::{assert_impl_all, assert_not_impl_all};
15
16pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
17 #[cfg(feature = "debug")]
18 log::error!("Creating mpsc channel with capacity {}", capacity);
19 let (tx_shared, rx_shared) = shared(StateExtension::new(capacity));
20 let sender = Sender { shared: tx_shared };
21
22 let receiver = Receiver { shared: rx_shared };
23
24 (sender, receiver)
25}
26
27pub struct Sender<T> {
31 pub(in crate::channels::mpsc) shared: SenderShared<StateExtension<T>>,
32}
33
34assert_impl_all!(Sender<String>: Clone, Send, Sync, fmt::Debug);
35
36impl<T> Clone for Sender<T> {
37 fn clone(&self) -> Self {
38 Self {
39 shared: self.shared.clone(),
40 }
41 }
42}
43
44impl<T> Sink for Sender<T> {
45 type Item = T;
46
47 fn poll_send(
48 self: std::pin::Pin<&mut Self>,
49 cx: &mut crate::Context<'_>,
50 mut value: Self::Item,
51 ) -> PollSend<Self::Item> {
52 loop {
53 if self.shared.is_closed() {
54 return PollSend::Rejected(value);
55 }
56
57 let guard = self.shared.recv_guard();
58 let queue = &self.shared.extension().queue;
59 match queue.push(value) {
60 Ok(_) => {
61 self.shared.notify_receivers();
62 return PollSend::Ready;
63 }
64 Err(v) => {
65 self.shared.subscribe_recv(cx);
66
67 if guard.is_expired() {
68 value = v;
69 continue;
70 }
71
72 return PollSend::Pending(v);
73 }
74 }
75 }
76 }
77}
78
79impl<T> fmt::Debug for Sender<T> {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 f.debug_struct("Sender").finish()
82 }
83}
84
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use crate::sink::SendError;
    use std::task::Poll;

    /// Adapts [`super::Sender`] to the `futures` `Sink` trait.
    ///
    /// NOTE(review): `poll_ready` only observes that the queue had room when it
    /// was polled. With several cloned senders, another sender may fill the
    /// queue before `start_send` runs. In that case `start_send` returns a
    /// `SendError` even though the receiver is still alive.
    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        /// Resolves once the queue has room for another value.
        ///
        /// A closed channel also reports ready. The following `start_send`
        /// then returns the item inside a `SendError`, since `poll_ready` has
        /// no item to hand back.
        fn poll_ready(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            loop {
                if self.shared.is_closed() {
                    return Poll::Ready(Ok(()));
                }

                let queue = &self.shared.extension().queue;
                // Taken before the fullness check, so a receive that races with
                // our subscription expires the guard and forces a retry.
                let guard = self.shared.recv_guard();

                if !queue.is_full() {
                    return Poll::Ready(Ok(()));
                }

                let mut cx = cx.into();
                self.shared.subscribe_recv(&mut cx);

                if guard.is_expired() {
                    continue;
                }

                return Poll::Pending;
            }
        }

        /// Pushes `item` onto the queue and wakes the receiver on success.
        ///
        /// # Errors
        ///
        /// Returns `SendError(item)` if the channel is closed or the queue is full.
        fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            if self.shared.is_closed() {
                return Err(SendError(item));
            }

            let result = self.shared.extension().queue.push(item).map_err(SendError);

            if result.is_ok() {
                self.shared.notify_receivers();
            }

            result
        }

        /// Values are handed to the queue immediately, so there is nothing to flush.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Always ready: this does not close the channel. The channel closes
        /// on the receiver side once every sender has been dropped.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
154
155pub struct Receiver<T> {
159 pub(in crate::channels::mpsc) shared: ReceiverShared<StateExtension<T>>,
160}
161
162assert_impl_all!(Receiver<SendMessage>: Send, Sync, fmt::Debug);
163assert_not_impl_all!(Receiver<SendMessage>: Clone);
164
165impl<T> Stream for Receiver<T> {
166 type Item = T;
167
168 fn poll_recv(
169 self: std::pin::Pin<&mut Self>,
170 cx: &mut crate::Context<'_>,
171 ) -> PollRecv<Self::Item> {
172 loop {
173 let guard = self.shared.send_guard();
174 match self.shared.extension().queue.pop() {
175 Some(v) => {
176 self.shared.notify_senders();
177 return PollRecv::Ready(v);
178 }
179 None => {
180 if self.shared.is_closed() {
181 return PollRecv::Closed;
182 }
183
184 self.shared.subscribe_send(cx);
185
186 if guard.is_expired() {
187 continue;
188 }
189
190 return PollRecv::Pending;
191 }
192 }
193 }
194 }
195}
196
197impl<T> fmt::Debug for Receiver<T> {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 f.debug_struct("Receiver").finish()
200 }
201}
202
203struct StateExtension<T> {
204 queue: ArrayQueue<T>,
205}
206
207impl<T> StateExtension<T> {
208 pub fn new(capacity: usize) -> Self {
209 Self {
210 queue: ArrayQueue::new(capacity),
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use std::{pin::Pin, task::Context};

    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    use super::{channel, Receiver, Sender};

    /// Pins both halves of a channel tuple in place (via `Pin::new`) so they
    /// can be polled directly.
    fn pin(
        chan: &mut (Sender<Message>, Receiver<Message>),
    ) -> (Pin<&mut Sender<Message>>, Pin<&mut Receiver<Message>>) {
        let tx = Pin::new(&mut chan.0);
        let rx = Pin::new(&mut chan.1);

        (tx, rx)
    }

    #[derive(Debug, PartialEq, Eq)]
    struct Message(usize);

    #[test]
    fn send_accepted() {
        let mut cx = panic_context();
        let mut chan = channel(2);
        let (tx, _) = pin(&mut chan);

        assert_eq!(PollSend::Ready, tx.poll_send(&mut cx, Message(1)));
    }

    /// Filling the queue to capacity succeeds, and one more send is handed
    /// back as pending instead of being accepted.
    #[test]
    fn send_blocks() {
        let mut cx = panic_context();
        let (mut tx, _rx) = channel(2);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );

        // The queue is at capacity now, so this send must block.
        assert_eq!(
            PollSend::Pending(Message(3)),
            Pin::new(&mut tx).poll_send(&mut noop_context(), Message(3))
        );
    }

    #[test]
    fn send_recv() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(2);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );
        assert_eq!(
            PollSend::Pending(Message(3)),
            Pin::new(&mut tx).poll_send(&mut noop_context(), Message(3))
        );

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Ready(Message(2)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn sender_disconnect() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(100);
        let mut tx2 = tx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx2).poll_send(&mut cx, Message(2))
        );

        drop(tx);
        drop(tx2);

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Ready(Message(2)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx).poll_recv(&mut cx));
    }

    #[test]
    fn receiver_disconnect() {
        let mut cx = panic_context();
        let (mut tx, rx) = channel(100);
        let mut tx2 = tx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx2).poll_send(&mut cx, Message(2))
        );

        drop(rx);

        assert_eq!(
            PollSend::Rejected(Message(3)),
            Pin::new(&mut tx).poll_send(&mut cx, Message(3))
        );

        assert_eq!(
            PollSend::Rejected(Message(4)),
            Pin::new(&mut tx2).poll_send(&mut cx, Message(4))
        );
    }

    #[test]
    fn wake_sender() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(1);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        let (w2, w2_count) = new_count_waker();
        let w2_context = Context::from_waker(&w2);
        assert_eq!(
            PollSend::Pending(Message(2)),
            Pin::new(&mut tx).poll_send(&mut w2_context.into(), Message(2))
        );

        assert_eq!(0, w2_count.get());

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(1, w2_count.get());
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );

        assert_eq!(1, w2_count.get());
    }

    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(100);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(1, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );

        assert_eq!(1, w1_count.get());
    }

    #[test]
    fn wake_sender_on_disconnect() {
        let (mut tx, rx) = channel(1);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);
        let mut w1_context: crate::Context<'_> = w1_context.into();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut w1_context, Message(1))
        );

        assert_eq!(
            PollSend::Pending(Message(2)),
            Pin::new(&mut tx).poll_send(&mut w1_context, Message(2))
        );

        assert_eq!(0, w1_count.get());

        drop(rx);

        assert_eq!(1, w1_count.get());
    }

    #[test]
    fn wake_receiver_on_disconnect() {
        let (tx, mut rx) = channel::<()>(100);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());
    }
}
466
#[cfg(test)]
mod tokio_tests {
    use std::time::Duration;

    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{capacity_iter, Channel, Channels, Message, CHANNEL_TEST_SENDERS, TEST_TIMEOUT},
    };

    /// One sender streams an ordered sequence; the receiver checks ordering
    /// and then joins the sender task.
    #[tokio::test(flavor = "multi_thread")]
    async fn simple() {
        for cap in capacity_iter() {
            let (mut tx, mut rx) = super::channel(cap);

            let join = spawn(async move {
                for message in Message::new_iter(0) {
                    tx.send(message).await.expect("send failed");
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                join.await.expect("Join failed");
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }

    /// Several concurrent senders each stream their own ordered sequence.
    #[tokio::test(flavor = "multi_thread")]
    async fn multi_sender() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);

            for i in 0..CHANNEL_TEST_SENDERS {
                let mut tx2 = tx.clone();
                spawn(async move {
                    for message in Message::new_multi_sender(i) {
                        tx2.send(message).await.expect("send failed");
                    }
                });
            }

            drop(tx);

            let rx_handle = spawn(async move {
                let mut channel = Channels::new(CHANNEL_TEST_SENDERS);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }

    /// Senders are repeatedly cloned and dropped while another sender is
    /// streaming, to stress the sender reference counting.
    #[tokio::test(flavor = "multi_thread")]
    async fn clone_monster() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);
            let (mut barrier, mut sender_quit) = crate::barrier::channel();

            let mut tx2 = tx.clone();
            spawn(async move {
                for message in Message::new_iter(0) {
                    tx2.send(message).await.expect("send failed");
                }

                barrier.send(()).await.expect("clone task shutdown failed");
            });

            spawn(async move {
                loop {
                    if sender_quit.try_recv().is_ok() {
                        break;
                    }

                    let tx2 = tx.clone();
                    tokio::time::sleep(Duration::from_micros(100)).await;
                    drop(tx2);

                    tokio::time::sleep(Duration::from_micros(50)).await;
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);

                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join failed");
        }
    }
}
586
#[cfg(test)]
mod async_std_tests {
    use std::time::Duration;

    use async_std::{
        future::timeout,
        task::{self, spawn},
    };

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{capacity_iter, Channel, Channels, Message, CHANNEL_TEST_SENDERS, TEST_TIMEOUT},
    };

    /// One sender streams an ordered sequence; the receiver checks ordering.
    #[async_std::test]
    async fn simple() {
        for cap in capacity_iter() {
            let (mut tx, mut rx) = super::channel(cap);

            spawn(async move {
                for message in Message::new_iter(0) {
                    tx.send(message).await.expect("send failed");
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }

    /// Several concurrent senders each stream their own ordered sequence.
    #[async_std::test]
    async fn multi_sender() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);

            for i in 0..CHANNEL_TEST_SENDERS {
                let mut tx2 = tx.clone();
                spawn(async move {
                    for message in Message::new_multi_sender(i) {
                        tx2.send(message).await.expect("send failed");
                    }
                });
            }

            drop(tx);

            let rx_handle = spawn(async move {
                let mut channel = Channels::new(CHANNEL_TEST_SENDERS);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }

    /// Senders are repeatedly cloned and dropped while another sender is
    /// streaming, to stress the sender reference counting. Runs on the
    /// async-std executor, like the rest of this module.
    #[async_std::test]
    async fn clone_monster() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);
            let (mut barrier, mut sender_quit) = crate::barrier::channel();

            let mut tx2 = tx.clone();
            spawn(async move {
                for message in Message::new_iter(0) {
                    tx2.send(message).await.expect("send failed");
                }

                barrier.send(()).await.expect("clone task shutdown failed");
            });

            spawn(async move {
                loop {
                    if sender_quit.try_recv().is_ok() {
                        break;
                    }

                    let tx2 = tx.clone();
                    task::sleep(Duration::from_micros(100)).await;
                    drop(tx2);
                    task::sleep(Duration::from_micros(50)).await;
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);

                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }
}