rustls/msgs/
persist.rs

1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::cmp;
4
5use pki_types::{DnsName, UnixTime};
6use zeroize::Zeroizing;
7
8use crate::enums::{CipherSuite, ProtocolVersion};
9use crate::error::InvalidMessage;
10use crate::msgs::base::{PayloadU16, PayloadU8};
11use crate::msgs::codec::{Codec, Reader};
12use crate::msgs::handshake::CertificateChain;
13#[cfg(feature = "tls12")]
14use crate::msgs::handshake::SessionId;
15#[cfg(feature = "tls12")]
16use crate::tls12::Tls12CipherSuite;
17use crate::tls13::Tls13CipherSuite;
18
19pub(crate) struct Retrieved<T> {
20    pub(crate) value: T,
21    retrieved_at: UnixTime,
22}
23
24impl<T> Retrieved<T> {
25    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
26        Self {
27            value,
28            retrieved_at,
29        }
30    }
31
32    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
33        Some(Retrieved {
34            value: f(&self.value)?,
35            retrieved_at: self.retrieved_at,
36        })
37    }
38}
39
40impl Retrieved<&Tls13ClientSessionValue> {
41    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
42        let age_secs = self
43            .retrieved_at
44            .as_secs()
45            .saturating_sub(self.value.common.epoch);
46        let age_millis = age_secs as u32 * 1000;
47        age_millis.wrapping_add(self.value.age_add)
48    }
49}
50
51impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
52    pub(crate) fn has_expired(&self) -> bool {
53        let common = &*self.value;
54        common.lifetime_secs != 0
55            && common
56                .epoch
57                .saturating_add(u64::from(common.lifetime_secs))
58                < self.retrieved_at.as_secs()
59    }
60}
61
62impl<T> core::ops::Deref for Retrieved<T> {
63    type Target = T;
64
65    fn deref(&self) -> &Self::Target {
66        &self.value
67    }
68}
69
70#[derive(Debug)]
71pub struct Tls13ClientSessionValue {
72    suite: &'static Tls13CipherSuite,
73    age_add: u32,
74    max_early_data_size: u32,
75    pub(crate) common: ClientSessionCommon,
76    quic_params: PayloadU16,
77}
78
79impl Tls13ClientSessionValue {
80    pub(crate) fn new(
81        suite: &'static Tls13CipherSuite,
82        ticket: Arc<PayloadU16>,
83        secret: &[u8],
84        server_cert_chain: CertificateChain<'static>,
85        time_now: UnixTime,
86        lifetime_secs: u32,
87        age_add: u32,
88        max_early_data_size: u32,
89    ) -> Self {
90        Self {
91            suite,
92            age_add,
93            max_early_data_size,
94            common: ClientSessionCommon::new(
95                ticket,
96                secret,
97                time_now,
98                lifetime_secs,
99                server_cert_chain,
100            ),
101            quic_params: PayloadU16(Vec::new()),
102        }
103    }
104
105    pub fn max_early_data_size(&self) -> u32 {
106        self.max_early_data_size
107    }
108
109    pub fn suite(&self) -> &'static Tls13CipherSuite {
110        self.suite
111    }
112
113    #[doc(hidden)]
114    /// Test only: rewind epoch by `delta` seconds.
115    pub fn rewind_epoch(&mut self, delta: u32) {
116        self.common.epoch -= delta as u64;
117    }
118
119    #[doc(hidden)]
120    /// Test only: replace `max_early_data_size` with `new`
121    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
122        self.max_early_data_size = new;
123    }
124
125    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
126        self.quic_params = PayloadU16(quic_params.to_vec());
127    }
128
129    pub fn quic_params(&self) -> Vec<u8> {
130        self.quic_params.0.clone()
131    }
132}
133
134impl core::ops::Deref for Tls13ClientSessionValue {
135    type Target = ClientSessionCommon;
136
137    fn deref(&self) -> &Self::Target {
138        &self.common
139    }
140}
141
/// Client-side TLS 1.2 session state kept for later resumption.
///
/// Every field is gated on the `tls12` feature; without it this type carries
/// no data.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}

#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Builds a session value whose epoch is `time_now`.
    ///
    /// `lifetime_secs` is capped by `ClientSessionCommon::new`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Arc<PayloadU16>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain<'static>,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        let common = ClientSessionCommon::new(
            ticket,
            master_secret,
            time_now,
            lifetime_secs,
            server_cert_chain,
        );

        Self {
            suite,
            session_id,
            extended_ms,
            common,
        }
    }

    /// Returns another shared handle to the stored ticket.
    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
        let shared = &self.common.ticket;
        Arc::clone(shared)
    }

    /// Returns the recorded `extended_ms` (extended master secret) flag.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Returns the cipher suite recorded for this session.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    #[doc(hidden)]
    /// Test only: rewind epoch by `delta` seconds.
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch -= u64::from(delta);
    }
}
199
#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    /// Exposes the version-independent session state.
    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
208
209#[derive(Debug, Clone)]
210pub struct ClientSessionCommon {
211    ticket: Arc<PayloadU16>,
212    secret: Zeroizing<PayloadU8>,
213    epoch: u64,
214    lifetime_secs: u32,
215    server_cert_chain: Arc<CertificateChain<'static>>,
216}
217
218impl ClientSessionCommon {
219    fn new(
220        ticket: Arc<PayloadU16>,
221        secret: &[u8],
222        time_now: UnixTime,
223        lifetime_secs: u32,
224        server_cert_chain: CertificateChain<'static>,
225    ) -> Self {
226        Self {
227            ticket,
228            secret: Zeroizing::new(PayloadU8(secret.to_vec())),
229            epoch: time_now.as_secs(),
230            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
231            server_cert_chain: Arc::new(server_cert_chain),
232        }
233    }
234
235    pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
236        &self.server_cert_chain
237    }
238
239    pub(crate) fn secret(&self) -> &[u8] {
240        self.secret.0.as_ref()
241    }
242
243    pub(crate) fn ticket(&self) -> &[u8] {
244        self.ticket.0.as_ref()
245    }
246}
247
/// Upper bound, in seconds, on the lifetime recorded for a client session: seven days.
///
/// RFC 8446 §4.6.1 forbids servers from advertising ticket lifetimes above
/// 604800 seconds (7 days); longer values are clamped to this bound when a
/// session is recorded.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;

/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period.  This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
255
256// --- Server types ---
257#[derive(Debug)]
258pub struct ServerSessionValue {
259    pub(crate) sni: Option<DnsName<'static>>,
260    pub(crate) version: ProtocolVersion,
261    pub(crate) cipher_suite: CipherSuite,
262    pub(crate) master_secret: Zeroizing<PayloadU8>,
263    pub(crate) extended_ms: bool,
264    pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
265    pub(crate) alpn: Option<PayloadU8>,
266    pub(crate) application_data: PayloadU16,
267    pub creation_time_sec: u64,
268    pub(crate) age_obfuscation_offset: u32,
269    freshness: Option<bool>,
270}
271
272impl Codec<'_> for ServerSessionValue {
273    fn encode(&self, bytes: &mut Vec<u8>) {
274        if let Some(ref sni) = self.sni {
275            1u8.encode(bytes);
276            let sni_bytes: &str = sni.as_ref();
277            PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
278        } else {
279            0u8.encode(bytes);
280        }
281        self.version.encode(bytes);
282        self.cipher_suite.encode(bytes);
283        self.master_secret.encode(bytes);
284        (u8::from(self.extended_ms)).encode(bytes);
285        if let Some(ref chain) = self.client_cert_chain {
286            1u8.encode(bytes);
287            chain.encode(bytes);
288        } else {
289            0u8.encode(bytes);
290        }
291        if let Some(ref alpn) = self.alpn {
292            1u8.encode(bytes);
293            alpn.encode(bytes);
294        } else {
295            0u8.encode(bytes);
296        }
297        self.application_data.encode(bytes);
298        self.creation_time_sec.encode(bytes);
299        self.age_obfuscation_offset
300            .encode(bytes);
301    }
302
303    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
304        let has_sni = u8::read(r)?;
305        let sni = if has_sni == 1 {
306            let dns_name = PayloadU8::read(r)?;
307            let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
308                Ok(dns_name) => dns_name.to_owned(),
309                Err(_) => return Err(InvalidMessage::InvalidServerName),
310            };
311
312            Some(dns_name)
313        } else {
314            None
315        };
316
317        let v = ProtocolVersion::read(r)?;
318        let cs = CipherSuite::read(r)?;
319        let ms = Zeroizing::new(PayloadU8::read(r)?);
320        let ems = u8::read(r)?;
321        let has_ccert = u8::read(r)? == 1;
322        let ccert = if has_ccert {
323            Some(CertificateChain::read(r)?.into_owned())
324        } else {
325            None
326        };
327        let has_alpn = u8::read(r)? == 1;
328        let alpn = if has_alpn {
329            Some(PayloadU8::read(r)?)
330        } else {
331            None
332        };
333        let application_data = PayloadU16::read(r)?;
334        let creation_time_sec = u64::read(r)?;
335        let age_obfuscation_offset = u32::read(r)?;
336
337        Ok(Self {
338            sni,
339            version: v,
340            cipher_suite: cs,
341            master_secret: ms,
342            extended_ms: ems == 1u8,
343            client_cert_chain: ccert,
344            alpn,
345            application_data,
346            creation_time_sec,
347            age_obfuscation_offset,
348            freshness: None,
349        })
350    }
351}
352
353impl ServerSessionValue {
354    pub(crate) fn new(
355        sni: Option<&DnsName<'_>>,
356        v: ProtocolVersion,
357        cs: CipherSuite,
358        ms: &[u8],
359        client_cert_chain: Option<CertificateChain<'static>>,
360        alpn: Option<Vec<u8>>,
361        application_data: Vec<u8>,
362        creation_time: UnixTime,
363        age_obfuscation_offset: u32,
364    ) -> Self {
365        Self {
366            sni: sni.map(|dns| dns.to_owned()),
367            version: v,
368            cipher_suite: cs,
369            master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
370            extended_ms: false,
371            client_cert_chain,
372            alpn: alpn.map(PayloadU8::new),
373            application_data: PayloadU16::new(application_data),
374            creation_time_sec: creation_time.as_secs(),
375            age_obfuscation_offset,
376            freshness: None,
377        }
378    }
379
380    #[cfg(feature = "tls12")]
381    pub(crate) fn set_extended_ms_used(&mut self) {
382        self.extended_ms = true;
383    }
384
385    pub(crate) fn set_freshness(
386        mut self,
387        obfuscated_client_age_ms: u32,
388        time_now: UnixTime,
389    ) -> Self {
390        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
391        let server_age_ms = (time_now
392            .as_secs()
393            .saturating_sub(self.creation_time_sec) as u32)
394            .saturating_mul(1000);
395
396        let age_difference = if client_age_ms < server_age_ms {
397            server_age_ms - client_age_ms
398        } else {
399            client_age_ms - server_age_ms
400        };
401
402        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
403        self
404    }
405
406    pub(crate) fn is_fresh(&self) -> bool {
407        self.freshness.unwrap_or_default()
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // protocol version
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret (length-prefixed)
            0x00, // extended_ms = false
            0x00, // no client certificate chain
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation_time_sec
            0xfe, 0xed, 0xf0, 0x0d, // age_obfuscation_offset
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.sni.is_none());
        assert!(ssv.client_cert_chain.is_none());
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        // Unlike `serversessionvalue_no_sni`, the certificate-chain presence
        // flag is set here, so the chain decode/encode path is exercised.
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // protocol version
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret (length-prefixed)
            0x00, // extended_ms = false
            0x01, // client certificate chain present
            0x00, 0x00, 0x00, // empty certificate list (u24 length)
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation_time_sec
            0xfe, 0xed, 0xf0, 0x0d, // age_obfuscation_offset
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.client_cert_chain.is_some());
        assert_eq!(ssv.get_encoding(), bytes);
    }
}