tokio/io/util/
mem.rs

1//! In-process memory IO types.
2
3use crate::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
4use crate::loom::sync::Mutex;
5
6use bytes::{Buf, BytesMut};
7use std::{
8    pin::Pin,
9    sync::Arc,
10    task::{self, ready, Poll, Waker},
11};
12
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// The pipe this end reads from. `duplex` wires the peer's `write` field
    /// to this same shared `SimplexStream`.
    read: Arc<Mutex<SimplexStream>>,
    /// The pipe this end writes into. `duplex` wires the peer's `read` field
    /// to this same shared `SimplexStream`.
    write: Arc<Mutex<SimplexStream>>,
}
53
/// A unidirectional pipe to read and write bytes in memory.
///
/// It can be constructed by the [`simplex`] function, which creates a pair of
/// reader and writer, or by calling [`SimplexStream::new_unsplit`], which
/// creates a single handle for both reading and writing.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut receiver, mut sender) = tokio::io::simplex(64);
///
/// sender.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// receiver.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct SimplexStream {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if its read index has advanced far enough.
    buffer: BytesMut,
    /// Set once the pipe has been closed from either side: by the writer
    /// (`close_write`, e.g. via `poll_shutdown` or dropping a `DuplexStream`)
    /// or by the reader (`close_read`, when a `DuplexStream` end is dropped).
    /// Once set, writes fail with `BrokenPipe` and reads return EOF after the
    /// buffer has been drained.
    is_closed: bool,
    /// The maximum amount of bytes that can be written before returning
    /// `Poll::Pending`.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
96
97// ===== impl DuplexStream =====
98
99/// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
100///
101/// The `max_buf_size` argument is the maximum amount of bytes that can be
102/// written to a side before the write returns `Poll::Pending`.
103#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
104pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
105    let one = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
106    let two = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
107
108    (
109        DuplexStream {
110            read: one.clone(),
111            write: two.clone(),
112        },
113        DuplexStream {
114            read: two,
115            write: one,
116        },
117    )
118}
119
120impl AsyncRead for DuplexStream {
121    // Previous rustc required this `self` to be `mut`, even though newer
122    // versions recognize it isn't needed to call `lock()`. So for
123    // compatibility, we include the `mut` and `allow` the lint.
124    //
125    // See https://github.com/rust-lang/rust/issues/73592
126    #[allow(unused_mut)]
127    fn poll_read(
128        mut self: Pin<&mut Self>,
129        cx: &mut task::Context<'_>,
130        buf: &mut ReadBuf<'_>,
131    ) -> Poll<std::io::Result<()>> {
132        Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
133    }
134}
135
136impl AsyncWrite for DuplexStream {
137    #[allow(unused_mut)]
138    fn poll_write(
139        mut self: Pin<&mut Self>,
140        cx: &mut task::Context<'_>,
141        buf: &[u8],
142    ) -> Poll<std::io::Result<usize>> {
143        Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
144    }
145
146    fn poll_write_vectored(
147        self: Pin<&mut Self>,
148        cx: &mut task::Context<'_>,
149        bufs: &[std::io::IoSlice<'_>],
150    ) -> Poll<Result<usize, std::io::Error>> {
151        Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
152    }
153
154    fn is_write_vectored(&self) -> bool {
155        true
156    }
157
158    #[allow(unused_mut)]
159    fn poll_flush(
160        mut self: Pin<&mut Self>,
161        cx: &mut task::Context<'_>,
162    ) -> Poll<std::io::Result<()>> {
163        Pin::new(&mut *self.write.lock()).poll_flush(cx)
164    }
165
166    #[allow(unused_mut)]
167    fn poll_shutdown(
168        mut self: Pin<&mut Self>,
169        cx: &mut task::Context<'_>,
170    ) -> Poll<std::io::Result<()>> {
171        Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
172    }
173}
174
175impl Drop for DuplexStream {
176    fn drop(&mut self) {
177        // notify the other side of the closure
178        self.write.lock().close_write();
179        self.read.lock().close_read();
180    }
181}
182
183// ===== impl SimplexStream =====
184
185/// Creates unidirectional buffer that acts like in memory pipe.
186///
187/// The `max_buf_size` argument is the maximum amount of bytes that can be
188/// written to a buffer before the it returns `Poll::Pending`.
189///
190/// # Unify reader and writer
191///
192/// The reader and writer half can be unified into a single structure
193/// of `SimplexStream` that supports both reading and writing or
194/// the `SimplexStream` can be already created as unified structure
195/// using [`SimplexStream::new_unsplit()`].
196///
197/// ```
198/// # async fn ex() -> std::io::Result<()> {
199/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
200/// let (writer, reader) = tokio::io::simplex(64);
201/// let mut simplex_stream = writer.unsplit(reader);
202/// simplex_stream.write_all(b"hello").await?;
203///
204/// let mut buf = [0u8; 5];
205/// simplex_stream.read_exact(&mut buf).await?;
206/// assert_eq!(&buf, b"hello");
207/// # Ok(())
208/// # }
209/// ```
210#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
211pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
212    split(SimplexStream::new_unsplit(max_buf_size))
213}
214
215impl SimplexStream {
216    /// Creates unidirectional buffer that acts like in memory pipe. To create split
217    /// version with separate reader and writer you can use [`simplex`] function.
218    ///
219    /// The `max_buf_size` argument is the maximum amount of bytes that can be
220    /// written to a buffer before the it returns `Poll::Pending`.
221    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
222    pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
223        SimplexStream {
224            buffer: BytesMut::new(),
225            is_closed: false,
226            max_buf_size,
227            read_waker: None,
228            write_waker: None,
229        }
230    }
231
232    fn close_write(&mut self) {
233        self.is_closed = true;
234        // needs to notify any readers that no more data will come
235        if let Some(waker) = self.read_waker.take() {
236            waker.wake();
237        }
238    }
239
240    fn close_read(&mut self) {
241        self.is_closed = true;
242        // needs to notify any writers that they have to abort
243        if let Some(waker) = self.write_waker.take() {
244            waker.wake();
245        }
246    }
247
248    fn poll_read_internal(
249        mut self: Pin<&mut Self>,
250        cx: &mut task::Context<'_>,
251        buf: &mut ReadBuf<'_>,
252    ) -> Poll<std::io::Result<()>> {
253        if self.buffer.has_remaining() {
254            let max = self.buffer.remaining().min(buf.remaining());
255            buf.put_slice(&self.buffer[..max]);
256            self.buffer.advance(max);
257            if max > 0 {
258                // The passed `buf` might have been empty, don't wake up if
259                // no bytes have been moved.
260                if let Some(waker) = self.write_waker.take() {
261                    waker.wake();
262                }
263            }
264            Poll::Ready(Ok(()))
265        } else if self.is_closed {
266            Poll::Ready(Ok(()))
267        } else {
268            self.read_waker = Some(cx.waker().clone());
269            Poll::Pending
270        }
271    }
272
273    fn poll_write_internal(
274        mut self: Pin<&mut Self>,
275        cx: &mut task::Context<'_>,
276        buf: &[u8],
277    ) -> Poll<std::io::Result<usize>> {
278        if self.is_closed {
279            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
280        }
281        let avail = self.max_buf_size - self.buffer.len();
282        if avail == 0 {
283            self.write_waker = Some(cx.waker().clone());
284            return Poll::Pending;
285        }
286
287        let len = buf.len().min(avail);
288        self.buffer.extend_from_slice(&buf[..len]);
289        if let Some(waker) = self.read_waker.take() {
290            waker.wake();
291        }
292        Poll::Ready(Ok(len))
293    }
294
295    fn poll_write_vectored_internal(
296        mut self: Pin<&mut Self>,
297        cx: &mut task::Context<'_>,
298        bufs: &[std::io::IoSlice<'_>],
299    ) -> Poll<Result<usize, std::io::Error>> {
300        if self.is_closed {
301            return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
302        }
303        let avail = self.max_buf_size - self.buffer.len();
304        if avail == 0 {
305            self.write_waker = Some(cx.waker().clone());
306            return Poll::Pending;
307        }
308
309        let mut rem = avail;
310        for buf in bufs {
311            if rem == 0 {
312                break;
313            }
314
315            let len = buf.len().min(rem);
316            self.buffer.extend_from_slice(&buf[..len]);
317            rem -= len;
318        }
319
320        if let Some(waker) = self.read_waker.take() {
321            waker.wake();
322        }
323        Poll::Ready(Ok(avail - rem))
324    }
325}
326
327impl AsyncRead for SimplexStream {
328    cfg_coop! {
329        fn poll_read(
330            self: Pin<&mut Self>,
331            cx: &mut task::Context<'_>,
332            buf: &mut ReadBuf<'_>,
333        ) -> Poll<std::io::Result<()>> {
334            ready!(crate::trace::trace_leaf(cx));
335            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
336
337            let ret = self.poll_read_internal(cx, buf);
338            if ret.is_ready() {
339                coop.made_progress();
340            }
341            ret
342        }
343    }
344
345    cfg_not_coop! {
346        fn poll_read(
347            self: Pin<&mut Self>,
348            cx: &mut task::Context<'_>,
349            buf: &mut ReadBuf<'_>,
350        ) -> Poll<std::io::Result<()>> {
351            ready!(crate::trace::trace_leaf(cx));
352            self.poll_read_internal(cx, buf)
353        }
354    }
355}
356
357impl AsyncWrite for SimplexStream {
358    cfg_coop! {
359        fn poll_write(
360            self: Pin<&mut Self>,
361            cx: &mut task::Context<'_>,
362            buf: &[u8],
363        ) -> Poll<std::io::Result<usize>> {
364            ready!(crate::trace::trace_leaf(cx));
365            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
366
367            let ret = self.poll_write_internal(cx, buf);
368            if ret.is_ready() {
369                coop.made_progress();
370            }
371            ret
372        }
373    }
374
375    cfg_not_coop! {
376        fn poll_write(
377            self: Pin<&mut Self>,
378            cx: &mut task::Context<'_>,
379            buf: &[u8],
380        ) -> Poll<std::io::Result<usize>> {
381            ready!(crate::trace::trace_leaf(cx));
382            self.poll_write_internal(cx, buf)
383        }
384    }
385
386    cfg_coop! {
387        fn poll_write_vectored(
388            self: Pin<&mut Self>,
389            cx: &mut task::Context<'_>,
390            bufs: &[std::io::IoSlice<'_>],
391        ) -> Poll<Result<usize, std::io::Error>> {
392            ready!(crate::trace::trace_leaf(cx));
393            let coop = ready!(crate::runtime::coop::poll_proceed(cx));
394
395            let ret = self.poll_write_vectored_internal(cx, bufs);
396            if ret.is_ready() {
397                coop.made_progress();
398            }
399            ret
400        }
401    }
402
403    cfg_not_coop! {
404        fn poll_write_vectored(
405            self: Pin<&mut Self>,
406            cx: &mut task::Context<'_>,
407            bufs: &[std::io::IoSlice<'_>],
408        ) -> Poll<Result<usize, std::io::Error>> {
409            ready!(crate::trace::trace_leaf(cx));
410            self.poll_write_vectored_internal(cx, bufs)
411        }
412    }
413
414    fn is_write_vectored(&self) -> bool {
415        true
416    }
417
418    fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
419        Poll::Ready(Ok(()))
420    }
421
422    fn poll_shutdown(
423        mut self: Pin<&mut Self>,
424        _: &mut task::Context<'_>,
425    ) -> Poll<std::io::Result<()>> {
426        self.close_write();
427        Poll::Ready(Ok(()))
428    }
429}