1use std::fmt;
6use std::sync::Arc;
7
8use atomic::{Atomic, Ordering};
9use static_assertions::{assert_impl_all, assert_not_impl_all};
10
11use crate::{
12 sink::{PollSend, Sink},
13 stream::{PollRecv, Stream},
14 sync::notifier::Notifier,
15};
16
17pub fn channel() -> (Sender, Receiver) {
19 #[cfg(feature = "debug")]
20 log::error!("Creating barrier channel");
21 let shared = Arc::new(Shared {
22 state: Atomic::new(State::Pending),
23 notify_rx: Notifier::new(),
24 });
25
26 let sender = Sender {
27 shared: shared.clone(),
28 };
29
30 let receiver = Receiver { shared };
31
32 (sender, receiver)
33}
34
/// The sending half of a barrier channel, created by [`channel`].
///
/// Sending `()` through the sender, or dropping it, closes the shared state
/// and notifies the receivers.
pub struct Sender {
    /// State shared with every `Receiver` of this channel.
    pub(in crate::channels::barrier) shared: Arc<Shared>,
}

// Static checks: the sender must be `Send + Sync + Debug`, and must NOT be
// `Clone`.
assert_impl_all!(Sender: Send, Sync, fmt::Debug);
assert_not_impl_all!(Sender: Clone);
44
45impl Sink for Sender {
46 type Item = ();
47
48 fn poll_send(
49 self: std::pin::Pin<&mut Self>,
50 _cx: &mut crate::Context<'_>,
51 _value: (),
52 ) -> PollSend<Self::Item> {
53 match self.shared.state.load(Ordering::Acquire) {
54 State::Pending => {
55 self.shared.close();
56 PollSend::Ready
57 }
58 State::Sent => PollSend::Rejected(()),
59 }
60 }
61}
62
63impl fmt::Debug for Sender {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("Sender").finish()
66 }
67}
68
/// `futures::sink::Sink<()>` implementation for the barrier `Sender`,
/// compiled only when the `futures-traits` feature is enabled.
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use super::State;
    use crate::sink::SendError;
    use atomic::Ordering;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    impl futures::sink::Sink<()> for super::Sender {
        type Error = SendError<()>;

        /// Ready with `Ok` while the state is `Pending`; ready with an error
        /// once the state is `Sent`. Never returns `Poll::Pending`.
        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let readiness = match self.shared.state.load(Ordering::Acquire) {
                State::Sent => Err(SendError(())),
                State::Pending => Ok(()),
            };

            Poll::Ready(readiness)
        }

        /// Closes the shared state if it is still `Pending`; otherwise
        /// returns a `SendError`.
        fn start_send(self: Pin<&mut Self>, _item: ()) -> Result<(), Self::Error> {
            let current = self.shared.state.load(Ordering::Acquire);

            match current {
                State::Sent => Err(SendError(())),
                State::Pending => {
                    self.shared.close();
                    Ok(())
                }
            }
        }

        /// Nothing is buffered, so flushing completes immediately.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let flushed = Ok(());
            Poll::Ready(flushed)
        }

        /// Completes immediately without touching the shared state.
        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let closed = Ok(());
            Poll::Ready(closed)
        }
    }
}
114
115impl Drop for Sender {
116 fn drop(&mut self) {
117 self.shared.close();
118 }
119}
120
/// The receiving half of a barrier channel, created by [`channel`].
///
/// Cloning a receiver clones the inner `Arc`, so every clone observes the
/// same shared state.
#[derive(Clone)]
pub struct Receiver {
    /// State shared with the `Sender` and all other receiver clones.
    pub(in crate::channels::barrier) shared: Arc<Shared>,
}

// Static check: the receiver must be `Clone + Send + Sync + Debug`.
assert_impl_all!(Receiver: Clone, Send, Sync, fmt::Debug);
128
/// The two states of a barrier channel, stored in `Shared::state`.
#[derive(Copy, Clone)]
enum State {
    /// Initial state set by `channel()`; receivers polling now will wait.
    Pending,
    /// Set by `Shared::close`; sends are rejected and receives are ready.
    Sent,
}
134
/// State shared between the `Sender` and every `Receiver` of one channel.
struct Shared {
    /// Current barrier state, accessed atomically.
    state: Atomic<State>,
    /// Notifier that receivers subscribe to while the state is `Pending`.
    notify_rx: Notifier,
}

impl Shared {
    /// Marks the barrier as `Sent` and then notifies subscribers.
    pub fn close(&self) {
        // Publish the new state before notifying, so a receiver that is
        // woken and polls again loads `Sent`.
        self.state.store(State::Sent, Ordering::Release);
        self.notify_rx.notify();
    }
}
146
147impl Stream for Receiver {
148 type Item = ();
149
150 fn poll_recv(
151 self: std::pin::Pin<&mut Self>,
152 cx: &mut crate::Context<'_>,
153 ) -> PollRecv<Self::Item> {
154 match self.shared.state.load(Ordering::Acquire) {
155 State::Pending => {
156 self.shared.notify_rx.subscribe(cx);
157
158 if let State::Sent = self.shared.state.load(Ordering::Acquire) {
159 return PollRecv::Ready(());
160 }
161
162 PollRecv::Pending
163 }
164 State::Sent => PollRecv::Ready(()),
165 }
166 }
167}
168
169impl fmt::Debug for Receiver {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("Receiver").finish()
172 }
173}
174
/// Synchronous unit tests that poll the barrier halves directly.
#[cfg(test)]
mod tests {
    use std::{pin::Pin, task::Context};

    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    use super::channel;

    /// The first send is accepted; a second send is rejected.
    #[test]
    fn send_accepted() {
        let (mut sender, _receiver) = channel();
        let mut cx = noop_context();

        let first = Pin::new(&mut sender).poll_send(&mut cx, ());
        assert_eq!(PollSend::Ready, first);

        let second = Pin::new(&mut sender).poll_send(&mut cx, ());
        assert_eq!(PollSend::Rejected(()), second);
    }

    /// After a send, the receiver is ready on every poll.
    #[test]
    fn send_recv() {
        let (mut sender, mut receiver) = channel();
        let mut cx = noop_context();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut sender).poll_send(&mut cx, ())
        );

        for _ in 0..2 {
            assert_eq!(
                PollRecv::Ready(()),
                Pin::new(&mut receiver).poll_recv(&mut cx)
            );
        }
    }

    /// Dropping the sender makes the receiver ready.
    #[test]
    fn sender_disconnect() {
        let (sender, mut receiver) = channel();
        let mut cx = noop_context();

        drop(sender);

        let polled = Pin::new(&mut receiver).poll_recv(&mut cx);
        assert_eq!(PollRecv::Ready(()), polled);
    }

    /// Sending and then dropping the sender keeps the receiver ready.
    #[test]
    fn send_then_disconnect() {
        let (mut sender, mut receiver) = channel();
        let mut cx = noop_context();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut sender).poll_send(&mut cx, ())
        );

        drop(sender);

        for _ in 0..2 {
            assert_eq!(
                PollRecv::Ready(()),
                Pin::new(&mut receiver).poll_recv(&mut cx)
            );
        }
    }

    /// A send still succeeds after the receiver has been dropped.
    #[test]
    fn receiver_disconnect() {
        let (mut sender, receiver) = channel();
        let mut cx = noop_context();

        drop(receiver);

        let sent = Pin::new(&mut sender).poll_send(&mut cx, ());
        assert_eq!(PollSend::Ready, sent);
    }

    /// A receiver cloned before the send becomes ready as well.
    #[test]
    fn receiver_clone() {
        let (mut sender, mut receiver) = channel();
        let mut cloned = receiver.clone();
        let mut cx = noop_context();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut sender).poll_send(&mut cx, ())
        );

        let original_poll = Pin::new(&mut receiver).poll_recv(&mut cx);
        assert_eq!(PollRecv::Ready(()), original_poll);

        let cloned_poll = Pin::new(&mut cloned).poll_recv(&mut cx);
        assert_eq!(PollRecv::Ready(()), cloned_poll);
    }

    /// A receiver cloned after the send is ready too.
    #[test]
    fn receiver_send_then_clone() {
        let (mut sender, mut receiver) = channel();
        let mut cx = noop_context();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut sender).poll_send(&mut cx, ())
        );

        let mut cloned = receiver.clone();

        let original_poll = Pin::new(&mut receiver).poll_recv(&mut cx);
        assert_eq!(PollRecv::Ready(()), original_poll);

        let cloned_poll = Pin::new(&mut cloned).poll_recv(&mut cx);
        assert_eq!(PollRecv::Ready(()), cloned_poll);
    }

    /// A pending receiver is woken once by the accepted send and not again
    /// by a rejected one.
    #[test]
    fn wake_receiver() {
        let (mut sender, mut receiver) = channel();
        let mut send_cx = panic_context();

        let (waker, wake_count) = new_count_waker();
        let recv_cx = Context::from_waker(&waker);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut receiver).poll_recv(&mut recv_cx.into())
        );
        assert_eq!(0, wake_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut sender).poll_send(&mut send_cx, ())
        );
        assert_eq!(1, wake_count.get());

        assert_eq!(
            PollSend::Rejected(()),
            Pin::new(&mut sender).poll_send(&mut send_cx, ())
        );
        assert_eq!(1, wake_count.get());
    }

    /// Dropping the sender wakes a pending receiver exactly once.
    #[test]
    fn wake_receiver_on_disconnect() {
        let (sender, mut receiver) = channel();

        let (waker, wake_count) = new_count_waker();
        let recv_cx = Context::from_waker(&waker);

        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut receiver).poll_recv(&mut recv_cx.into())
        );
        assert_eq!(0, wake_count.get());

        drop(sender);

        assert_eq!(1, wake_count.get());
    }
}
315
/// Integration-style tests that drive the barrier on the tokio runtime.
#[cfg(test)]
mod tokio_tests {
    use std::time::Duration;

    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    use super::Receiver;

    /// Waits for the barrier on `rx`, panicking if it is not released within
    /// 100 milliseconds.
    async fn assert_rx(mut rx: Receiver) {
        if let Err(_e) = timeout(Duration::from_millis(100), rx.recv()).await {
            panic!("Timeout waiting for barrier");
        }
    }

    /// A send from a spawned task releases the receiver.
    #[tokio::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, rx) = super::channel();

            spawn(async move {
                tx.send(()).await.expect("Should send message");
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Dropping the sender in a spawned task releases the receiver.
    #[tokio::test]
    async fn simple_drop() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            spawn(async move {
                drop(tx);
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Dropping the sender releases many receiver clones waiting concurrently.
    #[tokio::test]
    async fn multi_receiver() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            // Collect eagerly. `map` is lazy, so without `collect` no
            // receiver task is spawned until the handles are iterated below,
            // and the receivers would then run one at a time instead of
            // waiting on the barrier concurrently.
            let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
                .map(move |_| {
                    let rx2 = rx.clone();

                    spawn(async move {
                        assert_rx(rx2).await;
                    })
                })
                .collect();

            spawn(async move {
                drop(tx);
            });

            let rx_handle = spawn(async move {
                for handle in handles {
                    handle.await.expect("Assertion failure");
                }
            });

            timeout(TEST_TIMEOUT, rx_handle)
                .await
                .expect("test timeout")
                .expect("join error");
        }
    }
}
400
/// Integration-style tests that drive the barrier on the async-std runtime.
#[cfg(test)]
mod async_std_tests {
    use std::time::Duration;

    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    use super::Receiver;

    /// Waits for the barrier on `rx`, panicking if it is not released within
    /// 100 milliseconds.
    async fn assert_rx(mut rx: Receiver) {
        if let Err(_e) = timeout(Duration::from_millis(100), rx.recv()).await {
            panic!("Timeout waiting for barrier");
        }
    }

    /// A send from a spawned task releases the receiver.
    #[async_std::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, rx) = super::channel();

            spawn(async move {
                tx.send(()).await.expect("Should send message");
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Dropping the sender in a spawned task releases the receiver.
    #[async_std::test]
    async fn simple_drop() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            spawn(async move {
                drop(tx);
            });

            timeout(TEST_TIMEOUT, async move {
                assert_rx(rx).await;
            })
            .await
            .expect("test timeout");
        }
    }

    /// Dropping the sender releases many receiver clones that were spawned
    /// before the drop.
    #[async_std::test]
    async fn multi_receiver() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, rx) = super::channel();

            // Collect eagerly. `map` is lazy, so without `collect` the
            // receiver tasks would only be spawned inside the loop below,
            // after `tx` has already been dropped, and the wake-up path
            // would never be exercised.
            let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
                .map(|_| {
                    let rx2 = rx.clone();

                    spawn(async move {
                        assert_rx(rx2).await;
                    })
                })
                .collect();

            drop(tx);

            timeout(TEST_TIMEOUT, async move {
                for handle in handles {
                    handle.await;
                }
            })
            .await
            .expect("test timeout");
        }
    }
}
481}