// tokio_rustls/common/mod.rs

1use std::io::{self, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Tracks which halves of the TLS stream have been shut down.
#[derive(Debug)]
pub enum TlsState {
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    Stream,
    ReadShutdown,
    WriteShutdown,
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half closed; if the write half is already closed,
    /// the state collapses to `FullyShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) {
            TlsState::FullyShutdown
        } else {
            TlsState::ReadShutdown
        };
    }

    /// Marks the write half closed; if the read half is already closed,
    /// the state collapses to `FullyShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown) {
            TlsState::FullyShutdown
        } else {
            TlsState::WriteShutdown
        };
    }

    /// Returns true while the write half is still open.
    #[inline]
    pub fn writeable(&self) -> bool {
        match *self {
            TlsState::WriteShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns true while the read half is still open.
    #[inline]
    pub fn readable(&self) -> bool {
        match *self {
            TlsState::ReadShutdown | TlsState::FullyShutdown => false,
            _ => true,
        }
    }

    /// Returns true if the state is `EarlyData`.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Early data support is compiled out, so this is always false.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61
/// Pairs an async transport with a rustls session so the two can be
/// driven together by the poll-based read/write/handshake helpers.
pub struct Stream<'a, IO, C> {
    /// The underlying async I/O object (e.g. a TCP stream).
    pub io: &'a mut IO,
    /// The rustls connection state machine layered over `io`.
    pub session: &'a mut C,
    /// Set once the transport has reported EOF (a raw read of 0 bytes).
    pub eof: bool,
}
67
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Wraps `io` and `session` together, with the EOF flag cleared.
    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
        Stream {
            io,
            session,
            // The state so far is only used to detect EOF, so either Stream
            // or EarlyData state should both be all right.
            eof: false,
        }
    }

    /// Builder-style setter for the EOF flag.
    pub fn set_eof(mut self, eof: bool) -> Self {
        self.eof = eof;
        self
    }

    /// `Stream` is `Unpin`, so pinning a mutable reference is trivially safe.
    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
        Pin::new(self)
    }

    /// Reads TLS records from the transport into the session and processes
    /// them.
    ///
    /// Returns the number of transport bytes consumed, `Poll::Pending` if the
    /// transport is not ready, or an error. A record-processing failure is
    /// surfaced as `InvalidData`.
    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut reader = SyncReadAdapter { io: self.io, cx };

        let n = match self.session.read_tls(&mut reader) {
            Ok(n) => n,
            // `SyncReadAdapter` reports `Poll::Pending` as `WouldBlock`.
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
            Err(err) => return Poll::Ready(Err(err)),
        };

        self.session.process_new_packets().map_err(|err| {
            // In case we have an alert to send describing this error,
            // try a last-gasp write -- but don't predate the primary
            // error.
            let _ = self.write_io(cx);

            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        Poll::Ready(Ok(n))
    }

    /// Writes buffered TLS records from the session to the transport.
    ///
    /// Returns the number of transport bytes written, or `Poll::Pending` if
    /// the transport cannot accept more data right now.
    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
        let mut writer = SyncWriteAdapter { io: self.io, cx };

        match self.session.write_tls(&mut writer) {
            // `SyncWriteAdapter` reports `Poll::Pending` as `WouldBlock`.
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Drives the TLS handshake until it completes or neither direction can
    /// make progress.
    ///
    /// Returns `(bytes_read, bytes_written)` once the session is no longer
    /// handshaking, `Poll::Pending` when blocked with no progress made, or an
    /// error (including `UnexpectedEof` if the peer closed mid-handshake).
    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
        let mut wrlen = 0;
        let mut rdlen = 0;

        loop {
            let mut write_would_block = false;
            let mut read_would_block = false;
            let mut need_flush = false;

            // Flush any pending handshake records to the transport.
            while self.session.wants_write() {
                match self.write_io(cx) {
                    // A zero-length write means the transport is closed.
                    Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                    Poll::Ready(Ok(n)) => {
                        wrlen += n;
                        need_flush = true;
                    }
                    Poll::Pending => {
                        write_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            // Only flush if something was actually written this round.
            if need_flush {
                match Pin::new(&mut self.io).poll_flush(cx) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Pending => write_would_block = true,
                }
            }

            // Pull in any handshake records the session is waiting for.
            while !self.eof && self.session.wants_read() {
                match self.read_io(cx) {
                    Poll::Ready(Ok(0)) => self.eof = true,
                    Poll::Ready(Ok(n)) => rdlen += n,
                    Poll::Pending => {
                        read_would_block = true;
                        break;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (self.eof, self.session.is_handshaking()) {
                // Peer closed the connection before the handshake finished.
                (true, true) => {
                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
                    Poll::Ready(Err(err))
                }
                // Handshake complete: report progress made.
                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
                // Blocked: report partial progress if any, otherwise Pending.
                (_, true) if write_would_block || read_would_block => {
                    if rdlen != 0 || wrlen != 0 {
                        Poll::Ready(Ok((rdlen, wrlen)))
                    } else {
                        Poll::Pending
                    }
                }
                // Still handshaking and nothing blocked: keep driving.
                (..) => continue,
            };
        }
    }
}
184
impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'_, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Pulls TLS records from the transport while the session wants them,
    /// then copies any decrypted plaintext the session has buffered into
    /// `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut io_pending = false;

        // read a packet
        while !self.eof && self.session.wants_read() {
            match self.read_io(cx) {
                Poll::Ready(Ok(0)) => {
                    // Transport EOF: stop fetching and fall through to drain
                    // whatever plaintext the session still holds.
                    break;
                }
                Poll::Ready(Ok(_)) => (),
                Poll::Pending => {
                    io_pending = true;
                    break;
                }
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            }
        }

        match self.session.reader().read(buf.initialize_unfilled()) {
            // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
            // connection with a `CloseNotify` message and no more data will be forthcoming.
            //
            // Rustls yielded more data: advance the buffer, then see if more data is coming.
            //
            // We don't need to modify `self.eof` here, because it is only a temporary mark.
            // rustls will only return 0 if it has received `CloseNotify`,
            // in which case no additional processing is required.
            Ok(n) => {
                buf.advance(n);
                Poll::Ready(Ok(()))
            }

            // Rustls doesn't have more data to yield, but it believes the connection is open.
            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
                if !io_pending {
                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
                    // but if it does, we can try again.
                    //
                    // If the rustls state is abnormal, it may cause a cyclic wakeup.
                    // but tokio's cooperative budget will prevent infinite wakeup.
                    cx.waker().wake_by_ref();
                }

                Poll::Pending
            }

            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
244
impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
where
    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
    SD: SideData,
{
    /// Buffers `buf` into the rustls session, then pushes as many TLS
    /// records to the transport as possible without blocking.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let mut pos = 0;

        while pos != buf.len() {
            let mut would_block = false;

            // Feed plaintext into the session's internal send buffer.
            match self.session.writer().write(&buf[pos..]) {
                Ok(n) => pos += n,
                Err(err) => return Poll::Ready(Err(err)),
            };

            // Drain the resulting TLS records out to the transport.
            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (pos, would_block) {
                // Nothing accepted yet and the transport is not ready.
                (0, true) => Poll::Pending,
                // Partial progress: report what was accepted so far.
                (n, true) => Poll::Ready(Ok(n)),
                // Transport kept up; loop to buffer the rest of `buf`.
                (_, false) => continue,
            };
        }

        Poll::Ready(Ok(pos))
    }

    /// Vectored counterpart of `poll_write`: buffers `bufs` into the
    /// session, then flushes records until done or blocked.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        // All-empty input: nothing to do, succeed immediately.
        if bufs.iter().all(|buf| buf.is_empty()) {
            return Poll::Ready(Ok(0));
        }

        loop {
            let mut would_block = false;
            let written = self.session.writer().write_vectored(bufs)?;

            // Drain the resulting TLS records out to the transport.
            while self.session.wants_write() {
                match self.write_io(cx) {
                    Poll::Ready(Ok(0)) | Poll::Pending => {
                        would_block = true;
                        break;
                    }
                    Poll::Ready(Ok(_)) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }

            return match (written, would_block) {
                // Nothing buffered and transport blocked: Pending.
                (0, true) => Poll::Pending,
                // Nothing buffered but transport drained: retry the write.
                (0, false) => continue,
                // Some plaintext was accepted: report it.
                (n, _) => Poll::Ready(Ok(n)),
            };
        }
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Flushes the session's buffered records, then flushes the transport.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        self.session.writer().flush()?;
        // NOTE(review): `ready!` is not imported here, so it presumably is a
        // crate-local macro with the std `ready!` semantics (early-return on
        // `Poll::Pending`) — confirm against the crate root.
        while self.session.wants_write() {
            if ready!(self.write_io(cx))? == 0 {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }
        Pin::new(&mut self.io).poll_flush(cx)
    }

    /// Drains pending TLS records (e.g. a close_notify alert), then shuts
    /// down the underlying transport.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.session.wants_write() {
            if ready!(self.write_io(cx))? == 0 {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }

        Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
            Ok(()) => Ok(()),
            // When trying to shutdown, not being connected seems fine
            Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
            Err(err) => Err(err),
        })
    }
}
348
/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an
/// associated [`Context`].
///
/// Turns `Poll::Pending` into `WouldBlock`.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// The async reader being adapted.
    pub io: &'a mut T,
    /// The task context to poll `io` with.
    pub cx: &'a mut Context<'b>,
}
357
358impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
359    #[inline]
360    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
361        let mut buf = ReadBuf::new(buf);
362        match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
363            Poll::Ready(Ok(())) => Ok(buf.filled().len()),
364            Poll::Ready(Err(err)) => Err(err),
365            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
366        }
367    }
368}
369
/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an
/// associated [`Context`].
///
/// Turns `Poll::Pending` into `WouldBlock`.
pub struct SyncWriteAdapter<'a, 'b, T> {
    /// The async writer being adapted.
    pub io: &'a mut T,
    /// The task context to poll `io` with.
    pub cx: &'a mut Context<'b>,
}
378
379impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
380    #[inline]
381    fn poll_with<U>(
382        &mut self,
383        f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
384    ) -> io::Result<U> {
385        match f(Pin::new(self.io), self.cx) {
386            Poll::Ready(result) => result,
387            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
388        }
389    }
390}
391
392impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
393    #[inline]
394    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
395        self.poll_with(|io, cx| io.poll_write(cx, buf))
396    }
397
398    #[inline]
399    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
400        self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
401    }
402
403    fn flush(&mut self) -> io::Result<()> {
404        self.poll_with(|io, cx| io.poll_flush(cx))
405    }
406}
407
408#[cfg(test)]
409mod test_stream;