postage/sink/
chain.rs

1use crate::sink::{PollSend, Sink};
2use crate::Context;
3use atomic::{Atomic, Ordering};
4use pin_project::pin_project;
5use std::pin::Pin;
6
/// Progress of a `ChainSink` through its two inner sinks.
#[derive(Clone, Copy)]
enum State {
    /// Values are offered to the left sink.
    WritingLeft,
    /// The left sink has rejected a value; values are offered to the right sink.
    WritingRight,
    /// The right sink has rejected a value; every further value is rejected.
    Closed,
}
13#[pin_project]
14pub struct ChainSink<Left, Right> {
15    state: Atomic<State>,
16
17    #[pin]
18    left: Left,
19    #[pin]
20    right: Right,
21}
22
23impl<Left, Right> ChainSink<Left, Right>
24where
25    Left: Sink,
26    Right: Sink<Item = Left::Item>,
27{
28    pub fn new(left: Left, right: Right) -> Self {
29        Self {
30            state: Atomic::new(State::WritingLeft),
31            left,
32            right,
33        }
34    }
35}
36
37impl<Left, Right> Sink for ChainSink<Left, Right>
38where
39    Left: Sink,
40    Right: Sink<Item = Left::Item>,
41{
42    type Item = Left::Item;
43
44    fn poll_send(
45        self: Pin<&mut Self>,
46        cx: &mut Context<'_>,
47        mut value: Self::Item,
48    ) -> PollSend<Self::Item> {
49        let this = self.project();
50        let mut state = this.state.load(Ordering::Acquire);
51
52        if let State::WritingLeft = state {
53            match this.left.poll_send(cx, value) {
54                PollSend::Ready => return PollSend::Ready,
55                PollSend::Pending(value) => return PollSend::Pending(value),
56                PollSend::Rejected(returned_value) => {
57                    value = returned_value;
58                    this.state.store(State::WritingRight, Ordering::Release);
59                    state = State::WritingRight;
60                }
61            }
62        }
63
64        if let State::WritingRight = state {
65            match this.right.poll_send(cx, value) {
66                PollSend::Ready => return PollSend::Ready,
67                PollSend::Pending(value) => return PollSend::Pending(value),
68                PollSend::Rejected(returned_value) => {
69                    value = returned_value;
70
71                    this.state.store(State::Closed, Ordering::Release);
72                    return PollSend::Rejected(value);
73                }
74            }
75        }
76
77        if let State::Closed = state {
78            return PollSend::Rejected(value);
79        }
80
81        unreachable!();
82    }
83}
84
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::sink::*;
    use crate::{
        sink::{PollSend, Sink},
        Context,
    };

    use super::ChainSink;

    /// Polls `sink` a single time with `value` and returns the outcome.
    fn send_once<S>(sink: &mut S, cx: &mut Context<'_>, value: usize) -> PollSend<usize>
    where
        S: Sink<Item = usize> + Unpin,
    {
        Pin::new(sink).poll_send(cx, value)
    }

    // Each inner sink accepts one value; the third send is rejected.
    #[test]
    fn simple() {
        let mut left = test_sink(vec![PollSend::Ready]);
        let mut right = test_sink(vec![PollSend::Ready]);

        {
            let mut chain = ChainSink::new(&mut left, &mut right);
            let mut cx = Context::empty();

            assert_eq!(PollSend::Ready, send_once(&mut chain, &mut cx, 1usize));
            assert_eq!(PollSend::Ready, send_once(&mut chain, &mut cx, 2));
            assert_eq!(PollSend::Rejected(3), send_once(&mut chain, &mut cx, 3));
        }

        assert_eq!(&[1], left.values());
        assert_eq!(&[2], right.values());
    }

    // A pending left sink is reported as pending; the next send lands in the
    // right sink and the left sink never stores a value.
    #[test]
    fn waits_for_right() {
        let mut left = test_sink(vec![PollSend::Pending(1)]);
        let mut right = test_sink(vec![PollSend::Ready]);

        {
            let mut chain = ChainSink::new(&mut left, &mut right);
            let mut cx = Context::empty();

            assert_eq!(PollSend::Pending(1), send_once(&mut chain, &mut cx, 1usize));
            assert_eq!(PollSend::Ready, send_once(&mut chain, &mut cx, 2));
            assert_eq!(PollSend::Rejected(3), send_once(&mut chain, &mut cx, 3));
        }

        assert_eq!(Vec::<usize>::new(), left.values());
        assert_eq!(&[2], right.values());
    }

    // After both inner sinks have rejected a value, every later send is
    // rejected even though the inner sinks have `Ready` responses queued.
    #[test]
    fn ignores_after_close() {
        let mut left = test_sink(vec![PollSend::Rejected(1), PollSend::Ready]);
        let mut right = test_sink(vec![PollSend::Rejected(1), PollSend::Ready]);

        {
            let mut chain = ChainSink::new(&mut left, &mut right);
            let mut cx = Context::empty();

            assert_eq!(PollSend::Rejected(1), send_once(&mut chain, &mut cx, 1usize));
            assert_eq!(PollSend::Rejected(2), send_once(&mut chain, &mut cx, 2));
            assert_eq!(PollSend::Rejected(3), send_once(&mut chain, &mut cx, 3));
            assert_eq!(PollSend::Rejected(4), send_once(&mut chain, &mut cx, 4));
        }

        assert_eq!(Vec::<usize>::new(), left.values());
        assert_eq!(Vec::<usize>::new(), right.values());
    }
}