tokio/io/
stdio_common.rs

//! Contains utilities for stdout and stderr.
2use crate::io::AsyncWrite;
3use std::pin::Pin;
4use std::task::{Context, Poll};
/// [`AsyncWrite`] adapter used for stdout and stderr.
///
/// # Windows
///
/// A buffer longer than `DEFAULT_MAX_BUF_SIZE` is first truncated to
/// `DEFAULT_MAX_BUF_SIZE` bytes. If the start of the buffer looks like UTF-8
/// (only the first `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes are inspected),
/// the truncated buffer is trimmed further so that it does not end in the
/// middle of a character. Otherwise it is treated as binary data and is only
/// truncated. This keeps the truncation itself from splitting a UTF-8
/// character of textual output, because Windows stdio cannot cope with
/// non-UTF-8 data. Input that is not valid UTF-8 to begin with is not
/// repaired.
///
/// The same logic is also enabled in this crate's unit tests on every
/// platform.
///
/// # Other platforms
///
/// Passes data to `inner` as is.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    inner: W,
}
15
16impl<W> SplitByUtf8BoundaryIfWindows<W> {
17    pub(crate) fn new(inner: W) -> Self {
18        Self { inner }
19    }
20}
21
/// Maximum number of bytes a single `char` occupies when encoded as UTF-8.
/// Code points U+10000..=U+10FFFF take four bytes (RFC 3629).
const MAX_BYTES_PER_CHAR: usize = 4;
24
/// Size, in maximum-width characters, of the window sampled from the start
/// of an oversized buffer to guess whether it holds UTF-8 text.
///
/// The window is `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes long. Larger values
/// inspect more bytes before making the guess. Subject to tweaking.
const MAGIC_CONST: usize = 8;
27
28impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
29where
30    W: AsyncWrite + Unpin,
31{
32    fn poll_write(
33        mut self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        mut buf: &[u8],
36    ) -> Poll<Result<usize, std::io::Error>> {
37        // just a closure to avoid repetitive code
38        let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
39
40        // 1. Only windows stdio can suffer from non-utf8.
41        // We also check for `test` so that we can write some tests
42        // for further code. Since `AsyncWrite` can always shrink
43        // buffer at its discretion, excessive (i.e. in tests) shrinking
44        // does not break correctness.
45        // 2. If buffer is small, it will not be shrunk.
46        // That's why, it's "textness" will not change, so we don't have
47        // to fixup it.
48        if cfg!(not(any(target_os = "windows", test)))
49            || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE
50        {
51            return call_inner(buf);
52        }
53
54        buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE];
55
56        // Now there are two possibilities.
57        // If caller gave is binary buffer, we **should not** shrink it
58        // anymore, because excessive shrinking hits performance.
59        // If caller gave as binary buffer, we  **must** additionally
60        // shrink it to strip incomplete char at the end of buffer.
61        // that's why check we will perform now is allowed to have
62        // false-positive.
63
64        // Now let's look at the first MAX_BYTES_PER_CHAR * MAGIC_CONST bytes.
65        // if they are (possibly incomplete) utf8, then we can be quite sure
66        // that input buffer was utf8.
67
68        let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
69            Ok(_) => true,
70            Err(err) => {
71                let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
72                incomplete_bytes < MAX_BYTES_PER_CHAR
73            }
74        };
75
76        if have_to_fix_up {
77            // We must pop several bytes at the end which form incomplete
78            // character. To achieve it, we exploit UTF8 encoding:
79            // for any code point, all bytes except first start with 0b10 prefix.
80            // see https://en.wikipedia.org/wiki/UTF-8#Encoding for details
81            let trailing_incomplete_char_size = buf
82                .iter()
83                .rev()
84                .take(MAX_BYTES_PER_CHAR)
85                .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
86                .unwrap_or(0)
87                + 1;
88            buf = &buf[..buf.len() - trailing_incomplete_char_size];
89        }
90
91        call_inner(buf)
92    }
93
94    fn poll_flush(
95        mut self: Pin<&mut Self>,
96        cx: &mut Context<'_>,
97    ) -> Poll<Result<(), std::io::Error>> {
98        Pin::new(&mut self.inner).poll_flush(cx)
99    }
100
101    fn poll_shutdown(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104    ) -> Poll<Result<(), std::io::Error>> {
105        Pin::new(&mut self.inner).poll_shutdown(cx)
106    }
107}
108
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    /// Drives `fut` to completion on a fresh current-thread runtime.
    fn run_to_completion<F: std::future::Future>(fut: F) -> F::Output {
        let runtime = crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        runtime.block_on(fut)
    }

    /// Writer that accepts every chunk in full. It asserts that each chunk
    /// fits in `DEFAULT_MAX_BUF_SIZE` and is well-formed UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= DEFAULT_MAX_BUF_SIZE);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Writer that accepts every chunk in full and records the length of
    /// each `poll_write` call.
    #[derive(Default)]
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            assert!(accepted <= DEFAULT_MAX_BUF_SIZE);
            self.get_mut().write_history.push(accepted);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        let text = "█".repeat(DEFAULT_MAX_BUF_SIZE);
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        run_to_completion(async move {
            splitter.write_all(text.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose prefix happens to look like text. Even in this
        // corner case the splitter must not shrink the buffer by much.
        let text_prefix_len = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let binary_tail_len = DEFAULT_MAX_BUF_SIZE - text_prefix_len + 1;
        let mut data = vec![b'a'; text_prefix_len];
        data.resize(text_prefix_len + binary_tail_len, 0b1010_1010);

        let mut writer = LoggingMockWriter::default();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        run_to_completion(async {
            splitter.write_all(&data).await.unwrap();
        });

        let history = &writer.write_history;
        // At most two writes were needed...
        assert!(history.len() <= 2);
        // ...every byte was written...
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // ...and the first write left out at most MAX_BYTES_PER_CHAR + 1 bytes:
        // one byte beyond the DEFAULT_MAX_BUF_SIZE limit, plus up to one code point.
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}
222}