tokio_rustls/common/mod.rs

1use std::io::{self, BufRead as _, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Tracks which directions of a TLS stream are still open.
///
/// With the `early-data` feature enabled there is an additional `EarlyData` state.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data state (only with the `early-data` feature).
    /// NOTE(review): the `usize` and `Vec<u8>` payloads are not interpreted in this
    /// module — confirm their meaning at the use sites.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both directions are open.
    Stream,
    /// The read half is closed; writing is still allowed.
    ReadShutdown,
    /// The write half is closed; reading is still allowed.
    WriteShutdown,
    /// Both halves are closed.
    FullyShutdown,
}

impl TlsState {
    /// Records that the read half has been shut down.
    ///
    /// Moves to `FullyShutdown` when the write half is already closed, and to
    /// `ReadShutdown` otherwise (including from `Stream` and any early-data state).
    #[inline]
    pub fn shutdown_read(&mut self) {
        let next = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Records that the write half has been shut down.
    ///
    /// Moves to `FullyShutdown` when the read half is already closed, and to
    /// `WriteShutdown` otherwise (including from `Stream` and any early-data state).
    #[inline]
    pub fn shutdown_write(&mut self) {
        let next = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
        *self = next;
    }

    /// Returns `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        match self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => false,
            TlsState::Stream | TlsState::ReadShutdown => true,
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(..) => true,
        }
    }

    /// Returns `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        match self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => false,
            TlsState::Stream | TlsState::WriteShutdown => true,
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(..) => true,
        }
    }

    /// Returns `true` while in the early-data state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(*self, TlsState::EarlyData(..))
    }

    /// Always `false`: without the `early-data` feature there is no early-data state.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61
/// A borrowed pairing of a transport `io` with a TLS connection `session`; the
/// impls in this module build the async read/write traits on top of the two.
pub struct Stream<'a, IO, C> {
    /// The underlying transport that TLS records are read from and written to.
    pub io: &'a mut IO,
    /// The TLS connection; the impls in this module require it to deref to
    /// rustls's `ConnectionCommon`.
    pub session: &'a mut C,
    /// Whether EOF has been observed on `io`. While set, the read loops in
    /// `handshake` and `poll_fill_buf` do not read from the transport.
    pub eof: bool,
}
67
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Creates a stream over `io` and `session` with `eof` initially unset.
    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            // Assume the transport has not reported EOF yet; callers that know
            // otherwise can override this with `set_eof`.
            eof: false,
        }
    }

    /// Builder-style setter for `eof`: consumes `self`, updates the flag, and
    /// returns the stream.
    pub fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    /// Returns a `Pin<&mut Self>` so the poll-based trait methods can be called on a
    /// plain `&mut Stream`. `Pin::new` is usable here because `Stream` only holds
    /// references and a `bool`, all of which are `Unpin`.
    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }

    /// Reads TLS records from `io` into `session` and processes them.
    ///
    /// Returns the number of bytes `read_tls` pulled from the transport (`0` when
    /// the transport produced no bytes). A `WouldBlock` error from `read_tls` (how
    /// `SyncReadAdapter` reports a pending transport) becomes `Pending`; any other
    /// read error is returned as-is.
    ///
    /// If `process_new_packets` fails, a best-effort `write_io` is attempted so
    /// that queued records (such as an alert) may still reach the peer. The rustls
    /// error is then returned wrapped as `InvalidData`.
    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        self.session.process_new_packets().map_err(|err| {
            // In case we have an alert to send describing this error,
            // try a last-gasp write -- its result is deliberately ignored
            // so it cannot replace the primary error.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        Poll::Ready(Ok(n))
    }

    /// Writes queued TLS records from `session` to `io` via `write_tls`.
    ///
    /// A `WouldBlock` error (how `SyncWriteAdapter` reports a pending transport)
    /// becomes `Pending`. Every other result, including `Ok(0)`, is returned as
    /// `Ready` for the caller to interpret.
    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut writer = SyncWriteAdapter { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Drives the TLS handshake as far as the transport currently allows.
    ///
    /// Each pass does two things:
    /// - writes queued TLS records, flushing `io` if anything was written;
    /// - reads and processes incoming records while rustls wants input, stopping
    ///   when the transport blocks or EOF is seen.
    ///
    /// Returns `Ready(Ok((rdlen, wrlen)))`, the total bytes read from and written
    /// to `io`, in two cases:
    /// - once `session` is no longer handshaking;
    /// - earlier, if the transport blocked after some bytes were moved. The
    ///   handshake may still be in progress then, so callers should poll again.
    ///
    /// Returns `Pending` only when the transport blocked before any bytes moved.
    ///
    /// # Errors
    ///
    /// - `WriteZero` if the transport accepts zero bytes.
    /// - `UnexpectedEof` if EOF has been recorded while still handshaking.
    /// - Any error from `write_io`, `read_io`, or flushing `io`.
    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;
            let mut need_flush = false;

            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            // Only flush if this pass wrote something; a pending flush counts as
            // the write side blocking.
            if need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    // A zero-byte read is recorded as EOF; `!self.eof` then ends the loop.
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                (_, true) if write_would_block || read_would_block => {
                    // Report progress if any bytes moved, so the caller is not
                    // left waiting on a wakeup for work that already happened.
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                // Still handshaking and nothing blocked: go around again.
                (..) => continue,
            };
        }
    }

    /// Reads and processes TLS records, then returns the first chunk of plaintext
    /// buffered in `session`.
    ///
    /// Records are read from `io` while no EOF has been recorded and rustls wants
    /// input. The loop also stops early when a read returns zero bytes or the
    /// transport is pending.
    ///
    /// Takes `self` by value so the returned slice can borrow from `session` for
    /// the full lifetime `'a`.
    ///
    /// An empty slice signals EOF, which can happen after a `CloseNotify` once all
    /// buffered data has been drained. If rustls has no plaintext yet, `Pending` is
    /// returned. When the transport itself was not the reason for blocking, the
    /// task's waker is triggered immediately so the read is retried.
    pub(crate) fn poll_fill_buf(mut self, cx: &mut Context<'_>) -> Poll<io::Result<&'a [u8]>>
    where
        SD: 'a,
    {
        // Set when the transport itself returned `Pending` (so a waker is registered).
        let mut io_pending = false;

        // read a packet
        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => {
                    break;
                }
                Poll::Ready(Ok(_)) => (),
                Poll::Pending => {
                    io_pending = true;
                    break;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }

        match self.session.reader().into_first_chunk() {
            Ok(buf) => {
                // Note that this could be empty (i.e. EOF) if a `CloseNotify` has been
                // received and there is no more buffered data.
                Poll::Ready(Ok(buf))
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                if !io_pending {
                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
                    // but if it does, we can try again.
                    //
                    // If the rustls state is abnormal, it may cause a cyclic wakeup.
                    // but tokio's cooperative budget will prevent infinite wakeup.
                    cx.waker().wake_by_ref();
                }

                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
227
228impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
229where
230    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
231    SD: SideData + 'a,
232{
233    fn poll_read(
234        mut self: Pin<&mut Self>,
235        cx: &mut Context<'_>,
236        buf: &mut ReadBuf<'_>,
237    ) -> Poll<io::Result<()>> {
238        let data = ready!(self.as_mut().poll_fill_buf(cx))?;
239        let amount = buf.remaining().min(data.len());
240        buf.put_slice(&data[..amount]);
241        self.session.reader().consume(amount);
242        Poll::Ready(Ok(()))
243    }
244}
245
246impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncBufRead for Stream<'a, IO, C>
247where
248    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
249    SD: SideData + 'a,
250{
251    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
252        let this = self.get_mut();
253        Stream {
254            // reborrow
255            io: this.io,
256            session: this.session,
257            ..*this
258        }
259        .poll_fill_buf(cx)
260    }
261
262    fn consume(mut self: Pin<&mut Self>, amt: usize) {
263        self.session.reader().consume(amt);
264    }
265}
266
267impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
268where
269    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
270    SD: SideData,
271{
272    fn poll_write(
273        mut self: Pin<&mut Self>,
274        cx: &mut Context,
275        buf: &[u8],
276    ) -> Poll<io::Result<usize>> {
277        let mut pos = 0;
278
279        while pos != buf.len() {
280            let mut would_block = false;
281
282            match self.session.writer().write(&buf[pos..]) {
283                Ok(n) => pos += n,
284                Err(err) => return Poll::Ready(Err(err)),
285            };
286
287            while self.session.wants_write() {
288                match self.write_io(cx) {
289                    Poll::Ready(Ok(0)) | Poll::Pending => {
290                        would_block = true;
291                        break;
292                    }
293                    Poll::Ready(Ok(_)) => (),
294                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
295                }
296            }
297
298            return match (pos, would_block) {
299                (0, true) => Poll::Pending,
300                (n, true) => Poll::Ready(Ok(n)),
301                (_, false) => continue,
302            };
303        }
304
305        Poll::Ready(Ok(pos))
306    }
307
308    fn poll_write_vectored(
309        mut self: Pin<&mut Self>,
310        cx: &mut Context<'_>,
311        bufs: &[IoSlice<'_>],
312    ) -> Poll<io::Result<usize>> {
313        if bufs.iter().all(|buf| buf.is_empty()) {
314            return Poll::Ready(Ok(0));
315        }
316
317        loop {
318            let mut would_block = false;
319            let written = self.session.writer().write_vectored(bufs)?;
320
321            while self.session.wants_write() {
322                match self.write_io(cx) {
323                    Poll::Ready(Ok(0)) | Poll::Pending => {
324                        would_block = true;
325                        break;
326                    }
327                    Poll::Ready(Ok(_)) => (),
328                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
329                }
330            }
331
332            return match (written, would_block) {
333                (0, true) => Poll::Pending,
334                (0, false) => continue,
335                (n, _) => Poll::Ready(Ok(n)),
336            };
337        }
338    }
339
340    #[inline]
341    fn is_write_vectored(&self) -> bool {
342        true
343    }
344
345    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
346        self.session.writer().flush()?;
347        while self.session.wants_write() {
348            if ready!(self.write_io(cx))? == 0 {
349                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
350            }
351        }
352        Pin::new(&mut self.io).poll_flush(cx)
353    }
354
355    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
356        while self.session.wants_write() {
357            if ready!(self.write_io(cx))? == 0 {
358                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
359            }
360        }
361
362        Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
363            Ok(()) => Ok(()),
364            // When trying to shutdown, not being connected seems fine
365            Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
366            Err(err) => Err(err),
367        })
368    }
369}
370
371/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an
372/// associated [`Context`].
373///
374/// Turns `Poll::Pending` into `WouldBlock`.
375pub struct SyncReadAdapter<'a, 'b, T> {
376    pub io: &'a mut T,
377    pub cx: &'a mut Context<'b>,
378}
379
380impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
381    #[inline]
382    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
383        let mut buf = ReadBuf::new(buf);
384        match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
385            Poll::Ready(Ok(())) => Ok(buf.filled().len()),
386            Poll::Ready(Err(err)) => Err(err),
387            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
388        }
389    }
390}
391
392/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an
393/// associated [`Context`].
394///
395/// Turns `Poll::Pending` into `WouldBlock`.
396pub struct SyncWriteAdapter<'a, 'b, T> {
397    pub io: &'a mut T,
398    pub cx: &'a mut Context<'b>,
399}
400
401impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
402    #[inline]
403    fn poll_with<U>(
404        &mut self,
405        f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
406    ) -> io::Result<U> {
407        match f(Pin::new(self.io), self.cx) {
408            Poll::Ready(result) => result,
409            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
410        }
411    }
412}
413
414impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
415    #[inline]
416    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
417        self.poll_with(|io, cx| io.poll_write(cx, buf))
418    }
419
420    #[inline]
421    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
422        self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
423    }
424
425    fn flush(&mut self) -> io::Result<()> {
426        self.poll_with(|io, cx| io.poll_flush(cx))
427    }
428}
429
430#[cfg(test)]
431mod test_stream;