async_native_tls/
handshake.rs

1use std::future::Future;
2use std::io::{Read, Write};
3use std::marker::Unpin;
4use std::pin::Pin;
5use std::ptr::null_mut;
6use std::task::{Context, Poll};
7
8use native_tls::{Error, HandshakeError, MidHandshakeTlsStream};
9
10use crate::runtime::{AsyncRead, AsyncWrite};
11use crate::std_adapter::StdAdapter;
12use crate::TlsStream;
13
14pub(crate) async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
15where
16    F: FnOnce(
17            StdAdapter<S>,
18        )
19            -> Result<native_tls::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>>
20        + Unpin,
21    S: AsyncRead + AsyncWrite + Unpin,
22{
23    let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
24
25    match start.await {
26        Err(e) => Err(e),
27        Ok(StartedHandshake::Done(s)) => Ok(s),
28        Ok(StartedHandshake::Mid(s)) => MidHandshake(Some(s)).await,
29    }
30}
31
32struct MidHandshake<S>(Option<MidHandshakeTlsStream<StdAdapter<S>>>);
33
34enum StartedHandshake<S> {
35    Done(TlsStream<S>),
36    Mid(MidHandshakeTlsStream<StdAdapter<S>>),
37}
38
/// Future that makes the first handshake attempt by invoking the handshake closure.
///
/// It resolves on its first poll. The inner value is consumed at that point, so a
/// second poll panics.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);

/// Inputs for the first handshake attempt, consumed when [`StartedHandshakeFuture`] is polled.
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that runs the synchronous handshake on a `StdAdapter` wrapping `stream`.
    f: F,
    /// Transport stream the handshake runs over.
    stream: S,
}

45impl<F, S> Future for StartedHandshakeFuture<F, S>
46where
47    F: FnOnce(
48            StdAdapter<S>,
49        )
50            -> Result<native_tls::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>>
51        + Unpin,
52    S: Unpin,
53    StdAdapter<S>: Read + Write,
54{
55    type Output = Result<StartedHandshake<S>, Error>;
56
57    fn poll(
58        mut self: Pin<&mut Self>,
59        ctx: &mut Context<'_>,
60    ) -> Poll<Result<StartedHandshake<S>, Error>> {
61        let inner = self.0.take().expect("future polled after completion");
62        let stream = StdAdapter {
63            inner: inner.stream,
64            context: ctx as *mut _ as *mut (),
65        };
66
67        match (inner.f)(stream) {
68            Ok(mut s) => {
69                s.get_mut().context = null_mut();
70                Poll::Ready(Ok(StartedHandshake::Done(TlsStream::new(s))))
71            }
72            Err(HandshakeError::WouldBlock(mut s)) => {
73                s.get_mut().context = null_mut();
74                Poll::Ready(Ok(StartedHandshake::Mid(s)))
75            }
76            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
77        }
78    }
79}
80
81impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
82    type Output = Result<TlsStream<S>, Error>;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let mut_self = self.get_mut();
86        let mut s = mut_self.0.take().expect("future polled after completion");
87
88        s.get_mut().context = cx as *mut _ as *mut ();
89        match s.handshake() {
90            Ok(stream) => Poll::Ready(Ok(TlsStream::new(stream))),
91            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
92            Err(HandshakeError::WouldBlock(mut s)) => {
93                s.get_mut().context = null_mut();
94                mut_self.0 = Some(s);
95                Poll::Pending
96            }
97        }
98    }
99}