// postage/stream/merge.rs

1use crate::stream::{PollRecv, Stream};
2use pin_project::pin_project;
3use std::pin::Pin;
4
5use crate::Context;
/// Which inner stream `MergeStream::poll_recv` polls first on its next call.
#[derive(Clone, Copy)]
enum State {
    Left,
    Right,
}

impl State {
    /// Returns the opposite side: `Right` for `Left`, `Left` for `Right`.
    pub fn swap(&self) -> Self {
        match *self {
            State::Right => State::Left,
            State::Left => State::Right,
        }
    }
}
20
21#[pin_project]
22pub struct MergeStream<Left, Right> {
23    state: State,
24    #[pin]
25    left: Left,
26    #[pin]
27    right: Right,
28}
29
30impl<Left, Right> MergeStream<Left, Right>
31where
32    Left: Stream,
33    Right: Stream<Item = Left::Item>,
34{
35    pub fn new(left: Left, right: Right) -> Self {
36        Self {
37            state: State::Left,
38            left,
39            right,
40        }
41    }
42}
43
44impl<Left, Right> Stream for MergeStream<Left, Right>
45where
46    Left: Stream,
47    Right: Stream<Item = Left::Item>,
48{
49    type Item = Left::Item;
50
51    fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
52        let this = self.project();
53
54        let poll = match this.state {
55            State::Left => poll(this.left, this.right, cx),
56            State::Right => poll(this.right, this.left, cx),
57        };
58
59        if poll.swap() {
60            *this.state = this.state.swap();
61        }
62
63        poll.into_recv()
64    }
65}
66
67enum MergePoll<T> {
68    First(PollRecv<T>),
69    Second(PollRecv<T>),
70}
71
72impl<T> MergePoll<T> {
73    pub fn into_recv(self) -> PollRecv<T> {
74        match self {
75            MergePoll::First(p) => p,
76            MergePoll::Second(p) => p,
77        }
78    }
79
80    pub fn swap(&self) -> bool {
81        match self {
82            MergePoll::First(_) => true,
83            MergePoll::Second(PollRecv::Ready(_)) => true,
84            MergePoll::Second(PollRecv::Pending) => true,
85            MergePoll::Second(PollRecv::Closed) => false,
86        }
87    }
88}
89
90fn poll<A, B>(first: Pin<&mut A>, second: Pin<&mut B>, cx: &mut Context<'_>) -> MergePoll<A::Item>
91where
92    A: Stream,
93    B: Stream<Item = A::Item>,
94{
95    match first.poll_recv(cx) {
96        PollRecv::Ready(v) => MergePoll::First(PollRecv::Ready(v)),
97        PollRecv::Pending => MergePoll::Second(second.poll_recv(cx)),
98        PollRecv::Closed => MergePoll::Second(second.poll_recv(cx)),
99    }
100}
101
102#[cfg(test)]
103mod tests {
104    use std::pin::Pin;
105
106    use crate::test::stream::*;
107    use crate::{
108        stream::{PollRecv, Stream},
109        Context,
110    };
111
112    use super::MergeStream;
113
114    #[test]
115    fn simple_merge() {
116        let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(3)]);
117        let right = from_poll_iter(vec![PollRecv::Ready(2), PollRecv::Ready(4)]);
118        let mut find = MergeStream::new(left, right);
119
120        let mut cx = Context::empty();
121
122        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
123        assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
124        assert_eq!(PollRecv::Ready(3), Pin::new(&mut find).poll_recv(&mut cx));
125        assert_eq!(PollRecv::Ready(4), Pin::new(&mut find).poll_recv(&mut cx));
126        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
127        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
128    }
129
130    #[test]
131    fn swap_ready() {
132        let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(3)]);
133        let right = from_poll_iter(vec![PollRecv::Ready(2)]);
134        let mut find = MergeStream::new(left, right);
135
136        let mut cx = Context::empty();
137
138        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
139        assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
140        assert_eq!(PollRecv::Ready(3), Pin::new(&mut find).poll_recv(&mut cx));
141        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
142    }
143
144    #[test]
145    fn swap_pending() {
146        let left = from_poll_iter(vec![PollRecv::Pending, PollRecv::Ready(2)]);
147        let right = from_poll_iter(vec![PollRecv::Ready(1)]);
148        let mut find = MergeStream::new(left, right);
149
150        let mut cx = Context::empty();
151
152        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
153        assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
154        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
155    }
156
157    #[test]
158    fn swap_closed() {
159        let left = from_poll_iter(vec![PollRecv::Closed, PollRecv::Closed]);
160        let right = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(2)]);
161        let mut find = MergeStream::new(left, right);
162
163        let mut cx = Context::empty();
164
165        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
166        assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
167        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
168    }
169
170    #[test]
171    fn pending_uses_right() {
172        let left = from_poll_iter(vec![PollRecv::Pending]);
173        let right = from_poll_iter(vec![PollRecv::Ready(1)]);
174        let mut find = MergeStream::new(left, right);
175
176        let mut cx = Context::empty();
177
178        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
179        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
180    }
181
182    #[test]
183    fn pending_uses_left() {
184        let left = from_poll_iter(vec![PollRecv::Ready(1)]);
185        let right = from_poll_iter(vec![PollRecv::Pending]);
186        let mut find = MergeStream::new(left, right);
187
188        let mut cx = Context::empty();
189
190        assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
191        assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
192    }
193}