postage/channels/
barrier.rs

1//! Barriers transmit when the sender half is dropped, and can synchronize events in async tasks.
2//!
3//! The barrier can also be triggered with `tx.send(())`.
4
5use std::fmt;
6use std::sync::Arc;
7
8use atomic::{Atomic, Ordering};
9use static_assertions::{assert_impl_all, assert_not_impl_all};
10
11use crate::{
12    sink::{PollSend, Sink},
13    stream::{PollRecv, Stream},
14    sync::notifier::Notifier,
15};
16
17/// Constructs a pair of barrier endpoints, which transmits when the sender is dropped.
18pub fn channel() -> (Sender, Receiver) {
19    #[cfg(feature = "debug")]
20    log::error!("Creating barrier channel");
21    let shared = Arc::new(Shared {
22        state: Atomic::new(State::Pending),
23        notify_rx: Notifier::new(),
24    });
25
26    let sender = Sender {
27        shared: shared.clone(),
28    };
29
30    let receiver = Receiver { shared };
31
32    (sender, receiver)
33}
34
35/// The sender half of a barrier channel.  Dropping the sender transmits to the receiver.
36///
37/// Can also be triggered by sending `()` with the postage::Sink trait.
38pub struct Sender {
39    pub(in crate::channels::barrier) shared: Arc<Shared>,
40}
41
42assert_impl_all!(Sender: Send, Sync, fmt::Debug);
43assert_not_impl_all!(Sender: Clone);
44
45impl Sink for Sender {
46    type Item = ();
47
48    fn poll_send(
49        self: std::pin::Pin<&mut Self>,
50        _cx: &mut crate::Context<'_>,
51        _value: (),
52    ) -> PollSend<Self::Item> {
53        match self.shared.state.load(Ordering::Acquire) {
54            State::Pending => {
55                self.shared.close();
56                PollSend::Ready
57            }
58            State::Sent => PollSend::Rejected(()),
59        }
60    }
61}
62
63impl fmt::Debug for Sender {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("Sender").finish()
66    }
67}
68
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use super::State;
    use crate::sink::SendError;
    use atomic::Ordering;
    use std::task::{Context, Poll};

    /// `futures` sink support for the barrier sender, enabled by the `futures-traits` feature.
    impl futures::sink::Sink<()> for super::Sender {
        type Error = SendError<()>;

        /// Reports readiness immediately: `Ok` while the barrier is pending, and an error
        /// once it has been sent.
        fn poll_ready(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            let readiness = match self.shared.state.load(Ordering::Acquire) {
                State::Sent => Err(SendError(())),
                State::Pending => Ok(()),
            };

            Poll::Ready(readiness)
        }

        /// Fires the barrier, or fails with `SendError` if it was already sent.
        fn start_send(self: std::pin::Pin<&mut Self>, _item: ()) -> Result<(), Self::Error> {
            if let State::Sent = self.shared.state.load(Ordering::Acquire) {
                return Err(SendError(()));
            }

            self.shared.close();
            Ok(())
        }

        /// Always completes immediately with `Ok`; nothing is buffered by `start_send`.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Always completes immediately with `Ok`.  Note that this does not touch the shared
        /// state, so it does not fire the barrier by itself.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
114
115impl Drop for Sender {
116    fn drop(&mut self) {
117        self.shared.close();
118    }
119}
120
121/// A barrier reciever.  Can be used with the postage::Stream trait to return a `()` value when the Sender is dropped.
122#[derive(Clone)]
123pub struct Receiver {
124    pub(in crate::channels::barrier) shared: Arc<Shared>,
125}
126
127assert_impl_all!(Receiver: Clone, Send, Sync, fmt::Debug);
128
/// Whether the barrier has fired.
#[derive(Clone, Copy)]
enum State {
    /// The sender is still alive and has not sent `()`.
    Pending,
    /// The barrier has fired, either by an explicit send or by dropping the sender.
    Sent,
}
134
135struct Shared {
136    state: Atomic<State>,
137    notify_rx: Notifier,
138}
139
140impl Shared {
141    pub fn close(&self) {
142        self.state.store(State::Sent, Ordering::Release);
143        self.notify_rx.notify();
144    }
145}
146
147impl Stream for Receiver {
148    type Item = ();
149
150    fn poll_recv(
151        self: std::pin::Pin<&mut Self>,
152        cx: &mut crate::Context<'_>,
153    ) -> PollRecv<Self::Item> {
154        match self.shared.state.load(Ordering::Acquire) {
155            State::Pending => {
156                self.shared.notify_rx.subscribe(cx);
157
158                if let State::Sent = self.shared.state.load(Ordering::Acquire) {
159                    return PollRecv::Ready(());
160                }
161
162                PollRecv::Pending
163            }
164            State::Sent => PollRecv::Ready(()),
165        }
166    }
167}
168
169impl fmt::Debug for Receiver {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("Receiver").finish()
172    }
173}
174
#[cfg(test)]
mod tests {
    use std::{pin::Pin, task::Context};

    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    use super::{channel, Receiver, Sender};

    /// Polls the sender once with a `()` value.
    fn poll_send_once(tx: &mut Sender, cx: &mut crate::Context<'_>) -> PollSend<()> {
        Pin::new(tx).poll_send(cx, ())
    }

    /// Polls the receiver once.
    fn poll_recv_once(rx: &mut Receiver, cx: &mut crate::Context<'_>) -> PollRecv<()> {
        Pin::new(rx).poll_recv(cx)
    }

    /// The first send is accepted, the second one is rejected.
    #[test]
    fn send_accepted() {
        let mut cx = noop_context();
        let (mut tx, _rx) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));
        assert_eq!(PollSend::Rejected(()), poll_send_once(&mut tx, &mut cx));
    }

    /// After a send, the receiver is ready on every poll.
    #[test]
    fn send_recv() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));

        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
    }

    /// Dropping the sender without sending fires the barrier.
    #[test]
    fn sender_disconnect() {
        let mut cx = noop_context();
        let (tx, mut rx) = channel();

        drop(tx);

        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
    }

    /// Sending and then dropping the sender leaves the receiver ready.
    #[test]
    fn send_then_disconnect() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));

        drop(tx);

        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
    }

    /// A send still succeeds after the receiver has been dropped.
    #[test]
    fn receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        drop(rx);

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));
    }

    /// A receiver cloned before the send observes the barrier as well.
    #[test]
    fn receiver_clone() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();
        let mut rx2 = rx.clone();

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));

        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx2, &mut cx));
    }

    /// A receiver cloned after the send observes the barrier as well.
    #[test]
    fn receiver_send_then_clone() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));

        let mut rx2 = rx.clone();

        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx, &mut cx));
        assert_eq!(PollRecv::Ready(()), poll_recv_once(&mut rx2, &mut cx));
    }

    /// A pending receiver is woken exactly once by the accepted send; the rejected send
    /// does not wake it again.
    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();

        let (w, w_count) = new_count_waker();
        let w_context = Context::from_waker(&w);

        assert_eq!(
            PollRecv::Pending,
            poll_recv_once(&mut rx, &mut w_context.into())
        );

        assert_eq!(0, w_count.get());

        assert_eq!(PollSend::Ready, poll_send_once(&mut tx, &mut cx));

        assert_eq!(1, w_count.get());

        assert_eq!(PollSend::Rejected(()), poll_send_once(&mut tx, &mut cx));

        assert_eq!(1, w_count.get());
    }

    /// A pending receiver is woken when the sender is dropped.
    #[test]
    fn wake_receiver_on_disconnect() {
        let (tx, mut rx) = channel();

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Pending,
            poll_recv_once(&mut rx, &mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());
    }
}
315
#[cfg(test)]
mod tokio_tests {
    use std::time::Duration;

    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    use super::Receiver;

    /// Waits for the barrier on `rx`, panicking if it does not fire within 100ms.
    async fn assert_rx(mut rx: Receiver) {
        if let Err(_e) = timeout(Duration::from_millis(100), rx.recv()).await {
            panic!("Timeout waiting for barrier");
        }
    }

    /// A spawned task sends `()`, and the receiver must observe the barrier.
    #[tokio::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, rx) = super::channel();

            spawn(async move {
                tx.send(()).await.expect("Should send message");
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// A spawned task drops the sender, and the receiver must observe the barrier.
    #[tokio::test]
    async fn simple_drop() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            spawn(async move {
                drop(tx);
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Many receivers wait on the same barrier while a spawned task drops the sender.
    #[tokio::test]
    async fn multi_receiver() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            // `Iterator::map` is lazy, so collect eagerly: every receiver task must be
            // spawned before the sender is dropped.  Otherwise each task would only be
            // spawned inside the join loop below, one at a time, and the receivers would
            // never wait on the barrier concurrently.
            let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
                .map(|_| {
                    let rx2 = rx.clone();

                    spawn(async move {
                        assert_rx(rx2).await;
                    })
                })
                .collect();

            spawn(async move {
                drop(tx);
            });

            let rx_handle = spawn(async move {
                for handle in handles {
                    handle.await.expect("Assertion failure");
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }
}
400
#[cfg(test)]
mod async_std_tests {
    use std::time::Duration;

    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    use super::Receiver;

    /// Waits for the barrier on `rx`, panicking if it does not fire within 100ms.
    async fn assert_rx(mut rx: Receiver) {
        if let Err(_e) = timeout(Duration::from_millis(100), rx.recv()).await {
            panic!("Timeout waiting for barrier");
        }
    }

    /// A spawned task sends `()`, and the receiver must observe the barrier.
    #[async_std::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, rx) = super::channel();

            spawn(async move {
                tx.send(()).await.expect("Should send message");
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// A spawned task drops the sender, and the receiver must observe the barrier.
    #[async_std::test]
    async fn simple_drop() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            spawn(async move {
                drop(tx);
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Many receivers wait on the same barrier while the sender is dropped.
    #[async_std::test]
    async fn multi_receiver() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            // `Iterator::map` is lazy, so collect eagerly: every receiver task must be
            // spawned before the sender is dropped.  Otherwise no receiver would even be
            // spawned until the join loop below, which runs after the barrier has fired.
            let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
                .map(|_| {
                    let rx2 = rx.clone();

                    spawn(async move {
                        assert_rx(rx2).await;
                    })
                })
                .collect();

            drop(tx);

            timeout(TEST_TIMEOUT, async move {
                for handle in handles {
                    handle.await;
                }
            })
            .await
            .expect("test timeout");
        }
    }
}
481}