tokio/io/util/
copy.rs

1use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use std::future::Future;
4use std::io;
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
/// Intermediate buffer and bookkeeping for copying bytes from a reader to a
/// writer.
///
/// The bytes in `buf[pos..cap]` have been read but not yet written.
#[derive(Debug)]
pub(super) struct CopyBuffer {
    /// Set when a successful read added no bytes to the buffer.
    read_done: bool,
    /// Set after bytes were handed to the writer; cleared again once the
    /// writer has been flushed while waiting for the reader.
    need_flush: bool,
    /// Index in `buf` of the first byte that has not been written yet.
    pos: usize,
    /// Index in `buf` one past the last byte filled in by the reader.
    cap: usize,
    /// Running total of bytes accepted by the writer.
    amt: u64,
    /// Heap-allocated backing storage for the bytes in flight.
    buf: Box<[u8]>,
}
17
18impl CopyBuffer {
19    pub(super) fn new(buf_size: usize) -> Self {
20        Self {
21            read_done: false,
22            need_flush: false,
23            pos: 0,
24            cap: 0,
25            amt: 0,
26            buf: vec![0; buf_size].into_boxed_slice(),
27        }
28    }
29
30    fn poll_fill_buf<R>(
31        &mut self,
32        cx: &mut Context<'_>,
33        reader: Pin<&mut R>,
34    ) -> Poll<io::Result<()>>
35    where
36        R: AsyncRead + ?Sized,
37    {
38        let me = &mut *self;
39        let mut buf = ReadBuf::new(&mut me.buf);
40        buf.set_filled(me.cap);
41
42        let res = reader.poll_read(cx, &mut buf);
43        if let Poll::Ready(Ok(())) = res {
44            let filled_len = buf.filled().len();
45            me.read_done = me.cap == filled_len;
46            me.cap = filled_len;
47        }
48        res
49    }
50
51    fn poll_write_buf<R, W>(
52        &mut self,
53        cx: &mut Context<'_>,
54        mut reader: Pin<&mut R>,
55        mut writer: Pin<&mut W>,
56    ) -> Poll<io::Result<usize>>
57    where
58        R: AsyncRead + ?Sized,
59        W: AsyncWrite + ?Sized,
60    {
61        let me = &mut *self;
62        match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63            Poll::Pending => {
64                // Top up the buffer towards full if we can read a bit more
65                // data - this should improve the chances of a large write
66                if !me.read_done && me.cap < me.buf.len() {
67                    ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68                }
69                Poll::Pending
70            }
71            res => res,
72        }
73    }
74
75    pub(super) fn poll_copy<R, W>(
76        &mut self,
77        cx: &mut Context<'_>,
78        mut reader: Pin<&mut R>,
79        mut writer: Pin<&mut W>,
80    ) -> Poll<io::Result<u64>>
81    where
82        R: AsyncRead + ?Sized,
83        W: AsyncWrite + ?Sized,
84    {
85        ready!(crate::trace::trace_leaf(cx));
86        #[cfg(any(
87            feature = "fs",
88            feature = "io-std",
89            feature = "net",
90            feature = "process",
91            feature = "rt",
92            feature = "signal",
93            feature = "sync",
94            feature = "time",
95        ))]
96        // Keep track of task budget
97        let coop = ready!(crate::runtime::coop::poll_proceed(cx));
98        loop {
99            // If there is some space left in our buffer, then we try to read some
100            // data to continue, thus maximizing the chances of a large write.
101            if self.cap < self.buf.len() && !self.read_done {
102                match self.poll_fill_buf(cx, reader.as_mut()) {
103                    Poll::Ready(Ok(())) => {
104                        #[cfg(any(
105                            feature = "fs",
106                            feature = "io-std",
107                            feature = "net",
108                            feature = "process",
109                            feature = "rt",
110                            feature = "signal",
111                            feature = "sync",
112                            feature = "time",
113                        ))]
114                        coop.made_progress();
115                    }
116                    Poll::Ready(Err(err)) => {
117                        #[cfg(any(
118                            feature = "fs",
119                            feature = "io-std",
120                            feature = "net",
121                            feature = "process",
122                            feature = "rt",
123                            feature = "signal",
124                            feature = "sync",
125                            feature = "time",
126                        ))]
127                        coop.made_progress();
128                        return Poll::Ready(Err(err));
129                    }
130                    Poll::Pending => {
131                        // Ignore pending reads when our buffer is not empty, because
132                        // we can try to write data immediately.
133                        if self.pos == self.cap {
134                            // Try flushing when the reader has no progress to avoid deadlock
135                            // when the reader depends on buffered writer.
136                            if self.need_flush {
137                                ready!(writer.as_mut().poll_flush(cx))?;
138                                #[cfg(any(
139                                    feature = "fs",
140                                    feature = "io-std",
141                                    feature = "net",
142                                    feature = "process",
143                                    feature = "rt",
144                                    feature = "signal",
145                                    feature = "sync",
146                                    feature = "time",
147                                ))]
148                                coop.made_progress();
149                                self.need_flush = false;
150                            }
151
152                            return Poll::Pending;
153                        }
154                    }
155                }
156            }
157
158            // If our buffer has some data, let's write it out!
159            while self.pos < self.cap {
160                let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
161                #[cfg(any(
162                    feature = "fs",
163                    feature = "io-std",
164                    feature = "net",
165                    feature = "process",
166                    feature = "rt",
167                    feature = "signal",
168                    feature = "sync",
169                    feature = "time",
170                ))]
171                coop.made_progress();
172                if i == 0 {
173                    return Poll::Ready(Err(io::Error::new(
174                        io::ErrorKind::WriteZero,
175                        "write zero byte into writer",
176                    )));
177                } else {
178                    self.pos += i;
179                    self.amt += i as u64;
180                    self.need_flush = true;
181                }
182            }
183
184            // If pos larger than cap, this loop will never stop.
185            // In particular, user's wrong poll_write implementation returning
186            // incorrect written length may lead to thread blocking.
187            debug_assert!(
188                self.pos <= self.cap,
189                "writer returned length larger than input slice"
190            );
191
192            // All data has been written, the buffer can be considered empty again
193            self.pos = 0;
194            self.cap = 0;
195
196            // If we've written all the data and we've seen EOF, flush out the
197            // data and finish the transfer.
198            if self.read_done {
199                ready!(writer.as_mut().poll_flush(cx))?;
200                #[cfg(any(
201                    feature = "fs",
202                    feature = "io-std",
203                    feature = "net",
204                    feature = "process",
205                    feature = "rt",
206                    feature = "signal",
207                    feature = "sync",
208                    feature = "time",
209                ))]
210                coop.made_progress();
211                return Poll::Ready(Ok(self.amt));
212            }
213        }
214    }
215}
216
/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
///
/// Resolves to the total number of bytes copied, or to the first I/O error
/// encountered.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
    /// Source the bytes are read from.
    reader: &'a mut R,
    /// Destination the bytes are written to.
    writer: &'a mut W,
    /// Intermediate buffer and copy state, driven on every poll.
    buf: CopyBuffer,
}
226
227cfg_io_util! {
228    /// Asynchronously copies the entire contents of a reader into a writer.
229    ///
230    /// This function returns a future that will continuously read data from
231    /// `reader` and then write it into `writer` in a streaming fashion until
232    /// `reader` returns EOF or fails.
233    ///
234    /// On success, the total number of bytes that were copied from `reader` to
235    /// `writer` is returned.
236    ///
237    /// This is an asynchronous version of [`std::io::copy`][std].
238    ///
239    /// A heap-allocated copy buffer with 8 KB is created to take data from the
240    /// reader to the writer, check [`copy_buf`] if you want an alternative for
241    /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the
242    /// buffer capacity.
243    ///
244    /// [std]: std::io::copy
245    /// [`copy_buf`]: crate::io::copy_buf
246    /// [`AsyncBufRead`]: crate::io::AsyncBufRead
247    /// [`BufReader`]: crate::io::BufReader
248    ///
249    /// # Errors
250    ///
251    /// The returned future will return an error immediately if any call to
252    /// `poll_read` or `poll_write` returns an error.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// use tokio::io;
258    ///
259    /// # async fn dox() -> std::io::Result<()> {
260    /// let mut reader: &[u8] = b"hello";
261    /// let mut writer: Vec<u8> = vec![];
262    ///
263    /// io::copy(&mut reader, &mut writer).await?;
264    ///
265    /// assert_eq!(&b"hello"[..], &writer[..]);
266    /// # Ok(())
267    /// # }
268    /// ```
269    pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
270    where
271        R: AsyncRead + Unpin + ?Sized,
272        W: AsyncWrite + Unpin + ?Sized,
273    {
274        Copy {
275            reader,
276            writer,
277            buf: CopyBuffer::new(super::DEFAULT_BUF_SIZE)
278        }.await
279    }
280}
281
282impl<R, W> Future for Copy<'_, R, W>
283where
284    R: AsyncRead + Unpin + ?Sized,
285    W: AsyncWrite + Unpin + ?Sized,
286{
287    type Output = io::Result<u64>;
288
289    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
290        let me = &mut *self;
291
292        me.buf
293            .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
294    }
295}