1use std::fmt;
5use std::sync::Arc;
6
7use super::SendMessage;
8use crate::{
9 sink::{PollSend, Sink},
10 stream::{PollRecv, Stream},
11 sync::transfer::Transfer,
12};
13use static_assertions::{assert_impl_all, assert_not_impl_all};
14
15pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
17 #[cfg(feature = "debug")]
18 log::error!("Creating oneshot channel");
19
20 let shared = Arc::new(Transfer::new());
21 let sender = Sender {
22 shared: shared.clone(),
23 };
24
25 let receiver = Receiver { shared };
26
27 (sender, receiver)
28}
29
/// The sending half of a oneshot channel, created by [`channel`].
///
/// Values are sent through the [`Sink`] implementation. Dropping the sender
/// notifies the shared transfer via `sender_disconnect`.
pub struct Sender<T> {
    /// Transfer state shared with the paired [`Receiver`]; only visible inside
    /// the `oneshot` module.
    pub(in crate::channels::oneshot) shared: Arc<Transfer<T>>,
}
34
// Compile-time guarantees for the sender half: it must be `Send`, `Sync` and
// `Debug`, and it must not be `Clone`.
assert_impl_all!(Sender<SendMessage>: Send, Sync, fmt::Debug);
assert_not_impl_all!(Sender<SendMessage>: Clone);
37
38impl<T> Sink for Sender<T> {
39 type Item = T;
40
41 fn poll_send(
42 self: std::pin::Pin<&mut Self>,
43 _cx: &mut crate::Context<'_>,
44 value: Self::Item,
45 ) -> PollSend<Self::Item> {
46 match self.shared.send(value) {
47 Ok(_) => PollSend::Ready,
48 Err(v) => PollSend::Rejected(v),
49 }
50 }
51}
52
53impl<T> fmt::Debug for Sender<T> {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("Sender").finish()
56 }
57}
58
/// `futures::sink::Sink` adapter for the oneshot sender, compiled only when
/// the `futures-traits` feature is enabled.
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use crate::sink::SendError;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        /// Always reports readiness; acceptance is decided in `start_send`.
        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Forwards `item` to the shared transfer, wrapping a rejected item in
        /// [`SendError`] so the caller gets it back.
        fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            self.shared.send(item).map_err(SendError)
        }

        /// Nothing is buffered on the sender side, so flushing completes
        /// immediately.
        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Completes immediately without touching the shared transfer.
        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
93
94impl<T> Drop for Sender<T> {
95 fn drop(&mut self) {
96 self.shared.sender_disconnect();
97 }
98}
99
/// The receiving half of a oneshot channel, created by [`channel`].
///
/// Values are received through the [`Stream`] implementation. Dropping the
/// receiver notifies the shared transfer via `receiver_disconnect`.
pub struct Receiver<T> {
    /// Transfer state shared with the paired [`Sender`]; only visible inside
    /// the `oneshot` module.
    pub(in crate::channels::oneshot) shared: Arc<Transfer<T>>,
}
104
105assert_impl_all!(Sender<SendMessage>: Send, Sync, fmt::Debug);
106assert_not_impl_all!(Sender<SendMessage>: Clone);
107
108impl<T> Stream for Receiver<T> {
109 type Item = T;
110
111 fn poll_recv(
112 self: std::pin::Pin<&mut Self>,
113 cx: &mut crate::Context<'_>,
114 ) -> PollRecv<Self::Item> {
115 self.shared.recv(cx)
116 }
117}
118
119impl<T> Drop for Receiver<T> {
120 fn drop(&mut self) {
121 self.shared.receiver_disconnect();
122 }
123}
124
125impl<T> fmt::Debug for Receiver<T> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Receiver").finish()
128 }
129}
130
/// Synchronous unit tests that poll the channel halves by hand.
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use futures_test::task::new_count_waker;

    use super::{channel, Receiver, Sender};
    use crate::{
        sink::{PollSend, Sink},
        stream::{PollRecv, Stream},
        test::{noop_context, panic_context},
        Context,
    };

    /// Payload used by the tests that send a value.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct Message(usize);

    /// Polls the sender once with `value`.
    fn try_send<T>(tx: &mut Sender<T>, cx: &mut Context<'_>, value: T) -> PollSend<T> {
        Pin::new(tx).poll_send(cx, value)
    }

    /// Polls the receiver once.
    fn try_recv<T>(rx: &mut Receiver<T>, cx: &mut Context<'_>) -> PollRecv<T> {
        Pin::new(rx).poll_recv(cx)
    }

    /// The first send is accepted; a second one is handed back.
    #[test]
    fn send_accepted() {
        let mut cx = noop_context();
        let (mut tx, _rx) = channel();

        assert_eq!(PollSend::Ready, try_send(&mut tx, &mut cx, Message(1)));
        assert_eq!(
            PollSend::Rejected(Message(2)),
            try_send(&mut tx, &mut cx, Message(2))
        );
    }

    /// A sent value is received exactly once, then the channel reads as closed.
    #[test]
    fn send_recv() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(PollSend::Ready, try_send(&mut tx, &mut cx, Message(1)));

        assert_eq!(PollRecv::Ready(Message(1)), try_recv(&mut rx, &mut cx));
        assert_eq!(PollRecv::Closed, try_recv(&mut rx, &mut cx));
    }

    /// Dropping the sender before any poll closes the channel.
    #[test]
    fn sender_disconnect() {
        let mut cx = noop_context();
        let (tx, mut rx) = channel::<Message>();

        drop(tx);

        assert_eq!(PollRecv::Closed, try_recv(&mut rx, &mut cx));
    }

    /// Dropping the sender after the receiver went pending closes the channel.
    #[test]
    fn sender_disconnect_after_poll() {
        let mut cx = noop_context();
        let (tx, mut rx) = channel::<Message>();

        assert_eq!(PollRecv::Pending, try_recv(&mut rx, &mut cx));

        drop(tx);
        assert_eq!(PollRecv::Closed, try_recv(&mut rx, &mut cx));
    }

    /// A value sent before the sender is dropped is still delivered.
    #[test]
    fn send_then_disconnect() {
        let mut cx = noop_context();
        let (mut tx, mut rx) = channel();

        assert_eq!(PollSend::Ready, try_send(&mut tx, &mut cx, Message(1)));

        drop(tx);

        assert_eq!(PollRecv::Ready(Message(1)), try_recv(&mut rx, &mut cx));

        assert_eq!(PollRecv::Closed, try_recv(&mut rx, &mut cx));
    }

    /// Sending after the receiver is dropped hands the value back.
    #[test]
    fn receiver_disconnect() {
        let mut cx = noop_context();
        let (mut tx, rx) = channel();

        drop(rx);

        assert_eq!(
            PollSend::Rejected(Message(1)),
            try_send(&mut tx, &mut cx, Message(1))
        );
    }

    /// A successful send wakes the pending receiver once; a rejected send
    /// does not wake it again.
    #[test]
    fn wake_receiver() {
        let mut cx = panic_context();
        let (mut tx, mut rx) = channel();

        let (w1, w1_count) = new_count_waker();
        let mut w1_context = Context::from_waker(&w1);

        assert_eq!(PollRecv::Pending, try_recv(&mut rx, &mut w1_context));

        assert_eq!(0, w1_count.get());

        assert_eq!(PollSend::Ready, try_send(&mut tx, &mut cx, Message(1)));

        assert_eq!(1, w1_count.get());

        assert_eq!(
            PollSend::Rejected(Message(2)),
            try_send(&mut tx, &mut cx, Message(2))
        );

        assert_eq!(1, w1_count.get());
    }

    /// Dropping the sender wakes a pending receiver, which then sees `Closed`.
    #[test]
    fn sender_disconnect_wakes_receiver() {
        let (tx, mut rx) = channel::<usize>();

        let (w1, w1_count) = new_count_waker();
        let mut w1_context = Context::from_waker(&w1);

        assert_eq!(PollRecv::Pending, try_recv(&mut rx, &mut w1_context));

        assert_eq!(0, w1_count.get());

        drop(tx);

        assert_eq!(1, w1_count.get());

        assert_eq!(PollRecv::Closed, try_recv(&mut rx, &mut w1_context));
    }
}
288
/// End-to-end tests that drive the channel on the tokio runtime.
#[cfg(test)]
mod tokio_tests {
    use tokio::{task::spawn, time::timeout};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, TEST_TIMEOUT},
    };

    use super::channel;

    /// A value sent from a spawned task is received before the test timeout.
    #[tokio::test]
    async fn simple() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, mut rx) = channel();

            spawn(async move { tx.send(100usize).await });

            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(Some(100usize), msg);
        }
    }

    /// Dropping the sender in a spawned task makes `recv` resolve to `None`.
    #[tokio::test]
    async fn sender_disconnect() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, mut rx) = channel::<usize>();

            spawn(async move { drop(tx) });

            // Use the shared test timeout, as `simple` does, rather than a
            // separately hard-coded duration.
            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(None, msg);
        }
    }
}
333
/// End-to-end tests that drive the channel on the async-std runtime.
#[cfg(test)]
mod async_std_tests {
    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{CHANNEL_TEST_ITERATIONS, TEST_TIMEOUT},
    };

    use super::channel;

    /// A value sent from a spawned task is received before the test timeout.
    #[async_std::test]
    async fn simple() {
        for i in 0..CHANNEL_TEST_ITERATIONS {
            let (mut tx, mut rx) = channel();

            spawn(async move { tx.send(i).await });

            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(Some(i), msg);
        }
    }

    /// Dropping the sender in a spawned task makes `recv` resolve to `None`.
    #[async_std::test]
    async fn sender_disconnect() {
        for _ in 0..CHANNEL_TEST_ITERATIONS {
            let (tx, mut rx) = channel::<usize>();

            spawn(async move { drop(tx) });

            // Use the shared test timeout, as `simple` does, rather than a
            // separately hard-coded duration.
            let msg = timeout(TEST_TIMEOUT, async move { rx.recv().await })
                .await
                .expect("test timeout");

            assert_eq!(None, msg);
        }
    }
}