1use std::io::{self, BufRead as _, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Which directions of a TLS stream are still open.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase (only with the `early-data` feature). The fields are
    /// presumably a progress counter and buffered early data — TODO confirm
    /// against the code that constructs this variant.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Open for both reading and writing.
    Stream,
    /// The read direction has been shut down; writing may continue.
    ReadShutdown,
    /// The write direction has been shut down; reading may continue.
    WriteShutdown,
    /// Both directions have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Marks the read direction as closed.
    ///
    /// If writing was already closed the state becomes `FullyShutdown`,
    /// otherwise it becomes `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Marks the write direction as closed.
    ///
    /// If reading was already closed the state becomes `FullyShutdown`,
    /// otherwise it becomes `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// `true` unless the write direction has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_closed = matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_closed
    }

    /// `true` unless the read direction has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_closed = matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_closed
    }

    /// Whether the state is currently `EarlyData`.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Always `false`: the `EarlyData` variant does not exist without the
    /// `early-data` feature.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61
/// A short-lived pairing of a transport with a rustls connection, borrowed
/// mutably for the duration of one poll operation.
pub struct Stream<'a, IO, C> {
    /// The underlying byte transport.
    pub io: &'a mut IO,
    /// The rustls connection whose TLS records are moved to and from `io`.
    pub session: &'a mut C,
    /// When `true`, the read loops in this file stop reading from `io`.
    pub eof: bool,
}
67
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Wraps `io` and `session`, starting with `eof` cleared.
    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            eof: false,
        }
    }

    /// Builder-style setter for the `eof` flag; consumes and returns `self`.
    pub fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    /// Pins a mutable borrow of this stream.
    ///
    /// `Pin::new` is usable because every field is a reference or a `bool`,
    /// so `Stream` is `Unpin`.
    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }

    /// Reads TLS records from `io` into `session`, then has rustls process them.
    ///
    /// Returns the byte count reported by `read_tls`. Callers in this file treat
    /// `Ok(0)` as end-of-file. Returns `Poll::Pending` when the transport is not
    /// ready. If `process_new_packets` fails, a best-effort `write_io` is made
    /// (presumably to flush an alert rustls queued for the peer — confirm with
    /// rustls docs). Its result is ignored so it cannot mask the original
    /// failure, which is returned as `InvalidData`.
    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        self.session.process_new_packets().map_err(|err| {
            // Last-gasp write; deliberately ignored so the primary error wins.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        Poll::Ready(Ok(n))
    }

    /// Writes buffered TLS records from `session` to `io` via `write_tls`.
    ///
    /// Returns the byte count written, or `Poll::Pending` if the transport is
    /// not ready. The adapter maps the transport's `Pending` to `WouldBlock`,
    /// which is mapped back here.
    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut writer = SyncWriteAdapter { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Drives the TLS handshake as far as the transport currently allows.
    ///
    /// Each round first writes every pending TLS record to `io`, calling
    /// `poll_flush` when anything was written. It then reads from `io` while
    /// rustls wants input and EOF has not been seen.
    ///
    /// Returns:
    /// - `Ready(Ok((read, written)))` once the handshake is no longer in progress,
    ///   or once the transport would block after some bytes were moved. The
    ///   counts are the bytes read from / written to the transport.
    /// - `Pending` if the transport would block and no bytes were moved.
    /// - an `UnexpectedEof` error if EOF is seen while still handshaking.
    /// - a `WriteZero` error if the transport accepts zero bytes.
    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;
            let mut need_flush = false;

            // Send everything rustls has queued.
            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            if need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            // Feed rustls until it has enough input or the transport stalls.
            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                // Report progress if any was made; otherwise wait. `Pending`
                // here came from the transport, so a wakeup is registered.
                (_, true) if write_would_block || read_would_block => {
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                // Still handshaking and nothing blocked: run another round.
                (..) => continue,
            };
        }
    }

    /// Reads from the transport while rustls wants input, then returns rustls's
    /// first chunk of buffered plaintext.
    ///
    /// Takes `self` by value so the returned slice can borrow `session` for the
    /// full lifetime `'a`. A transport read of zero bytes simply ends the read
    /// loop; what rustls then reports is up to `into_first_chunk`. An empty chunk
    /// presumably signals a clean EOF — confirm against rustls `Reader` docs.
    ///
    /// If rustls reports `WouldBlock` although the transport did not return
    /// `Pending`, no transport wakeup was registered. The task's waker is
    /// therefore woken immediately so the caller polls again.
    pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
    where
        SD: 'a,
    {
        let mut io_pending = false;

        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => {
                    break;
                }
                Poll::Ready(Ok(_)) => (),
                Poll::Pending => {
                    io_pending = true;
                    break;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }

        match self.session.reader().into_first_chunk() {
            Ok(buf) => {
                Poll::Ready(Ok(buf))
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                if !io_pending {
                    // Self-wake so `Pending` is not returned without a wakeup.
                    cx.waker().wake_by_ref();
                }

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
227
228impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
229where
230 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
231 SD: SideData + 'a,
232{
233 fn poll_read(
234 mut self: Pin<&mut Self>,
235 cx: &mut Context<'_>,
236 buf: &mut ReadBuf<'_>,
237 ) -> Poll<io::Result<()>> {
238 let data = ready!(self.as_mut().poll_fill_buf(cx))?;
239 let amount = buf.remaining().min(data.len());
240 buf.put_slice(&data[..amount]);
241 self.session.reader().consume(amount);
242 Poll::Ready(Ok(()))
243 }
244}
245
246impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
247where
248 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
249 SD: SideData + 'a,
250{
251 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
252 let this = self.get_mut();
253 Stream {
254 io: this.io,
256 session: this.session,
257 ..*this
258 }
259 .poll_fill_buf(cx)
260 }
261
262 fn consume(mut self: Pin<&mut Self>, amt: usize) {
263 self.session.reader().consume(amt);
264 }
265}
266
267impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
268where
269 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
270 SD: SideData,
271{
272 fn poll_write(
273 mut self: Pin<&mut Self>,
274 cx: &mut Context,
275 buf: &[u8],
276 ) -> Poll<io::Result<usize>> {
277 let mut pos = 0;
278
279 while pos != buf.len() {
280 let mut would_block = false;
281
282 match self.session.writer().write(&buf[pos..]) {
283 Ok(n) => pos += n,
284 Err(err) => return Poll::Ready(Err(err)),
285 };
286
287 while self.session.wants_write() {
288 match self.write_io(cx) {
289 Poll::Ready(Ok(0)) | Poll::Pending => {
290 would_block = true;
291 break;
292 }
293 Poll::Ready(Ok(_)) => (),
294 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
295 }
296 }
297
298 return match (pos, would_block) {
299 (0, true) => Poll::Pending,
300 (n, true) => Poll::Ready(Ok(n)),
301 (_, false) => continue,
302 };
303 }
304
305 Poll::Ready(Ok(pos))
306 }
307
308 fn poll_write_vectored(
309 mut self: Pin<&mut Self>,
310 cx: &mut Context<'_>,
311 bufs: &[IoSlice<'_>],
312 ) -> Poll<io::Result<usize>> {
313 if bufs.iter().all(|buf| buf.is_empty()) {
314 return Poll::Ready(Ok(0));
315 }
316
317 loop {
318 let mut would_block = false;
319 let written = self.session.writer().write_vectored(bufs)?;
320
321 while self.session.wants_write() {
322 match self.write_io(cx) {
323 Poll::Ready(Ok(0)) | Poll::Pending => {
324 would_block = true;
325 break;
326 }
327 Poll::Ready(Ok(_)) => (),
328 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
329 }
330 }
331
332 return match (written, would_block) {
333 (0, true) => Poll::Pending,
334 (0, false) => continue,
335 (n, _) => Poll::Ready(Ok(n)),
336 };
337 }
338 }
339
340 #[inline]
341 fn is_write_vectored(&self) -> bool {
342 true
343 }
344
345 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
346 self.session.writer().flush()?;
347 while self.session.wants_write() {
348 if ready!(self.write_io(cx))? == 0 {
349 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
350 }
351 }
352 Pin::new(&mut self.io).poll_flush(cx)
353 }
354
355 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
356 while self.session.wants_write() {
357 if ready!(self.write_io(cx))? == 0 {
358 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
359 }
360 }
361
362 Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
363 Ok(()) => Ok(()),
364 Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
366 Err(err) => Err(err),
367 })
368 }
369}
370
/// Presents an async transport plus its task [`Context`] as a synchronous
/// [`Read`], so it can be handed to rustls's `read_tls`.
///
/// A `Poll::Pending` from the transport surfaces as an
/// [`io::ErrorKind::WouldBlock`] error.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// Transport that reads are forwarded to.
    pub io: &'a mut T,
    /// Task context passed to every poll of `io`.
    pub cx: &'a mut Context<'b>,
}
379
380impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
381 #[inline]
382 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
383 let mut buf = ReadBuf::new(buf);
384 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
385 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
386 Poll::Ready(Err(err)) => Err(err),
387 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
388 }
389 }
390}
391
/// Presents an async transport plus its task [`Context`] as a synchronous
/// [`Write`], so it can be handed to rustls's `write_tls`.
///
/// A `Poll::Pending` from the transport surfaces as an
/// [`io::ErrorKind::WouldBlock`] error.
pub struct SyncWriteAdapter<'a, 'b, T> {
    /// Transport that writes are forwarded to.
    pub io: &'a mut T,
    /// Task context passed to every poll of `io`.
    pub cx: &'a mut Context<'b>,
}

impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
    /// Runs one poll-style operation `f` against the pinned transport.
    ///
    /// `Poll::Ready(r)` yields `r` unchanged; `Poll::Pending` becomes a
    /// `WouldBlock` error.
    #[inline]
    fn poll_with<U>(
        &mut self,
        f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
    ) -> io::Result<U> {
        let pinned = Pin::new(&mut *self.io);
        if let Poll::Ready(result) = f(pinned, self.cx) {
            result
        } else {
            Err(io::ErrorKind::WouldBlock.into())
        }
    }
}
413
414impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
415 #[inline]
416 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
417 self.poll_with(|io, cx| io.poll_write(cx, buf))
418 }
419
420 #[inline]
421 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
422 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
423 }
424
425 fn flush(&mut self) -> io::Result<()> {
426 self.poll_with(|io, cx| io.poll_flush(cx))
427 }
428}
429
430#[cfg(test)]
431mod test_stream;