1use super::SendSyncMessage;
8use std::{
9 fmt,
10 ops::{Deref, DerefMut},
11 sync::atomic::{AtomicUsize, Ordering},
12};
13
14use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
15use static_assertions::{assert_impl_all, assert_not_impl_all};
16
17use crate::{
18 sink::{PollSend, Sink},
19 stream::{PollRecv, Stream},
20 sync::{shared, ReceiverShared, SenderShared},
21};
22
23pub fn channel<T: Clone + Default>() -> (Sender<T>, Receiver<T>) {
25 channel_with(T::default())
26}
27
28pub fn channel_with<T: Clone>(value: T) -> (Sender<T>, Receiver<T>) {
30 #[cfg(feature = "debug")]
31 log::error!("Creating watch channel");
32
33 let (tx_shared, rx_shared) = shared(StateExtension::new(value));
34 let sender = Sender { shared: tx_shared };
35
36 let receiver = Receiver {
37 shared: rx_shared,
38 generation: AtomicUsize::new(0),
39 };
40
41 (sender, receiver)
42}
43
44pub fn channel_with_option<T: Clone>() -> (Sender<Option<T>>, Receiver<Option<T>>) {
48 channel::<Option<T>>()
49}
50
51pub struct Sender<T> {
53 pub(in crate::channels::watch) shared: SenderShared<StateExtension<T>>,
54}
55
56assert_impl_all!(Sender<SendSyncMessage>: Send, Sync, fmt::Debug);
57assert_not_impl_all!(Sender<SendSyncMessage>: Clone);
58
59impl<T> Sink for Sender<T> {
60 type Item = T;
61
62 fn poll_send(
63 self: std::pin::Pin<&mut Self>,
64 _cx: &mut crate::Context<'_>,
65 value: Self::Item,
66 ) -> PollSend<Self::Item> {
67 if self.shared.is_closed() {
68 return PollSend::Rejected(value);
69 }
70
71 self.shared.extension().push(value);
72 self.shared.notify_receivers();
73
74 PollSend::Ready
75 }
76}
77
78#[allow(clippy::needless_lifetimes)]
79impl<T> Sender<T> {
80 pub fn borrow_mut<'s>(&'s mut self) -> RefMut<'s, T> {
84 let extension = self.shared.extension();
85 let lock = extension.value.write();
86
87 RefMut {
88 lock,
89 shared: self.shared.clone(),
90 }
91 }
92
93 pub fn subscribe(&mut self) -> Receiver<T> {
95 Receiver {
96 shared: self.shared.clone_receiver(),
97 generation: AtomicUsize::new(0),
98 }
99 }
100
101 pub fn borrow<'s>(&'s mut self) -> Ref<'s, T> {
103 let extension = self.shared.extension();
104 let lock = extension.value.read();
105
106 Ref { lock }
107 }
108}
109
110impl<T> fmt::Debug for Sender<T> {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 f.debug_struct("Sender").finish()
113 }
114}
115
/// `futures::sink::Sink` support for the watch `Sender`, enabled by the
/// `futures-traits` feature.
///
/// A watch sender never applies backpressure: it is always ready, and
/// flushing has nothing to wait for. Only `start_send` does any work, with the
/// same effect as the crate's own `Sink::poll_send`.
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use crate::sink::SendError;

    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Replaces the current value and wakes receivers, or returns `item`
        /// inside `SendError` if the channel is closed.
        fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let shared = &self.shared;

            if shared.is_closed() {
                Err(SendError(item))
            } else {
                shared.extension().push(item);
                shared.notify_receivers();
                Ok(())
            }
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// NOTE(review): reports success without closing anything; presumably
        /// the channel only closes once the `Sender` itself is dropped.
        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
158
159pub struct Receiver<T> {
163 pub(in crate::channels::watch) shared: ReceiverShared<StateExtension<T>>,
164 pub(in crate::channels::watch) generation: AtomicUsize,
165}
166
167assert_impl_all!(Receiver<SendSyncMessage>: Clone, Send, Sync, fmt::Debug);
168
169impl<T> Stream for Receiver<T>
170where
171 T: Clone,
172{
173 type Item = T;
174
175 fn poll_recv(
176 self: std::pin::Pin<&mut Self>,
177 cx: &mut crate::Context<'_>,
178 ) -> PollRecv<Self::Item> {
179 loop {
180 let guard = self.shared.send_guard();
181
182 match self.try_recv_internal() {
183 TryRecv::Pending => {
184 if self.shared.is_closed() {
185 return PollRecv::Closed;
186 }
187
188 self.shared.subscribe_send(cx);
189
190 if guard.is_expired() {
191 continue;
192 }
193
194 return PollRecv::Pending;
195 }
196 TryRecv::Ready(v) => return PollRecv::Ready(v),
197 }
198 }
199 }
200}
201
202impl<T> Receiver<T>
203where
204 T: Clone,
205{
206 fn try_recv_internal(&self) -> TryRecv<T> {
207 let state = self.shared.extension();
208 if self.generation.load(std::sync::atomic::Ordering::SeqCst)
209 > state.generation(Ordering::SeqCst)
210 {
211 return TryRecv::Pending;
212 }
213
214 let borrow = self.shared.extension().value.read();
215 let stored_generation = self.shared.extension().generation(Ordering::SeqCst);
216 self.generation
217 .store(stored_generation + 1, Ordering::Release);
218 TryRecv::Ready(borrow.clone())
219 }
220}
221
222enum TryRecv<T> {
223 Pending,
224 Ready(T),
225}
226
227impl<T> Clone for Receiver<T> {
228 fn clone(&self) -> Self {
229 Self {
230 shared: self.shared.clone(),
231 generation: AtomicUsize::new(0),
232 }
233 }
234}
235
236impl<T> fmt::Debug for Receiver<T> {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 f.debug_struct("Receiver").finish()
239 }
240}
241
242pub struct RefMut<'t, T> {
245 lock: RwLockWriteGuard<'t, T>,
246 shared: SenderShared<StateExtension<T>>,
247}
248
249impl<'t, T> DerefMut for RefMut<'t, T> {
250 fn deref_mut(&mut self) -> &mut Self::Target {
251 &mut *self.lock
252 }
253}
254
255impl<'t, T> Deref for RefMut<'t, T> {
256 type Target = T;
257
258 fn deref(&self) -> &Self::Target {
259 &*self.lock
260 }
261}
262
263impl<'t, T> Drop for RefMut<'t, T> {
264 fn drop(&mut self) {
265 self.shared.extension().increment();
266 self.shared.notify_receivers();
267 }
268}
269
270pub struct Ref<'t, T> {
272 lock: RwLockReadGuard<'t, T>,
273}
274
275impl<'t, T> Deref for Ref<'t, T> {
276 type Target = T;
277
278 fn deref(&self) -> &Self::Target {
279 &*self.lock
280 }
281}
282
283impl<T> Receiver<T> {
284 pub fn borrow(&self) -> Ref<'_, T> {
286 let lock = self.shared.extension().value.read();
287 Ref { lock }
288 }
289}
290
291struct StateExtension<T> {
292 generation: AtomicUsize,
293 value: RwLock<T>,
294}
295
296impl<T> StateExtension<T> {
297 pub fn new(value: T) -> Self {
298 Self {
299 generation: AtomicUsize::new(0),
300 value: RwLock::new(value),
301 }
302 }
303
304 pub fn push(&self, value: T) {
305 let mut lock = self.value.write();
306 *lock = value;
307
308 self.generation.fetch_add(1, Ordering::SeqCst);
309 drop(lock);
310 }
311
312 pub fn increment(&self) {
313 self.generation.fetch_add(1, Ordering::SeqCst);
314 }
315
316 pub fn generation(&self, ordering: Ordering) -> usize {
317 self.generation.load(ordering)
318 }
319}
320
#[cfg(test)]
mod tests {
    use std::{pin::Pin, task::Context};

    use super::{channel, channel_with, channel_with_option};
    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    /// Test payload. The derived `Default` is `State(0)`, which is the initial
    /// value of a channel created with `channel()`.
    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct State(usize);

    #[test]
    fn send_accepted() {
        let mut cx = noop_context();
        let (mut tx, _rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );
    }

    #[test]
    fn send_recv() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(PollRecv::Pending, Pin::new(&mut rx).poll_recv(&mut cx));
    }

    #[test]
    fn recv_default() {
        let mut cx = panic_context();
        let (_tx, mut rx) = channel();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn recv_initial_value() {
        let (_tx, mut rx) = channel_with(State(7));

        assert_eq!(
            PollRecv::Ready(State(7)),
            Pin::new(&mut rx).poll_recv(&mut panic_context())
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn recv_option_starts_none() {
        let (_tx, mut rx) = channel_with_option::<State>();

        assert_eq!(
            PollRecv::Ready(None),
            Pin::new(&mut rx).poll_recv(&mut panic_context())
        );
    }

    #[test]
    fn borrow_default() {
        let (_tx, rx) = channel();

        assert_eq!(&State(0), &*rx.borrow());
    }

    #[test]
    fn borrow_sent() {
        let mut cx = panic_context();
        let (mut tx, rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(&State(1), &*rx.borrow());
    }

    #[test]
    fn borrow_mut_notifies() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        *tx.borrow_mut() = State(1);
        assert_eq!(1, w1_count.get());
        assert_eq!(&State(1), &*rx.borrow());

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
    }

    #[test]
    fn sender_disconnect() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();
        let mut rx2 = rx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        drop(tx);

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx).poll_recv(&mut cx));

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx2).poll_recv(&mut cx));
    }

    #[test]
    fn receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        drop(rx);

        assert_eq!(
            PollSend::Rejected(State(1)),
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );
    }

    #[test]
    fn send_then_receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        drop(rx);

        assert_eq!(
            PollSend::Rejected(State(2)),
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );
    }

    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(1, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );

        assert_eq!(1, w1_count.get());
    }

    #[test]
    fn wake_receiver_on_disconnect() {
        let (tx, mut rx) = channel::<State>();

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut panic_context())
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());
    }

    // These two tests never await, so they run under the plain test harness
    // rather than spinning up an async runtime.
    #[test]
    fn subscribe_default() {
        let mut cx = panic_context();
        let (mut tx, _rx) = channel();
        let mut rx2 = tx.subscribe();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx2).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn subscribe_both_receive_value() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();
        let mut rx2 = tx.subscribe();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx2).poll_recv(&mut noop_context())
        );
    }

    #[test]
    fn debug_output() {
        let (tx, rx) = channel::<State>();

        assert_eq!("Sender", format!("{:?}", tx));
        assert_eq!("Receiver", format!("{:?}", rx));
    }
}
593
#[cfg(test)]
mod tokio_tests {
    use tokio::{spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    #[tokio::test]
    async fn simple() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // Skip the first message; presumably it corresponds to the
            // channel's default initial value (TODO confirm against
            // `crate::test::Message`).
            iter.next();

            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    #[tokio::test]
    async fn send_borrow_mut() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            iter.next();

            for message in iter {
                *tx.borrow_mut() = message;
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    #[tokio::test]
    async fn multi_receiver() {
        let (mut tx, rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            iter.next();

            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        // Spawn every receiver task up front. A lazy `map` would spawn each
        // task only when the await loop below reached it, i.e. after the
        // previous receiver had already finished, so the receivers would never
        // run concurrently.
        let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
            .map(move |_| {
                let mut rx2 = rx.clone();
                let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();

                spawn(async move {
                    while let Some(message) = rx2.recv().await {
                        channels.assert_message(&message);
                    }
                })
            })
            .collect();

        timeout(TEST_TIMEOUT, async move {
            for handle in handles {
                handle.await.expect("join failed");
            }
        })
        .await
        .expect("test timeout");
    }
}
683
#[cfg(test)]
mod async_std_tests {
    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    #[async_std::test]
    async fn simple() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // Skip the first message; presumably it corresponds to the
            // channel's default initial value (TODO confirm against
            // `crate::test::Message`).
            iter.next();

            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    #[async_std::test]
    async fn send_borrow_mut() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            iter.next();

            for message in iter {
                *tx.borrow_mut() = message;
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    // Runs on the async-std runtime like the rest of this module (it
    // previously used `#[tokio::test]` and `tokio::task::spawn`).
    #[async_std::test]
    async fn multi_receiver() {
        let (mut tx, rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            iter.next();

            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        // Spawn every receiver task up front so they run concurrently; a lazy
        // `map` would only spawn each one after the previous had finished.
        let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
            .map(move |_| {
                let mut rx2 = rx.clone();
                let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();

                spawn(async move {
                    while let Some(message) = rx2.recv().await {
                        channels.assert_message(&message);
                    }
                })
            })
            .collect();

        timeout(TEST_TIMEOUT, async move {
            for handle in handles {
                handle.await;
            }
        })
        .await
        .expect("test timeout");
    }
}