rustls/
record_layer.rs

1use alloc::boxed::Box;
2use core::cmp::min;
3
4use crate::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter};
5use crate::error::Error;
6use crate::log::trace;
7use crate::msgs::message::{InboundPlainMessage, OutboundOpaqueMessage, OutboundPlainMessage};
8
/// Lifecycle of the keying material for one direction (read or write) of
/// the record layer.
#[derive(PartialEq)]
enum DirectionState {
    /// No keying material.
    Invalid,

    /// Keying material present, but not yet in use.
    Prepared,

    /// Keying material in use.
    Active,
}
20
21/// Record layer that tracks decryption and encryption keys.
22pub(crate) struct RecordLayer {
23    message_encrypter: Box<dyn MessageEncrypter>,
24    message_decrypter: Box<dyn MessageDecrypter>,
25    write_seq_max: u64,
26    write_seq: u64,
27    read_seq: u64,
28    has_decrypted: bool,
29    encrypt_state: DirectionState,
30    decrypt_state: DirectionState,
31
32    // Message encrypted with other keys may be encountered, so failures
33    // should be swallowed by the caller.  This struct tracks the amount
34    // of message size this is allowed for.
35    trial_decryption_len: Option<usize>,
36}
37
38impl RecordLayer {
39    /// Create new record layer with no keys.
40    pub(crate) fn new() -> Self {
41        Self {
42            message_encrypter: <dyn MessageEncrypter>::invalid(),
43            message_decrypter: <dyn MessageDecrypter>::invalid(),
44            write_seq_max: 0,
45            write_seq: 0,
46            read_seq: 0,
47            has_decrypted: false,
48            encrypt_state: DirectionState::Invalid,
49            decrypt_state: DirectionState::Invalid,
50            trial_decryption_len: None,
51        }
52    }
53
54    /// Decrypt a TLS message.
55    ///
56    /// `encr` is a decoded message allegedly received from the peer.
57    /// If it can be decrypted, its decryption is returned.  Otherwise,
58    /// an error is returned.
59    pub(crate) fn decrypt_incoming<'a>(
60        &mut self,
61        encr: InboundOpaqueMessage<'a>,
62    ) -> Result<Option<Decrypted<'a>>, Error> {
63        if self.decrypt_state != DirectionState::Active {
64            return Ok(Some(Decrypted {
65                want_close_before_decrypt: false,
66                plaintext: encr.into_plain_message(),
67            }));
68        }
69
70        // Set to `true` if the peer appears to getting close to encrypting
71        // too many messages with this key.
72        //
73        // Perhaps if we send an alert well before their counter wraps, a
74        // buggy peer won't make a terrible mistake here?
75        //
76        // Note that there's no reason to refuse to decrypt: the security
77        // failure has already happened.
78        let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT;
79
80        let encrypted_len = encr.payload.len();
81        match self
82            .message_decrypter
83            .decrypt(encr, self.read_seq)
84        {
85            Ok(plaintext) => {
86                self.read_seq += 1;
87                if !self.has_decrypted {
88                    self.has_decrypted = true;
89                }
90                Ok(Some(Decrypted {
91                    want_close_before_decrypt,
92                    plaintext,
93                }))
94            }
95            Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => {
96                trace!("Dropping undecryptable message after aborted early_data");
97                Ok(None)
98            }
99            Err(err) => Err(err),
100        }
101    }
102
103    /// Encrypt a TLS message.
104    ///
105    /// `plain` is a TLS message we'd like to send.  This function
106    /// panics if the requisite keying material hasn't been established yet.
107    pub(crate) fn encrypt_outgoing(
108        &mut self,
109        plain: OutboundPlainMessage<'_>,
110    ) -> OutboundOpaqueMessage {
111        debug_assert!(self.encrypt_state == DirectionState::Active);
112        assert!(self.next_pre_encrypt_action() != PreEncryptAction::Refuse);
113        let seq = self.write_seq;
114        self.write_seq += 1;
115        self.message_encrypter
116            .encrypt(plain, seq)
117            .unwrap()
118    }
119
120    /// Prepare to use the given `MessageEncrypter` for future message encryption.
121    /// It is not used until you call `start_encrypting`.
122    pub(crate) fn prepare_message_encrypter(
123        &mut self,
124        cipher: Box<dyn MessageEncrypter>,
125        max_messages: u64,
126    ) {
127        self.message_encrypter = cipher;
128        self.write_seq = 0;
129        self.write_seq_max = min(SEQ_SOFT_LIMIT, max_messages);
130        self.encrypt_state = DirectionState::Prepared;
131    }
132
133    /// Prepare to use the given `MessageDecrypter` for future message decryption.
134    /// It is not used until you call `start_decrypting`.
135    pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
136        self.message_decrypter = cipher;
137        self.read_seq = 0;
138        self.decrypt_state = DirectionState::Prepared;
139    }
140
141    /// Start using the `MessageEncrypter` previously provided to the previous
142    /// call to `prepare_message_encrypter`.
143    pub(crate) fn start_encrypting(&mut self) {
144        debug_assert!(self.encrypt_state == DirectionState::Prepared);
145        self.encrypt_state = DirectionState::Active;
146    }
147
148    /// Start using the `MessageDecrypter` previously provided to the previous
149    /// call to `prepare_message_decrypter`.
150    pub(crate) fn start_decrypting(&mut self) {
151        debug_assert!(self.decrypt_state == DirectionState::Prepared);
152        self.decrypt_state = DirectionState::Active;
153    }
154
155    /// Set and start using the given `MessageEncrypter` for future outgoing
156    /// message encryption.
157    pub(crate) fn set_message_encrypter(
158        &mut self,
159        cipher: Box<dyn MessageEncrypter>,
160        max_messages: u64,
161    ) {
162        self.prepare_message_encrypter(cipher, max_messages);
163        self.start_encrypting();
164    }
165
166    /// Set and start using the given `MessageDecrypter` for future incoming
167    /// message decryption.
168    pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) {
169        self.prepare_message_decrypter(cipher);
170        self.start_decrypting();
171        self.trial_decryption_len = None;
172    }
173
174    /// Set and start using the given `MessageDecrypter` for future incoming
175    /// message decryption, and enable "trial decryption" mode for when TLS1.3
176    /// 0-RTT is attempted but rejected by the server.
177    pub(crate) fn set_message_decrypter_with_trial_decryption(
178        &mut self,
179        cipher: Box<dyn MessageDecrypter>,
180        max_length: usize,
181    ) {
182        self.prepare_message_decrypter(cipher);
183        self.start_decrypting();
184        self.trial_decryption_len = Some(max_length);
185    }
186
187    pub(crate) fn finish_trial_decryption(&mut self) {
188        self.trial_decryption_len = None;
189    }
190
191    pub(crate) fn next_pre_encrypt_action(&self) -> PreEncryptAction {
192        self.pre_encrypt_action(0)
193    }
194
195    /// Return a remedial action when we are near to encrypting too many messages.
196    ///
197    /// `add` is added to the current sequence number.  `add` as `0` means
198    /// "the next message processed by `encrypt_outgoing`"
199    pub(crate) fn pre_encrypt_action(&self, add: u64) -> PreEncryptAction {
200        match self.write_seq.saturating_add(add) {
201            v if v == self.write_seq_max => PreEncryptAction::RefreshOrClose,
202            SEQ_HARD_LIMIT.. => PreEncryptAction::Refuse,
203            _ => PreEncryptAction::Nothing,
204        }
205    }
206
207    pub(crate) fn is_encrypting(&self) -> bool {
208        self.encrypt_state == DirectionState::Active
209    }
210
211    /// Return true if we have ever decrypted a message. This is used in place
212    /// of checking the read_seq since that will be reset on key updates.
213    pub(crate) fn has_decrypted(&self) -> bool {
214        self.has_decrypted
215    }
216
217    pub(crate) fn write_seq(&self) -> u64 {
218        self.write_seq
219    }
220
221    pub(crate) fn read_seq(&self) -> u64 {
222        self.read_seq
223    }
224
225    pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize {
226        self.message_encrypter
227            .encrypted_payload_len(payload_len)
228    }
229
230    fn doing_trial_decryption(&mut self, requested: usize) -> bool {
231        match self
232            .trial_decryption_len
233            .and_then(|value| value.checked_sub(requested))
234        {
235            Some(remaining) => {
236                self.trial_decryption_len = Some(remaining);
237                true
238            }
239            _ => false,
240        }
241    }
242}
243
244/// Result of decryption.
245#[derive(Debug)]
246pub(crate) struct Decrypted<'a> {
247    /// Whether the peer appears to be getting close to encrypting too many messages with this key.
248    pub(crate) want_close_before_decrypt: bool,
249    /// The decrypted message.
250    pub(crate) plaintext: InboundPlainMessage<'a>,
251}
252
/// What, if anything, must be done before another message is encrypted
/// with the current key.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum PreEncryptAction {
    /// No action is needed before calling `encrypt_outgoing`
    Nothing,

    /// A `key_update` request should be sent ASAP.
    ///
    /// If that is not possible (for example, the connection is TLS1.2), a `close_notify`
    /// alert should be sent instead.
    RefreshOrClose,

    /// Do not call `encrypt_outgoing` further, it will panic rather than
    /// over-use the key.
    Refuse,
}
268
/// Sequence number at which a key refresh (or close) is wanted: it caps the
/// per-key write limit and is the read sequence number at which the peer is
/// considered to be getting close to over-using its key.
const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64;
/// Write sequence numbers at or beyond this value are refused outright.
const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64;
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks a record layer through prepare → start → decrypt → re-key and
    /// checks that `has_decrypted` latches while `read_seq` is reset.
    #[test]
    fn test_has_decrypted() {
        use crate::{ContentType, ProtocolVersion};

        /// Decrypter that hands the message back without touching it.
        struct PassThroughDecrypter;
        impl MessageDecrypter for PassThroughDecrypter {
            fn decrypt<'a>(
                &mut self,
                m: InboundOpaqueMessage<'a>,
                _: u64,
            ) -> Result<InboundPlainMessage<'a>, Error> {
                Ok(m.into_plain_message())
            }
        }

        // A record layer starts out invalid, having never decrypted.
        let mut record_layer = RecordLayer::new();
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Invalid
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Preparing the record layer should update the decrypt state, but shouldn't affect whether it
        // has decrypted.
        record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(
            record_layer.decrypt_state,
            DirectionState::Prepared
        ));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Starting decryption should update the decrypt state, but not affect whether it has decrypted.
        record_layer.start_decrypting();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(!record_layer.has_decrypted());

        // Decrypting a message should update the read_seq and track that we have now performed
        // a decryption.
        record_layer
            .decrypt_incoming(InboundOpaqueMessage::new(
                ContentType::Handshake,
                ProtocolVersion::TLSv1_2,
                &mut [0xC0, 0xFF, 0xEE],
            ))
            .unwrap();
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 1);
        assert!(record_layer.has_decrypted());

        // Resetting the record layer message decrypter (as if a key update occurred) should reset
        // the read_seq number, but not our knowledge of whether we have decrypted previously.
        record_layer.set_message_decrypter(Box::new(PassThroughDecrypter));
        assert!(matches!(record_layer.decrypt_state, DirectionState::Active));
        assert_eq!(record_layer.read_seq, 0);
        assert!(record_layer.has_decrypted());
    }
}