tokio/net/tcp/
split_owned.rs

//! `TcpStream` owned split support.
//!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
//! with the `TcpStream::into_split` method.  `OwnedReadHalf` implements
//! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
//!
//! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
//! split has no associated overhead and enforces all invariants at the type
//! level.
10
11use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready};
12use crate::net::TcpStream;
13
14use std::error::Error;
15use std::future::poll_fn;
16use std::net::{Shutdown, SocketAddr};
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::{fmt, io};
21
22cfg_io_util! {
23    use bytes::BufMut;
24}
25
26/// Owned read half of a [`TcpStream`], created by [`into_split`].
27///
28/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
29/// on the [`AsyncReadExt`] trait.
30///
31/// [`TcpStream`]: TcpStream
32/// [`into_split`]: TcpStream::into_split()
33/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
34#[derive(Debug)]
35pub struct OwnedReadHalf {
36    inner: Arc<TcpStream>,
37}
38
39/// Owned write half of a [`TcpStream`], created by [`into_split`].
40///
41/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
42/// shut down the TCP stream in the write direction.  Dropping the write half
43/// will also shut down the write half of the TCP stream.
44///
45/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
46/// on the [`AsyncWriteExt`] trait.
47///
48/// [`TcpStream`]: TcpStream
49/// [`into_split`]: TcpStream::into_split()
50/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
51/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
52/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
53#[derive(Debug)]
54pub struct OwnedWriteHalf {
55    inner: Arc<TcpStream>,
56    shutdown_on_drop: bool,
57}
58
59pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
60    let arc = Arc::new(stream);
61    let read = OwnedReadHalf {
62        inner: Arc::clone(&arc),
63    };
64    let write = OwnedWriteHalf {
65        inner: arc,
66        shutdown_on_drop: true,
67    };
68    (read, write)
69}
70
71pub(crate) fn reunite(
72    read: OwnedReadHalf,
73    write: OwnedWriteHalf,
74) -> Result<TcpStream, ReuniteError> {
75    if Arc::ptr_eq(&read.inner, &write.inner) {
76        write.forget();
77        // This unwrap cannot fail as the api does not allow creating more than two Arcs,
78        // and we just dropped the other half.
79        Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
80    } else {
81        Err(ReuniteError(read, write))
82    }
83}
84
85/// Error indicating that two halves were not from the same socket, and thus could
86/// not be reunited.
87#[derive(Debug)]
88pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
89
90impl fmt::Display for ReuniteError {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        write!(
93            f,
94            "tried to reunite halves that are not from the same socket"
95        )
96    }
97}
98
99impl Error for ReuniteError {}
100
101impl OwnedReadHalf {
102    /// Attempts to put the two halves of a `TcpStream` back together and
103    /// recover the original socket. Succeeds only if the two halves
104    /// originated from the same call to [`into_split`].
105    ///
106    /// [`into_split`]: TcpStream::into_split()
107    pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
108        reunite(self, other)
109    }
110
111    /// Attempt to receive data on the socket, without removing that data from
112    /// the queue, registering the current task for wakeup if data is not yet
113    /// available.
114    ///
115    /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
116    /// `Waker` from the `Context` passed to the most recent call is scheduled
117    /// to receive a wakeup.
118    ///
119    /// See the [`TcpStream::poll_peek`] level documentation for more details.
120    ///
121    /// # Examples
122    ///
123    /// ```no_run
124    /// use tokio::io::{self, ReadBuf};
125    /// use tokio::net::TcpStream;
126    ///
127    /// use std::future::poll_fn;
128    ///
129    /// #[tokio::main]
130    /// async fn main() -> io::Result<()> {
131    ///     let stream = TcpStream::connect("127.0.0.1:8000").await?;
132    ///     let (mut read_half, _) = stream.into_split();
133    ///     let mut buf = [0; 10];
134    ///     let mut buf = ReadBuf::new(&mut buf);
135    ///
136    ///     poll_fn(|cx| {
137    ///         read_half.poll_peek(cx, &mut buf)
138    ///     }).await?;
139    ///
140    ///     Ok(())
141    /// }
142    /// ```
143    ///
144    /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
145    pub fn poll_peek(
146        &mut self,
147        cx: &mut Context<'_>,
148        buf: &mut ReadBuf<'_>,
149    ) -> Poll<io::Result<usize>> {
150        self.inner.poll_peek(cx, buf)
151    }
152
153    /// Receives data on the socket from the remote address to which it is
154    /// connected, without removing that data from the queue. On success,
155    /// returns the number of bytes peeked.
156    ///
157    /// See the [`TcpStream::peek`] level documentation for more details.
158    ///
159    /// [`TcpStream::peek`]: TcpStream::peek
160    ///
161    /// # Examples
162    ///
163    /// ```no_run
164    /// use tokio::net::TcpStream;
165    /// use tokio::io::AsyncReadExt;
166    /// use std::error::Error;
167    ///
168    /// #[tokio::main]
169    /// async fn main() -> Result<(), Box<dyn Error>> {
170    ///     // Connect to a peer
171    ///     let stream = TcpStream::connect("127.0.0.1:8080").await?;
172    ///     let (mut read_half, _) = stream.into_split();
173    ///
174    ///     let mut b1 = [0; 10];
175    ///     let mut b2 = [0; 10];
176    ///
177    ///     // Peek at the data
178    ///     let n = read_half.peek(&mut b1).await?;
179    ///
180    ///     // Read the data
181    ///     assert_eq!(n, read_half.read(&mut b2[..n]).await?);
182    ///     assert_eq!(&b1[..n], &b2[..n]);
183    ///
184    ///     Ok(())
185    /// }
186    /// ```
187    ///
188    /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
189    ///
190    /// [`read`]: fn@crate::io::AsyncReadExt::read
191    /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
192    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
193        let mut buf = ReadBuf::new(buf);
194        poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
195    }
196
197    /// Waits for any of the requested ready states.
198    ///
199    /// This function is usually paired with [`try_read()`]. It can be used instead
200    /// of [`readable()`] to check the returned ready set for [`Ready::READABLE`]
201    /// and [`Ready::READ_CLOSED`] events.
202    ///
203    /// The function may complete without the socket being ready. This is a
204    /// false-positive and attempting an operation will return with
205    /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
206    /// [`Ready`] set, so you should always check the returned value and possibly
207    /// wait again if the requested states are not set.
208    ///
209    /// This function is equivalent to [`TcpStream::ready`].
210    ///
211    /// [`try_read()`]: Self::try_read
212    /// [`readable()`]: Self::readable
213    ///
214    /// # Cancel safety
215    ///
216    /// This method is cancel safe. Once a readiness event occurs, the method
217    /// will continue to return immediately until the readiness event is
218    /// consumed by an attempt to read or write that fails with `WouldBlock` or
219    /// `Poll::Pending`.
220    pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
221        self.inner.ready(interest).await
222    }
223
224    /// Waits for the socket to become readable.
225    ///
226    /// This function is equivalent to `ready(Interest::READABLE)` and is usually
227    /// paired with `try_read()`.
228    ///
229    /// This function is also equivalent to [`TcpStream::ready`].
230    ///
231    /// # Cancel safety
232    ///
233    /// This method is cancel safe. Once a readiness event occurs, the method
234    /// will continue to return immediately until the readiness event is
235    /// consumed by an attempt to read that fails with `WouldBlock` or
236    /// `Poll::Pending`.
237    pub async fn readable(&self) -> io::Result<()> {
238        self.inner.readable().await
239    }
240
241    /// Tries to read data from the stream into the provided buffer, returning how
242    /// many bytes were read.
243    ///
244    /// Receives any pending data from the socket but does not wait for new data
245    /// to arrive. On success, returns the number of bytes read. Because
246    /// `try_read()` is non-blocking, the buffer does not have to be stored by
247    /// the async task and can exist entirely on the stack.
248    ///
249    /// Usually, [`readable()`] or [`ready()`] is used with this function.
250    ///
251    /// [`readable()`]: Self::readable()
252    /// [`ready()`]: Self::ready()
253    ///
254    /// # Return
255    ///
256    /// If data is successfully read, `Ok(n)` is returned, where `n` is the
257    /// number of bytes read. If `n` is `0`, then it can indicate one of two scenarios:
258    ///
259    /// 1. The stream's read half is closed and will no longer yield data.
260    /// 2. The specified buffer was 0 bytes in length.
261    ///
262    /// If the stream is not ready to read data,
263    /// `Err(io::ErrorKind::WouldBlock)` is returned.
264    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
265        self.inner.try_read(buf)
266    }
267
268    /// Tries to read data from the stream into the provided buffers, returning
269    /// how many bytes were read.
270    ///
271    /// Data is copied to fill each buffer in order, with the final buffer
272    /// written to possibly being only partially filled. This method behaves
273    /// equivalently to a single call to [`try_read()`] with concatenated
274    /// buffers.
275    ///
276    /// Receives any pending data from the socket but does not wait for new data
277    /// to arrive. On success, returns the number of bytes read. Because
278    /// `try_read_vectored()` is non-blocking, the buffer does not have to be
279    /// stored by the async task and can exist entirely on the stack.
280    ///
281    /// Usually, [`readable()`] or [`ready()`] is used with this function.
282    ///
283    /// [`try_read()`]: Self::try_read()
284    /// [`readable()`]: Self::readable()
285    /// [`ready()`]: Self::ready()
286    ///
287    /// # Return
288    ///
289    /// If data is successfully read, `Ok(n)` is returned, where `n` is the
290    /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
291    /// and will no longer yield data. If the stream is not ready to read data
292    /// `Err(io::ErrorKind::WouldBlock)` is returned.
293    pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
294        self.inner.try_read_vectored(bufs)
295    }
296
297    cfg_io_util! {
298        /// Tries to read data from the stream into the provided buffer, advancing the
299        /// buffer's internal cursor, returning how many bytes were read.
300        ///
301        /// Receives any pending data from the socket but does not wait for new data
302        /// to arrive. On success, returns the number of bytes read. Because
303        /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
304        /// the async task and can exist entirely on the stack.
305        ///
306        /// Usually, [`readable()`] or [`ready()`] is used with this function.
307        ///
308        /// [`readable()`]: Self::readable()
309        /// [`ready()`]: Self::ready()
310        ///
311        /// # Return
312        ///
313        /// If data is successfully read, `Ok(n)` is returned, where `n` is the
314        /// number of bytes read. `Ok(0)` indicates the stream's read half is closed
315        /// and will no longer yield data. If the stream is not ready to read data
316        /// `Err(io::ErrorKind::WouldBlock)` is returned.
317        pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
318            self.inner.try_read_buf(buf)
319        }
320    }
321
322    /// Returns the remote address that this stream is connected to.
323    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
324        self.inner.peer_addr()
325    }
326
327    /// Returns the local address that this stream is bound to.
328    pub fn local_addr(&self) -> io::Result<SocketAddr> {
329        self.inner.local_addr()
330    }
331}
332
333impl AsyncRead for OwnedReadHalf {
334    fn poll_read(
335        self: Pin<&mut Self>,
336        cx: &mut Context<'_>,
337        buf: &mut ReadBuf<'_>,
338    ) -> Poll<io::Result<()>> {
339        self.inner.poll_read_priv(cx, buf)
340    }
341}
342
343impl OwnedWriteHalf {
344    /// Attempts to put the two halves of a `TcpStream` back together and
345    /// recover the original socket. Succeeds only if the two halves
346    /// originated from the same call to [`into_split`].
347    ///
348    /// [`into_split`]: TcpStream::into_split()
349    pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
350        reunite(other, self)
351    }
352
353    /// Destroys the write half, but don't close the write half of the stream
354    /// until the read half is dropped. If the read half has already been
355    /// dropped, this closes the stream.
356    pub fn forget(mut self) {
357        self.shutdown_on_drop = false;
358        drop(self);
359    }
360
361    /// Waits for any of the requested ready states.
362    ///
363    /// This function is usually paired with [`try_write()`]. It can be used instead
364    /// of [`writable()`] to check the returned ready set for [`Ready::WRITABLE`]
365    /// and [`Ready::WRITE_CLOSED`] events.
366    ///
367    /// The function may complete without the socket being ready. This is a
368    /// false-positive and attempting an operation will return with
369    /// `io::ErrorKind::WouldBlock`. The function can also return with an empty
370    /// [`Ready`] set, so you should always check the returned value and possibly
371    /// wait again if the requested states are not set.
372    ///
373    /// This function is equivalent to [`TcpStream::ready`].
374    ///
375    /// [`try_write()`]: Self::try_write
376    /// [`writable()`]: Self::writable
377    ///
378    /// # Cancel safety
379    ///
380    /// This method is cancel safe. Once a readiness event occurs, the method
381    /// will continue to return immediately until the readiness event is
382    /// consumed by an attempt to read or write that fails with `WouldBlock` or
383    /// `Poll::Pending`.
384    pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
385        self.inner.ready(interest).await
386    }
387
388    /// Waits for the socket to become writable.
389    ///
390    /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
391    /// paired with `try_write()`.
392    ///
393    /// # Cancel safety
394    ///
395    /// This method is cancel safe. Once a readiness event occurs, the method
396    /// will continue to return immediately until the readiness event is
397    /// consumed by an attempt to write that fails with `WouldBlock` or
398    /// `Poll::Pending`.
399    pub async fn writable(&self) -> io::Result<()> {
400        self.inner.writable().await
401    }
402
403    /// Tries to write a buffer to the stream, returning how many bytes were
404    /// written.
405    ///
406    /// The function will attempt to write the entire contents of `buf`, but
407    /// only part of the buffer may be written.
408    ///
409    /// This function is usually paired with `writable()`.
410    ///
411    /// # Return
412    ///
413    /// If data is successfully written, `Ok(n)` is returned, where `n` is the
414    /// number of bytes written. If the stream is not ready to write data,
415    /// `Err(io::ErrorKind::WouldBlock)` is returned.
416    pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
417        self.inner.try_write(buf)
418    }
419
420    /// Tries to write several buffers to the stream, returning how many bytes
421    /// were written.
422    ///
423    /// Data is written from each buffer in order, with the final buffer read
424    /// from possible being only partially consumed. This method behaves
425    /// equivalently to a single call to [`try_write()`] with concatenated
426    /// buffers.
427    ///
428    /// This function is usually paired with `writable()`.
429    ///
430    /// [`try_write()`]: Self::try_write()
431    ///
432    /// # Return
433    ///
434    /// If data is successfully written, `Ok(n)` is returned, where `n` is the
435    /// number of bytes written. If the stream is not ready to write data,
436    /// `Err(io::ErrorKind::WouldBlock)` is returned.
437    pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
438        self.inner.try_write_vectored(bufs)
439    }
440
441    /// Returns the remote address that this stream is connected to.
442    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
443        self.inner.peer_addr()
444    }
445
446    /// Returns the local address that this stream is bound to.
447    pub fn local_addr(&self) -> io::Result<SocketAddr> {
448        self.inner.local_addr()
449    }
450}
451
452impl Drop for OwnedWriteHalf {
453    fn drop(&mut self) {
454        if self.shutdown_on_drop {
455            let _ = self.inner.shutdown_std(Shutdown::Write);
456        }
457    }
458}
459
460impl AsyncWrite for OwnedWriteHalf {
461    fn poll_write(
462        self: Pin<&mut Self>,
463        cx: &mut Context<'_>,
464        buf: &[u8],
465    ) -> Poll<io::Result<usize>> {
466        self.inner.poll_write_priv(cx, buf)
467    }
468
469    fn poll_write_vectored(
470        self: Pin<&mut Self>,
471        cx: &mut Context<'_>,
472        bufs: &[io::IoSlice<'_>],
473    ) -> Poll<io::Result<usize>> {
474        self.inner.poll_write_vectored_priv(cx, bufs)
475    }
476
477    fn is_write_vectored(&self) -> bool {
478        self.inner.is_write_vectored()
479    }
480
481    #[inline]
482    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
483        // tcp flush is a no-op
484        Poll::Ready(Ok(()))
485    }
486
487    // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
488    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
489        let res = self.inner.shutdown_std(Shutdown::Write);
490        if res.is_ok() {
491            Pin::into_inner(self).shutdown_on_drop = false;
492        }
493        res.into()
494    }
495}
496
497impl AsRef<TcpStream> for OwnedReadHalf {
498    fn as_ref(&self) -> &TcpStream {
499        &self.inner
500    }
501}
502
503impl AsRef<TcpStream> for OwnedWriteHalf {
504    fn as_ref(&self) -> &TcpStream {
505        &self.inner
506    }
507}