1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::cmp;
4
5use pki_types::{DnsName, UnixTime};
6use zeroize::Zeroizing;
7
8use crate::enums::{CipherSuite, ProtocolVersion};
9use crate::error::InvalidMessage;
10use crate::msgs::base::{PayloadU16, PayloadU8};
11use crate::msgs::codec::{Codec, Reader};
12use crate::msgs::handshake::CertificateChain;
13#[cfg(feature = "tls12")]
14use crate::msgs::handshake::SessionId;
15#[cfg(feature = "tls12")]
16use crate::tls12::Tls12CipherSuite;
17use crate::tls13::Tls13CipherSuite;
18
19pub(crate) struct Retrieved<T> {
20 pub(crate) value: T,
21 retrieved_at: UnixTime,
22}
23
24impl<T> Retrieved<T> {
25 pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
26 Self {
27 value,
28 retrieved_at,
29 }
30 }
31
32 pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
33 Some(Retrieved {
34 value: f(&self.value)?,
35 retrieved_at: self.retrieved_at,
36 })
37 }
38}
39
40impl Retrieved<&Tls13ClientSessionValue> {
41 pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
42 let age_secs = self
43 .retrieved_at
44 .as_secs()
45 .saturating_sub(self.value.common.epoch);
46 let age_millis = age_secs as u32 * 1000;
47 age_millis.wrapping_add(self.value.age_add)
48 }
49}
50
51impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
52 pub(crate) fn has_expired(&self) -> bool {
53 let common = &*self.value;
54 common.lifetime_secs != 0
55 && common
56 .epoch
57 .saturating_add(u64::from(common.lifetime_secs))
58 < self.retrieved_at.as_secs()
59 }
60}
61
62impl<T> core::ops::Deref for Retrieved<T> {
63 type Target = T;
64
65 fn deref(&self) -> &Self::Target {
66 &self.value
67 }
68}
69
70#[derive(Debug)]
71pub struct Tls13ClientSessionValue {
72 suite: &'static Tls13CipherSuite,
73 age_add: u32,
74 max_early_data_size: u32,
75 pub(crate) common: ClientSessionCommon,
76 quic_params: PayloadU16,
77}
78
79impl Tls13ClientSessionValue {
80 pub(crate) fn new(
81 suite: &'static Tls13CipherSuite,
82 ticket: Arc<PayloadU16>,
83 secret: &[u8],
84 server_cert_chain: CertificateChain<'static>,
85 time_now: UnixTime,
86 lifetime_secs: u32,
87 age_add: u32,
88 max_early_data_size: u32,
89 ) -> Self {
90 Self {
91 suite,
92 age_add,
93 max_early_data_size,
94 common: ClientSessionCommon::new(
95 ticket,
96 secret,
97 time_now,
98 lifetime_secs,
99 server_cert_chain,
100 ),
101 quic_params: PayloadU16(Vec::new()),
102 }
103 }
104
105 pub fn max_early_data_size(&self) -> u32 {
106 self.max_early_data_size
107 }
108
109 pub fn suite(&self) -> &'static Tls13CipherSuite {
110 self.suite
111 }
112
113 #[doc(hidden)]
114 pub fn rewind_epoch(&mut self, delta: u32) {
116 self.common.epoch -= delta as u64;
117 }
118
119 #[doc(hidden)]
120 pub fn _private_set_max_early_data_size(&mut self, new: u32) {
122 self.max_early_data_size = new;
123 }
124
125 pub fn set_quic_params(&mut self, quic_params: &[u8]) {
126 self.quic_params = PayloadU16(quic_params.to_vec());
127 }
128
129 pub fn quic_params(&self) -> Vec<u8> {
130 self.quic_params.0.clone()
131 }
132}
133
134impl core::ops::Deref for Tls13ClientSessionValue {
135 type Target = ClientSessionCommon;
136
137 fn deref(&self) -> &Self::Target {
138 &self.common
139 }
140}
141
/// Client-side state for a TLS 1.2 session.
///
/// Every field is gated on the `tls12` feature, so the struct is empty when
/// that feature is disabled.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// Cipher suite associated with the session.
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    /// Session identifier supplied to `new`.
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    /// Flag supplied to `new` and reported by `extended_ms`.
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    /// State shared with the TLS 1.3 session value.
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}

#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Builds a session value.
    ///
    /// `ticket`, `master_secret`, `time_now`, `lifetime_secs` and
    /// `server_cert_chain` are forwarded to `ClientSessionCommon::new`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Arc<PayloadU16>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain<'static>,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        let common = ClientSessionCommon::new(
            ticket,
            master_secret,
            time_now,
            lifetime_secs,
            server_cert_chain,
        );

        Self {
            suite,
            session_id,
            extended_ms,
            common,
        }
    }

    /// Returns another handle to the shared ticket payload.
    ///
    /// NOTE(review): nothing is mutated here; the `&mut self` receiver is kept
    /// so that method resolution against the `Deref` target's `ticket` is
    /// unchanged for existing callers.
    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
        Arc::clone(&self.common.ticket)
    }

    /// Returns the `extended_ms` flag supplied at construction.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Returns the cipher suite associated with the session.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    /// Moves the session's `epoch` back by `delta` seconds, stopping at zero.
    ///
    /// The subtraction saturates: an unchecked `-=` would panic in
    /// overflow-checked builds, or wrap to a far-future epoch otherwise,
    /// whenever `delta` exceeds the current epoch.
    #[doc(hidden)]
    pub fn rewind_epoch(&mut self, delta: u32) {
        let epoch = &mut self.common.epoch;
        *epoch = epoch.saturating_sub(u64::from(delta));
    }
}
199
/// Exposes the shared `ClientSessionCommon` state directly on the TLS 1.2
/// session value.
#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
208
209#[derive(Debug, Clone)]
210pub struct ClientSessionCommon {
211 ticket: Arc<PayloadU16>,
212 secret: Zeroizing<PayloadU8>,
213 epoch: u64,
214 lifetime_secs: u32,
215 server_cert_chain: Arc<CertificateChain<'static>>,
216}
217
218impl ClientSessionCommon {
219 fn new(
220 ticket: Arc<PayloadU16>,
221 secret: &[u8],
222 time_now: UnixTime,
223 lifetime_secs: u32,
224 server_cert_chain: CertificateChain<'static>,
225 ) -> Self {
226 Self {
227 ticket,
228 secret: Zeroizing::new(PayloadU8(secret.to_vec())),
229 epoch: time_now.as_secs(),
230 lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
231 server_cert_chain: Arc::new(server_cert_chain),
232 }
233 }
234
235 pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
236 &self.server_cert_chain
237 }
238
239 pub(crate) fn secret(&self) -> &[u8] {
240 self.secret.0.as_ref()
241 }
242
243 pub(crate) fn ticket(&self) -> &[u8] {
244 self.ticket.0.as_ref()
245 }
246}
247
/// Upper bound, in seconds, applied to `lifetime_secs` by
/// `ClientSessionCommon::new`: seven days.
static MAX_TICKET_LIFETIME: u32 = 60 * 60 * 24 * 7;

/// Largest difference, in milliseconds, between the client-reported and
/// server-computed ticket ages that `ServerSessionValue::set_freshness`
/// still treats as fresh: sixty seconds.
static MAX_FRESHNESS_SKEW_MS: u32 = 1000 * 60;
255
256#[derive(Debug)]
258pub struct ServerSessionValue {
259 pub(crate) sni: Option<DnsName<'static>>,
260 pub(crate) version: ProtocolVersion,
261 pub(crate) cipher_suite: CipherSuite,
262 pub(crate) master_secret: Zeroizing<PayloadU8>,
263 pub(crate) extended_ms: bool,
264 pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
265 pub(crate) alpn: Option<PayloadU8>,
266 pub(crate) application_data: PayloadU16,
267 pub creation_time_sec: u64,
268 pub(crate) age_obfuscation_offset: u32,
269 freshness: Option<bool>,
270}
271
272impl Codec<'_> for ServerSessionValue {
273 fn encode(&self, bytes: &mut Vec<u8>) {
274 if let Some(ref sni) = self.sni {
275 1u8.encode(bytes);
276 let sni_bytes: &str = sni.as_ref();
277 PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
278 } else {
279 0u8.encode(bytes);
280 }
281 self.version.encode(bytes);
282 self.cipher_suite.encode(bytes);
283 self.master_secret.encode(bytes);
284 (u8::from(self.extended_ms)).encode(bytes);
285 if let Some(ref chain) = self.client_cert_chain {
286 1u8.encode(bytes);
287 chain.encode(bytes);
288 } else {
289 0u8.encode(bytes);
290 }
291 if let Some(ref alpn) = self.alpn {
292 1u8.encode(bytes);
293 alpn.encode(bytes);
294 } else {
295 0u8.encode(bytes);
296 }
297 self.application_data.encode(bytes);
298 self.creation_time_sec.encode(bytes);
299 self.age_obfuscation_offset
300 .encode(bytes);
301 }
302
303 fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
304 let has_sni = u8::read(r)?;
305 let sni = if has_sni == 1 {
306 let dns_name = PayloadU8::read(r)?;
307 let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
308 Ok(dns_name) => dns_name.to_owned(),
309 Err(_) => return Err(InvalidMessage::InvalidServerName),
310 };
311
312 Some(dns_name)
313 } else {
314 None
315 };
316
317 let v = ProtocolVersion::read(r)?;
318 let cs = CipherSuite::read(r)?;
319 let ms = Zeroizing::new(PayloadU8::read(r)?);
320 let ems = u8::read(r)?;
321 let has_ccert = u8::read(r)? == 1;
322 let ccert = if has_ccert {
323 Some(CertificateChain::read(r)?.into_owned())
324 } else {
325 None
326 };
327 let has_alpn = u8::read(r)? == 1;
328 let alpn = if has_alpn {
329 Some(PayloadU8::read(r)?)
330 } else {
331 None
332 };
333 let application_data = PayloadU16::read(r)?;
334 let creation_time_sec = u64::read(r)?;
335 let age_obfuscation_offset = u32::read(r)?;
336
337 Ok(Self {
338 sni,
339 version: v,
340 cipher_suite: cs,
341 master_secret: ms,
342 extended_ms: ems == 1u8,
343 client_cert_chain: ccert,
344 alpn,
345 application_data,
346 creation_time_sec,
347 age_obfuscation_offset,
348 freshness: None,
349 })
350 }
351}
352
353impl ServerSessionValue {
354 pub(crate) fn new(
355 sni: Option<&DnsName<'_>>,
356 v: ProtocolVersion,
357 cs: CipherSuite,
358 ms: &[u8],
359 client_cert_chain: Option<CertificateChain<'static>>,
360 alpn: Option<Vec<u8>>,
361 application_data: Vec<u8>,
362 creation_time: UnixTime,
363 age_obfuscation_offset: u32,
364 ) -> Self {
365 Self {
366 sni: sni.map(|dns| dns.to_owned()),
367 version: v,
368 cipher_suite: cs,
369 master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
370 extended_ms: false,
371 client_cert_chain,
372 alpn: alpn.map(PayloadU8::new),
373 application_data: PayloadU16::new(application_data),
374 creation_time_sec: creation_time.as_secs(),
375 age_obfuscation_offset,
376 freshness: None,
377 }
378 }
379
380 #[cfg(feature = "tls12")]
381 pub(crate) fn set_extended_ms_used(&mut self) {
382 self.extended_ms = true;
383 }
384
385 pub(crate) fn set_freshness(
386 mut self,
387 obfuscated_client_age_ms: u32,
388 time_now: UnixTime,
389 ) -> Self {
390 let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
391 let server_age_ms = (time_now
392 .as_secs()
393 .saturating_sub(self.creation_time_sec) as u32)
394 .saturating_mul(1000);
395
396 let age_difference = if client_age_ms < server_age_ms {
397 server_age_ms - client_age_ms
398 } else {
399 client_age_ms - server_age_ms
400 };
401
402 self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
403 self
404 }
405
406 pub(crate) fn is_fresh(&self) -> bool {
407 self.freshness.unwrap_or_default()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `encoded` as a `ServerSessionValue` and checks that encoding
    /// the result reproduces the input exactly.
    fn assert_round_trips(encoded: &[u8]) {
        let mut reader = Reader::init(encoded);
        let decoded = ServerSessionValue::read(&mut reader).unwrap();
        assert_eq!(decoded.get_encoding(), encoded);
    }

    /// Smoke test: formatting a value with `{:?}` must not panic.
    #[cfg(feature = "std")]
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};

        let creation_time = UnixTime::now();
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            creation_time,
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    /// Round-trips an encoding whose leading SNI presence flag is zero.
    #[test]
    fn serversessionvalue_no_sni() {
        let encoded = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        assert_round_trips(&encoded);
    }

    /// NOTE(review): these bytes are identical to `serversessionvalue_no_sni`
    /// and the certificate-chain presence flag (11th byte) is zero, so no
    /// certificate is actually exercised here. Confirm intent.
    #[test]
    fn serversessionvalue_with_cert() {
        let encoded = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        assert_round_trips(&encoded);
    }
}