1use crate::stream::{PollRecv, Stream};
2use pin_project::pin_project;
3use std::pin::Pin;
4
5use crate::Context;
/// Which of the two merged streams gets polled first on the next `poll_recv`.
#[derive(Copy, Clone)]
enum State {
    /// The left stream is polled first.
    Left,
    /// The right stream is polled first.
    Right,
}

impl State {
    /// Returns the opposite priority: `Left` becomes `Right` and `Right` becomes `Left`.
    pub fn swap(&self) -> Self {
        match *self {
            State::Right => State::Left,
            State::Left => State::Right,
        }
    }
}
20
21#[pin_project]
22pub struct MergeStream<Left, Right> {
23 state: State,
24 #[pin]
25 left: Left,
26 #[pin]
27 right: Right,
28}
29
30impl<Left, Right> MergeStream<Left, Right>
31where
32 Left: Stream,
33 Right: Stream<Item = Left::Item>,
34{
35 pub fn new(left: Left, right: Right) -> Self {
36 Self {
37 state: State::Left,
38 left,
39 right,
40 }
41 }
42}
43
44impl<Left, Right> Stream for MergeStream<Left, Right>
45where
46 Left: Stream,
47 Right: Stream<Item = Left::Item>,
48{
49 type Item = Left::Item;
50
51 fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
52 let this = self.project();
53
54 let poll = match this.state {
55 State::Left => poll(this.left, this.right, cx),
56 State::Right => poll(this.right, this.left, cx),
57 };
58
59 if poll.swap() {
60 *this.state = this.state.swap();
61 }
62
63 poll.into_recv()
64 }
65}
66
67enum MergePoll<T> {
68 First(PollRecv<T>),
69 Second(PollRecv<T>),
70}
71
72impl<T> MergePoll<T> {
73 pub fn into_recv(self) -> PollRecv<T> {
74 match self {
75 MergePoll::First(p) => p,
76 MergePoll::Second(p) => p,
77 }
78 }
79
80 pub fn swap(&self) -> bool {
81 match self {
82 MergePoll::First(_) => true,
83 MergePoll::Second(PollRecv::Ready(_)) => true,
84 MergePoll::Second(PollRecv::Pending) => true,
85 MergePoll::Second(PollRecv::Closed) => false,
86 }
87 }
88}
89
90fn poll<A, B>(first: Pin<&mut A>, second: Pin<&mut B>, cx: &mut Context<'_>) -> MergePoll<A::Item>
91where
92 A: Stream,
93 B: Stream<Item = A::Item>,
94{
95 match first.poll_recv(cx) {
96 PollRecv::Ready(v) => MergePoll::First(PollRecv::Ready(v)),
97 PollRecv::Pending => MergePoll::Second(second.poll_recv(cx)),
98 PollRecv::Closed => MergePoll::Second(second.poll_recv(cx)),
99 }
100}
101
102#[cfg(test)]
103mod tests {
104 use std::pin::Pin;
105
106 use crate::test::stream::*;
107 use crate::{
108 stream::{PollRecv, Stream},
109 Context,
110 };
111
112 use super::MergeStream;
113
114 #[test]
115 fn simple_merge() {
116 let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(3)]);
117 let right = from_poll_iter(vec![PollRecv::Ready(2), PollRecv::Ready(4)]);
118 let mut find = MergeStream::new(left, right);
119
120 let mut cx = Context::empty();
121
122 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
123 assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
124 assert_eq!(PollRecv::Ready(3), Pin::new(&mut find).poll_recv(&mut cx));
125 assert_eq!(PollRecv::Ready(4), Pin::new(&mut find).poll_recv(&mut cx));
126 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
127 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
128 }
129
130 #[test]
131 fn swap_ready() {
132 let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(3)]);
133 let right = from_poll_iter(vec![PollRecv::Ready(2)]);
134 let mut find = MergeStream::new(left, right);
135
136 let mut cx = Context::empty();
137
138 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
139 assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
140 assert_eq!(PollRecv::Ready(3), Pin::new(&mut find).poll_recv(&mut cx));
141 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
142 }
143
144 #[test]
145 fn swap_pending() {
146 let left = from_poll_iter(vec![PollRecv::Pending, PollRecv::Ready(2)]);
147 let right = from_poll_iter(vec![PollRecv::Ready(1)]);
148 let mut find = MergeStream::new(left, right);
149
150 let mut cx = Context::empty();
151
152 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
153 assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
154 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
155 }
156
157 #[test]
158 fn swap_closed() {
159 let left = from_poll_iter(vec![PollRecv::Closed, PollRecv::Closed]);
160 let right = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(2)]);
161 let mut find = MergeStream::new(left, right);
162
163 let mut cx = Context::empty();
164
165 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
166 assert_eq!(PollRecv::Ready(2), Pin::new(&mut find).poll_recv(&mut cx));
167 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
168 }
169
170 #[test]
171 fn pending_uses_right() {
172 let left = from_poll_iter(vec![PollRecv::Pending]);
173 let right = from_poll_iter(vec![PollRecv::Ready(1)]);
174 let mut find = MergeStream::new(left, right);
175
176 let mut cx = Context::empty();
177
178 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
179 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
180 }
181
182 #[test]
183 fn pending_uses_left() {
184 let left = from_poll_iter(vec![PollRecv::Ready(1)]);
185 let right = from_poll_iter(vec![PollRecv::Pending]);
186 let mut find = MergeStream::new(left, right);
187
188 let mut cx = Context::empty();
189
190 assert_eq!(PollRecv::Ready(1), Pin::new(&mut find).poll_recv(&mut cx));
191 assert_eq!(PollRecv::Closed, Pin::new(&mut find).poll_recv(&mut cx));
192 }
193}