// postage/stream/find.rs

1use std::pin::Pin;
2
3use crate::Context;
4use atomic::{Atomic, Ordering};
5
6use crate::stream::{PollRecv, Stream};
7
/// Lifecycle of a [`FindStream`].
#[derive(Clone, Copy)]
enum State {
    /// Still polling the upstream stream for an item that satisfies the condition.
    Reading,
    /// Finished: every subsequent poll reports the stream as closed.
    Closed,
}
13
14pub struct FindStream<From, Condition> {
15    state: Atomic<State>,
16    from: From,
17    condition: Condition,
18}
19
20impl<From, Condition> FindStream<From, Condition>
21where
22    From: Stream,
23    Condition: Fn(&From::Item) -> bool,
24{
25    pub fn new(from: From, condition: Condition) -> Self {
26        Self {
27            state: Atomic::new(State::Reading),
28            from,
29            condition,
30        }
31    }
32}
33
34impl<From, Condition> Stream for FindStream<From, Condition>
35where
36    From: Stream + Unpin,
37    Condition: Fn(&From::Item) -> bool + Unpin,
38{
39    type Item = From::Item;
40
41    fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
42        let this = self.get_mut();
43
44        if let State::Closed = this.state.load(Ordering::Acquire) {
45            return PollRecv::Closed;
46        }
47
48        loop {
49            let from = Pin::new(&mut this.from);
50            match from.poll_recv(cx) {
51                PollRecv::Ready(value) => {
52                    if (this.condition)(&value) {
53                        this.state.store(State::Closed, Ordering::Release);
54
55                        return PollRecv::Ready(value);
56                    }
57                }
58                PollRecv::Pending => return PollRecv::Pending,
59                PollRecv::Closed => return PollRecv::Closed,
60            }
61        }
62    }
63}
64
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::stream::*;
    use crate::{
        stream::{PollRecv, Stream},
        Context,
    };

    use super::FindStream;

    /// Polls `stream` exactly once, hiding the `Pin::new(..)` boilerplate.
    fn poll_once<S>(stream: &mut S, cx: &mut Context<'_>) -> PollRecv<S::Item>
    where
        S: Stream + Unpin,
    {
        Pin::new(stream).poll_recv(cx)
    }

    #[test]
    fn find() {
        let source = from_iter(vec![1, 2, 3]);
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Ready(2), poll_once(&mut find, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn find_first_item() {
        // A match on the very first upstream item is returned immediately.
        let source = from_iter(vec![2, 1]);
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Ready(2), poll_once(&mut find, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn find_none() {
        let source = from_iter(vec![1, 3]);
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
        // The stream must remain closed on subsequent polls.
        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn find_empty() {
        let source = from_iter(Vec::<i32>::new());
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn find_only_once() {
        let source = from_iter(vec![1, 2, 2]);
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Ready(2), poll_once(&mut find, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn forward_pending() {
        let source = pending::<usize>();
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Pending, poll_once(&mut find, &mut cx));
    }

    #[test]
    fn forward_closed() {
        let source = closed::<usize>();
        let mut find = FindStream::new(source, |i| *i == 2);

        let mut cx = Context::empty();

        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut find, &mut cx));
    }
}
128}