rustls/
stream.rs

1use core::ops::{Deref, DerefMut};
2use std::io::{IoSlice, Read, Result, Write};
3
4use crate::conn::{ConnectionCommon, SideData};
5
/// Borrowing adapter that pairs a rustls connection `C` with a transport `T`
/// (for example a socket) so the pair can be driven through `io::Read` and
/// `io::Write`.
///
/// With it, a rustls connection can be handled like any ordinary stream.
/// Both halves are held by mutable reference for the lifetime `'a`.
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection being driven.
    pub conn: &'a mut C,

    /// The transport underneath the connection, such as a socket.
    pub sock: &'a mut T,
}
18
19impl<'a, C, T, S> Stream<'a, C, T>
20where
21    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
22    T: 'a + Read + Write,
23    S: SideData,
24{
25    /// Make a new Stream using the Connection `conn` and socket-like object
26    /// `sock`.  This does not fail and does no IO.
27    pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self {
28        Self { conn, sock }
29    }
30
31    /// If we're handshaking, complete all the IO for that.
32    /// If we have data to write, write it all.
33    fn complete_prior_io(&mut self) -> Result<()> {
34        if self.conn.is_handshaking() {
35            self.conn.complete_io(self.sock)?;
36        }
37
38        if self.conn.wants_write() {
39            self.conn.complete_io(self.sock)?;
40        }
41
42        Ok(())
43    }
44}
45
46impl<'a, C, T, S> Read for Stream<'a, C, T>
47where
48    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
49    T: 'a + Read + Write,
50    S: SideData,
51{
52    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
53        self.complete_prior_io()?;
54
55        // We call complete_io() in a loop since a single call may read only
56        // a partial packet from the underlying transport. A full packet is
57        // needed to get more plaintext, which we must do if EOF has not been
58        // hit.
59        while self.conn.wants_read() {
60            if self.conn.complete_io(self.sock)?.0 == 0 {
61                break;
62            }
63        }
64
65        self.conn.reader().read(buf)
66    }
67
68    #[cfg(read_buf)]
69    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
70        self.complete_prior_io()?;
71
72        // We call complete_io() in a loop since a single call may read only
73        // a partial packet from the underlying transport. A full packet is
74        // needed to get more plaintext, which we must do if EOF has not been
75        // hit.
76        while self.conn.wants_read() {
77            if self.conn.complete_io(self.sock)?.0 == 0 {
78                break;
79            }
80        }
81
82        self.conn.reader().read_buf(cursor)
83    }
84}
85
86impl<'a, C, T, S> Write for Stream<'a, C, T>
87where
88    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
89    T: 'a + Read + Write,
90    S: SideData,
91{
92    fn write(&mut self, buf: &[u8]) -> Result<usize> {
93        self.complete_prior_io()?;
94
95        let len = self.conn.writer().write(buf)?;
96
97        // Try to write the underlying transport here, but don't let
98        // any errors mask the fact we've consumed `len` bytes.
99        // Callers will learn of permanent errors on the next call.
100        let _ = self.conn.complete_io(self.sock);
101
102        Ok(len)
103    }
104
105    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
106        self.complete_prior_io()?;
107
108        let len = self
109            .conn
110            .writer()
111            .write_vectored(bufs)?;
112
113        // Try to write the underlying transport here, but don't let
114        // any errors mask the fact we've consumed `len` bytes.
115        // Callers will learn of permanent errors on the next call.
116        let _ = self.conn.complete_io(self.sock);
117
118        Ok(len)
119    }
120
121    fn flush(&mut self) -> Result<()> {
122        self.complete_prior_io()?;
123
124        self.conn.writer().flush()?;
125        if self.conn.wants_write() {
126            self.conn.complete_io(self.sock)?;
127        }
128        Ok(())
129    }
130}
131
/// Owning adapter that bundles a rustls connection `C` together with an
/// underlying blocking transport `T` (for example a socket) and exposes the
/// pair through `io::Read` and `io::Write`.
///
/// With it, a rustls connection can be handled like any ordinary stream.
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    T: Read + Write,
{
    /// The connection owned by this stream.
    pub conn: C,

    /// The transport underneath the connection, such as a socket.
    pub sock: T,
}
145
146impl<C, T, S> StreamOwned<C, T>
147where
148    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
149    T: Read + Write,
150    S: SideData,
151{
152    /// Make a new StreamOwned taking the Connection `conn` and socket-like
153    /// object `sock`.  This does not fail and does no IO.
154    ///
155    /// This is the same as `Stream::new` except `conn` and `sock` are
156    /// moved into the StreamOwned.
157    pub fn new(conn: C, sock: T) -> Self {
158        Self { conn, sock }
159    }
160
161    /// Get a reference to the underlying socket
162    pub fn get_ref(&self) -> &T {
163        &self.sock
164    }
165
166    /// Get a mutable reference to the underlying socket
167    pub fn get_mut(&mut self) -> &mut T {
168        &mut self.sock
169    }
170
171    /// Extract the `conn` and `sock` parts from the `StreamOwned`
172    pub fn into_parts(self) -> (C, T) {
173        (self.conn, self.sock)
174    }
175}
176
177impl<'a, C, T, S> StreamOwned<C, T>
178where
179    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
180    T: Read + Write,
181    S: SideData,
182{
183    fn as_stream(&'a mut self) -> Stream<'a, C, T> {
184        Stream {
185            conn: &mut self.conn,
186            sock: &mut self.sock,
187        }
188    }
189}
190
191impl<C, T, S> Read for StreamOwned<C, T>
192where
193    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
194    T: Read + Write,
195    S: SideData,
196{
197    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
198        self.as_stream().read(buf)
199    }
200
201    #[cfg(read_buf)]
202    fn read_buf(&mut self, cursor: core::io::BorrowedCursor<'_>) -> Result<()> {
203        self.as_stream().read_buf(cursor)
204    }
205}
206
207impl<C, T, S> Write for StreamOwned<C, T>
208where
209    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
210    T: Read + Write,
211    S: SideData,
212{
213    fn write(&mut self, buf: &[u8]) -> Result<usize> {
214        self.as_stream().write(buf)
215    }
216
217    fn flush(&mut self) -> Result<()> {
218        self.as_stream().flush()
219    }
220}
221
#[cfg(test)]
mod tests {
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::{client::ClientConnection, server::ServerConnection};

    // Each test only names a concrete instantiation of the stream types
    // through a local type alias; nothing is constructed or executed beyond
    // that.

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        type _BorrowedClient<'a> = Stream<'a, ClientConnection, TcpStream>;
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        type _OwnedClient = StreamOwned<ClientConnection, TcpStream>;
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        type _OwnedServer = StreamOwned<ServerConnection, TcpStream>;
    }
}