tokio/io/
blocking.rs

1use crate::io::sys;
2use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::cmp;
5use std::future::Future;
6use std::io;
7use std::io::prelude::*;
8use std::pin::Pin;
9use std::task::{ready, Context, Poll};
10
/// Adapter that implements the async I/O traits for a blocking `T` by handing
/// each operation to `sys::run` and polling the handle it returns.
///
/// `T` should not implement _both_ Read and Write.
#[derive(Debug)]
pub(crate) struct Blocking<T> {
    /// The wrapped blocking value. It is moved into the closure given to
    /// `sys::run` while an operation is in flight (leaving `None` here) and is
    /// put back once that operation has been polled to completion.
    inner: Option<T>,
    /// Whether an operation is in flight, and who currently owns the buffer.
    state: State<T>,
    /// `true` if the lower IO layer needs flushing.
    need_flush: bool,
}
19
/// Byte buffer with a cursor, used to stage data between the async caller and
/// the blocking operation.
#[derive(Debug)]
pub(crate) struct Buf {
    /// Backing storage for the staged bytes.
    buf: Vec<u8>,
    /// Index into `buf` of the first byte that has not been consumed yet;
    /// `buf[pos..]` is the unread part.
    pos: usize,
}
25
/// Upper bound (2 MiB) on the number of bytes staged in a `Buf` for a single
/// blocking read or write issued by `Blocking`.
pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 << 20;
27
/// Progress of the single operation a `Blocking` may have in flight.
#[derive(Debug)]
enum State<T> {
    /// No operation is running. Holds the reusable buffer; it is wrapped in an
    /// `Option` so the poll functions can move it out with `take()`.
    Idle(Option<Buf>),
    /// An operation started with `sys::run` is running. When it completes it
    /// yields the operation's result together with the buffer and the inner
    /// value, which were moved into it.
    Busy(sys::Blocking<(io::Result<usize>, Buf, T)>),
}
33
34cfg_io_blocking! {
35    impl<T> Blocking<T> {
36        #[cfg_attr(feature = "fs", allow(dead_code))]
37        pub(crate) fn new(inner: T) -> Blocking<T> {
38            Blocking {
39                inner: Some(inner),
40                state: State::Idle(Some(Buf::with_capacity(0))),
41                need_flush: false,
42            }
43        }
44    }
45}
46
47impl<T> AsyncRead for Blocking<T>
48where
49    T: Read + Unpin + Send + 'static,
50{
51    fn poll_read(
52        mut self: Pin<&mut Self>,
53        cx: &mut Context<'_>,
54        dst: &mut ReadBuf<'_>,
55    ) -> Poll<io::Result<()>> {
56        loop {
57            match self.state {
58                State::Idle(ref mut buf_cell) => {
59                    let mut buf = buf_cell.take().unwrap();
60
61                    if !buf.is_empty() {
62                        buf.copy_to(dst);
63                        *buf_cell = Some(buf);
64                        return Poll::Ready(Ok(()));
65                    }
66
67                    buf.ensure_capacity_for(dst, DEFAULT_MAX_BUF_SIZE);
68                    let mut inner = self.inner.take().unwrap();
69
70                    self.state = State::Busy(sys::run(move || {
71                        let res = buf.read_from(&mut inner);
72                        (res, buf, inner)
73                    }));
74                }
75                State::Busy(ref mut rx) => {
76                    let (res, mut buf, inner) = ready!(Pin::new(rx).poll(cx))?;
77                    self.inner = Some(inner);
78
79                    match res {
80                        Ok(_) => {
81                            buf.copy_to(dst);
82                            self.state = State::Idle(Some(buf));
83                            return Poll::Ready(Ok(()));
84                        }
85                        Err(e) => {
86                            assert!(buf.is_empty());
87
88                            self.state = State::Idle(Some(buf));
89                            return Poll::Ready(Err(e));
90                        }
91                    }
92                }
93            }
94        }
95    }
96}
97
98impl<T> AsyncWrite for Blocking<T>
99where
100    T: Write + Unpin + Send + 'static,
101{
102    fn poll_write(
103        mut self: Pin<&mut Self>,
104        cx: &mut Context<'_>,
105        src: &[u8],
106    ) -> Poll<io::Result<usize>> {
107        loop {
108            match self.state {
109                State::Idle(ref mut buf_cell) => {
110                    let mut buf = buf_cell.take().unwrap();
111
112                    assert!(buf.is_empty());
113
114                    let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE);
115                    let mut inner = self.inner.take().unwrap();
116
117                    self.state = State::Busy(sys::run(move || {
118                        let n = buf.len();
119                        let res = buf.write_to(&mut inner).map(|()| n);
120
121                        (res, buf, inner)
122                    }));
123                    self.need_flush = true;
124
125                    return Poll::Ready(Ok(n));
126                }
127                State::Busy(ref mut rx) => {
128                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
129                    self.state = State::Idle(Some(buf));
130                    self.inner = Some(inner);
131
132                    // If error, return
133                    res?;
134                }
135            }
136        }
137    }
138
139    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
140        loop {
141            let need_flush = self.need_flush;
142            match self.state {
143                // The buffer is not used here
144                State::Idle(ref mut buf_cell) => {
145                    if need_flush {
146                        let buf = buf_cell.take().unwrap();
147                        let mut inner = self.inner.take().unwrap();
148
149                        self.state = State::Busy(sys::run(move || {
150                            let res = inner.flush().map(|()| 0);
151                            (res, buf, inner)
152                        }));
153
154                        self.need_flush = false;
155                    } else {
156                        return Poll::Ready(Ok(()));
157                    }
158                }
159                State::Busy(ref mut rx) => {
160                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
161                    self.state = State::Idle(Some(buf));
162                    self.inner = Some(inner);
163
164                    // If error, return
165                    res?;
166                }
167            }
168        }
169    }
170
171    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
172        Poll::Ready(Ok(()))
173    }
174}
175
/// Evaluates the given I/O expression again and again until its result is
/// anything other than an `io::ErrorKind::Interrupted` error, and yields that
/// result.
macro_rules! uninterruptibly {
    ($op:expr) => {{
        loop {
            let outcome = $op;
            let interrupted =
                matches!(&outcome, Err(err) if err.kind() == io::ErrorKind::Interrupted);

            if !interrupted {
                break outcome;
            }
        }
    }};
}
187
188impl Buf {
189    pub(crate) fn with_capacity(n: usize) -> Buf {
190        Buf {
191            buf: Vec::with_capacity(n),
192            pos: 0,
193        }
194    }
195
196    pub(crate) fn is_empty(&self) -> bool {
197        self.len() == 0
198    }
199
200    pub(crate) fn len(&self) -> usize {
201        self.buf.len() - self.pos
202    }
203
204    pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
205        let n = cmp::min(self.len(), dst.remaining());
206        dst.put_slice(&self.bytes()[..n]);
207        self.pos += n;
208
209        if self.pos == self.buf.len() {
210            self.buf.truncate(0);
211            self.pos = 0;
212        }
213
214        n
215    }
216
217    pub(crate) fn copy_from(&mut self, src: &[u8], max_buf_size: usize) -> usize {
218        assert!(self.is_empty());
219
220        let n = cmp::min(src.len(), max_buf_size);
221
222        self.buf.extend_from_slice(&src[..n]);
223        n
224    }
225
226    pub(crate) fn bytes(&self) -> &[u8] {
227        &self.buf[self.pos..]
228    }
229
230    pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>, max_buf_size: usize) {
231        assert!(self.is_empty());
232
233        let len = cmp::min(bytes.remaining(), max_buf_size);
234
235        if self.buf.len() < len {
236            self.buf.reserve(len - self.buf.len());
237        }
238
239        unsafe {
240            self.buf.set_len(len);
241        }
242    }
243
244    pub(crate) fn read_from<T: Read>(&mut self, rd: &mut T) -> io::Result<usize> {
245        let res = uninterruptibly!(rd.read(&mut self.buf));
246
247        if let Ok(n) = res {
248            self.buf.truncate(n);
249        } else {
250            self.buf.clear();
251        }
252
253        assert_eq!(self.pos, 0);
254
255        res
256    }
257
258    pub(crate) fn write_to<T: Write>(&mut self, wr: &mut T) -> io::Result<()> {
259        assert_eq!(self.pos, 0);
260
261        // `write_all` already ignores interrupts
262        let res = wr.write_all(&self.buf);
263        self.buf.clear();
264        res
265    }
266}
267
268cfg_fs! {
269    impl Buf {
270        pub(crate) fn discard_read(&mut self) -> i64 {
271            let ret = -(self.bytes().len() as i64);
272            self.pos = 0;
273            self.buf.truncate(0);
274            ret
275        }
276
277        pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>], max_buf_size: usize) -> usize {
278            assert!(self.is_empty());
279
280            let mut rem = max_buf_size;
281            for buf in bufs {
282                if rem == 0 {
283                    break
284                }
285
286                let len = buf.len().min(rem);
287                self.buf.extend_from_slice(&buf[..len]);
288                rem -= len;
289            }
290
291            max_buf_size - rem
292        }
293    }
294}