1use crate::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
4use crate::loom::sync::Mutex;
5
6use bytes::{Buf, BytesMut};
7use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, ready, Poll, Waker},
11};
12
#[derive(Debug)]
/// A bidirectional in-memory pipe, created as a connected pair by [`duplex`].
///
/// Bytes written to one end are read from the other end, and vice versa. Each
/// direction is backed by its own bounded [`SimplexStream`] shared (behind a
/// mutex) with the peer.
///
/// Dropping either end closes both directions:
/// - the peer's reads return EOF once the buffered data has been drained;
/// - the peer's writes fail with `BrokenPipe`.
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    // Incoming direction: this end reads from it, the peer writes to it.
    read: Arc<Mutex<SimplexStream>>,
    // Outgoing direction: this end writes to it, the peer reads from it.
    write: Arc<Mutex<SimplexStream>>,
}
53
#[derive(Debug)]
/// A unidirectional in-memory pipe: written bytes are buffered until read.
///
/// At most `max_buf_size` unread bytes are held. A write into a full buffer
/// returns `Poll::Pending` until a read frees space.
///
/// Used directly, the same value both reads and writes the buffer. [`simplex`]
/// splits it into separate read and write halves.
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct SimplexStream {
    // Bytes that have been written but not yet read.
    buffer: BytesMut,
    // Set once either side closes. Afterwards, reads return EOF when the
    // buffer is empty and writes fail with `BrokenPipe`.
    is_closed: bool,
    // Upper bound on the number of unread bytes held in `buffer`.
    max_buf_size: usize,
    // Waker of a reader parked on an empty buffer. It is woken by writes and
    // by `close_write`.
    read_waker: Option<Waker>,
    // Waker of a writer parked on a full buffer. It is woken by reads that
    // consume bytes and by `close_read`.
    write_waker: Option<Waker>,
}
96
97#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
104pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
105 let one = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
106 let two = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
107
108 (
109 DuplexStream {
110 read: one.clone(),
111 write: two.clone(),
112 },
113 DuplexStream {
114 read: two,
115 write: one,
116 },
117 )
118}
119
120impl AsyncRead for DuplexStream {
121 #[allow(unused_mut)]
127 fn poll_read(
128 mut self: Pin<&mut Self>,
129 cx: &mut task::Context<'_>,
130 buf: &mut ReadBuf<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
133 }
134}
135
136impl AsyncWrite for DuplexStream {
137 #[allow(unused_mut)]
138 fn poll_write(
139 mut self: Pin<&mut Self>,
140 cx: &mut task::Context<'_>,
141 buf: &[u8],
142 ) -> Poll<std::io::Result<usize>> {
143 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
144 }
145
146 fn poll_write_vectored(
147 self: Pin<&mut Self>,
148 cx: &mut task::Context<'_>,
149 bufs: &[std::io::IoSlice<'_>],
150 ) -> Poll<Result<usize, std::io::Error>> {
151 Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
152 }
153
154 fn is_write_vectored(&self) -> bool {
155 true
156 }
157
158 #[allow(unused_mut)]
159 fn poll_flush(
160 mut self: Pin<&mut Self>,
161 cx: &mut task::Context<'_>,
162 ) -> Poll<std::io::Result<()>> {
163 Pin::new(&mut *self.write.lock()).poll_flush(cx)
164 }
165
166 #[allow(unused_mut)]
167 fn poll_shutdown(
168 mut self: Pin<&mut Self>,
169 cx: &mut task::Context<'_>,
170 ) -> Poll<std::io::Result<()>> {
171 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
172 }
173}
174
175impl Drop for DuplexStream {
176 fn drop(&mut self) {
177 self.write.lock().close_write();
179 self.read.lock().close_read();
180 }
181}
182
183#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
211pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
212 split(SimplexStream::new_unsplit(max_buf_size))
213}
214
215impl SimplexStream {
216 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
222 pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
223 SimplexStream {
224 buffer: BytesMut::new(),
225 is_closed: false,
226 max_buf_size,
227 read_waker: None,
228 write_waker: None,
229 }
230 }
231
232 fn close_write(&mut self) {
233 self.is_closed = true;
234 if let Some(waker) = self.read_waker.take() {
236 waker.wake();
237 }
238 }
239
240 fn close_read(&mut self) {
241 self.is_closed = true;
242 if let Some(waker) = self.write_waker.take() {
244 waker.wake();
245 }
246 }
247
248 fn poll_read_internal(
249 mut self: Pin<&mut Self>,
250 cx: &mut task::Context<'_>,
251 buf: &mut ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 if self.buffer.has_remaining() {
254 let max = self.buffer.remaining().min(buf.remaining());
255 buf.put_slice(&self.buffer[..max]);
256 self.buffer.advance(max);
257 if max > 0 {
258 if let Some(waker) = self.write_waker.take() {
261 waker.wake();
262 }
263 }
264 Poll::Ready(Ok(()))
265 } else if self.is_closed {
266 Poll::Ready(Ok(()))
267 } else {
268 self.read_waker = Some(cx.waker().clone());
269 Poll::Pending
270 }
271 }
272
273 fn poll_write_internal(
274 mut self: Pin<&mut Self>,
275 cx: &mut task::Context<'_>,
276 buf: &[u8],
277 ) -> Poll<std::io::Result<usize>> {
278 if self.is_closed {
279 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
280 }
281 let avail = self.max_buf_size - self.buffer.len();
282 if avail == 0 {
283 self.write_waker = Some(cx.waker().clone());
284 return Poll::Pending;
285 }
286
287 let len = buf.len().min(avail);
288 self.buffer.extend_from_slice(&buf[..len]);
289 if let Some(waker) = self.read_waker.take() {
290 waker.wake();
291 }
292 Poll::Ready(Ok(len))
293 }
294
295 fn poll_write_vectored_internal(
296 mut self: Pin<&mut Self>,
297 cx: &mut task::Context<'_>,
298 bufs: &[std::io::IoSlice<'_>],
299 ) -> Poll<Result<usize, std::io::Error>> {
300 if self.is_closed {
301 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
302 }
303 let avail = self.max_buf_size - self.buffer.len();
304 if avail == 0 {
305 self.write_waker = Some(cx.waker().clone());
306 return Poll::Pending;
307 }
308
309 let mut rem = avail;
310 for buf in bufs {
311 if rem == 0 {
312 break;
313 }
314
315 let len = buf.len().min(rem);
316 self.buffer.extend_from_slice(&buf[..len]);
317 rem -= len;
318 }
319
320 if let Some(waker) = self.read_waker.take() {
321 waker.wake();
322 }
323 Poll::Ready(Ok(avail - rem))
324 }
325}
326
327impl AsyncRead for SimplexStream {
328 cfg_coop! {
329 fn poll_read(
330 self: Pin<&mut Self>,
331 cx: &mut task::Context<'_>,
332 buf: &mut ReadBuf<'_>,
333 ) -> Poll<std::io::Result<()>> {
334 ready!(crate::trace::trace_leaf(cx));
335 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
336
337 let ret = self.poll_read_internal(cx, buf);
338 if ret.is_ready() {
339 coop.made_progress();
340 }
341 ret
342 }
343 }
344
345 cfg_not_coop! {
346 fn poll_read(
347 self: Pin<&mut Self>,
348 cx: &mut task::Context<'_>,
349 buf: &mut ReadBuf<'_>,
350 ) -> Poll<std::io::Result<()>> {
351 ready!(crate::trace::trace_leaf(cx));
352 self.poll_read_internal(cx, buf)
353 }
354 }
355}
356
357impl AsyncWrite for SimplexStream {
358 cfg_coop! {
359 fn poll_write(
360 self: Pin<&mut Self>,
361 cx: &mut task::Context<'_>,
362 buf: &[u8],
363 ) -> Poll<std::io::Result<usize>> {
364 ready!(crate::trace::trace_leaf(cx));
365 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
366
367 let ret = self.poll_write_internal(cx, buf);
368 if ret.is_ready() {
369 coop.made_progress();
370 }
371 ret
372 }
373 }
374
375 cfg_not_coop! {
376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut task::Context<'_>,
379 buf: &[u8],
380 ) -> Poll<std::io::Result<usize>> {
381 ready!(crate::trace::trace_leaf(cx));
382 self.poll_write_internal(cx, buf)
383 }
384 }
385
386 cfg_coop! {
387 fn poll_write_vectored(
388 self: Pin<&mut Self>,
389 cx: &mut task::Context<'_>,
390 bufs: &[std::io::IoSlice<'_>],
391 ) -> Poll<Result<usize, std::io::Error>> {
392 ready!(crate::trace::trace_leaf(cx));
393 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
394
395 let ret = self.poll_write_vectored_internal(cx, bufs);
396 if ret.is_ready() {
397 coop.made_progress();
398 }
399 ret
400 }
401 }
402
403 cfg_not_coop! {
404 fn poll_write_vectored(
405 self: Pin<&mut Self>,
406 cx: &mut task::Context<'_>,
407 bufs: &[std::io::IoSlice<'_>],
408 ) -> Poll<Result<usize, std::io::Error>> {
409 ready!(crate::trace::trace_leaf(cx));
410 self.poll_write_vectored_internal(cx, bufs)
411 }
412 }
413
414 fn is_write_vectored(&self) -> bool {
415 true
416 }
417
418 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
419 Poll::Ready(Ok(()))
420 }
421
422 fn poll_shutdown(
423 mut self: Pin<&mut Self>,
424 _: &mut task::Context<'_>,
425 ) -> Poll<std::io::Result<()>> {
426 self.close_write();
427 Poll::Ready(Ok(()))
428 }
429}