1use alloc::vec::Vec;
2use core::mem;
3use core::ops::Range;
4
5use super::buffers::{BufferProgress, Coalescer, Delocator, Locator};
6use crate::error::InvalidMessage;
7use crate::msgs::codec::{u24, Codec};
8use crate::msgs::message::InboundPlainMessage;
9use crate::{ContentType, ProtocolVersion};
10
11#[derive(Debug)]
12pub(crate) struct HandshakeDeframer {
13 spans: Vec<FragmentSpan>,
15
16 outer_discard: usize,
19}
20
21impl HandshakeDeframer {
22 pub(crate) fn input_message(
37 &mut self,
38 msg: InboundPlainMessage<'_>,
39 containing_buffer: &Locator,
40 outer_discard: usize,
41 ) {
42 debug_assert_eq!(msg.typ, ContentType::Handshake);
43 debug_assert!(containing_buffer.fully_contains(msg.payload));
44 debug_assert!(self.outer_discard <= outer_discard);
45
46 self.outer_discard = outer_discard;
47
48 if let Some(_last_incomplete) = self
57 .spans
58 .last()
59 .filter(|span| !span.is_complete())
60 {
61 self.spans.push(FragmentSpan {
62 version: msg.version,
63 size: None,
64 bounds: containing_buffer.locate(msg.payload),
65 });
66 return;
67 }
68
69 for span in DissectHandshakeIter::new(msg, containing_buffer) {
72 self.spans.push(span);
73 }
74 }
75
76 pub(crate) fn progress(&self) -> BufferProgress {
78 BufferProgress::new(self.outer_discard)
79 }
80
81 pub(crate) fn has_message_ready(&self) -> bool {
83 match self.spans.first() {
84 Some(span) => span.is_complete(),
85 None => false,
86 }
87 }
88
89 pub(crate) fn is_active(&self) -> bool {
91 !self.spans.is_empty()
92 }
93
94 pub(crate) fn is_aligned(&self) -> bool {
97 self.spans
98 .iter()
99 .all(|span| span.is_complete())
100 }
101
102 pub(crate) fn iter<'a, 'b>(&'a mut self, containing_buffer: &'b [u8]) -> HandshakeIter<'a, 'b> {
104 HandshakeIter {
105 deframer: self,
106 containing_buffer: Delocator::new(containing_buffer),
107 index: 0,
108 }
109 }
110
111 pub(crate) fn coalesce(&mut self, containing_buffer: &mut [u8]) -> Result<(), InvalidMessage> {
165 while let Some(i) = self.requires_coalesce() {
169 self.coalesce_one(i, Coalescer::new(containing_buffer));
170 }
171
172 match self
174 .spans
175 .iter()
176 .any(|span| span.size.unwrap_or_default() > MAX_HANDSHAKE_SIZE)
177 {
178 true => Err(InvalidMessage::HandshakePayloadTooLarge),
179 false => Ok(()),
180 }
181 }
182
183 fn coalesce_one(&mut self, index: usize, mut containing_buffer: Coalescer<'_>) {
186 let second = self.spans.remove(index + 1);
187 let mut first = self.spans.remove(index);
188
189 let len = second.bounds.len();
191 let target = Range {
192 start: first.bounds.end,
193 end: first.bounds.end + len,
194 };
195
196 containing_buffer.copy_within(second.bounds, target);
197 let delocator = containing_buffer.delocator();
198
199 first.bounds.end += len;
201
202 let msg = InboundPlainMessage {
204 typ: ContentType::Handshake,
205 version: first.version,
206 payload: delocator.slice_from_range(&first.bounds),
207 };
208
209 for (i, span) in DissectHandshakeIter::new(msg, &delocator.locator()).enumerate() {
210 self.spans.insert(index + i, span);
211 }
212 }
213
214 fn requires_coalesce(&self) -> Option<usize> {
219 self.spans
220 .split_last()
221 .and_then(|(_last, elements)| {
222 elements
223 .iter()
224 .enumerate()
225 .find_map(|(i, span)| (!span.is_complete()).then_some(i))
226 })
227 }
228}
229
230impl Default for HandshakeDeframer {
231 fn default() -> Self {
232 Self {
233 spans: Vec::with_capacity(16),
236 outer_discard: 0,
237 }
238 }
239}
240
241struct DissectHandshakeIter<'a, 'b> {
242 version: ProtocolVersion,
243 payload: &'b [u8],
244 containing_buffer: &'a Locator,
245}
246
247impl<'a, 'b> DissectHandshakeIter<'a, 'b> {
248 fn new(msg: InboundPlainMessage<'b>, containing_buffer: &'a Locator) -> Self {
249 Self {
250 version: msg.version,
251 payload: msg.payload,
252 containing_buffer,
253 }
254 }
255}
256
257impl Iterator for DissectHandshakeIter<'_, '_> {
258 type Item = FragmentSpan;
259
260 fn next(&mut self) -> Option<Self::Item> {
261 if self.payload.is_empty() {
262 return None;
263 }
264
265 if self.payload.len() < HANDSHAKE_HEADER_LEN {
267 let buf = mem::take(&mut self.payload);
268 let bounds = self.containing_buffer.locate(buf);
269 return Some(FragmentSpan {
270 version: self.version,
271 size: None,
272 bounds: bounds.clone(),
273 });
274 }
275
276 let (header, rest) = mem::take(&mut self.payload).split_at(HANDSHAKE_HEADER_LEN);
277
278 let size = u24::read_bytes(&header[1..])
280 .unwrap()
281 .into();
282
283 let available = if size < rest.len() {
284 self.payload = &rest[size..];
285 size
286 } else {
287 rest.len()
288 };
289
290 let mut bounds = self.containing_buffer.locate(header);
291 bounds.end += available;
292 Some(FragmentSpan {
293 version: self.version,
294 size: Some(size),
295 bounds: bounds.clone(),
296 })
297 }
298}
299
300pub(crate) struct HandshakeIter<'a, 'b> {
301 deframer: &'a mut HandshakeDeframer,
302 containing_buffer: Delocator<'b>,
303 index: usize,
304}
305
306impl<'b> Iterator for HandshakeIter<'_, 'b> {
307 type Item = (InboundPlainMessage<'b>, usize);
308
309 fn next(&mut self) -> Option<Self::Item> {
310 let next_span = self.deframer.spans.get(self.index)?;
311
312 if !next_span.is_complete() {
313 return None;
314 }
315
316 let discard = if self.deframer.spans.len() - 1 == self.index {
320 mem::take(&mut self.deframer.outer_discard)
321 } else {
322 0
323 };
324
325 self.index += 1;
326 Some((
327 InboundPlainMessage {
328 typ: ContentType::Handshake,
329 version: next_span.version,
330 payload: self
331 .containing_buffer
332 .slice_from_range(&next_span.bounds),
333 },
334 discard,
335 ))
336 }
337}
338
339impl Drop for HandshakeIter<'_, '_> {
340 fn drop(&mut self) {
341 self.deframer.spans.drain(..self.index);
342 }
343}
344
345#[derive(Debug)]
346struct FragmentSpan {
347 version: ProtocolVersion,
349
350 size: Option<usize>,
355
356 bounds: Range<usize>,
358}
359
360impl FragmentSpan {
361 fn is_complete(&self) -> bool {
364 match self.size {
365 Some(sz) => sz + HANDSHAKE_HEADER_LEN == self.bounds.len(),
366 None => false,
367 }
368 }
369}
370
/// Size of a handshake message header: one leading byte (the handshake type)
/// followed by a three-byte (24-bit) body length.
const HANDSHAKE_HEADER_LEN: usize = 4;

/// Largest handshake body length that `HandshakeDeframer::coalesce` accepts
/// (65535). This is far below the 24-bit maximum the header can express.
const MAX_HANDSHAKE_SIZE: usize = u16::MAX as usize;
377
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::msgs::deframer::DeframerIter;

    /// Feeds `slice` (a sub-slice of `within`) to `hs` as a TLS 1.3 handshake
    /// record, reporting the end of `slice` as the outer discard point.
    fn add_bytes(hs: &mut HandshakeDeframer, slice: &[u8], within: &[u8]) {
        let msg = InboundPlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_3,
            payload: slice,
        };
        let locator = Locator::new(within);
        let discard = locator.locate(slice).end;
        hs.input_message(msg, &locator, discard);
    }

    #[test]
    fn coalesce() {
        // One message (header 21 00 00 01, body ff) delivered as three
        // fragments: [3..4], [4..6] and [8..10], with a gap before the last.
        let mut input = vec![0, 0, 0, 0x21, 0, 0, 0, 0, 0x01, 0xff, 0x00, 0x01];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..4], &input);
        assert_eq!(hs.requires_coalesce(), None);
        add_bytes(&mut hs, &input[4..6], &input);
        assert_eq!(hs.requires_coalesce(), Some(0));
        add_bytes(&mut hs, &input[8..10], &input);
        assert_eq!(hs.requires_coalesce(), Some(0));

        hs.coalesce(&mut input).unwrap();

        let (msg, discard) = hs.iter(&input).next().unwrap();
        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x01, 0xff]);

        // Everything up to the end of the last fragment may now be dropped.
        input.drain(..discard);
        assert_eq!(input, &[0, 1]);
    }

    #[test]
    fn append() {
        // A header declaring a 5-byte body, with the body in a later fragment.
        let mut input = vec![0, 0, 0, 0x21, 0, 0, 5, 0, 0, 1, 2, 3, 4, 5, 0];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..7], &input);
        add_bytes(&mut hs, &input[9..14], &input);
        assert_eq!(hs.spans.len(), 2);

        hs.coalesce(&mut input).unwrap();
        assert_eq!(hs.spans.len(), 1);

        let (msg, discard) = hs.iter(&input).next().unwrap();
        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x05, 1, 2, 3, 4, 5]);

        input.drain(..discard);
        assert_eq!(input, &[0]);
    }

    #[test]
    fn coalesce_rejects_excess_size_message() {
        // `X` bytes lie outside every fragment.
        const X: u8 = 0xff;
        // The header is split across [0..3] and [4..6]. Once joined it reads
        // 21 01 00 00, declaring a body of 0x010000 bytes: one more than
        // MAX_HANDSHAKE_SIZE.
        let mut input = vec![0x21, 0x01, 0x00, X, 0x00, 0xab, X];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[0..3], &input);
        add_bytes(&mut hs, &input[4..6], &input);

        assert_eq!(
            hs.coalesce(&mut input),
            Err(InvalidMessage::HandshakePayloadTooLarge)
        );
    }

    #[test]
    fn iter_only_returns_full_messages() {
        // A complete one-byte-body message followed by a header whose body is missing.
        let input = [0, 0, 0, 0x21, 0, 0, 1, 0xab, 0x21, 0, 0, 1];

        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..8], &input);
        add_bytes(&mut hs, &input[8..12], &input);

        let mut iter = hs.iter(&input);
        let (msg, discard) = iter.next().unwrap();
        assert!(iter.next().is_none());

        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x01, 0xab]);
        // A partial message is still tracked, so nothing may be discarded yet.
        assert_eq!(discard, 0);
    }

    #[test]
    fn handshake_flight() {
        let mut input = include_bytes!("../../testdata/handshake-test.1.bin").to_vec();
        let locator = Locator::new(&input);

        let mut hs = HandshakeDeframer::default();

        let mut iter = DeframerIter::new(&mut input[..]);

        // `bytes_consumed` is read inside the loop, so a `for` loop cannot be used.
        while let Some(message) = iter.next() {
            let plain = message.unwrap().into_plain_message();
            hs.input_message(plain, &locator, iter.bytes_consumed());
        }

        hs.coalesce(&mut input[..]).unwrap();

        let mut iter = hs.iter(&input[..]);
        for _ in 0..4 {
            let (msg, discard) = iter.next().unwrap();
            assert!(matches!(
                msg,
                InboundPlainMessage {
                    typ: ContentType::Handshake,
                    ..
                }
            ));
            assert_eq!(discard, 0);
        }

        // Only the final message releases the whole consumed input.
        let (msg, discard) = iter.next().unwrap();
        assert!(matches!(
            msg,
            InboundPlainMessage {
                typ: ContentType::Handshake,
                ..
            }
        ));
        assert_eq!(discard, 4280);
        drop(iter);

        input.drain(0..discard);
        assert!(input.is_empty());
    }
}
525}