1use crate::enums::{ContentType, ProtocolVersion};
2use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
3use crate::Error;
/// Largest payload (plaintext fragment) carried by a single record: 2^14 bytes.
pub(crate) const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Bytes of record header in front of every fragment, following the TLS record
/// layout: content type, protocol version, payload length.
pub(crate) const PACKET_OVERHEAD: usize = 1 /* content type */ + 2 /* version */ + 2 /* length */;
/// Largest complete record (header plus payload) the fragmenter can be asked to emit.
pub(crate) const MAX_FRAGMENT_SIZE: usize = PACKET_OVERHEAD + MAX_FRAGMENT_LEN;
7
/// Splits outgoing plaintext into record-sized fragments.
///
/// The only state is the per-fragment payload limit. That limit excludes the
/// record header, and it is adjusted through `set_max_fragment_size`.
pub struct MessageFragmenter {
    /// Upper bound on the number of payload bytes placed in one fragment.
    max_frag: usize,
}
11
12impl Default for MessageFragmenter {
13 fn default() -> Self {
14 Self {
15 max_frag: MAX_FRAGMENT_LEN,
16 }
17 }
18}
19
20impl MessageFragmenter {
21 pub fn fragment_message<'a>(
29 &self,
30 msg: &'a PlainMessage,
31 ) -> impl Iterator<Item = OutboundPlainMessage<'a>> + 'a {
32 self.fragment_payload(msg.typ, msg.version, msg.payload.bytes().into())
33 }
34
35 pub(crate) fn fragment_payload<'a>(
43 &self,
44 typ: ContentType,
45 version: ProtocolVersion,
46 payload: OutboundChunks<'a>,
47 ) -> impl ExactSizeIterator<Item = OutboundPlainMessage<'a>> {
48 Chunker::new(payload, self.max_frag).map(move |payload| OutboundPlainMessage {
49 typ,
50 version,
51 payload,
52 })
53 }
54
55 pub fn set_max_fragment_size(&mut self, max_fragment_size: Option<usize>) -> Result<(), Error> {
64 self.max_frag = match max_fragment_size {
65 Some(sz @ 32..=MAX_FRAGMENT_SIZE) => sz - PACKET_OVERHEAD,
66 None => MAX_FRAGMENT_LEN,
67 _ => return Err(Error::BadMaxFragmentSize),
68 };
69 Ok(())
70 }
71}
72
73struct Chunker<'a> {
75 payload: OutboundChunks<'a>,
76 limit: usize,
77}
78
79impl<'a> Chunker<'a> {
80 fn new(payload: OutboundChunks<'a>, limit: usize) -> Self {
81 Self { payload, limit }
82 }
83}
84
85impl<'a> Iterator for Chunker<'a> {
86 type Item = OutboundChunks<'a>;
87
88 fn next(&mut self) -> Option<Self::Item> {
89 if self.payload.is_empty() {
90 return None;
91 }
92
93 let (before, after) = self.payload.split_at(self.limit);
94 self.payload = after;
95 Some(before)
96 }
97}
98
99impl ExactSizeIterator for Chunker<'_> {
100 fn len(&self) -> usize {
101 (self.payload.len() + self.limit - 1) / self.limit
102 }
103}
104
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;
    use std::vec;

    use super::{MessageFragmenter, MAX_FRAGMENT_SIZE, PACKET_OVERHEAD};
    use crate::enums::{ContentType, ProtocolVersion};
    use crate::msgs::base::Payload;
    use crate::msgs::message::{OutboundChunks, OutboundPlainMessage, PlainMessage};
    use crate::Error;

    /// Asserts that `m` carries the expected type, version and payload bytes, and
    /// that its unencrypted encoding is exactly `total_len` bytes long.
    fn msg_eq(
        m: &OutboundPlainMessage<'_>,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.payload.to_vec(), bytes);

        let buf = m.to_unencrypted_opaque().encode();

        assert_eq!(total_len, buf.len());
    }

    /// Builds a 69-byte handshake message containing the bytes 1..=69.
    fn sixty_nine_byte_message() -> PlainMessage {
        PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new((1..70u8).collect::<Vec<u8>>()),
        }
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let m = sixty_nine_byte_message();

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 3);
        msg_eq(
            &q[0],
            32,
            &typ,
            &version,
            &[
                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27,
            ],
        );
        msg_eq(
            &q[1],
            32,
            &typ,
            &version,
            &[
                28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
                49, 50, 51, 52, 53, 54,
            ],
        );
        msg_eq(
            &q[2],
            20,
            &typ,
            &version,
            &[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
        );
    }

    #[test]
    fn non_fragment() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        let q = frag
            .fragment_message(&m)
            .collect::<Vec<_>>();
        assert_eq!(q.len(), 1);
        msg_eq(
            &q[0],
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
    }

    #[test]
    fn fragment_multiple_slices() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let payload_owner: Vec<&[u8]> = vec![&[b'a'; 8], &[b'b'; 12], &[b'c'; 32], &[b'd'; 20]];
        let borrowed_payload = OutboundChunks::new(&payload_owner);
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(37))
            .unwrap();

        let fragments = frag
            .fragment_payload(typ, version, borrowed_payload)
            .collect::<Vec<_>>();
        assert_eq!(fragments.len(), 3);
        msg_eq(
            &fragments[0],
            37,
            &typ,
            &version,
            b"aaaaaaaabbbbbbbbbbbbcccccccccccc",
        );
        msg_eq(
            &fragments[1],
            37,
            &typ,
            &version,
            b"ccccccccccccccccccccdddddddddddd",
        );
        msg_eq(&fragments[2], 13, &typ, &version, b"dddddddd");
    }

    #[test]
    fn exact_multiple_of_limit() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        // 30 + 24 = 54 bytes: exactly two 27-byte fragments at a 32-byte record size.
        let payload_owner: Vec<&[u8]> = vec![&[b'x'; 30], &[b'y'; 24]];
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();

        let mut fragments = frag.fragment_payload(typ, version, OutboundChunks::new(&payload_owner));
        assert_eq!(fragments.len(), 2);

        let first = fragments.next().unwrap();
        assert_eq!(fragments.len(), 1);
        msg_eq(&first, 32, &typ, &version, &[b'x'; 27]);

        let second = fragments.next().unwrap();
        assert_eq!(fragments.len(), 0);
        let mut second_expected = vec![b'x'; 3];
        second_expected.extend_from_slice(&[b'y'; 24]);
        msg_eq(&second, 32, &typ, &version, &second_expected);

        assert!(fragments.next().is_none());
    }

    #[test]
    fn empty_payload_produces_no_fragments() {
        let m = PlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(Vec::<u8>::new()),
        };
        let frag = MessageFragmenter::default();
        assert_eq!(frag.fragment_message(&m).count(), 0);
    }

    #[test]
    fn max_fragment_size_bounds() {
        let mut frag = MessageFragmenter::default();
        for bad in [0, 31, MAX_FRAGMENT_SIZE + 1] {
            assert!(matches!(
                frag.set_max_fragment_size(Some(bad)),
                Err(Error::BadMaxFragmentSize)
            ));
        }
        for good in [32, PACKET_OVERHEAD + 100, MAX_FRAGMENT_SIZE] {
            assert!(frag.set_max_fragment_size(Some(good)).is_ok());
        }
        assert!(frag.set_max_fragment_size(None).is_ok());
    }

    #[test]
    fn rejected_size_keeps_previous_setting() {
        let m = sixty_nine_byte_message();
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        assert!(frag.set_max_fragment_size(Some(31)).is_err());
        // Still a 27-byte payload limit: 69 bytes -> 27 + 27 + 15.
        assert_eq!(frag.fragment_message(&m).count(), 3);
    }

    #[test]
    fn none_restores_default_limit() {
        let m = sixty_nine_byte_message();
        let mut frag = MessageFragmenter::default();
        frag.set_max_fragment_size(Some(32))
            .unwrap();
        frag.set_max_fragment_size(None)
            .unwrap();
        // The default limit is far above 69 bytes, so nothing is split.
        assert_eq!(frag.fragment_message(&m).count(), 1);
    }
}