1use std::io::{self, IoSlice, Read, Write};
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use rustls::{ConnectionCommon, SideData};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9mod handshake;
10pub(crate) use handshake::{IoSession, MidHandshake};
11
/// Shutdown/lifecycle state of a TLS stream, tracking its read and write halves
/// independently.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase (only with the `early-data` feature). The payload is
    /// presumably a write position plus buffered early-data bytes — TODO confirm
    /// against the early-data code path.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both halves are open.
    Stream,
    /// Only the read half has been shut down.
    ReadShutdown,
    /// Only the write half has been shut down.
    WriteShutdown,
    /// Both halves have been shut down.
    FullyShutdown,
}

impl TlsState {
    /// Closes the read half: becomes `FullyShutdown` if writing was already
    /// closed, otherwise `ReadShutdown` (this also applies to `Stream` and any
    /// early-data state).
    #[inline]
    pub fn shutdown_read(&mut self) {
        let next = if self.writeable() {
            Self::ReadShutdown
        } else {
            Self::FullyShutdown
        };
        *self = next;
    }

    /// Closes the write half: becomes `FullyShutdown` if reading was already
    /// closed, otherwise `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        let next = if self.readable() {
            Self::WriteShutdown
        } else {
            Self::FullyShutdown
        };
        *self = next;
    }

    /// `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        let closed = matches!(self, Self::WriteShutdown | Self::FullyShutdown);
        !closed
    }

    /// `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        let closed = matches!(self, Self::ReadShutdown | Self::FullyShutdown);
        !closed
    }

    /// `true` while in the early-data state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, Self::EarlyData(..))
    }

    /// Without the `early-data` feature there is no early-data state, so this
    /// is always `false`.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
61
62pub struct Stream<'a, IO, C> {
63 pub io: &'a mut IO,
64 pub session: &'a mut C,
65 pub eof: bool,
66}
67
68impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
69where
70 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
71 SD: SideData,
72{
73 pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
74 Stream {
75 io,
76 session,
77 eof: false,
80 }
81 }
82
83 pub fn set_eof(mut self, eof: bool) -> Self {
84 self.eof = eof;
85 self
86 }
87
88 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
89 Pin::new(self)
90 }
91
92 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
93 let mut reader = SyncReadAdapter { io: self.io, cx };
94
95 let n = match self.session.read_tls(&mut reader) {
96 Ok(n) => n,
97 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
98 Err(err) => return Poll::Ready(Err(err)),
99 };
100
101 self.session.process_new_packets().map_err(|err| {
102 let _ = self.write_io(cx);
106
107 io::Error::new(io::ErrorKind::InvalidData, err)
108 })?;
109
110 Poll::Ready(Ok(n))
111 }
112
113 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
114 let mut writer = SyncWriteAdapter { io: self.io, cx };
115
116 match self.session.write_tls(&mut writer) {
117 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
118 result => Poll::Ready(result),
119 }
120 }
121
122 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
123 let mut wrlen = 0;
124 let mut rdlen = 0;
125
126 loop {
127 let mut write_would_block = false;
128 let mut read_would_block = false;
129 let mut need_flush = false;
130
131 while self.session.wants_write() {
132 match self.write_io(cx) {
133 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
134 Poll::Ready(Ok(n)) => {
135 wrlen += n;
136 need_flush = true;
137 }
138 Poll::Pending => {
139 write_would_block = true;
140 break;
141 }
142 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
143 }
144 }
145
146 if need_flush {
147 match Pin::new(&mut self.io).poll_flush(cx) {
148 Poll::Ready(Ok(())) => (),
149 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
150 Poll::Pending => write_would_block = true,
151 }
152 }
153
154 while !self.eof && self.session.wants_read() {
155 match self.read_io(cx) {
156 Poll::Ready(Ok(0)) => self.eof = true,
157 Poll::Ready(Ok(n)) => rdlen += n,
158 Poll::Pending => {
159 read_would_block = true;
160 break;
161 }
162 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
163 }
164 }
165
166 return match (self.eof, self.session.is_handshaking()) {
167 (true, true) => {
168 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
169 Poll::Ready(Err(err))
170 }
171 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
172 (_, true) if write_would_block || read_would_block => {
173 if rdlen != 0 || wrlen != 0 {
174 Poll::Ready(Ok((rdlen, wrlen)))
175 } else {
176 Poll::Pending
177 }
178 }
179 (..) => continue,
180 };
181 }
182 }
183}
184
185impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'_, IO, C>
186where
187 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
188 SD: SideData,
189{
190 fn poll_read(
191 mut self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &mut ReadBuf<'_>,
194 ) -> Poll<io::Result<()>> {
195 let mut io_pending = false;
196
197 while !self.eof && self.session.wants_read() {
199 match self.read_io(cx) {
200 Poll::Ready(Ok(0)) => {
201 break;
202 }
203 Poll::Ready(Ok(_)) => (),
204 Poll::Pending => {
205 io_pending = true;
206 break;
207 }
208 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
209 }
210 }
211
212 match self.session.reader().read(buf.initialize_unfilled()) {
213 Ok(n) => {
222 buf.advance(n);
223 Poll::Ready(Ok(()))
224 }
225
226 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
228 if !io_pending {
229 cx.waker().wake_by_ref();
235 }
236
237 Poll::Pending
238 }
239
240 Err(err) => Poll::Ready(Err(err)),
241 }
242 }
243}
244
245impl<IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'_, IO, C>
246where
247 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
248 SD: SideData,
249{
250 fn poll_write(
251 mut self: Pin<&mut Self>,
252 cx: &mut Context,
253 buf: &[u8],
254 ) -> Poll<io::Result<usize>> {
255 let mut pos = 0;
256
257 while pos != buf.len() {
258 let mut would_block = false;
259
260 match self.session.writer().write(&buf[pos..]) {
261 Ok(n) => pos += n,
262 Err(err) => return Poll::Ready(Err(err)),
263 };
264
265 while self.session.wants_write() {
266 match self.write_io(cx) {
267 Poll::Ready(Ok(0)) | Poll::Pending => {
268 would_block = true;
269 break;
270 }
271 Poll::Ready(Ok(_)) => (),
272 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
273 }
274 }
275
276 return match (pos, would_block) {
277 (0, true) => Poll::Pending,
278 (n, true) => Poll::Ready(Ok(n)),
279 (_, false) => continue,
280 };
281 }
282
283 Poll::Ready(Ok(pos))
284 }
285
286 fn poll_write_vectored(
287 mut self: Pin<&mut Self>,
288 cx: &mut Context<'_>,
289 bufs: &[IoSlice<'_>],
290 ) -> Poll<io::Result<usize>> {
291 if bufs.iter().all(|buf| buf.is_empty()) {
292 return Poll::Ready(Ok(0));
293 }
294
295 loop {
296 let mut would_block = false;
297 let written = self.session.writer().write_vectored(bufs)?;
298
299 while self.session.wants_write() {
300 match self.write_io(cx) {
301 Poll::Ready(Ok(0)) | Poll::Pending => {
302 would_block = true;
303 break;
304 }
305 Poll::Ready(Ok(_)) => (),
306 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
307 }
308 }
309
310 return match (written, would_block) {
311 (0, true) => Poll::Pending,
312 (0, false) => continue,
313 (n, _) => Poll::Ready(Ok(n)),
314 };
315 }
316 }
317
318 #[inline]
319 fn is_write_vectored(&self) -> bool {
320 true
321 }
322
323 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
324 self.session.writer().flush()?;
325 while self.session.wants_write() {
326 if ready!(self.write_io(cx))? == 0 {
327 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
328 }
329 }
330 Pin::new(&mut self.io).poll_flush(cx)
331 }
332
333 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
334 while self.session.wants_write() {
335 if ready!(self.write_io(cx))? == 0 {
336 return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
337 }
338 }
339
340 Poll::Ready(match ready!(Pin::new(&mut self.io).poll_shutdown(cx)) {
341 Ok(()) => Ok(()),
342 Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
344 Err(err) => Err(err),
345 })
346 }
347}
348
349pub struct SyncReadAdapter<'a, 'b, T> {
354 pub io: &'a mut T,
355 pub cx: &'a mut Context<'b>,
356}
357
358impl<T: AsyncRead + Unpin> Read for SyncReadAdapter<'_, '_, T> {
359 #[inline]
360 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
361 let mut buf = ReadBuf::new(buf);
362 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
363 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
364 Poll::Ready(Err(err)) => Err(err),
365 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
366 }
367 }
368}
369
370pub struct SyncWriteAdapter<'a, 'b, T> {
375 pub io: &'a mut T,
376 pub cx: &'a mut Context<'b>,
377}
378
379impl<T: Unpin> SyncWriteAdapter<'_, '_, T> {
380 #[inline]
381 fn poll_with<U>(
382 &mut self,
383 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
384 ) -> io::Result<U> {
385 match f(Pin::new(self.io), self.cx) {
386 Poll::Ready(result) => result,
387 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
388 }
389 }
390}
391
392impl<T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'_, '_, T> {
393 #[inline]
394 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
395 self.poll_with(|io, cx| io.poll_write(cx, buf))
396 }
397
398 #[inline]
399 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
400 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
401 }
402
403 fn flush(&mut self) -> io::Result<()> {
404 self.poll_with(|io, cx| io.poll_flush(cx))
405 }
406}
407
408#[cfg(test)]
409mod test_stream;