rustls/msgs/message/
outbound.rs

1use alloc::vec::Vec;
2
3use super::{MessageError, PlainMessage, HEADER_SIZE, MAX_PAYLOAD};
4use crate::enums::{ContentType, ProtocolVersion};
5use crate::msgs::base::Payload;
6use crate::msgs::codec::{Codec, Reader};
7use crate::record_layer::RecordLayer;
8
9/// A TLS frame, named `TLSPlaintext` in the standard.
10///
11/// This outbound type borrows its "to be encrypted" payload from the "user".
12/// It is used for fragmenting and is consumed by encryption.
13#[derive(Debug)]
14pub struct OutboundPlainMessage<'a> {
15    pub typ: ContentType,
16    pub version: ProtocolVersion,
17    pub payload: OutboundChunks<'a>,
18}
19
20impl OutboundPlainMessage<'_> {
21    pub(crate) fn encoded_len(&self, record_layer: &RecordLayer) -> usize {
22        HEADER_SIZE + record_layer.encrypted_len(self.payload.len())
23    }
24
25    pub(crate) fn to_unencrypted_opaque(&self) -> OutboundOpaqueMessage {
26        let mut payload = PrefixedPayload::with_capacity(self.payload.len());
27        payload.extend_from_chunks(&self.payload);
28        OutboundOpaqueMessage {
29            version: self.version,
30            typ: self.typ,
31            payload,
32        }
33    }
34}
35
/// A collection of borrowed plaintext slices.
///
/// Warning: OutboundChunks does not guarantee that the simplest variant is used.
/// Multiple can hold non fragmented or empty payloads.
#[derive(Debug, Clone)]
pub enum OutboundChunks<'a> {
    /// A single byte slice. Contrary to `Multiple`, this uses a single pointer indirection
    Single(&'a [u8]),
    /// A collection of chunks (byte slices)
    /// and cursors to single out a fragmented range of bytes.
    /// OutboundChunks assumes that start <= end
    Multiple {
        /// The borrowed slices, logically concatenated in order.
        chunks: &'a [&'a [u8]],
        /// Offset (into the concatenation of `chunks`) of the first selected byte.
        start: usize,
        /// Offset (into the concatenation of `chunks`) one past the last selected byte.
        end: usize,
    },
}

impl<'a> OutboundChunks<'a> {
    /// Create a payload from a slice of byte slices.
    ///
    /// Exactly one slice yields [`OutboundChunks::Single`]. Any other number of slices
    /// (including none) yields [`OutboundChunks::Multiple`] selecting every byte:
    /// `start = 0` and `end` = the summed length of all slices.
    pub fn new(chunks: &'a [&'a [u8]]) -> Self {
        match chunks {
            [single] => Self::Single(*single),
            _ => Self::Multiple {
                chunks,
                start: 0,
                end: chunks
                    .iter()
                    .map(|chunk| chunk.len())
                    .sum(),
            },
        }
    }

    /// Create a payload with a single empty slice
    pub fn new_empty() -> Self {
        Self::Single(&[])
    }

    /// Flatten the selected bytes into an owned vector.
    pub fn to_vec(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(self.len());
        self.copy_to_vec(&mut vec);
        vec
    }

    /// Append all selected bytes to `vec`.
    ///
    /// For `Multiple`, only bytes in `start..end` of the concatenated chunks are copied.
    pub fn copy_to_vec(&self, vec: &mut Vec<u8>) {
        match *self {
            Self::Single(chunk) => vec.extend_from_slice(chunk),
            Self::Multiple { chunks, start, end } => {
                // Position of the current chunk's first byte within the logical
                // concatenation of all chunks.
                let mut offset = 0;
                for chunk in chunks {
                    let chunk_start = offset;
                    offset += chunk.len();
                    if chunk_start >= end {
                        // Offsets only grow, so no later chunk can overlap the range either.
                        break;
                    }
                    if offset <= start {
                        // This chunk lies entirely before the selected range.
                        continue;
                    }
                    let from = start.saturating_sub(chunk_start);
                    let to = Ord::min(end - chunk_start, chunk.len());
                    vec.extend_from_slice(&chunk[from..to]);
                }
            }
        }
    }

    /// Split self in two, around an index
    ///
    /// Works similarly to `split_at` in the core library, except it doesn't panic if out of
    /// bound. Any `mid` at or past the end, up to and including `usize::MAX`, puts every
    /// byte in the first half and leaves the second half empty.
    pub fn split_at(&self, mid: usize) -> (Self, Self) {
        match *self {
            Self::Single(chunk) => {
                let mid = Ord::min(mid, chunk.len());
                (Self::Single(&chunk[..mid]), Self::Single(&chunk[mid..]))
            }
            Self::Multiple { chunks, start, end } => {
                // A plain `start + mid` would overflow for a huge `mid` once `start > 0`.
                // That panics in debug builds and, in release, yields a range with
                // `end < start`. Saturating keeps the `start <= end` invariant.
                let mid = Ord::min(start.saturating_add(mid), end);
                (
                    Self::Multiple {
                        chunks,
                        start,
                        end: mid,
                    },
                    Self::Multiple {
                        chunks,
                        start: mid,
                        end,
                    },
                )
            }
        }
    }

    /// Returns true if the payload is empty
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of selected bytes (for `Multiple`, `end - start`).
    pub fn len(&self) -> usize {
        match self {
            Self::Single(chunk) => chunk.len(),
            Self::Multiple { start, end, .. } => end - start,
        }
    }
}

impl<'a> From<&'a [u8]> for OutboundChunks<'a> {
    /// Wraps a single borrowed slice without copying it.
    fn from(payload: &'a [u8]) -> Self {
        Self::Single(payload)
    }
}
150
151/// A TLS frame, named `TLSPlaintext` in the standard.
152///
153/// This outbound type owns all memory for its interior parts.
154/// It results from encryption and is used for io write.
155#[derive(Clone, Debug)]
156pub struct OutboundOpaqueMessage {
157    pub typ: ContentType,
158    pub version: ProtocolVersion,
159    pub payload: PrefixedPayload,
160}
161
162impl OutboundOpaqueMessage {
163    /// Construct a new `OpaqueMessage` from constituent fields.
164    ///
165    /// `body` is moved into the `payload` field.
166    pub fn new(typ: ContentType, version: ProtocolVersion, payload: PrefixedPayload) -> Self {
167        Self {
168            typ,
169            version,
170            payload,
171        }
172    }
173
174    /// Construct by decoding from a [`Reader`].
175    ///
176    /// `MessageError` allows callers to distinguish between valid prefixes (might
177    /// become valid if we read more data) and invalid data.
178    pub fn read(r: &mut Reader<'_>) -> Result<Self, MessageError> {
179        let (typ, version, len) = read_opaque_message_header(r)?;
180
181        let content = r
182            .take(len as usize)
183            .ok_or(MessageError::TooShortForLength)?;
184
185        Ok(Self {
186            typ,
187            version,
188            payload: PrefixedPayload::from(content),
189        })
190    }
191
192    pub fn encode(self) -> Vec<u8> {
193        let length = self.payload.len() as u16;
194        let mut encoded_payload = self.payload.0;
195        encoded_payload[0] = self.typ.into();
196        encoded_payload[1..3].copy_from_slice(&self.version.to_array());
197        encoded_payload[3..5].copy_from_slice(&(length).to_be_bytes());
198        encoded_payload
199    }
200
201    /// Force conversion into a plaintext message.
202    ///
203    /// This should only be used for messages that are known to be in plaintext. Otherwise, the
204    /// `OutboundOpaqueMessage` should be decrypted into a `PlainMessage` using a `MessageDecrypter`.
205    pub fn into_plain_message(self) -> PlainMessage {
206        PlainMessage {
207            version: self.version,
208            typ: self.typ,
209            payload: Payload::Owned(self.payload.as_ref().to_vec()),
210        }
211    }
212}
213
214#[derive(Clone, Debug)]
215pub struct PrefixedPayload(Vec<u8>);
216
217impl PrefixedPayload {
218    pub fn with_capacity(capacity: usize) -> Self {
219        let mut prefixed_payload = Vec::with_capacity(HEADER_SIZE + capacity);
220        prefixed_payload.resize(HEADER_SIZE, 0);
221        Self(prefixed_payload)
222    }
223
224    pub fn extend_from_slice(&mut self, slice: &[u8]) {
225        self.0.extend_from_slice(slice)
226    }
227
228    pub fn extend_from_chunks(&mut self, chunks: &OutboundChunks<'_>) {
229        chunks.copy_to_vec(&mut self.0)
230    }
231
232    pub fn truncate(&mut self, len: usize) {
233        self.0.truncate(len + HEADER_SIZE)
234    }
235
236    fn len(&self) -> usize {
237        self.0.len() - HEADER_SIZE
238    }
239}
240
241impl AsRef<[u8]> for PrefixedPayload {
242    fn as_ref(&self) -> &[u8] {
243        &self.0[HEADER_SIZE..]
244    }
245}
246
247impl AsMut<[u8]> for PrefixedPayload {
248    fn as_mut(&mut self) -> &mut [u8] {
249        &mut self.0[HEADER_SIZE..]
250    }
251}
252
253impl<'a> Extend<&'a u8> for PrefixedPayload {
254    fn extend<T: IntoIterator<Item = &'a u8>>(&mut self, iter: T) {
255        self.0.extend(iter)
256    }
257}
258
259impl From<&[u8]> for PrefixedPayload {
260    fn from(content: &[u8]) -> Self {
261        let mut payload = Vec::with_capacity(HEADER_SIZE + content.len());
262        payload.extend(&[0u8; HEADER_SIZE]);
263        payload.extend(content);
264        Self(payload)
265    }
266}
267
268impl<const N: usize> From<&[u8; N]> for PrefixedPayload {
269    fn from(content: &[u8; N]) -> Self {
270        Self::from(&content[..])
271    }
272}
273
274pub(crate) fn read_opaque_message_header(
275    r: &mut Reader<'_>,
276) -> Result<(ContentType, ProtocolVersion, u16), MessageError> {
277    let typ = ContentType::read(r).map_err(|_| MessageError::TooShortForHeader)?;
278    // Don't accept any new content-types.
279    if let ContentType::Unknown(_) = typ {
280        return Err(MessageError::InvalidContentType);
281    }
282
283    let version = ProtocolVersion::read(r).map_err(|_| MessageError::TooShortForHeader)?;
284    // Accept only versions 0x03XX for any XX.
285    match version {
286        ProtocolVersion::Unknown(ref v) if (v & 0xff00) != 0x0300 => {
287            return Err(MessageError::UnknownProtocolVersion);
288        }
289        _ => {}
290    };
291
292    let len = u16::read(r).map_err(|_| MessageError::TooShortForHeader)?;
293
294    // Reject undersize messages
295    //  implemented per section 5.1 of RFC8446 (TLSv1.3)
296    //              per section 6.2.1 of RFC5246 (TLSv1.2)
297    if typ != ContentType::ApplicationData && len == 0 {
298        return Err(MessageError::InvalidEmptyPayload);
299    }
300
301    // Reject oversize messages
302    if len >= MAX_PAYLOAD {
303        return Err(MessageError::MessageTooLarge);
304    }
305
306    Ok((typ, version, len))
307}
308
#[cfg(test)]
mod tests {
    // Unit tests for `OutboundChunks`: splitting, flattening and out-of-bounds handling.
    use std::{println, vec};

    use super::*;

    // Splitting a `Single` payload behaves like `<[u8]>::split_at`.
    #[test]
    fn split_at_with_single_slice() {
        let owner: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7];
        let borrowed_payload = OutboundChunks::Single(owner);

        let (before, after) = borrowed_payload.split_at(6);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5]);
        assert_eq!(after.to_vec(), &[6, 7]);
    }

    // Splitting a `Multiple` payload at points inside the 1st, 3rd and 4th chunk.
    #[test]
    fn split_at_with_multiple_slices() {
        let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];
        let borrowed_payload = OutboundChunks::new(&owner);

        let (before, after) = borrowed_payload.split_at(3);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2]);
        assert_eq!(after.to_vec(), &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);

        let (before, after) = borrowed_payload.split_at(8);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(after.to_vec(), &[8, 9, 10, 11, 12]);

        let (before, after) = borrowed_payload.split_at(11);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(after.to_vec(), &[11, 12]);
    }

    // A `mid` past the end must not panic: everything lands in the first half and the
    // second half is empty, for `Single`, `Multiple` and empty payloads alike.
    #[test]
    fn split_out_of_bounds() {
        let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];

        let single_payload = OutboundChunks::Single(owner[0]);
        let (before, after) = single_payload.split_at(17);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2, 3]);
        assert!(after.is_empty());

        let multiple_payload = OutboundChunks::new(&owner);
        let (before, after) = multiple_payload.split_at(17);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        assert!(after.is_empty());

        let empty_payload = OutboundChunks::new_empty();
        let (before, after) = empty_payload.split_at(17);
        println!("before:{:?}\nafter:{:?}", before, after);
        assert!(before.is_empty());
        assert!(after.is_empty());
    }

    // Repeatedly peeling 2-byte fragments off a payload interleaved with empty chunks
    // yields exactly the expected fragments; empty chunks contribute nothing.
    #[test]
    fn empty_slices_mixed() {
        let owner: Vec<&[u8]> = vec![&[], &[], &[0], &[], &[1, 2], &[], &[3], &[4], &[], &[]];
        let mut borrowed_payload = OutboundChunks::new(&owner);
        let mut fragment_count = 0;
        let mut fragment;
        let expected_fragments: &[&[u8]] = &[&[0, 1], &[2, 3], &[4]];

        while !borrowed_payload.is_empty() {
            (fragment, borrowed_payload) = borrowed_payload.split_at(2);
            println!("{fragment:?}");
            assert_eq!(&expected_fragments[fragment_count], &fragment.to_vec());
            fragment_count += 1;
        }
        assert_eq!(fragment_count, expected_fragments.len());
    }

    // Compares chained splits against `<[u8]>::split_at` on the flat bytes 0..127.
    // Chunk i covers owner[2^i - 1 .. 2^(i+1) - 1], so chunk lengths are 1, 2, 4, ..., 64.
    // Every sub-range start..end and every interior mid is checked.
    #[test]
    fn exhaustive_splitting() {
        let owner: Vec<u8> = (0..127).collect();
        let slices = (0..7)
            .map(|i| &owner[((1 << i) - 1)..((1 << (i + 1)) - 1)])
            .collect::<Vec<_>>();
        let payload = OutboundChunks::new(&slices);

        assert_eq!(payload.to_vec(), owner);
        println!("{:#?}", payload);

        for start in 0..128 {
            for end in start..128 {
                for mid in 0..(end - start) {
                    let witness = owner[start..end].split_at(mid);
                    // Narrow to start..end (keep the head up to `end`, then drop the
                    // first `start` bytes), then split that window at `mid`.
                    let split_payload = payload
                        .split_at(end)
                        .0
                        .split_at(start)
                        .1
                        .split_at(mid);
                    assert_eq!(
                        witness.0,
                        split_payload.0.to_vec(),
                        "start: {start}, mid:{mid}, end:{end}"
                    );
                    assert_eq!(
                        witness.1,
                        split_payload.1.to_vec(),
                        "start: {start}, mid:{mid}, end:{end}"
                    );
                }
            }
        }
    }
}