1use std::pin::Pin;
2
3use crate::Context;
4use atomic::{Atomic, Ordering};
5
6use crate::stream::{PollRecv, Stream};
7
/// Lifecycle of a [`FindStream`].
///
/// The stream starts in `Reading` and moves to `Closed` only after it has
/// yielded a matching item. An upstream `Closed` is forwarded to the caller
/// but does not by itself move the state to `Closed`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum State {
    /// Still pulling items from upstream, looking for a match.
    Reading,
    /// A matching item has been yielded; every later poll reports `Closed`.
    Closed,
}
13
14pub struct FindStream<From, Condition> {
15 state: Atomic<State>,
16 from: From,
17 condition: Condition,
18}
19
20impl<From, Condition> FindStream<From, Condition>
21where
22 From: Stream,
23 Condition: Fn(&From::Item) -> bool,
24{
25 pub fn new(from: From, condition: Condition) -> Self {
26 Self {
27 state: Atomic::new(State::Reading),
28 from,
29 condition,
30 }
31 }
32}
33
34impl<From, Condition> Stream for FindStream<From, Condition>
35where
36 From: Stream + Unpin,
37 Condition: Fn(&From::Item) -> bool + Unpin,
38{
39 type Item = From::Item;
40
41 fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollRecv<Self::Item> {
42 let this = self.get_mut();
43
44 if let State::Closed = this.state.load(Ordering::Acquire) {
45 return PollRecv::Closed;
46 }
47
48 loop {
49 let from = Pin::new(&mut this.from);
50 match from.poll_recv(cx) {
51 PollRecv::Ready(value) => {
52 if (this.condition)(&value) {
53 this.state.store(State::Closed, Ordering::Release);
54
55 return PollRecv::Ready(value);
56 }
57 }
58 PollRecv::Pending => return PollRecv::Pending,
59 PollRecv::Closed => return PollRecv::Closed,
60 }
61 }
62 }
63}
64
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::test::stream::*;
    use crate::{
        stream::{PollRecv, Stream},
        Context,
    };

    use super::FindStream;

    /// Pins `stream` in place and polls it exactly once with `cx`.
    fn poll_once<S>(stream: &mut S, cx: &mut Context<'_>) -> PollRecv<S::Item>
    where
        S: Stream + Unpin,
    {
        Pin::new(stream).poll_recv(cx)
    }

    #[test]
    fn find() {
        let mut cx = Context::empty();
        let mut stream = FindStream::new(from_iter(vec![1, 2, 3]), |i| *i == 2);

        assert_eq!(PollRecv::Ready(2), poll_once(&mut stream, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut stream, &mut cx));
    }

    #[test]
    fn find_none() {
        let mut cx = Context::empty();
        let mut stream = FindStream::new(from_iter(vec![1, 3]), |i| *i == 2);

        // No element matches, so upstream runs dry and its `Closed` is forwarded.
        assert_eq!(PollRecv::Closed, poll_once(&mut stream, &mut cx));
    }

    #[test]
    fn find_only_once() {
        let mut cx = Context::empty();
        let mut stream = FindStream::new(from_iter(vec![1, 2, 2]), |i| *i == 2);

        // The second `2` is never yielded: the stream closes after the first match.
        assert_eq!(PollRecv::Ready(2), poll_once(&mut stream, &mut cx));
        assert_eq!(PollRecv::Closed, poll_once(&mut stream, &mut cx));
    }

    #[test]
    fn forward_pending() {
        let mut cx = Context::empty();
        let mut stream = FindStream::new(pending::<usize>(), |i| *i == 2);

        assert_eq!(PollRecv::Pending, poll_once(&mut stream, &mut cx));
    }

    #[test]
    fn forward_closed() {
        let mut cx = Context::empty();
        let mut stream = FindStream::new(closed::<usize>(), |i| *i == 2);

        assert_eq!(PollRecv::Closed, poll_once(&mut stream, &mut cx));
    }
}