// rustls/msgs/fragmenter.rs

1use crate::enums::{ContentType, ProtocolVersion};
2use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
3use crate::Error;
/// Largest plaintext payload carried by a single TLS record: 2^14 bytes.
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Bytes of record header preceding each payload:
/// content type (1) + protocol version (2) + payload length (2).
pub(crate) const PACKET_OVERHEAD: usize = 1 + 2 + 2;
/// Largest encoded record produced here: a maximal payload plus its header.
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
7
8pub struct MessageFragmenter {
9    max_frag: usize,
10}
11
12impl Default for MessageFragmenter {
13    fn default() -> Self {
14        Self {
15            max_frag: MAX_FRAGMENT_LEN,
16        }
17    }
18}
19
20impl MessageFragmenter {
21    /// Take `msg` and fragment it into new messages with the same type and version.
22    ///
23    /// Each returned message size is no more than `max_frag`.
24    ///
25    /// Return an iterator across those messages.
26    ///
27    /// Payloads are borrowed from `msg`.
28    pub fn fragment_message<'a>(
29        &self,
30        msg: &'a PlainMessage,
31    ) -> impl Iterator<Item = OutboundPlainMessage<'a>> + 'a {
32        self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into())
33    }
34
35    /// Take `payload` and fragment it into new messages with given type and version.
36    ///
37    /// Each returned message size is no more than `max_frag`.
38    ///
39    /// Return an iterator across those messages.
40    ///
41    /// Payloads are borrowed from `payload`.
42    pub(crate) fn fragment_payload<'a>(
43        &self,
44        typ: ContentType,
45        version: ProtocolVersion,
46        payload: OutboundChunks<'a>,
47    ) -> impl ExactSizeIterator<Item = OutboundPlainMessage<'a>> {
48        Chunker::new(payload, self.max_frag).map(move |payload| OutboundPlainMessage {
49            typ,
50            version,
51            payload,
52        })
53    }
54
55    /// Set the maximum fragment size that will be produced.
56    ///
57    /// This includes overhead. A `max_fragment_size` of 10 will produce TLS fragments
58    /// up to 10 bytes long.
59    ///
60    /// A `max_fragment_size` of `None` sets the highest allowable fragment size.
61    ///
62    /// Returns BadMaxFragmentSize if the size is smaller than 32 or larger than 16389.
63    pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> {
64        self.max_frag = match max_fragment_size {
65            Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
66            None => MAX_FRAGMENT_LEN,
67            _ => return Err(Error::BadMaxFragmentSize),
68        };
69        Ok(())
70    }
71}
72
73/// An iterator over borrowed fragments of a payload
74struct Chunker<'a> {
75    payload: OutboundChunks<'a>,
76    limit: usize,
77}
78
79impl<'a> Chunker<'a> {
80    fn new(payload: OutboundChunks<'a>, limit: usize) -> Self {
81        Self { payload, limit }
82    }
83}
84
85impl<'a> Iterator for Chunker<'a> {
86    type Item = OutboundChunks<'a>;
87
88    fn next(&mut self) -> Option<Self::Item> {
89        if self.payload.is_empty() {
90            return None;
91        }
92
93        let (before, after) = self.payload.split_at(self.limit);
94        self.payload = after;
95        Some(before)
96    }
97}
98
99impl ExactSizeIterator for Chunker<'_> {
100    fn len(&self) -> usize {
101        (self.payload.len() + self.limit - 1) / self.limit
102    }
103}
104
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;
    use std::vec;

    use super::{MessageFragmenter, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::{ContentType, ProtocolVersion};
    use crate::msgs::base::Payload;
    use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
    use crate::Error;

    /// Asserts that `m` has the given type, version and payload, and that its
    /// unencrypted encoding is exactly `total_len` bytes long.
    fn msg_eq(
        m: &OutboundPlainMessage<'_>,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload.to_vec(), bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let data: Vec<u8> = (1..70u8).collect();
        let m = PlainMessage {
            typ,
            version,
            payload: Payload::new(data),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn fragment_multiple_slices() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]];
        let borrowed_payload = OutboundChunks::new(&payload_owner);
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(37)) // 32 + packet overhead
            .unwrap();

        let fragments = frag
            .fragment_payload(typ, version, borrowed_payload)
            .collect::<Vec<_>>();
        assert_eq!(fragments.len(), 3);
        msg_eq(
            &fragments[0],
            37,
            &typ,
            &version,
            b"aaaaaaaabbbbbbbbbbbbcccccccccccc",
        );
        msg_eq(
            &fragments[1],
            37,
            &typ,
            &version,
            b"ccccccccccccccccccccdddddddddddd",
        );
        msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd");
    }

    #[test]
    fn accepts_boundary_sizes() {
        let mut frag = MessageFragmenter::default();
        assert!(frag.set_max_fragment_size(Some(32)).is_ok());
        assert!(frag
            .set_max_fragment_size(Some(MAX_FRAGMENT_SIZE))
            .is_ok());
        assert!(frag.set_max_fragment_size(None).is_ok());
    }

    #[test]
    fn rejects_out_of_range_sizes_and_keeps_previous_limit() {
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();

        assert!(matches!(
            frag.set_max_fragment_size(Some(31)),
            Err(Error::BadMaxFragmentSize)
        ));
        assert!(matches!(
            frag.set_max_fragment_size(Some(MAX_FRAGMENT_SIZE + 1)),
            Err(Error::BadMaxFragmentSize)
        ));

        // The rejected calls must not disturb the 32-byte setting made above.
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(vec![0u8; 54]),
        };
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 2);
        assert!(q
            .iter()
            .all(|f| f.payload.len() == 32 - PACKET_OVERHEAD));
    }

    #[test]
    fn empty_payload_yields_no_fragments() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(Vec::<u8>::new()),
        };
        let frag = MessageFragmenter::default();
        assert_eq!(frag.fragment_message(&m).count(), 0);
    }

    #[test]
    fn reported_len_matches_fragment_count() {
        // 55 bytes split at 27 bytes per fragment: 27 + 27 + 1.
        let payload_owner: Vec<&[u8]> = vec![&[b'x'; 27], &[b'y'; 28]];
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();

        let fragments = frag.fragment_payload(
            ContentType::Handshake,
            ProtocolVersion::TLSv1_2,
            OutboundChunks::new(&payload_owner),
        );
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments.count(), 3);
    }
}