1use std::pin::Pin;
2
3use atomic::{Atomic, Ordering};
4
5use crate::stream::{PollRecv, Stream};
6use crate::Context;
7use pin_project::pin_project;
8
/// Which half of a [`ChainStream`] the next poll should be directed at.
#[derive(Clone, Copy)]
enum State {
    /// The left stream has not reported `Closed` yet; poll it.
    Left,
    /// The left stream reported `Closed`; poll the right stream.
    Right,
    /// Both streams reported `Closed`; neither is polled again.
    Closed,
}
15
16#[pin_project]
17pub struct ChainStream<Left, Right> {
18 state: Atomic<State>,
19 #[pin]
20 left: Left,
21 #[pin]
22 right: Right,
23}
24
25impl<Left, Right> ChainStream<Left, Right>
26where
27 Left: Stream,
28 Right: Stream<Item = Left::Item>,
29{
30 pub fn new(left: Left, right: Right) -> Self {
31 Self {
32 state: Atomic::new(State::Left),
33 left,
34 right,
35 }
36 }
37}
38
39impl<Left, Right> Stream for ChainStream<Left, Right>
40where
41 Left: Stream,
42 Right: Stream<Item = Left::Item>,
43{
44 type Item = Left::Item;
45
46 fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
47 let this = self.project();
48 let mut state = this.state.load(Ordering::Acquire);
49
50 if let State::Left = state {
51 match this.left.poll_recv(cx) {
52 PollRecv::Ready(v) => return PollRecv::Ready(v),
53 PollRecv::Pending => return PollRecv::Pending,
54 PollRecv::Closed => {
55 this.state.store(State::Right, Ordering::Release);
56 state = State::Right;
57 }
58 }
59 }
60
61 if let State::Right = state {
62 match this.right.poll_recv(cx) {
63 PollRecv::Ready(v) => return PollRecv::Ready(v),
64 PollRecv::Pending => return PollRecv::Pending,
65 PollRecv::Closed => {
66 this.state.store(State::Closed, Ordering::Release);
67 return PollRecv::Closed;
68 }
69 }
70 }
71
72 if let State::Closed = state {
73 return PollRecv::Closed;
74 }
75
76 unreachable!();
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::stream::*;
    use crate::{
        stream::{PollRecv, Stream},
        Context,
    };

    use super::ChainStream;

    /// Pins `stream` in place and polls it exactly once.
    fn poll_once<S>(stream: &mut S, cx: &mut Context<'_>) -> PollRecv<S::Item>
    where
        S: Stream + Unpin,
    {
        Pin::new(stream).poll_recv(cx)
    }

    /// Items from the left stream come first, then items from the right,
    /// and the chain closes once both are drained.
    #[test]
    fn chain() {
        let left = from_poll_iter(vec![PollRecv::Ready(1), PollRecv::Ready(2)]);
        let right = from_poll_iter(vec![PollRecv::Ready(3)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Ready(1), poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Ready(2), poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Ready(3), poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut chain, &mut cx));
    }

    /// A pending left stream makes the chain pending; the right stream's item
    /// only shows up on a later poll.
    // NOTE(review): relies on `from_poll_iter` reporting `Closed` once its
    // scripted polls run out — confirm against the test helper.
    #[test]
    fn waits_for_right() {
        let left = from_poll_iter(vec![PollRecv::Pending]);
        let right = from_poll_iter(vec![PollRecv::Ready(1)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Pending, poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Ready(1), poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut chain, &mut cx));
    }

    /// After both inner streams have reported `Closed`, the chain stays
    /// closed even though the scripted streams would yield items if polled
    /// again.
    #[test]
    fn ignores_after_close() {
        let left = from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(1)]);
        let right = from_poll_iter(vec![PollRecv::Closed, PollRecv::Ready(2)]);
        let mut chain = ChainStream::new(left, right);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Closed, poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut chain, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut chain, &mut cx));
    }
}