1use std::io;
2#[cfg(unix)]
3use std::os::unix::io::{AsRawFd, RawFd};
4#[cfg(windows)]
5use std::os::windows::io::{AsRawSocket, RawSocket};
6use std::pin::Pin;
7#[cfg(feature = "early-data")]
8use std::task::Waker;
9use std::task::{Context, Poll};
10
11use rustls::ClientConnection;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13
14use crate::common::{IoSession, Stream, TlsState};
15
16#[derive(Debug)]
19pub struct TlsStream<IO> {
20 pub(crate) io: IO,
21 pub(crate) session: ClientConnection,
22 pub(crate) state: TlsState,
23
24 #[cfg(feature = "early-data")]
25 pub(crate) early_waker: Option<Waker>,
26}
27
28impl<IO> TlsStream<IO> {
29 #[inline]
30 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
31 (&self.io, &self.session)
32 }
33
34 #[inline]
35 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
36 (&mut self.io, &mut self.session)
37 }
38
39 #[inline]
40 pub fn into_inner(self) -> (IO, ClientConnection) {
41 (self.io, self.session)
42 }
43}
44
45#[cfg(unix)]
46impl<S> AsRawFd for TlsStream<S>
47where
48 S: AsRawFd,
49{
50 fn as_raw_fd(&self) -> RawFd {
51 self.get_ref().0.as_raw_fd()
52 }
53}
54
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped IO object.
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}
64
65impl<IO> IoSession for TlsStream<IO> {
66 type Io = IO;
67 type Session = ClientConnection;
68
69 #[inline]
70 fn skip_handshake(&self) -> bool {
71 self.state.is_early_data()
72 }
73
74 #[inline]
75 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
76 (&mut self.state, &mut self.io, &mut self.session)
77 }
78
79 #[inline]
80 fn into_io(self) -> Self::Io {
81 self.io
82 }
83}
84
85impl<IO> AsyncRead for TlsStream<IO>
86where
87 IO: AsyncRead + AsyncWrite + Unpin,
88{
89 fn poll_read(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 buf: &mut ReadBuf<'_>,
93 ) -> Poll<io::Result<()>> {
94 match self.state {
95 #[cfg(feature = "early-data")]
96 TlsState::EarlyData(..) => {
97 let this = self.get_mut();
98
99 if this
106 .early_waker
107 .as_ref()
108 .filter(|waker| cx.waker().will_wake(waker))
109 .is_none()
110 {
111 this.early_waker = Some(cx.waker().clone());
112 }
113
114 Poll::Pending
115 }
116 TlsState::Stream | TlsState::WriteShutdown => {
117 let this = self.get_mut();
118 let mut stream =
119 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
120 let prev = buf.remaining();
121
122 match stream.as_mut_pin().poll_read(cx, buf) {
123 Poll::Ready(Ok(())) => {
124 if prev == buf.remaining() || stream.eof {
125 this.state.shutdown_read();
126 }
127
128 Poll::Ready(Ok(()))
129 }
130 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
131 this.state.shutdown_read();
132 Poll::Ready(Err(err))
133 }
134 output => output,
135 }
136 }
137 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
138 }
139 }
140}
141
142impl<IO> AsyncWrite for TlsStream<IO>
143where
144 IO: AsyncRead + AsyncWrite + Unpin,
145{
146 fn poll_write(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &[u8],
152 ) -> Poll<io::Result<usize>> {
153 let this = self.get_mut();
154 let mut stream =
155 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
156
157 #[cfg(feature = "early-data")]
158 {
159 let bufs = [io::IoSlice::new(buf)];
160 let written = ready!(poll_handle_early_data(
161 &mut this.state,
162 &mut stream,
163 &mut this.early_waker,
164 cx,
165 &bufs
166 ))?;
167 if written != 0 {
168 return Poll::Ready(Ok(written));
169 }
170 }
171
172 stream.as_mut_pin().poll_write(cx, buf)
173 }
174
175 fn poll_write_vectored(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 bufs: &[io::IoSlice<'_>],
181 ) -> Poll<io::Result<usize>> {
182 let this = self.get_mut();
183 let mut stream =
184 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
185
186 #[cfg(feature = "early-data")]
187 {
188 let written = ready!(poll_handle_early_data(
189 &mut this.state,
190 &mut stream,
191 &mut this.early_waker,
192 cx,
193 bufs
194 ))?;
195 if written != 0 {
196 return Poll::Ready(Ok(written));
197 }
198 }
199
200 stream.as_mut_pin().poll_write_vectored(cx, bufs)
201 }
202
203 #[inline]
204 fn is_write_vectored(&self) -> bool {
205 true
206 }
207
208 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
209 let this = self.get_mut();
210 let mut stream =
211 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
212
213 #[cfg(feature = "early-data")]
214 ready!(poll_handle_early_data(
215 &mut this.state,
216 &mut stream,
217 &mut this.early_waker,
218 cx,
219 &[]
220 ))?;
221
222 stream.as_mut_pin().poll_flush(cx)
223 }
224
225 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
226 #[cfg(feature = "early-data")]
227 {
228 if matches!(self.state, TlsState::EarlyData(..)) {
230 ready!(self.as_mut().poll_flush(cx))?;
231 }
232 }
233
234 if self.state.writeable() {
235 self.session.send_close_notify();
236 self.state.shutdown_write();
237 }
238
239 let this = self.get_mut();
240 let mut stream =
241 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
242 stream.as_mut_pin().poll_shutdown(cx)
243 }
244}
245
/// Drives the early-data (0-RTT) phase of a client stream.
///
/// If `state` is not `EarlyData`, this does nothing and returns `Ready(Ok(0))`.
/// Otherwise:
/// 1. The non-empty slices in `bufs` are offered to the session's early-data
///    writer. The accepted bytes are also recorded in the state's buffer, and if
///    any were accepted their count is returned immediately.
/// 2. If nothing was accepted, the handshake is driven to completion.
/// 3. If the server did not accept early data, the recorded bytes are replayed as
///    regular application data, resuming from the saved position.
/// 4. The state becomes `Stream`, and any reader parked in `early_waker` is woken.
///
/// A `Ready(Ok(0))` return therefore tells the caller to proceed with a normal
/// write/flush.
///
/// # Errors
/// Propagates early-data write, handshake and replay-write errors. Returns
/// `ErrorKind::WriteZero` if the replay write accepts zero bytes.
#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
    state: &mut TlsState,
    stream: &mut Stream<IO, ClientConnection>,
    early_waker: &mut Option<Waker>,
    cx: &mut Context<'_>,
    bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    if let TlsState::EarlyData(pos, data) = state {
        use std::io::Write;

        // Try to send the payload as early data.
        if let Some(mut early_data) = stream.session.early_data() {
            let mut written = 0;

            for buf in bufs {
                if buf.is_empty() {
                    continue;
                }

                let len = match early_data.write(buf) {
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(err) => return Poll::Ready(Err(err)),
                };

                written += len;
                // Keep a copy in case the server rejects early data and it must be replayed.
                data.extend_from_slice(&buf[..len]);

                if len < buf.len() {
                    break;
                }
            }

            if written != 0 {
                return Poll::Ready(Ok(written));
            }
        }

        // Complete the handshake.
        while stream.session.is_handshaking() {
            ready!(stream.handshake(cx))?;
        }

        // Early data was not accepted: replay it over the established connection.
        if !stream.session.is_early_data_accepted() {
            while *pos < data.len() {
                let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                if len == 0 {
                    // Without this guard a zero-length write would spin this loop forever.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to replay rejected early data",
                    )));
                }
                *pos += len;
            }
        }

        *state = TlsState::Stream;

        if let Some(waker) = early_waker.take() {
            waker.wake();
        }
    }

    Poll::Ready(Ok(0))
}