postage/stream/
chain.rs

1use std::pin::Pin;
2
3use atomic::{Atomic, Ordering};
4
5use crate::stream::{PollRecv, Stream};
6use crate::Context;
7use pin_project::pin_project;
8
/// Tracks which half of a [`ChainStream`] is currently being polled.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the state can be printed and
/// compared directly in assertions and diagnostics.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
    /// The left stream has not reported `Closed` yet, so it is polled first.
    Left,
    /// The left stream reported `Closed`; the right stream is polled instead.
    Right,
    /// The right stream reported `Closed` as well; the chain is finished.
    Closed,
}
15
16#[pin_project]
17pub struct ChainStream<Left, Right> {
18    state: Atomic<State>,
19    #[pin]
20    left: Left,
21    #[pin]
22    right: Right,
23}
24
25impl<Left, Right> ChainStream<Left, Right>
26where
27    Left: Stream,
28    Right: Stream<Item = Left::Item>,
29{
30    pub fn new(left: Left, right: Right) -> Self {
31        Self {
32            state: Atomic::new(State::Left),
33            left,
34            right,
35        }
36    }
37}
38
39impl<Left, Right> Stream for ChainStream<Left, Right>
40where
41    Left: Stream,
42    Right: Stream<Item = Left::Item>,
43{
44    type Item = Left::Item;
45
46    fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
47        let this = self.project();
48        let mut state = this.state.load(Ordering::Acquire);
49
50        if let State::Left = state {
51            match this.left.poll_recv(cx) {
52                PollRecv::Ready(v) => return PollRecv::Ready(v),
53                PollRecv::Pending => return PollRecv::Pending,
54                PollRecv::Closed => {
55                    this.state.store(State::Right, Ordering::Release);
56                    state = State::Right;
57                }
58            }
59        }
60
61        if let State::Right = state {
62            match this.right.poll_recv(cx) {
63                PollRecv::Ready(v) => return PollRecv::Ready(v),
64                PollRecv::Pending => return PollRecv::Pending,
65                PollRecv::Closed => {
66                    this.state.store(State::Closed, Ordering::Release);
67                    return PollRecv::Closed;
68                }
69            }
70        }
71
72        if let State::Closed = state {
73            return PollRecv::Closed;
74        }
75
76        unreachable!();
77    }
78}
79
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::stream::*;
    use crate::{
        stream::{PollRecv, Stream},
        Context,
    };

    use super::ChainStream;

    /// Items of the left stream come first, then the right stream's, then `Closed`.
    #[test]
    fn chain() {
        let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(2)]);
        let right = from_poll_iter(vec![PollRecv::Ready(3)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();
        let mut poll = || Pin::new(&mut chain).poll_recv(&mut cx);

        assert_eq!(PollRecv::Ready(1), poll());
        assert_eq!(PollRecv::Ready(2), poll());
        assert_eq!(PollRecv::Ready(3), poll());
        assert_eq!(PollRecv::Closed, poll());
    }

    /// A pending left stream holds the chain back; the right stream is only
    /// reached on the following poll.
    #[test]
    fn waits_for_right() {
        // Presumably `from_poll_iter` reports `Closed` once its scripted results
        // run out, which is what moves the chain on to `right` in the second poll.
        let left = from_poll_iter(vec![PollRecv::Pending]);
        let right = from_poll_iter(vec![PollRecv::Ready(1)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();
        let mut poll = || Pin::new(&mut chain).poll_recv(&mut cx);

        assert_eq!(PollRecv::Pending, poll());
        assert_eq!(PollRecv::Ready(1), poll());
        assert_eq!(PollRecv::Closed, poll());
    }

    /// After both halves have closed, later items from either stream are never observed.
    #[test]
    fn ignores_after_close() {
        let left = from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(1)]);
        let right = from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(2)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();
        let mut poll = || Pin::new(&mut chain).poll_recv(&mut cx);

        assert_eq!(PollRecv::Closed, poll());
        assert_eq!(PollRecv::Closed, poll());
        assert_eq!(PollRecv::Closed, poll());
    }
}