1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4
5use crate::crypto::hash;
6use crate::msgs::codec::Codec;
7use crate::msgs::enums::HashAlgorithm;
8use crate::msgs::handshake::HandshakeMessagePayload;
9use crate::msgs::message::{Message, MessagePayload};
10
/// Plain byte buffer of handshake message encodings, used before a hash
/// context exists.
///
/// The bytes are stored verbatim so they can be hashed on demand by
/// `hash_given`, or fed into a freshly started context by `start_hash`.
#[derive(Clone)]
pub(crate) struct HandshakeHashBuffer {
    /// Concatenation of every handshake encoding added so far.
    buffer: Vec<u8>,
    /// When `true`, `start_hash` moves `buffer` into the resulting
    /// `HandshakeHash` instead of discarding it.
    client_auth_enabled: bool,
}
21
22impl HandshakeHashBuffer {
23 pub(crate) fn new() -> Self {
24 Self {
25 buffer: Vec::new(),
26 client_auth_enabled: false,
27 }
28 }
29
30 pub(crate) fn set_client_auth_enabled(&mut self) {
33 self.client_auth_enabled = true;
34 }
35
36 pub(crate) fn add_message(&mut self, m: &Message<'_>) {
38 match &m.payload {
39 MessagePayload::Handshake { encoded, .. } => self.add_raw(encoded.bytes()),
40 MessagePayload::HandshakeFlight(payload) => self.add_raw(payload.bytes()),
41 _ => {}
42 };
43 }
44
45 fn add_raw(&mut self, buf: &[u8]) {
47 self.buffer.extend_from_slice(buf);
48 }
49
50 pub(crate) fn hash_given(
52 &self,
53 provider: &'static dyn hash::Hash,
54 extra: &[u8],
55 ) -> hash::Output {
56 let mut ctx = provider.start();
57 ctx.update(&self.buffer);
58 ctx.update(extra);
59 ctx.finish()
60 }
61
62 pub(crate) fn start_hash(self, provider: &'static dyn hash::Hash) -> HandshakeHash {
64 let mut ctx = provider.start();
65 ctx.update(&self.buffer);
66 HandshakeHash {
67 provider,
68 ctx,
69 client_auth: match self.client_auth_enabled {
70 true => Some(self.buffer),
71 false => None,
72 },
73 }
74 }
75}
76
/// Running hash over handshake message bytes, with optional retention of
/// the raw bytes alongside it.
pub(crate) struct HandshakeHash {
    /// Hash implementation; used to start a fresh context in
    /// `rollup_for_hrr` and to report the algorithm.
    provider: &'static dyn hash::Hash,
    /// Incremental hashing context that receives every added byte.
    ctx: Box<dyn hash::Context>,

    /// Copy of every byte fed to `ctx`, kept only while this is `Some`.
    client_auth: Option<Vec<u8>>,
}
91
92impl HandshakeHash {
93 pub(crate) fn abandon_client_auth(&mut self) {
96 self.client_auth = None;
97 }
98
99 pub(crate) fn add_message(&mut self, m: &Message<'_>) -> &mut Self {
101 match &m.payload {
102 MessagePayload::Handshake { encoded, .. } => self.add_raw(encoded.bytes()),
103 MessagePayload::HandshakeFlight(payload) => self.add_raw(payload.bytes()),
104 _ => self,
105 }
106 }
107
108 pub(crate) fn add(&mut self, bytes: &[u8]) {
110 self.add_raw(bytes);
111 }
112
113 fn add_raw(&mut self, buf: &[u8]) -> &mut Self {
115 self.ctx.update(buf);
116
117 if let Some(buffer) = &mut self.client_auth {
118 buffer.extend_from_slice(buf);
119 }
120
121 self
122 }
123
124 pub(crate) fn hash_given(&self, extra: &[u8]) -> hash::Output {
127 let mut ctx = self.ctx.fork();
128 ctx.update(extra);
129 ctx.finish()
130 }
131
132 pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
133 let old_hash = self.ctx.finish();
134 let old_handshake_hash_msg =
135 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
136
137 HandshakeHashBuffer {
138 client_auth_enabled: self.client_auth.is_some(),
139 buffer: old_handshake_hash_msg.get_encoding(),
140 }
141 }
142
143 pub(crate) fn rollup_for_hrr(&mut self) {
147 let ctx = &mut self.ctx;
148
149 let old_ctx = mem::replace(ctx, self.provider.start());
150 let old_hash = old_ctx.finish();
151 let old_handshake_hash_msg =
152 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
153
154 self.add_raw(&old_handshake_hash_msg.get_encoding());
155 }
156
157 pub(crate) fn current_hash(&self) -> hash::Output {
159 self.ctx.fork_finish()
160 }
161
162 #[cfg(feature = "tls12")]
166 pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
167 self.client_auth.take()
168 }
169
170 pub(crate) fn algorithm(&self) -> HashAlgorithm {
172 self.provider.algorithm()
173 }
174}
175
176impl Clone for HandshakeHash {
177 fn clone(&self) -> Self {
178 Self {
179 provider: self.provider,
180 ctx: self.ctx.fork(),
181 client_auth: self.client_auth.clone(),
182 }
183 }
184}
185
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use super::provider::hash::SHA256;
    use super::*;
    use crate::crypto::hash::Hash;
    use crate::enums::{HandshakeType, ProtocolVersion};
    use crate::msgs::base::Payload;
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};

    /// Leading bytes of the SHA-256 digest expected after hashing `hello`
    /// followed by `world`.
    const HELLOWORLD_DIGEST_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Asserts that the current hash of `hh` starts with
    /// `HELLOWORLD_DIGEST_PREFIX`.
    fn assert_helloworld_digest(hh: &HandshakeHash) {
        let digest = hh.current_hash();
        let digest: &[u8] = digest.as_ref();
        assert_eq!(&digest[..4], &HELLOWORLD_DIGEST_PREFIX[..]);
    }

    /// Builds a buffer holding `hello`, optionally with client-auth
    /// retention enabled, and checks the buffered length.
    fn hello_buffer(client_auth: bool) -> HandshakeHashBuffer {
        let mut hhb = HandshakeHashBuffer::new();
        if client_auth {
            hhb.set_client_auth_enabled();
        }
        hhb.add_raw(b"hello");
        assert_eq!(hhb.buffer.len(), 5);
        hhb
    }

    #[test]
    fn hashes_correctly() {
        let mut hh = hello_buffer(false).start_hash(&SHA256);
        assert!(hh.client_auth.is_none());
        hh.add_raw(b"world");
        assert_helloworld_digest(&hh);
    }

    #[test]
    fn hashes_message_types() {
        // Handshake payload: must be hashed.
        let server_hello_done_message = Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::handshake(HandshakeMessagePayload {
                typ: HandshakeType::ServerHelloDone,
                payload: HandshakePayload::ServerHelloDone,
            }),
        };

        // Non-handshake payload: must not affect the hash.
        let app_data_ignored = Message {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::ApplicationData(Payload::Borrowed(b"hello")),
        };

        // Pre-encoded handshake flight: must be hashed verbatim.
        let end_of_early_data_flight = Message {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::HandshakeFlight(Payload::Borrowed(b"\x05\x00\x00\x00")),
        };

        let messages = [
            &server_hello_done_message,
            &app_data_ignored,
            &end_of_early_data_flight,
        ];
        let expected = SHA256.hash(b"\x0e\x00\x00\x00\x05\x00\x00\x00");

        // Buffering stage.
        let mut hhb = HandshakeHashBuffer::new();
        for m in messages {
            hhb.add_message(m);
        }
        assert_eq!(
            hhb.start_hash(&SHA256)
                .current_hash()
                .as_ref(),
            expected.as_ref()
        );

        // Incremental hashing stage.
        let mut hh = HandshakeHashBuffer::new().start_hash(&SHA256);
        for m in messages {
            hh.add_message(m);
        }
        assert_eq!(hh.current_hash().as_ref(), expected.as_ref());
    }

    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut hh = hello_buffer(true).start_hash(&SHA256);
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(5));

        hh.add_raw(b"world");
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(10));
        assert_helloworld_digest(&hh);

        let buf = hh.take_handshake_buf();
        assert_eq!(Some(b"helloworld".to_vec()), buf);
    }

    #[test]
    fn abandon() {
        let mut hh = hello_buffer(true).start_hash(&SHA256);
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(5));

        hh.abandon_client_auth();
        assert!(hh.client_auth.is_none());

        // Later input still reaches the hash but is no longer retained.
        hh.add_raw(b"world");
        assert!(hh.client_auth.is_none());
        assert_helloworld_digest(&hh);
    }

    #[test]
    fn clones_correctly() {
        let hhb = hello_buffer(true);

        // A cloned buffer starts out equal to its source...
        let mut hhb_prime = hhb.clone();
        assert_eq!(hhb_prime.buffer, hhb.buffer);
        assert!(hhb_prime.client_auth_enabled);

        // ...and diverges once written to, leaving the source untouched.
        hhb_prime.add_raw(b"world");
        assert_eq!(hhb_prime.buffer.len(), 10);
        assert_ne!(hhb.buffer, hhb_prime.buffer);

        let hh = hhb.start_hash(&SHA256);
        let original_digest = hh.current_hash();
        let original_digest: &[u8] = original_digest.as_ref();

        // A cloned hash starts out with the same digest...
        let mut hh_prime = hh.clone();
        assert_eq!(original_digest, hh_prime.current_hash().as_ref());

        // ...and updating the clone does not disturb the source.
        hh_prime.add_raw(b"goodbye");
        assert_eq!(hh.current_hash().as_ref(), original_digest);
        assert_ne!(hh_prime.current_hash().as_ref(), original_digest);
    }
}