tokio/io/util/copy.rs
1use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2
3use std::future::Future;
4use std::io;
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
/// Reusable state for streaming bytes from an `AsyncRead` into an
/// `AsyncWrite` through an owned intermediate buffer.
///
/// Bytes read but not yet written live in `buf[pos..cap]`; `buf[cap..]` is
/// free space for the next read.
#[derive(Debug)]
pub(super) struct CopyBuffer {
    /// Set when a read completes successfully without adding any bytes,
    /// which the copy loop treats as EOF.
    read_done: bool,
    /// Set after the writer accepts bytes and cleared after the writer is
    /// flushed while the reader is stalled with nothing left to write.
    need_flush: bool,
    /// Index of the first buffered byte that has not been written yet.
    pos: usize,
    /// Index one past the last byte filled by the reader.
    cap: usize,
    /// Running total of bytes accepted by the writer.
    amt: u64,
    /// Backing storage for data in flight between reader and writer.
    buf: Box<[u8]>,
}
17
18impl CopyBuffer {
19 pub(super) fn new(buf_size: usize) -> Self {
20 Self {
21 read_done: false,
22 need_flush: false,
23 pos: 0,
24 cap: 0,
25 amt: 0,
26 buf: vec![0; buf_size].into_boxed_slice(),
27 }
28 }
29
30 fn poll_fill_buf<R>(
31 &mut self,
32 cx: &mut Context<'_>,
33 reader: Pin<&mut R>,
34 ) -> Poll<io::Result<()>>
35 where
36 R: AsyncRead + ?Sized,
37 {
38 let me = &mut *self;
39 let mut buf = ReadBuf::new(&mut me.buf);
40 buf.set_filled(me.cap);
41
42 let res = reader.poll_read(cx, &mut buf);
43 if let Poll::Ready(Ok(())) = res {
44 let filled_len = buf.filled().len();
45 me.read_done = me.cap == filled_len;
46 me.cap = filled_len;
47 }
48 res
49 }
50
51 fn poll_write_buf<R, W>(
52 &mut self,
53 cx: &mut Context<'_>,
54 mut reader: Pin<&mut R>,
55 mut writer: Pin<&mut W>,
56 ) -> Poll<io::Result<usize>>
57 where
58 R: AsyncRead + ?Sized,
59 W: AsyncWrite + ?Sized,
60 {
61 let me = &mut *self;
62 match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63 Poll::Pending => {
64 // Top up the buffer towards full if we can read a bit more
65 // data - this should improve the chances of a large write
66 if !me.read_done && me.cap < me.buf.len() {
67 ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68 }
69 Poll::Pending
70 }
71 res => res,
72 }
73 }
74
75 pub(super) fn poll_copy<R, W>(
76 &mut self,
77 cx: &mut Context<'_>,
78 mut reader: Pin<&mut R>,
79 mut writer: Pin<&mut W>,
80 ) -> Poll<io::Result<u64>>
81 where
82 R: AsyncRead + ?Sized,
83 W: AsyncWrite + ?Sized,
84 {
85 ready!(crate::trace::trace_leaf(cx));
86 #[cfg(any(
87 feature = "fs",
88 feature = "io-std",
89 feature = "net",
90 feature = "process",
91 feature = "rt",
92 feature = "signal",
93 feature = "sync",
94 feature = "time",
95 ))]
96 // Keep track of task budget
97 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
98 loop {
99 // If there is some space left in our buffer, then we try to read some
100 // data to continue, thus maximizing the chances of a large write.
101 if self.cap < self.buf.len() && !self.read_done {
102 match self.poll_fill_buf(cx, reader.as_mut()) {
103 Poll::Ready(Ok(())) => {
104 #[cfg(any(
105 feature = "fs",
106 feature = "io-std",
107 feature = "net",
108 feature = "process",
109 feature = "rt",
110 feature = "signal",
111 feature = "sync",
112 feature = "time",
113 ))]
114 coop.made_progress();
115 }
116 Poll::Ready(Err(err)) => {
117 #[cfg(any(
118 feature = "fs",
119 feature = "io-std",
120 feature = "net",
121 feature = "process",
122 feature = "rt",
123 feature = "signal",
124 feature = "sync",
125 feature = "time",
126 ))]
127 coop.made_progress();
128 return Poll::Ready(Err(err));
129 }
130 Poll::Pending => {
131 // Ignore pending reads when our buffer is not empty, because
132 // we can try to write data immediately.
133 if self.pos == self.cap {
134 // Try flushing when the reader has no progress to avoid deadlock
135 // when the reader depends on buffered writer.
136 if self.need_flush {
137 ready!(writer.as_mut().poll_flush(cx))?;
138 #[cfg(any(
139 feature = "fs",
140 feature = "io-std",
141 feature = "net",
142 feature = "process",
143 feature = "rt",
144 feature = "signal",
145 feature = "sync",
146 feature = "time",
147 ))]
148 coop.made_progress();
149 self.need_flush = false;
150 }
151
152 return Poll::Pending;
153 }
154 }
155 }
156 }
157
158 // If our buffer has some data, let's write it out!
159 while self.pos < self.cap {
160 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
161 #[cfg(any(
162 feature = "fs",
163 feature = "io-std",
164 feature = "net",
165 feature = "process",
166 feature = "rt",
167 feature = "signal",
168 feature = "sync",
169 feature = "time",
170 ))]
171 coop.made_progress();
172 if i == 0 {
173 return Poll::Ready(Err(io::Error::new(
174 io::ErrorKind::WriteZero,
175 "write zero byte into writer",
176 )));
177 } else {
178 self.pos += i;
179 self.amt += i as u64;
180 self.need_flush = true;
181 }
182 }
183
184 // If pos larger than cap, this loop will never stop.
185 // In particular, user's wrong poll_write implementation returning
186 // incorrect written length may lead to thread blocking.
187 debug_assert!(
188 self.pos <= self.cap,
189 "writer returned length larger than input slice"
190 );
191
192 // All data has been written, the buffer can be considered empty again
193 self.pos = 0;
194 self.cap = 0;
195
196 // If we've written all the data and we've seen EOF, flush out the
197 // data and finish the transfer.
198 if self.read_done {
199 ready!(writer.as_mut().poll_flush(cx))?;
200 #[cfg(any(
201 feature = "fs",
202 feature = "io-std",
203 feature = "net",
204 feature = "process",
205 feature = "rt",
206 feature = "signal",
207 feature = "sync",
208 feature = "time",
209 ))]
210 coop.made_progress();
211 return Poll::Ready(Ok(self.amt));
212 }
213 }
214 }
215}
216
/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
///
/// It borrows both ends for its whole lifetime and delegates the transfer
/// itself to its `CopyBuffer`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
    /// Source of the bytes being copied.
    reader: &'a mut R,
    /// Destination of the copied bytes.
    writer: &'a mut W,
    /// Buffer and progress state for the transfer.
    buf: CopyBuffer,
}
226
227cfg_io_util! {
228 /// Asynchronously copies the entire contents of a reader into a writer.
229 ///
230 /// This function returns a future that will continuously read data from
231 /// `reader` and then write it into `writer` in a streaming fashion until
232 /// `reader` returns EOF or fails.
233 ///
234 /// On success, the total number of bytes that were copied from `reader` to
235 /// `writer` is returned.
236 ///
237 /// This is an asynchronous version of [`std::io::copy`][std].
238 ///
239 /// A heap-allocated copy buffer with 8 KB is created to take data from the
240 /// reader to the writer, check [`copy_buf`] if you want an alternative for
241 /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the
242 /// buffer capacity.
243 ///
244 /// [std]: std::io::copy
245 /// [`copy_buf`]: crate::io::copy_buf
246 /// [`AsyncBufRead`]: crate::io::AsyncBufRead
247 /// [`BufReader`]: crate::io::BufReader
248 ///
249 /// # Errors
250 ///
251 /// The returned future will return an error immediately if any call to
252 /// `poll_read` or `poll_write` returns an error.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// use tokio::io;
258 ///
259 /// # async fn dox() -> std::io::Result<()> {
260 /// let mut reader: &[u8] = b"hello";
261 /// let mut writer: Vec<u8> = vec![];
262 ///
263 /// io::copy(&mut reader, &mut writer).await?;
264 ///
265 /// assert_eq!(&b"hello"[..], &writer[..]);
266 /// # Ok(())
267 /// # }
268 /// ```
269 pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
270 where
271 R: AsyncRead + Unpin + ?Sized,
272 W: AsyncWrite + Unpin + ?Sized,
273 {
274 Copy {
275 reader,
276 writer,
277 buf: CopyBuffer::new(super::DEFAULT_BUF_SIZE)
278 }.await
279 }
280}
281
282impl<R, W> Future for Copy<'_, R, W>
283where
284 R: AsyncRead + Unpin + ?Sized,
285 W: AsyncWrite + Unpin + ?Sized,
286{
287 type Output = io::Result<u64>;
288
289 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
290 let me = &mut *self;
291
292 me.buf
293 .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
294 }
295}