1use crate::sink::{PollSend, Sink};
2use crate::Context;
3use atomic::{Atomic, Ordering};
4use pin_project::pin_project;
5use std::pin::Pin;
6
/// Which inner sink a `ChainSink` currently forwards values to.
///
/// Transitions only ever move forward:
/// `WritingLeft` -> `WritingRight` -> `Closed`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum State {
    /// Values are offered to the left sink.
    WritingLeft,
    /// The left sink rejected a value; values are now offered to the right sink.
    WritingRight,
    /// The right sink rejected a value; every further send is rejected.
    Closed,
}
/// A sink that forwards values to `left` until `left` rejects one, then
/// forwards values to `right`, starting with the value `left` rejected.
///
/// Once `right` also rejects a value the chain is closed, and every later
/// send is rejected without being offered to either inner sink.
#[pin_project]
pub struct ChainSink<Left, Right> {
    // Current forwarding target; starts at `State::WritingLeft` and only
    // ever advances (see `poll_send`).
    state: Atomic<State>,

    // Both inner sinks are structurally pinned so `poll_send` can call their
    // pinned `poll_send` methods through the projection.
    #[pin]
    left: Left,
    #[pin]
    right: Right,
}
22
23impl<Left, Right> ChainSink<Left, Right>
24where
25 Left: Sink,
26 Right: Sink<Item = Left::Item>,
27{
28 pub fn new(left: Left, right: Right) -> Self {
29 Self {
30 state: Atomic::new(State::WritingLeft),
31 left,
32 right,
33 }
34 }
35}
36
37impl<Left, Right> Sink for ChainSink<Left, Right>
38where
39 Left: Sink,
40 Right: Sink<Item = Left::Item>,
41{
42 type Item = Left::Item;
43
44 fn poll_send(
45 self: Pin<&mut Self>,
46 cx: &mut Context<'_>,
47 mut value: Self::Item,
48 ) -> PollSend<Self::Item> {
49 let this = self.project();
50 let mut state = this.state.load(Ordering::Acquire);
51
52 if let State::WritingLeft = state {
53 match this.left.poll_send(cx, value) {
54 PollSend::Ready => return PollSend::Ready,
55 PollSend::Pending(value) => return PollSend::Pending(value),
56 PollSend::Rejected(returned_value) => {
57 value = returned_value;
58 this.state.store(State::WritingRight, Ordering::Release);
59 state = State::WritingRight;
60 }
61 }
62 }
63
64 if let State::WritingRight = state {
65 match this.right.poll_send(cx, value) {
66 PollSend::Ready => return PollSend::Ready,
67 PollSend::Pending(value) => return PollSend::Pending(value),
68 PollSend::Rejected(returned_value) => {
69 value = returned_value;
70
71 this.state.store(State::Closed, Ordering::Release);
72 return PollSend::Rejected(value);
73 }
74 }
75 }
76
77 if let State::Closed = state {
78 return PollSend::Rejected(value);
79 }
80
81 unreachable!();
82 }
83}
84
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::sink::*;
    use crate::{
        sink::{PollSend, Sink},
        Context,
    };

    use super::ChainSink;

    /// `left` accepts the first value, the second value goes to `right`, and
    /// the third is rejected once `right` also stops accepting.
    #[test]
    fn simple() {
        // NOTE(review): the expectations below imply that a `test_sink` rejects
        // once its scripted responses are used up — confirm in `crate::test::sink`.
        let mut left = test_sink(vec![PollSend::Ready]);
        let mut right = test_sink(vec![PollSend::Ready]);
        let mut chain = ChainSink::new(&mut left, &mut right);

        let mut cx = Context::empty();

        assert_eq!(
            PollSend::Ready,
            Pin::new(&mut chain).poll_send(&mut cx, 1usize)
        );
        assert_eq!(PollSend::Ready, Pin::new(&mut chain).poll_send(&mut cx, 2));
        assert_eq!(
            PollSend::Rejected(3),
            Pin::new(&mut chain).poll_send(&mut cx, 3)
        );

        // The chain holds `&mut left` / `&mut right`; drop it before inspecting them.
        drop(chain);

        assert_eq!(&[1], left.values());
        assert_eq!(&[2], right.values());
    }

    /// A `Pending` result from `left` is handed back to the caller unchanged,
    /// and the chain keeps writing to `left` on the next send.
    #[test]
    fn waits_for_right() {
        // NOTE(review): the second send reaching `right` presumably happens because
        // `left`'s single scripted response is used up and it then rejects; the
        // test name may not reflect that — confirm intent.
        let mut left = test_sink(vec![PollSend::Pending(1)]);
        let mut right = test_sink(vec![PollSend::Ready]);
        let mut chain = ChainSink::new(&mut left, &mut right);

        let mut cx = Context::empty();

        assert_eq!(
            PollSend::Pending(1),
            Pin::new(&mut chain).poll_send(&mut cx, 1usize)
        );
        assert_eq!(PollSend::Ready, Pin::new(&mut chain).poll_send(&mut cx, 2));
        assert_eq!(
            PollSend::Rejected(3),
            Pin::new(&mut chain).poll_send(&mut cx, 3)
        );

        drop(chain);

        // `left` is expected to have recorded nothing, since it never reported `Ready`.
        assert_eq!(Vec::<usize>::new(), left.values());
        assert_eq!(&[2], right.values());
    }

    /// Both sinks reject the first value, so the chain closes on the first
    /// send; every later send is rejected without reaching either sink.
    #[test]
    fn ignores_after_close() {
        // The trailing `Ready` responses should never be consumed, because a
        // closed chain does not call the inner sinks.
        let mut left = test_sink(vec![PollSend::Rejected(1), PollSend::Ready]);
        let mut right = test_sink(vec![PollSend::Rejected(1), PollSend::Ready]);
        let mut chain = ChainSink::new(&mut left, &mut right);

        let mut cx = Context::empty();

        assert_eq!(
            PollSend::Rejected(1),
            Pin::new(&mut chain).poll_send(&mut cx, 1usize)
        );
        assert_eq!(
            PollSend::Rejected(2),
            Pin::new(&mut chain).poll_send(&mut cx, 2)
        );
        assert_eq!(
            PollSend::Rejected(3),
            Pin::new(&mut chain).poll_send(&mut cx, 3)
        );
        assert_eq!(
            PollSend::Rejected(4),
            Pin::new(&mut chain).poll_send(&mut cx, 4)
        );

        drop(chain);

        assert_eq!(Vec::<usize>::new(), left.values());
        assert_eq!(Vec::<usize>::new(), right.values());
    }
}