rustls/
hash_hs.rs

1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4
5use crate::crypto::hash;
6use crate::msgs::codec::Codec;
7use crate::msgs::enums::HashAlgorithm;
8use crate::msgs::handshake::HandshakeMessagePayload;
9use crate::msgs::message::{Message, MessagePayload};
10
/// Early stage buffering of handshake payloads.
///
/// Until the hash algorithm used to verify the handshake is known, handshake
/// messages are simply appended to a byte buffer. A HelloRetryRequest may
/// restart the transcript, turning a `HandshakeHash` back into a
/// `HandshakeHashBuffer`.
#[derive(Clone)]
pub(crate) struct HandshakeHashBuffer {
    /// Concatenated encodings of every handshake message added so far.
    buffer: Vec<u8>,
    /// When set, the buffered bytes are kept as a full handshake log after
    /// `start_hash` (needed for client auth).
    client_auth_enabled: bool,
}
21
22impl HandshakeHashBuffer {
23    pub(crate) fn new() -> Self {
24        Self {
25            buffer: Vec::new(),
26            client_auth_enabled: false,
27        }
28    }
29
30    /// We might be doing client auth, so need to keep a full
31    /// log of the handshake.
32    pub(crate) fn set_client_auth_enabled(&mut self) {
33        self.client_auth_enabled = true;
34    }
35
36    /// Hash/buffer a handshake message.
37    pub(crate) fn add_message(&mut self, m: &Message<'_>) {
38        match &m.payload {
39            MessagePayload::Handshake { encoded, .. } => self.add_raw(encoded.bytes()),
40            MessagePayload::HandshakeFlight(payload) => self.add_raw(payload.bytes()),
41            _ => {}
42        };
43    }
44
45    /// Hash or buffer a byte slice.
46    fn add_raw(&mut self, buf: &[u8]) {
47        self.buffer.extend_from_slice(buf);
48    }
49
50    /// Get the hash value if we were to hash `extra` too.
51    pub(crate) fn hash_given(
52        &self,
53        provider: &'static dyn hash::Hash,
54        extra: &[u8],
55    ) -> hash::Output {
56        let mut ctx = provider.start();
57        ctx.update(&self.buffer);
58        ctx.update(extra);
59        ctx.finish()
60    }
61
62    /// We now know what hash function the verify_data will use.
63    pub(crate) fn start_hash(self, provider: &'static dyn hash::Hash) -> HandshakeHash {
64        let mut ctx = provider.start();
65        ctx.update(&self.buffer);
66        HandshakeHash {
67            provider,
68            ctx,
69            client_auth: match self.client_auth_enabled {
70                true => Some(self.buffer),
71                false => None,
72            },
73        }
74    }
75}
76
77/// This deals with keeping a running hash of the handshake
78/// payloads.  This is computed by buffering initially.  Once
79/// we know what hash function we need to use we switch to
80/// incremental hashing.
81///
82/// For client auth, we also need to buffer all the messages.
83/// This is disabled in cases where client auth is not possible.
84pub(crate) struct HandshakeHash {
85    provider: &'static dyn hash::Hash,
86    ctx: Box<dyn hash::Context>,
87
88    /// buffer for client-auth.
89    client_auth: Option<Vec<u8>>,
90}
91
92impl HandshakeHash {
93    /// We decided not to do client auth after all, so discard
94    /// the transcript.
95    pub(crate) fn abandon_client_auth(&mut self) {
96        self.client_auth = None;
97    }
98
99    /// Hash/buffer a handshake message.
100    pub(crate) fn add_message(&mut self, m: &Message<'_>) -> &mut Self {
101        match &m.payload {
102            MessagePayload::Handshake { encoded, .. } => self.add_raw(encoded.bytes()),
103            MessagePayload::HandshakeFlight(payload) => self.add_raw(payload.bytes()),
104            _ => self,
105        }
106    }
107
108    /// Hash/buffer an encoded handshake message.
109    pub(crate) fn add(&mut self, bytes: &[u8]) {
110        self.add_raw(bytes);
111    }
112
113    /// Hash or buffer a byte slice.
114    fn add_raw(&mut self, buf: &[u8]) -> &mut Self {
115        self.ctx.update(buf);
116
117        if let Some(buffer) = &mut self.client_auth {
118            buffer.extend_from_slice(buf);
119        }
120
121        self
122    }
123
124    /// Get the hash value if we were to hash `extra` too,
125    /// using hash function `hash`.
126    pub(crate) fn hash_given(&self, extra: &[u8]) -> hash::Output {
127        let mut ctx = self.ctx.fork();
128        ctx.update(extra);
129        ctx.finish()
130    }
131
132    pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
133        let old_hash = self.ctx.finish();
134        let old_handshake_hash_msg =
135            HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
136
137        HandshakeHashBuffer {
138            client_auth_enabled: self.client_auth.is_some(),
139            buffer: old_handshake_hash_msg.get_encoding(),
140        }
141    }
142
143    /// Take the current hash value, and encapsulate it in a
144    /// 'handshake_hash' handshake message.  Start this hash
145    /// again, with that message at the front.
146    pub(crate) fn rollup_for_hrr(&mut self) {
147        let ctx = &mut self.ctx;
148
149        let old_ctx = mem::replace(ctx, self.provider.start());
150        let old_hash = old_ctx.finish();
151        let old_handshake_hash_msg =
152            HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
153
154        self.add_raw(&old_handshake_hash_msg.get_encoding());
155    }
156
157    /// Get the current hash value.
158    pub(crate) fn current_hash(&self) -> hash::Output {
159        self.ctx.fork_finish()
160    }
161
162    /// Takes this object's buffer containing all handshake messages
163    /// so far.  This method only works once; it resets the buffer
164    /// to empty.
165    #[cfg(feature = "tls12")]
166    pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
167        self.client_auth.take()
168    }
169
170    /// The hashing algorithm
171    pub(crate) fn algorithm(&self) -> HashAlgorithm {
172        self.provider.algorithm()
173    }
174}
175
176impl Clone for HandshakeHash {
177    fn clone(&self) -> Self {
178        Self {
179            provider: self.provider,
180            ctx: self.ctx.fork(),
181            client_auth: self.client_auth.clone(),
182        }
183    }
184}
185
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use super::provider::hash::SHA256;
    use super::*;
    use crate::crypto::hash::Hash;
    use crate::enums::{HandshakeType, ProtocolVersion};
    use crate::msgs::base::Payload;
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};

    /// First four bytes of SHA-256("helloworld").
    const HELLOWORLD_SHA256_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Check that `hash` currently holds the SHA-256 of "helloworld".
    fn assert_hashed_helloworld(hash: &HandshakeHash) {
        let digest = hash.current_hash();
        assert_eq!(&digest.as_ref()[..4], &HELLOWORLD_SHA256_PREFIX[..]);
    }

    /// Length of the client-auth log, or `None` when no log is kept.
    fn client_auth_len(hash: &HandshakeHash) -> Option<usize> {
        hash.client_auth.as_ref().map(Vec::len)
    }

    #[test]
    fn hashes_correctly() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.add_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&SHA256);
        assert!(running.client_auth.is_none());
        running.add_raw(b"world");
        assert_hashed_helloworld(&running);
    }

    #[test]
    fn hashes_message_types() {
        // handshake protocol encoding of 0x0e 00 00 00
        let server_hello_done_message = Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::handshake(HandshakeMessagePayload {
                typ: HandshakeType::ServerHelloDone,
                payload: HandshakePayload::ServerHelloDone,
            }),
        };

        let app_data_ignored = Message {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::ApplicationData(Payload::Borrowed(b"hello")),
        };

        let end_of_early_data_flight = Message {
            version: ProtocolVersion::TLSv1_3,
            payload: MessagePayload::HandshakeFlight(Payload::Borrowed(b"\x05\x00\x00\x00")),
        };

        let messages = [
            &server_hello_done_message,
            &app_data_ignored,
            &end_of_early_data_flight,
        ];
        let expected = SHA256.hash(b"\x0e\x00\x00\x00\x05\x00\x00\x00");

        // buffered mode
        let mut buffered = HandshakeHashBuffer::new();
        for m in &messages {
            buffered.add_message(m);
        }
        let buffered_hash = buffered
            .start_hash(&SHA256)
            .current_hash();
        assert_eq!(buffered_hash.as_ref(), expected.as_ref());

        // non-buffered mode
        let mut running = HandshakeHashBuffer::new().start_hash(&SHA256);
        for m in &messages {
            running.add_message(m);
        }
        assert_eq!(running.current_hash().as_ref(), expected.as_ref());
    }

    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.set_client_auth_enabled();
        buffered.add_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&SHA256);
        assert_eq!(client_auth_len(&running), Some(5));
        running.add_raw(b"world");
        assert_eq!(client_auth_len(&running), Some(10));
        assert_hashed_helloworld(&running);

        assert_eq!(running.take_handshake_buf(), Some(b"helloworld".to_vec()));
    }

    #[test]
    fn abandon() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.set_client_auth_enabled();
        buffered.add_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&SHA256);
        assert_eq!(client_auth_len(&running), Some(5));

        running.abandon_client_auth();
        assert_eq!(running.client_auth, None);
        running.add_raw(b"world");
        assert_eq!(running.client_auth, None);
        assert_hashed_helloworld(&running);
    }

    #[test]
    fn clones_correctly() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.set_client_auth_enabled();
        buffered.add_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        // Cloning the HHB should result in the same buffer and client auth state.
        let mut buffered_copy = buffered.clone();
        assert_eq!(buffered_copy.buffer, buffered.buffer);
        assert!(buffered_copy.client_auth_enabled);

        // Updating the HHB clone shouldn't affect the original.
        buffered_copy.add_raw(b"world");
        assert_eq!(buffered_copy.buffer.len(), 10);
        assert_ne!(buffered.buffer, buffered_copy.buffer);

        let running = buffered.start_hash(&SHA256);
        let snapshot = running.current_hash();

        // Cloning the HH should result in the same current hash.
        let mut running_copy = running.clone();
        assert_eq!(running_copy.current_hash().as_ref(), snapshot.as_ref());

        // Updating the HH clone shouldn't affect the original.
        running_copy.add_raw(b"goodbye");
        assert_eq!(running.current_hash().as_ref(), snapshot.as_ref());
        assert_ne!(running_copy.current_hash().as_ref(), snapshot.as_ref());
    }
}