postage/channels/
watch.rs

1//! A state distribution channel.  The internal state can be borrowed or cloned, but receivers do not observe every value.
2//!  
3//! When the channel is created, the receiver will immediately observe `T::default()`.  Cloned receivers will immediately observe the latest stored value.
4//!
5//! Senders can mutably borrow the contained value (which notifies receivers on release).  Receivers can immutably borrow the contained value.
6
7use super::SendSyncMessage;
8use std::{
9    fmt,
10    ops::{Deref, DerefMut},
11    sync::atomic::{AtomicUsize, Ordering},
12};
13
14use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
15use static_assertions::{assert_impl_all, assert_not_impl_all};
16
17use crate::{
18    sink::{PollSend, Sink},
19    stream::{PollRecv, Stream},
20    sync::{shared, ReceiverShared, SenderShared},
21};
22
23/// Constructs a new watch channel pair, filled with `T::default()`.
24pub fn channel<T: Clone + Default>() -> (Sender<T>, Receiver<T>) {
25    channel_with(T::default())
26}
27
28/// Constructs a new watch channel pair, filled with the provided value
29pub fn channel_with<T: Clone>(value: T) -> (Sender<T>, Receiver<T>) {
30    #[cfg(feature = "debug")]
31    log::error!("Creating watch channel");
32
33    let (tx_shared, rx_shared) = shared(StateExtension::new(value));
34    let sender = Sender { shared: tx_shared };
35
36    let receiver = Receiver {
37        shared: rx_shared,
38        generation: AtomicUsize::new(0),
39    };
40
41    (sender, receiver)
42}
43
44/// Constructs a pair of channel endpoints that store Option<T>
45///
46/// This is helpful if T does not implement Default, and you don't have an initial value.
47pub fn channel_with_option<T: Clone>() -> (Sender<Option<T>>, Receiver<Option<T>>) {
48    channel::<Option<T>>()
49}
50
51/// The sender half of a watch channel.  The stored value can be updated with the postage::Sink trait.
52pub struct Sender<T> {
53    pub(in crate::channels::watch) shared: SenderShared<StateExtension<T>>,
54}
55
56assert_impl_all!(Sender<SendSyncMessage>: Send, Sync, fmt::Debug);
57assert_not_impl_all!(Sender<SendSyncMessage>: Clone);
58
59impl<T> Sink for Sender<T> {
60    type Item = T;
61
62    fn poll_send(
63        self: std::pin::Pin<&mut Self>,
64        _cx: &mut crate::Context<'_>,
65        value: Self::Item,
66    ) -> PollSend<Self::Item> {
67        if self.shared.is_closed() {
68            return PollSend::Rejected(value);
69        }
70
71        self.shared.extension().push(value);
72        self.shared.notify_receivers();
73
74        PollSend::Ready
75    }
76}
77
78#[allow(clippy::needless_lifetimes)]
79impl<T> Sender<T> {
80    /// Mutably borrows the contained value, blocking the channel while the borrow is held.
81    ///
82    /// After the borrow is released, receivers will be notified of a new value.
83    pub fn borrow_mut<'s>(&'s mut self) -> RefMut<'s, T> {
84        let extension = self.shared.extension();
85        let lock = extension.value.write();
86
87        RefMut {
88            lock,
89            shared: self.shared.clone(),
90        }
91    }
92
93    /// Creates a new Receiver that listens to this channel.
94    pub fn subscribe(&mut self) -> Receiver<T> {
95        Receiver {
96            shared: self.shared.clone_receiver(),
97            generation: AtomicUsize::new(0),
98        }
99    }
100
101    /// Immutably borrows the contained value, blocking the channel while the borrow is held.
102    pub fn borrow<'s>(&'s mut self) -> Ref<'s, T> {
103        let extension = self.shared.extension();
104        let lock = extension.value.read();
105
106        Ref { lock }
107    }
108}
109
110impl<T> fmt::Debug for Sender<T> {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        f.debug_struct("Sender").finish()
113    }
114}
115
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use crate::sink::SendError;

    /// `futures::sink::Sink` adapter for the watch sender.
    ///
    /// Readiness, flush and close all complete immediately; only `start_send`
    /// touches the channel.
    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        /// Always reports ready.
        fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Stores `item` and notifies receivers, or returns it inside
        /// `SendError` when the shared state reports closed.
        fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let shared = &self.shared;

            if !shared.is_closed() {
                shared.extension().push(item);
                shared.notify_receivers();
                Ok(())
            } else {
                Err(SendError(item))
            }
        }

        /// Nothing is buffered, so flushing completes immediately.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Completes immediately without changing the channel.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
158
159/// The receiver half of a watch channel.  Can recieve state updates with the postage::Sink trait.
160///
161/// The reciever will be woken when new values arive, but is not guaranteed to recieve every message.
162pub struct Receiver<T> {
163    pub(in crate::channels::watch) shared: ReceiverShared<StateExtension<T>>,
164    pub(in crate::channels::watch) generation: AtomicUsize,
165}
166
167assert_impl_all!(Receiver<SendSyncMessage>: Clone, Send, Sync, fmt::Debug);
168
169impl<T> Stream for Receiver<T>
170where
171    T: Clone,
172{
173    type Item = T;
174
175    fn poll_recv(
176        self: std::pin::Pin<&mut Self>,
177        cx: &mut crate::Context<'_>,
178    ) -> PollRecv<Self::Item> {
179        loop {
180            let guard = self.shared.send_guard();
181
182            match self.try_recv_internal() {
183                TryRecv::Pending => {
184                    if self.shared.is_closed() {
185                        return PollRecv::Closed;
186                    }
187
188                    self.shared.subscribe_send(cx);
189
190                    if guard.is_expired() {
191                        continue;
192                    }
193
194                    return PollRecv::Pending;
195                }
196                TryRecv::Ready(v) => return PollRecv::Ready(v),
197            }
198        }
199    }
200}
201
202impl<T> Receiver<T>
203where
204    T: Clone,
205{
206    fn try_recv_internal(&self) -> TryRecv<T> {
207        let state = self.shared.extension();
208        if self.generation.load(std::sync::atomic::Ordering::SeqCst)
209            > state.generation(Ordering::SeqCst)
210        {
211            return TryRecv::Pending;
212        }
213
214        let borrow = self.shared.extension().value.read();
215        let stored_generation = self.shared.extension().generation(Ordering::SeqCst);
216        self.generation
217            .store(stored_generation + 1, Ordering::Release);
218        TryRecv::Ready(borrow.clone())
219    }
220}
221
/// Outcome of a single non-blocking receive attempt.
enum TryRecv<T> {
    /// Nothing new since the receiver's last read.
    Pending,
    /// A clone of the stored value.
    Ready(T),
}
226
227impl<T> Clone for Receiver<T> {
228    fn clone(&self) -> Self {
229        Self {
230            shared: self.shared.clone(),
231            generation: AtomicUsize::new(0),
232        }
233    }
234}
235
236impl<T> fmt::Debug for Receiver<T> {
237    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238        f.debug_struct("Receiver").finish()
239    }
240}
241
242/// A mutable reference to the value contained in the channel.
243/// Receivers are notified when the borrow is released.
244pub struct RefMut<'t, T> {
245    lock: RwLockWriteGuard<'t, T>,
246    shared: SenderShared<StateExtension<T>>,
247}
248
249impl<'t, T> DerefMut for RefMut<'t, T> {
250    fn deref_mut(&mut self) -> &mut Self::Target {
251        &mut *self.lock
252    }
253}
254
255impl<'t, T> Deref for RefMut<'t, T> {
256    type Target = T;
257
258    fn deref(&self) -> &Self::Target {
259        &*self.lock
260    }
261}
262
263impl<'t, T> Drop for RefMut<'t, T> {
264    fn drop(&mut self) {
265        self.shared.extension().increment();
266        self.shared.notify_receivers();
267    }
268}
269
270/// An immutable reference to the value contained in the channel.
271pub struct Ref<'t, T> {
272    lock: RwLockReadGuard<'t, T>,
273}
274
275impl<'t, T> Deref for Ref<'t, T> {
276    type Target = T;
277
278    fn deref(&self) -> &Self::Target {
279        &*self.lock
280    }
281}
282
283impl<T> Receiver<T> {
284    /// Borrows the value in the channel, blocking the channel while the value is held.
285    pub fn borrow(&self) -> Ref<'_, T> {
286        let lock = self.shared.extension().value.read();
287        Ref { lock }
288    }
289}
290
291struct StateExtension<T> {
292    generation: AtomicUsize,
293    value: RwLock<T>,
294}
295
296impl<T> StateExtension<T> {
297    pub fn new(value: T) -> Self {
298        Self {
299            generation: AtomicUsize::new(0),
300            value: RwLock::new(value),
301        }
302    }
303
304    pub fn push(&self, value: T) {
305        let mut lock = self.value.write();
306        *lock = value;
307
308        self.generation.fetch_add(1, Ordering::SeqCst);
309        drop(lock);
310    }
311
312    pub fn increment(&self) {
313        self.generation.fetch_add(1, Ordering::SeqCst);
314    }
315
316    pub fn generation(&self, ordering: Ordering) -> usize {
317        self.generation.load(ordering)
318    }
319}
320
/// Synchronous unit tests that drive the channel by polling it directly.
#[cfg(test)]
mod tests {
    use std::{pin::Pin, task::Context};

    use super::channel;
    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
    };
    use futures_test::task::new_count_waker;

    /// Minimal payload; the derived default is `State(0)`.
    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct State(usize);

    /// Repeated sends are accepted even if nothing is received in between.
    #[test]
    fn send_accepted() {
        let mut cx = noop_context();
        let (mut tx, _rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );
        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );
    }

    /// A sent value is received once, after which the receiver is pending.
    #[test]
    fn send_recv() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(PollRecv::Pending, Pin::new(&mut rx).poll_recv(&mut cx));
    }

    /// A fresh receiver observes the default value first.
    #[test]
    fn recv_default() {
        let mut cx = panic_context();
        let (_tx, mut rx) = channel();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );
    }

    /// Borrowing from a fresh channel yields the default value.
    #[test]
    fn borrow_default() {
        let (_tx, rx) = channel();

        assert_eq!(&State(0), &*rx.borrow());
    }

    /// Borrowing after a send yields the sent value.
    #[test]
    fn borrow_sent() {
        let mut cx = panic_context();
        let (mut tx, rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(&State(1), &*rx.borrow());
    }

    /// Releasing a mutable borrow wakes a pending receiver and publishes the edit.
    #[test]
    fn borrow_mut_notifies() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        *tx.borrow_mut() = State(1);
        assert_eq!(1, w1_count.get());
        assert_eq!(&State(1), &*rx.borrow());

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
    }

    /// After the sender is dropped, each receiver still gets the last value, then `Closed`.
    #[test]
    fn sender_disconnect() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();
        let mut rx2 = rx.clone();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        drop(tx);

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx).poll_recv(&mut cx));

        assert_eq!(
            PollRecv::Ready(State(1)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );

        assert_eq!(PollRecv::Closed, Pin::new(&mut rx2).poll_recv(&mut cx));
    }

    /// Sending with no receiver left is rejected and returns the value.
    #[test]
    fn receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        drop(rx);

        assert_eq!(
            PollSend::Rejected(State(1)),
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );
    }

    /// A send succeeds while the receiver exists and is rejected once it is dropped.
    #[test]
    fn send_then_receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        drop(rx);

        assert_eq!(
            PollSend::Rejected(State(2)),
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );
    }

    /// A pending receiver is woken once by the first send and not again by the second.
    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(1))
        );

        assert_eq!(1, w1_count.get());

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut tx).poll_send(&mut cx, State(2))
        );

        assert_eq!(1, w1_count.get());
    }

    /// Dropping the sender wakes a pending receiver.
    #[test]
    fn wake_receiver_on_disconnect() {
        let (tx, mut rx) = channel::<State>();

        let (w1, w1_count) = new_count_waker();
        let w1_context = Context::from_waker(&w1);

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut panic_context())
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut w1_context.into())
        );

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());
    }

    /// A subscribed receiver observes the default value first.
    // Nothing here awaits, so no async runtime is needed.
    #[test]
    fn subscribe_default() {
        let mut cx = panic_context();
        let (mut tx, _rx) = channel();
        let mut rx2 = tx.subscribe();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx2).poll_recv(&mut noop_context())
        );
    }

    /// The original and the subscribed receiver each observe the stored value once.
    // Nothing here awaits, so no async runtime is needed.
    #[test]
    fn subscribe_both_receive_value() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();
        let mut rx2 = tx.subscribe();

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx).poll_recv(&mut noop_context())
        );

        assert_eq!(
            PollRecv::Ready(State(0)),
            Pin::new(&mut rx2).poll_recv(&mut cx)
        );
        assert_eq!(
            PollRecv::Pending,
            Pin::new(&mut rx2).poll_recv(&mut noop_context())
        );
    }
}
593
/// End-to-end tests running the channel on the tokio runtime.
#[cfg(test)]
mod tokio_tests {
    use tokio::{spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    /// One sender task using `send`, one receiver; skipped messages are allowed.
    #[tokio::test]
    async fn simple() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    /// Same as `simple`, but the sender publishes through `borrow_mut`.
    #[tokio::test]
    async fn send_borrow_mut() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                *tx.borrow_mut() = message;
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    /// One sender task and several cloned receivers running as separate tasks.
    #[tokio::test]
    async fn multi_receiver() {
        let (mut tx, rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        // Collect eagerly: a lazy `map` would spawn each receiver only when the
        // previous one had been awaited, so they would never run side by side.
        let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
            .map(move |_| {
                let mut rx2 = rx.clone();
                let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();

                spawn(async move {
                    while let Some(message) = rx2.recv().await {
                        channels.assert_message(&message);
                    }
                })
            })
            .collect();

        timeout(TEST_TIMEOUT, async move {
            for handle in handles {
                handle.await.expect("join failed");
            }
        })
        .await
        .expect("test timeout");
    }
}
683
/// End-to-end tests running the channel on the async-std runtime.
#[cfg(test)]
mod async_std_tests {

    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    /// One sender task using `send`, one receiver; skipped messages are allowed.
    #[async_std::test]
    async fn simple() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    /// Same as `simple`, but the sender publishes through `borrow_mut`.
    #[async_std::test]
    async fn send_borrow_mut() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                *tx.borrow_mut() = message;
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    /// One sender task and several cloned receivers running as separate tasks.
    // Runs on async-std like the rest of this module (it previously used the
    // tokio test attribute and `tokio::task::spawn` for the sender).
    #[async_std::test]
    async fn multi_receiver() {
        let (mut tx, rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // skip state 0
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        // Collect eagerly: a lazy `map` would spawn each receiver only when the
        // previous one had been awaited, so they would never run side by side.
        let handles: Vec<_> = (0..CHANNEL_TEST_RECEIVERS)
            .map(move |_| {
                let mut rx2 = rx.clone();
                let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();

                spawn(async move {
                    while let Some(message) = rx2.recv().await {
                        channels.assert_message(&message);
                    }
                })
            })
            .collect();

        timeout(TEST_TIMEOUT, async move {
            for handle in handles {
                handle.await;
            }
        })
        .await
        .expect("test timeout");
    }
}
773}