postage/channels/
oneshot.rs

//! Oneshot channels transmit a single value between a sender and a receiver.
//!
//! Neither endpoint can be cloned.  If the sender drops without sending, the receiver receives a `None` value.
4use std::fmt;
5use std::sync::Arc;
6
7use super::SendMessage;
8use crate::{
9    sink::{PollSend, Sink},
10    stream::{PollRecv, Stream},
11    sync::transfer::Transfer,
12};
13use static_assertions::{assert_impl_all, assert_not_impl_all};
14
15/// Constructs a pair of oneshot endpoints
16pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
17    #[cfg(feature = "debug")]
18    log::error!("Creating oneshot channel");
19
20    let shared = Arc::new(Transfer::new());
21    let sender = Sender {
22        shared: shared.clone(),
23    };
24
25    let receiver = Receiver { shared };
26
27    (sender, receiver)
28}
29
30/// The sender half of a oneshot channel.  Can transmit a single message with the postage::Sink trait.
31pub struct Sender<T> {
32    pub(in crate::channels::oneshot) shared: Arc<Transfer<T>>,
33}
34
35assert_impl_all!(Sender<SendMessage>: Send, Sync, fmt::Debug);
36assert_not_impl_all!(Sender<SendMessage>: Clone);
37
38impl<T> Sink for Sender<T> {
39    type Item = T;
40
41    fn poll_send(
42        self: std::pin::Pin<&mut Self>,
43        _cx: &mut crate::Context<'_>,
44        value: Self::Item,
45    ) -> PollSend<Self::Item> {
46        match self.shared.send(value) {
47            Ok(_) => PollSend::Ready,
48            Err(v) => PollSend::Rejected(v),
49        }
50    }
51}
52
53impl<T> fmt::Debug for Sender<T> {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        f.debug_struct("Sender").finish()
56    }
57}
58
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use crate::sink::SendError;
    use std::task::Poll;

    /// Exposes the oneshot sender through the `futures` `Sink` trait.
    ///
    /// The channel applies no backpressure, so readiness, flushing and closing
    /// all complete immediately. `start_send` fails with a `SendError` carrying
    /// the item back whenever the shared slot rejects it (e.g. a message was
    /// already sent, or the receiver is gone).
    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        fn poll_ready(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            // `SendError` is a tuple struct, so its constructor can be passed
            // directly instead of wrapping it in a closure.
            self.shared.send(item).map_err(SendError)
        }

        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Completes immediately. Note that this does not itself disconnect
        /// the channel; the receiver is notified only when the sender is dropped.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
93
94impl<T> Drop for Sender<T> {
95    fn drop(&mut self) {
96        self.shared.sender_disconnect();
97    }
98}
99
100/// The receiver half of a oneshot channel.  Can recieve a single message (or none if the sender drops) with the postage::Stream trait.
101pub struct Receiver<T> {
102    pub(in crate::channels::oneshot) shared: Arc<Transfer<T>>,
103}
104
105assert_impl_all!(Sender<SendMessage>: Send, Sync, fmt::Debug);
106assert_not_impl_all!(Sender<SendMessage>: Clone);
107
108impl<T> Stream for Receiver<T> {
109    type Item = T;
110
111    fn poll_recv(
112        self: std::pin::Pin<&mut Self>,
113        cx: &mut crate::Context<'_>,
114    ) -> PollRecv<Self::Item> {
115        self.shared.recv(cx)
116    }
117}
118
119impl<T> Drop for Receiver<T> {
120    fn drop(&mut self) {
121        self.shared.receiver_disconnect();
122    }
123}
124
125impl<T> fmt::Debug for Receiver<T> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("Receiver").finish()
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
        Context,
    };
    use futures_test::task::new_count_waker;

    use super::{channel, Receiver, Sender};

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Message(usize);

    /// Pins `tx` in place and polls it once with `value`.
    fn poll_send_once<T>(tx: &mut Sender<T>, cx: &mut Context<'_>, value: T) -> PollSend<T> {
        Pin::new(tx).poll_send(cx, value)
    }

    /// Pins `rx` in place and polls it once.
    fn poll_recv_once<T>(rx: &mut Receiver<T>, cx: &mut Context<'_>) -> PollRecv<T> {
        Pin::new(rx).poll_recv(cx)
    }

    /// The first send is accepted; a second send is rejected and the value is
    /// handed back.
    #[test]
    fn send_accepted() {
        let mut ctx = noop_context();
        // `_rx` (not `_`) keeps the receiver alive for the whole test.
        let (mut sender, _rx) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut sender, &mut ctx, Message(1)));
        assert_eq!(
            PollSend::Rejected(Message(2)),
            poll_send_once(&mut sender, &mut ctx, Message(2))
        );
    }

    /// A sent value is received exactly once, after which the receiver is closed.
    #[test]
    fn send_recv() {
        let mut ctx = noop_context();
        let (mut sender, mut receiver) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut sender, &mut ctx, Message(1)));

        assert_eq!(
            PollRecv::Ready(Message(1)),
            poll_recv_once(&mut receiver, &mut ctx)
        );
        assert_eq!(PollRecv::Closed, poll_recv_once(&mut receiver, &mut ctx));
    }

    /// Dropping the sender before the receiver is ever polled closes the channel.
    #[test]
    fn sender_disconnect() {
        let mut ctx = noop_context();
        let (sender, mut receiver) = channel::<Message>();

        drop(sender);

        assert_eq!(PollRecv::Closed, poll_recv_once(&mut receiver, &mut ctx));
    }

    /// A receiver that is already pending observes the sender being dropped.
    #[test]
    fn sender_disconnect_after_poll() {
        let mut ctx = noop_context();
        let (sender, mut receiver) = channel::<Message>();

        assert_eq!(PollRecv::Pending, poll_recv_once(&mut receiver, &mut ctx));

        drop(sender);
        assert_eq!(PollRecv::Closed, poll_recv_once(&mut receiver, &mut ctx));
    }

    /// A value sent before the sender drops is still delivered, then the
    /// channel closes.
    #[test]
    fn send_then_disconnect() {
        let mut ctx = noop_context();
        let (mut sender, mut receiver) = channel();

        assert_eq!(PollSend::Ready, poll_send_once(&mut sender, &mut ctx, Message(1)));

        drop(sender);

        assert_eq!(
            PollRecv::Ready(Message(1)),
            poll_recv_once(&mut receiver, &mut ctx)
        );

        assert_eq!(PollRecv::Closed, poll_recv_once(&mut receiver, &mut ctx));
    }

    /// Sending after the receiver has been dropped is rejected.
    #[test]
    fn receiver_disconnect() {
        let mut ctx = noop_context();
        let (mut sender, receiver) = channel();

        drop(receiver);

        assert_eq!(
            PollSend::Rejected(Message(1)),
            poll_send_once(&mut sender, &mut ctx, Message(1))
        );
    }

    /// A successful send wakes a pending receiver exactly once; the rejected
    /// second send does not wake it again.  The sender is polled with a panic
    /// context, presumably to assert that sending never touches the sender's waker.
    #[test]
    fn wake_receiver() {
        let mut sender_ctx = panic_context();
        let (mut sender, mut receiver) = channel();

        let (waker, wake_count) = new_count_waker();
        let mut receiver_ctx = Context::from_waker(&waker);

        assert_eq!(
            PollRecv::Pending,
            poll_recv_once(&mut receiver, &mut receiver_ctx)
        );
        assert_eq!(0, wake_count.get());

        assert_eq!(
            PollSend::Ready,
            poll_send_once(&mut sender, &mut sender_ctx, Message(1))
        );
        assert_eq!(1, wake_count.get());

        assert_eq!(
            PollSend::Rejected(Message(2)),
            poll_send_once(&mut sender, &mut sender_ctx, Message(2))
        );
        assert_eq!(1, wake_count.get());
    }

    /// Dropping the sender wakes a pending receiver, which then sees `Closed`.
    #[test]
    fn sender_disconnect_wakes_receiver() {
        let (sender, mut receiver) = channel::<usize>();

        let (waker, wake_count) = new_count_waker();
        let mut receiver_ctx = Context::from_waker(&waker);

        assert_eq!(
            PollRecv::Pending,
            poll_recv_once(&mut receiver, &mut receiver_ctx)
        );
        assert_eq!(0, wake_count.get());

        drop(sender);

        assert_eq!(1, wake_count.get());
        assert_eq!(
            PollRecv::Closed,
            poll_recv_once(&mut receiver, &mut receiver_ctx)
        );
    }
}
288
#[cfg(test)]
mod tokio_tests {
    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, TEST_TIMEOUT},
    };

    use super::channel;

    /// A value sent from a spawned task arrives at the receiver.
    #[tokio::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, mut rx) = channel();

            spawn(async move { tx.send(100usize).await });

            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(Some(100usize), msg);
        }
    }

    /// Dropping the sender on another task resolves `recv` to `None`.
    #[tokio::test]
    async fn sender_disconnect() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, mut rx) = channel::<usize>();

            spawn(async move { drop(tx) });

            // Use the shared TEST_TIMEOUT (as `simple` does) instead of a
            // hard-coded 100ms, so the limit is configured in one place.
            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(None, msg);
        }
    }
}
333
#[cfg(test)]
mod async_std_tests {
    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, TEST_TIMEOUT},
    };

    use super::channel;

    /// Each iteration's value, sent from a spawned task, arrives at the receiver.
    #[async_std::test]
    async fn simple() {
        for i in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, mut rx) = channel();

            spawn(async move { tx.send(i).await });

            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(Some(i), msg);
        }
    }

    /// Dropping the sender on another task resolves `recv` to `None`.
    #[async_std::test]
    async fn sender_disconnect() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, mut rx) = channel::<usize>();

            spawn(async move { drop(tx) });

            // Use the shared TEST_TIMEOUT (as `simple` does) instead of a
            // hard-coded 100ms, so the limit is configured in one place.
            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(None, msg);
        }
    }
}