1use crate::io::AsyncWrite;
3use std::pin::Pin;
4use std::task::{Context, Poll};
/// [`AsyncWrite`] adapter that, on Windows, avoids handing the wrapped writer a buffer
/// that was cut in the middle of a UTF-8 character.
///
/// Writes longer than `crate::io::blocking::DEFAULT_MAX_BUF_SIZE` are shortened to that
/// size. If the shortened buffer looks like text, its last (possibly incomplete)
/// character is dropped as well, so the inner writer only receives whole characters.
/// Buffers that look like binary data are only cut to the size limit.
///
/// On all other platforms (outside of this crate's tests) every write is passed to the
/// inner writer unchanged.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// The writer that actually receives the (possibly shortened) data.
    inner: W,
}

impl<W> SplitByUtf8BoundaryIfWindows<W> {
    /// Wraps `inner`; every write, flush and shutdown is forwarded to it.
    pub(crate) fn new(inner: W) -> Self {
        SplitByUtf8BoundaryIfWindows { inner }
    }
}
21
/// Length in bytes of the longest possible UTF-8 encoding of a single `char` (4).
const MAX_BYTES_PER_CHAR: usize = char::MAX.len_utf8();

/// How many characters' worth of bytes (`MAGIC_CONST * MAX_BYTES_PER_CHAR`) at the start
/// of an over-long write are inspected to guess whether the buffer holds text.
/// NOTE(review): appears to be a tuning knob rather than a derived value.
const MAGIC_CONST: usize = 8;
27
28impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
29where
30 W: AsyncWrite + Unpin,
31{
32 fn poll_write(
33 mut self: Pin<&mut Self>,
34 cx: &mut Context<'_>,
35 mut buf: &[u8],
36 ) -> Poll<Result<usize, std::io::Error>> {
37 let mut call_inner = move |buf| Pin::new(&mut self.inner).poll_write(cx, buf);
39
40 if cfg!(not(any(target_os = "windows", test)))
49 || buf.len() <= crate::io::blocking::DEFAULT_MAX_BUF_SIZE
50 {
51 return call_inner(buf);
52 }
53
54 buf = &buf[..crate::io::blocking::DEFAULT_MAX_BUF_SIZE];
55
56 let have_to_fix_up = match std::str::from_utf8(&buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST]) {
69 Ok(_) => true,
70 Err(err) => {
71 let incomplete_bytes = MAX_BYTES_PER_CHAR * MAGIC_CONST - err.valid_up_to();
72 incomplete_bytes < MAX_BYTES_PER_CHAR
73 }
74 };
75
76 if have_to_fix_up {
77 let trailing_incomplete_char_size = buf
82 .iter()
83 .rev()
84 .take(MAX_BYTES_PER_CHAR)
85 .position(|byte| *byte < 0b1000_0000 || *byte >= 0b1100_0000)
86 .unwrap_or(0)
87 + 1;
88 buf = &buf[..buf.len() - trailing_incomplete_char_size];
89 }
90
91 call_inner(buf)
92 }
93
94 fn poll_flush(
95 mut self: Pin<&mut Self>,
96 cx: &mut Context<'_>,
97 ) -> Poll<Result<(), std::io::Error>> {
98 Pin::new(&mut self.inner).poll_flush(cx)
99 }
100
101 fn poll_shutdown(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 ) -> Poll<Result<(), std::io::Error>> {
105 Pin::new(&mut self.inner).poll_shutdown(cx)
106 }
107}
108
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::blocking::DEFAULT_MAX_BUF_SIZE;
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    /// Drives `fut` to completion on a fresh current-thread runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Mock writer that accepts every write in full, asserting that each chunk respects
    /// the size limit and is well-formed UTF-8.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Mock writer that accepts every write in full and records the length of each chunk.
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        fn new() -> Self {
            LoggingMockWriter {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            assert!(buf.len() <= DEFAULT_MAX_BUF_SIZE);
            self.write_history.push(buf.len());
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_splitter() {
        // Every character is 3 bytes long, so cutting at an arbitrary byte offset would
        // generally split one; TextMockWriter rejects any chunk that is not valid UTF-8.
        let data = str::repeat("█", DEFAULT_MAX_BUF_SIZE);
        let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        block_on(async move {
            wr.write_all(data.as_bytes()).await.unwrap();
        });
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_pseudo_text() {
        // Binary data whose head happens to be text: the splitter will classify it as
        // text, and we check that it still does not shrink the buffer by much.
        let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;
        let mut data: Vec<u8> = str::repeat("a", checked_count).into();
        data.extend(std::iter::repeat(0b1010_1010).take(DEFAULT_MAX_BUF_SIZE - checked_count + 1));
        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        block_on(async {
            splitter.write_all(&data).await.unwrap();
        });
        assert!(writer.write_history.len() <= 2);
        // Nothing may be lost or duplicated across the writes.
        assert_eq!(
            writer.write_history.iter().copied().sum::<usize>(),
            data.len()
        );
        // The first write is short by at most one byte for exceeding the size limit plus
        // at most one whole UTF-8 character of trimming.
        assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_binary_cut_only_at_size_limit() {
        // 0xFF never occurs in UTF-8, so the head probe classifies this buffer as binary:
        // the first write must be cut exactly at the size limit, with no extra trimming.
        let data = vec![0xFF_u8; DEFAULT_MAX_BUF_SIZE + 1];
        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        block_on(async {
            splitter.write_all(&data).await.unwrap();
        });
        assert_eq!(writer.write_history, vec![DEFAULT_MAX_BUF_SIZE, 1]);
    }

    #[test]
    fn test_short_buffer_forwarded_whole() {
        // A buffer within the size limit is passed through untouched, even when it ends
        // partway through a character.
        let data = &"█".as_bytes()[..2];
        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        block_on(async {
            splitter.write_all(data).await.unwrap();
        });
        assert_eq!(writer.write_history, vec![2]);
    }
}