postage/channels/
mpsc.rs

1//! A fixed-capacity multi-producer, single-consumer channel.  
2//!
3//! The producer can be cloned, and the sender task is suspended if the channel becomes full.
4
5use std::fmt;
6
7use super::SendMessage;
8use crate::{
9    sink::{PollSend, Sink},
10    stream::{PollRecv, Stream},
11    sync::{shared, ReceiverShared, SenderShared},
12};
13use crossbeam_queue::ArrayQueue;
14use static_assertions::{assert_impl_all, assert_not_impl_all};
15
16pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
17    #[cfg(feature = "debug")]
18    log::error!("Creating mpsc channel with capacity {}", capacity);
19    let (tx_shared, rx_shared) = shared(StateExtension::new(capacity));
20    let sender = Sender { shared: tx_shared };
21
22    let receiver = Receiver { shared: rx_shared };
23
24    (sender, receiver)
25}
26
27/// The sender half of an mpsc channel.  Can send messages with the postage::Sink trait.
28///
29/// Can be cloned.
30pub struct Sender<T> {
31    pub(in crate::channels::mpsc) shared: SenderShared<StateExtension<T>>,
32}
33
34assert_impl_all!(Sender<String>: Clone, Send, Sync, fmt::Debug);
35
36impl<T> Clone for Sender<T> {
37    fn clone(&self) -> Self {
38        Self {
39            shared: self.shared.clone(),
40        }
41    }
42}
43
44impl<T> Sink for Sender<T> {
45    type Item = T;
46
47    fn poll_send(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut crate::Context<'_>,
50        mut value: Self::Item,
51    ) -> PollSend<Self::Item> {
52        loop {
53            if self.shared.is_closed() {
54                return PollSend::Rejected(value);
55            }
56
57            let guard = self.shared.recv_guard();
58            let queue = &self.shared.extension().queue;
59            match queue.push(value) {
60                Ok(_) => {
61                    self.shared.notify_receivers();
62                    return PollSend::Ready;
63                }
64                Err(v) => {
65                    self.shared.subscribe_recv(cx);
66
67                    if guard.is_expired() {
68                        value = v;
69                        continue;
70                    }
71
72                    return PollSend::Pending(v);
73                }
74            }
75        }
76    }
77}
78
79impl<T> fmt::Debug for Sender<T> {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_struct("Sender").finish()
82    }
83}
84
#[cfg(feature = "futures-traits")]
mod impl_futures {
    //! Adapter implementing `futures::sink::Sink` for the mpsc `Sender`.

    use crate::sink::SendError;
    use std::task::Poll;

    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        /// Resolves once the queue has room for another item, or the channel
        /// is closed.
        ///
        /// A closed channel is reported as `Ready(Ok(()))`: `SendError` must
        /// carry the rejected item, and `poll_ready` has no item yet. The
        /// following `start_send` returns the error instead.
        fn poll_ready(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            loop {
                if self.shared.is_closed() {
                    return Poll::Ready(Ok(()));
                }

                let queue = &self.shared.extension().queue;
                // Taken before the capacity check; if it has expired by the
                // time we are subscribed, state changed concurrently and we
                // re-check instead of parking.
                let guard = self.shared.recv_guard();

                if !queue.is_full() {
                    return Poll::Ready(Ok(()));
                }

                let mut cx = cx.into();
                self.shared.subscribe_recv(&mut cx);

                if !guard.is_expired() {
                    return Poll::Pending;
                }
            }
        }

        /// Pushes `item` onto the queue and notifies the receiver.
        ///
        /// Returns `Err(SendError(item))` when the channel is closed or the
        /// queue is full.
        ///
        /// NOTE(review): `Sender` is clonable, so another clone can fill the
        /// queue between this sender's `poll_ready` and `start_send`. That case
        /// surfaces as a `SendError` that is indistinguishable from closure.
        fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            if self.shared.is_closed() {
                return Err(SendError(item));
            }

            let result = self.shared.extension().queue.push(item).map_err(SendError);

            if result.is_ok() {
                self.shared.notify_receivers();
            }

            result
        }

        /// Items go straight into the shared queue, so there is nothing to flush.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// No-op: the receiver observes closure only once every `Sender`
        /// clone has been dropped.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
154
155/// The receiver half of an mpsc channel.  Cannot be cloned.
156///
157/// Can receive messages with the postage::Stream trait.
158pub struct Receiver<T> {
159    pub(in crate::channels::mpsc) shared: ReceiverShared<StateExtension<T>>,
160}
161
162assert_impl_all!(Receiver<SendMessage>: Send, Sync, fmt::Debug);
163assert_not_impl_all!(Receiver<SendMessage>: Clone);
164
165impl<T> Stream for Receiver<T> {
166    type Item = T;
167
168    fn poll_recv(
169        self: std::pin::Pin<&mut Self>,
170        cx: &mut crate::Context<'_>,
171    ) -> PollRecv<Self::Item> {
172        loop {
173            let guard = self.shared.send_guard();
174            match self.shared.extension().queue.pop() {
175                Some(v) => {
176                    self.shared.notify_senders();
177                    return PollRecv::Ready(v);
178                }
179                None => {
180                    if self.shared.is_closed() {
181                        return PollRecv::Closed;
182                    }
183
184                    self.shared.subscribe_send(cx);
185
186                    if guard.is_expired() {
187                        continue;
188                    }
189
190                    return PollRecv::Pending;
191                }
192            }
193        }
194    }
195}
196
197impl<T> fmt::Debug for Receiver<T> {
198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199        f.debug_struct("Receiver").finish()
200    }
201}
202
203struct StateExtension<T> {
204    queue: ArrayQueue<T>,
205}
206
207impl<T> StateExtension<T> {
208    pub fn new(capacity: usize) -> Self {
209        Self {
210            queue: ArrayQueue::new(capacity),
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    //! Single-threaded tests that poll the channel halves directly.

    use std::{pin::Pin, task::Context};

    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    use super::{channel, Receiver, Sender};

    /// Pins both halves of a channel in place for direct `poll_*` calls.
    fn pin(
        chan: &mut (Sender<Message>, Receiver<Message>),
    ) -> (Pin<&mut Sender<Message>>, Pin<&mut Receiver<Message>>) {
        let tx = Pin::new(&mut chan.0);
        let rx = Pin::new(&mut chan.1);

        (tx, rx)
    }

    #[derive(Debug, PartialEq, Eq)]
    struct Message(usize);

    #[test]
    fn send_accepted() {
        let mut cx = panic_context();
        let mut chan = channel(2);
        let (tx, _) = pin(&mut chan);

        assert_eq!(PollSend::Ready, tx.poll_send(&mut cx, Message(1)));
    }

    #[test]
    fn send_blocks() {
        let mut cx = panic_context();
        let (mut tx, _rx) = channel(2);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );
        // The queue is now full: the third send must be handed back as pending.
        assert_eq!(
            PollSend::Pending(Message(3)),
            Pin::new(&mut tx).poll_send(&mut noop_context(), Message(3))
        );
    }

    #[test]
    fn send_recv() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(2);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );
        assert_eq!(
            PollSend::Pending(Message(3)),
            Pin::new(&mut tx).poll_send(&mut noop_context(), Message(3))
        );

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Ready(Message(2)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn sender_disconnect() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(100);
        let mut tx2 = tx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx2).poll_send(&mut cx, Message(2))
        );

        drop(tx);
        drop(tx2);

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(
            PollRecv::Ready(Message(2)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx).poll_recv(&mut cx));
    }

    #[test]
    fn receiver_disconnect() {
        let mut cx = panic_context();
        let (mut tx, rx) = channel(100);
        let mut tx2 = tx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx2).poll_send(&mut cx, Message(2))
        );

        drop(rx);

        assert_eq!(
            PollSend::Rejected(Message(3)),
            Pin::new(&mut tx).poll_send(&mut cx, Message(3))
        );

        assert_eq!(
            PollSend::Rejected(Message(4)),
            Pin::new(&mut tx2).poll_send(&mut cx, Message(4))
        );
    }

    #[test]
    fn wake_sender() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(1);

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        let (w2, w2_count) = new_count_waker();
        let w2_context = Context::from_waker(&w2);
        assert_eq!(
            PollSend::Pending(Message(2)),
            Pin::new(&mut tx).poll_send(&mut w2_context.into(), Message(2))
        );

        assert_eq!(0, w2_count.get());

        assert_eq!(
            PollRecv::Ready(Message(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(1, w2_count.get());
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );

        assert_eq!(1, w2_count.get());
    }

    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel(100);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(1))
        );

        assert_eq!(1, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, Message(2))
        );

        assert_eq!(1, w1_count.get());
    }

    #[test]
    fn wake_sender_on_disconnect() {
        let (mut tx, rx) = channel(1);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);
        let mut w1_context: crate::Context<'_> = w1_context.into();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut w1_context, Message(1))
        );

        assert_eq!(
            PollSend::Pending(Message(2)),
            Pin::new(&mut tx).poll_send(&mut w1_context, Message(2))
        );

        assert_eq!(0, w1_count.get());

        drop(rx);

        assert_eq!(1, w1_count.get());
    }

    #[test]
    fn wake_receiver_on_disconnect() {
        let (tx, mut rx) = channel::<()>(100);

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());
    }
}
466
#[cfg(test)]
mod tokio_tests {
    //! Multi-threaded stress tests on the tokio runtime.
    //!
    //! Sender tasks are joined from inside the receiver task, so a failed
    //! `send` surfaces as a test failure rather than being silently
    //! swallowed by a detached task.

    use std::time::Duration;

    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{capacity_iter, Channel, Channels, Message, CHANNEL_TEST_SENDERS, TEST_TIMEOUT},
    };

    #[tokio::test(flavor = "multi_thread")]
    async fn simple() {
        for cap in capacity_iter() {
            let (mut tx, mut rx) = super::channel(cap);

            let join = spawn(async move {
                for message in Message::new_iter(0) {
                    tx.send(message).await.expect("send failed");
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                join.await.expect("Join failed");
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn multi_sender() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);

            let mut senders = Vec::new();
            for i in 0..CHANNEL_TEST_SENDERS {
                let mut tx2 = tx.clone();
                senders.push(spawn(async move {
                    for message in Message::new_multi_sender(i) {
                        tx2.send(message).await.expect("send failed");
                    }
                }));
            }

            drop(tx);

            let rx_handle = spawn(async move {
                let mut channel = Channels::new(CHANNEL_TEST_SENDERS);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                for sender in senders {
                    sender.await.expect("sender task failed");
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn clone_monster() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);
            let (mut barrier, mut sender_quit) = crate::barrier::channel();

            let mut tx2 = tx.clone();
            let sender = spawn(async move {
                for message in Message::new_iter(0) {
                    tx2.send(message).await.expect("send failed");
                }

                barrier.send(()).await.expect("clone task shutdown failed");
            });

            // Keep cloning and dropping senders until the real sender finishes,
            // to exercise the sender reference counting under contention.
            spawn(async move {
                loop {
                    if sender_quit.try_recv().is_ok() {
                        break;
                    }

                    let tx2 = tx.clone();
                    tokio::time::sleep(Duration::from_micros(100)).await;
                    drop(tx2);

                    tokio::time::sleep(Duration::from_micros(50)).await;
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);

                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                sender.await.expect("sender task failed");
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join failed");
        }
    }
}
586
#[cfg(test)]
mod async_std_tests {
    //! Stress tests on the async-std runtime.
    //!
    //! Every test here runs under `#[async_std::test]`; awaiting an async-std
    //! `JoinHandle` re-raises a panic from the joined task, so sender failures
    //! are observed.

    use std::time::Duration;

    use async_std::{
        future::timeout,
        task::{self, spawn},
    };

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{capacity_iter, Channel, Channels, Message, CHANNEL_TEST_SENDERS, TEST_TIMEOUT},
    };

    #[async_std::test]
    async fn simple() {
        for cap in capacity_iter() {
            let (mut tx, mut rx) = super::channel(cap);

            let sender = spawn(async move {
                for message in Message::new_iter(0) {
                    tx.send(message).await.expect("send failed");
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                sender.await;
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }

    #[async_std::test]
    async fn multi_sender() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);

            let mut senders = Vec::new();
            for i in 0..CHANNEL_TEST_SENDERS {
                let mut tx2 = tx.clone();
                senders.push(spawn(async move {
                    for message in Message::new_multi_sender(i) {
                        tx2.send(message).await.expect("send failed");
                    }
                }));
            }

            drop(tx);

            let rx_handle = spawn(async move {
                let mut channel = Channels::new(CHANNEL_TEST_SENDERS);
                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                for sender in senders {
                    sender.await;
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }

    #[async_std::test]
    async fn clone_monster() {
        for cap in capacity_iter() {
            let (tx, mut rx) = super::channel(cap);
            let (mut barrier, mut sender_quit) = crate::barrier::channel();

            let mut tx2 = tx.clone();
            let sender = spawn(async move {
                for message in Message::new_iter(0) {
                    tx2.send(message).await.expect("send failed");
                }

                barrier.send(()).await.expect("clone task shutdown failed");
            });

            // Keep cloning and dropping senders until the real sender finishes,
            // to exercise the sender reference counting under contention.
            spawn(async move {
                loop {
                    if sender_quit.try_recv().is_ok() {
                        break;
                    }

                    let tx2 = tx.clone();
                    task::sleep(Duration::from_micros(100)).await;
                    drop(tx2);
                    task::sleep(Duration::from_micros(50)).await;
                }
            });

            let rx_handle = spawn(async move {
                let mut channel = Channel::new(0);

                while let Some(message) = rx.recv().await {
                    channel.assert_message(&message);
                }
                sender.await;
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout");
        }
    }
}