rustls/msgs/
handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::sync::Arc;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::Deref;
8use core::{fmt, iter};
9
10use pki_types::{CertificateDer, DnsName};
11
12#[cfg(feature = "tls12")]
13use crate::crypto::ActiveKeyExchange;
14use crate::crypto::SecureRandom;
15use crate::enums::{
16    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
17    ProtocolVersion, SignatureScheme,
18};
19use crate::error::InvalidMessage;
20#[cfg(feature = "tls12")]
21use crate::ffdhe_groups::FfdheGroup;
22use crate::log::warn;
23use crate::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
24use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
25use crate::msgs::enums::{
26    CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType,
27    ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest,
28    NamedGroup, PSKKeyExchangeMode, ServerNameType,
29};
30use crate::rand;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Generate a newtype that wraps one of the length-prefixed payload types.
///
/// The generated tuple struct holds a single `$inner` value (in this file that is
/// `PayloadU8` or `PayloadU16`) and is meant for message fields where nothing more
/// than access to the raw bytes is required. For each generated type the macro emits:
///
/// * `From<Vec<u8>>`, building the inner payload with `$inner::new`;
/// * `AsRef<[u8]>`, exposing the bytes stored in the inner payload;
/// * `Codec`, delegating both `encode` and `read` to the inner payload.
macro_rules! wrapped_payload {
    ($(#[$attr:meta])* $vis:vis struct $name:ident, $inner:ident,) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $vis struct $name($inner);

        impl From<Vec<u8>> for $name {
            fn from(raw: Vec<u8>) -> Self {
                Self($inner::new(raw))
            }
        }

        impl AsRef<[u8]> for $name {
            fn as_ref(&self) -> &[u8] {
                let Self(wrapped) = self;
                wrapped.0.as_slice()
            }
        }

        impl Codec<'_> for $name {
            fn encode(&self, bytes: &mut Vec<u8>) {
                let Self(wrapped) = self;
                wrapped.encode(bytes);
            }

            fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
                $inner::read(r).map(Self)
            }
        }
    };
}
68
/// A fixed-size, 32-byte opaque value.
///
/// Used, for example, as the `random` field of `ClientHelloPayload`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
// Fixed `Random` value associated with HelloRetryRequest handling.
// NOTE(review): presumably the special ServerHello.random constant defined by
// RFC 8446 to mark a HelloRetryRequest -- confirm against the spec and the users
// of this static (not visible in this part of the file).
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

// A `Random` consisting of 32 zero bytes.
static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
/// Wraps an existing 32-byte array as a `Random` without modifying it.
impl From<[u8; 32]> for Random {
    #[inline]
    fn from(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
}
115
/// A session identifier of at most 32 bytes, stored inline.
///
/// Only the first `len` bytes of `data` are meaningful; the remainder is padding
/// and is ignored by `Debug`, `PartialEq`, `AsRef` and `encode`.
#[derive(Copy, Clone)]
pub struct SessionId {
    // Number of valid bytes at the start of `data` (expected to be <= 32).
    len: usize,
    // Backing storage; bytes at index `len` and beyond are unused.
    data: [u8; 32],
}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    #[cfg(feature = "tls12")]
181    pub(crate) fn is_empty(&self) -> bool {
182        self.len == 0
183    }
184}
185
186impl AsRef<[u8]> for SessionId {
187    fn as_ref(&self) -> &[u8] {
188        &self.data[..self.len]
189    }
190}
191
/// An extension that has no dedicated representation in this module.
///
/// It keeps the extension type together with the extension body as raw bytes.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
    // Extension type as read from (or to be written to) the wire.
    pub(crate) typ: ExtensionType,
    // Uninterpreted extension body.
    pub(crate) payload: Payload<'static>,
}
197
198impl UnknownExtension {
199    fn encode(&self, bytes: &mut Vec<u8>) {
200        self.payload.encode(bytes);
201    }
202
203    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204        let payload = Payload::read(r).into_owned();
205        Self { typ, payload }
206    }
207}
208
// `SIZE_LEN` picks the `ListLength` used for lists of each element type.
// NOTE(review): presumably this is the width of the length prefix written by the
// `Vec<T>` codec -- confirm in the `codec` module.

// Lists of EC point formats use `ListLength::U8`.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of named groups use `ListLength::U16`.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of signature schemes use `ListLength::U16`.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::U16;
}
220
/// The body of a single `ServerName` entry.
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload {
    // A host_name entry that parsed as a DNS name.
    HostName(DnsName<'static>),
    // A host_name entry whose bytes parsed as an IP address; the raw bytes are kept.
    IpAddress(PayloadU16),
    // The body of an entry whose name type is not `HostName`, kept uninterpreted.
    Unknown(Payload<'static>),
}
227
228impl ServerNamePayload {
229    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
230        Self::HostName(hostname)
231    }
232
233    fn read_hostname(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
234        use pki_types::ServerName;
235        let raw = PayloadU16::read(r)?;
236
237        match ServerName::try_from(raw.0.as_slice()) {
238            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
239            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
240            Ok(_) | Err(_) => {
241                warn!(
242                    "Illegal SNI hostname received {:?}",
243                    String::from_utf8_lossy(&raw.0)
244                );
245                Err(InvalidMessage::InvalidServerName)
246            }
247        }
248    }
249
250    fn encode(&self, bytes: &mut Vec<u8>) {
251        match *self {
252            Self::HostName(ref name) => {
253                (name.as_ref().len() as u16).encode(bytes);
254                bytes.extend_from_slice(name.as_ref().as_bytes());
255            }
256            Self::IpAddress(ref r) => r.encode(bytes),
257            Self::Unknown(ref r) => r.encode(bytes),
258        }
259    }
260}
261
/// One entry of a server name list: a name type plus the matching payload.
#[derive(Clone, Debug)]
pub struct ServerName {
    // Name type read from the wire; decides how `payload` is parsed.
    pub(crate) typ: ServerNameType,
    // Entry body, see `ServerNamePayload`.
    pub(crate) payload: ServerNamePayload,
}
267
268impl Codec<'_> for ServerName {
269    fn encode(&self, bytes: &mut Vec<u8>) {
270        self.typ.encode(bytes);
271        self.payload.encode(bytes);
272    }
273
274    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
275        let typ = ServerNameType::read(r)?;
276
277        let payload = match typ {
278            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
279            _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()),
280        };
281
282        Ok(Self { typ, payload })
283    }
284}
285
// Lists of server names use `ListLength::U16`.
impl TlsListElement for ServerName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
289
/// Convenience queries over a list of `ServerName` entries.
pub(crate) trait ConvertServerNameList {
    /// Returns whether the list contains a repeated name type.
    fn has_duplicate_names_for_type(&self) -> bool;
    /// Returns a DNS host name from the list, if it contains one.
    fn single_hostname(&self) -> Option<DnsName<'_>>;
}
294
295impl ConvertServerNameList for [ServerName] {
296    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
297    fn has_duplicate_names_for_type(&self) -> bool {
298        has_duplicates::<_, _, u8>(self.iter().map(|name| name.typ))
299    }
300
301    fn single_hostname(&self) -> Option<DnsName<'_>> {
302        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
303            if let ServerNamePayload::HostName(ref dns) = name.payload {
304                Some(dns.borrow())
305            } else {
306                None
307            }
308        }
309
310        self.iter()
311            .filter_map(only_dns_hostnames)
312            .next()
313    }
314}
315
// A protocol name (used by the `Protocols` extension variants below): an opaque
// byte string stored in a `PayloadU8`.
wrapped_payload!(pub struct ProtocolName, PayloadU8,);

// Lists of protocol names use `ListLength::U16`.
impl TlsListElement for ProtocolName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
321
/// Conversions between a list of `ProtocolName`s and plain byte slices.
pub(crate) trait ConvertProtocolNameList {
    /// Builds a list holding a copy of every slice in `names`, in order.
    fn from_slices(names: &[&[u8]]) -> Self;
    /// Borrows every name in the list as a byte slice, in order.
    fn to_slices(&self) -> Vec<&[u8]>;
    /// Returns the only name in the list, or `None` unless there is exactly one.
    fn as_single_slice(&self) -> Option<&[u8]>;
}
327
328impl ConvertProtocolNameList for Vec<ProtocolName> {
329    fn from_slices(names: &[&[u8]]) -> Self {
330        let mut ret = Self::new();
331
332        for name in names {
333            ret.push(ProtocolName::from(name.to_vec()));
334        }
335
336        ret
337    }
338
339    fn to_slices(&self) -> Vec<&[u8]> {
340        self.iter()
341            .map(|proto| proto.as_ref())
342            .collect::<Vec<&[u8]>>()
343    }
344
345    fn as_single_slice(&self) -> Option<&[u8]> {
346        if self.len() == 1 {
347            Some(self[0].as_ref())
348        } else {
349            None
350        }
351    }
352}
353
354// --- TLS 1.3 Key shares ---
/// A key share: a named group together with an opaque, `u16`-length-prefixed payload.
#[derive(Clone, Debug)]
pub struct KeyShareEntry {
    // Group this share belongs to.
    pub(crate) group: NamedGroup,
    // Opaque key exchange data for `group`.
    pub(crate) payload: PayloadU16,
}
360
361impl KeyShareEntry {
362    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
363        Self {
364            group,
365            payload: PayloadU16::new(payload.into()),
366        }
367    }
368
369    pub fn group(&self) -> NamedGroup {
370        self.group
371    }
372}
373
374impl Codec<'_> for KeyShareEntry {
375    fn encode(&self, bytes: &mut Vec<u8>) {
376        self.group.encode(bytes);
377        self.payload.encode(bytes);
378    }
379
380    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
381        let group = NamedGroup::read(r)?;
382        let payload = PayloadU16::read(r)?;
383
384        Ok(Self { group, payload })
385    }
386}
387
388// --- TLS 1.3 PresharedKey offers ---
/// One pre-shared key identity: an opaque identity plus its obfuscated ticket age.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    // Opaque identity bytes.
    pub(crate) identity: PayloadU16,
    // Ticket age value, stored and encoded as a plain `u32`.
    pub(crate) obfuscated_ticket_age: u32,
}
394
395impl PresharedKeyIdentity {
396    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
397        Self {
398            identity: PayloadU16::new(id),
399            obfuscated_ticket_age: age,
400        }
401    }
402}
403
404impl Codec<'_> for PresharedKeyIdentity {
405    fn encode(&self, bytes: &mut Vec<u8>) {
406        self.identity.encode(bytes);
407        self.obfuscated_ticket_age.encode(bytes);
408    }
409
410    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
411        Ok(Self {
412            identity: PayloadU16::read(r)?,
413            obfuscated_ticket_age: u32::read(r)?,
414        })
415    }
416}
417
// Lists of PSK identities use `ListLength::U16`.
impl TlsListElement for PresharedKeyIdentity {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// A PSK binder: opaque bytes stored in a `PayloadU8`.
wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,);

// Lists of PSK binders use `ListLength::U16`.
impl TlsListElement for PresharedKeyBinder {
    const SIZE_LEN: ListLength = ListLength::U16;
}
427
/// A pre-shared key offer: a list of identities and a list of binders.
///
/// The two lists are stored and encoded independently; nothing in this type
/// enforces that they have the same length.
#[derive(Clone, Debug)]
pub struct PresharedKeyOffer {
    // Offered identities, encoded first.
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    // Binders, encoded after the identities.
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
433
434impl PresharedKeyOffer {
435    /// Make a new one with one entry.
436    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
437        Self {
438            identities: vec![id],
439            binders: vec![PresharedKeyBinder::from(binder)],
440        }
441    }
442}
443
444impl Codec<'_> for PresharedKeyOffer {
445    fn encode(&self, bytes: &mut Vec<u8>) {
446        self.identities.encode(bytes);
447        self.binders.encode(bytes);
448    }
449
450    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
451        Ok(Self {
452            identities: Vec::read(r)?,
453            binders: Vec::read(r)?,
454        })
455    }
456}
457
458// --- RFC6066 certificate status request ---
// A responder id: opaque bytes stored in a `PayloadU16`.
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);

// Lists of responder ids use `ListLength::U16`.
impl TlsListElement for ResponderId {
    const SIZE_LEN: ListLength = ListLength::U16;
}
464
/// The OCSP form of a certificate status request: responder ids plus opaque
/// request extensions.
#[derive(Clone, Debug)]
pub struct OcspCertificateStatusRequest {
    // Responder ids, encoded first.
    pub(crate) responder_ids: Vec<ResponderId>,
    // Request extensions, kept as uninterpreted bytes.
    pub(crate) extensions: PayloadU16,
}
470
471impl Codec<'_> for OcspCertificateStatusRequest {
472    fn encode(&self, bytes: &mut Vec<u8>) {
473        CertificateStatusType::OCSP.encode(bytes);
474        self.responder_ids.encode(bytes);
475        self.extensions.encode(bytes);
476    }
477
478    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
479        Ok(Self {
480            responder_ids: Vec::read(r)?,
481            extensions: PayloadU16::read(r)?,
482        })
483    }
484}
485
/// A certificate status request, distinguished by its status type.
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
    // A request with status type `OCSP`.
    Ocsp(OcspCertificateStatusRequest),
    // Any other status type, with the body kept as uninterpreted bytes.
    Unknown((CertificateStatusType, Payload<'static>)),
}
491
492impl Codec<'_> for CertificateStatusRequest {
493    fn encode(&self, bytes: &mut Vec<u8>) {
494        match self {
495            Self::Ocsp(ref r) => r.encode(bytes),
496            Self::Unknown((typ, payload)) => {
497                typ.encode(bytes);
498                payload.encode(bytes);
499            }
500        }
501    }
502
503    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
504        let typ = CertificateStatusType::read(r)?;
505
506        match typ {
507            CertificateStatusType::OCSP => {
508                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
509                Ok(Self::Ocsp(ocsp_req))
510            }
511            _ => {
512                let data = Payload::read(r).into_owned();
513                Ok(Self::Unknown((typ, data)))
514            }
515        }
516    }
517}
518
519impl CertificateStatusRequest {
520    pub(crate) fn build_ocsp() -> Self {
521        let ocsp = OcspCertificateStatusRequest {
522            responder_ids: Vec::new(),
523            extensions: PayloadU16::empty(),
524        };
525        Self::Ocsp(ocsp)
526    }
527}
528
529// ---
530
// Lists of PSK key exchange modes use `ListLength::U8`.
impl TlsListElement for PSKKeyExchangeMode {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of key share entries use `ListLength::U16`.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of protocol versions use `ListLength::U8`.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of certificate types use `ListLength::U8`.
impl TlsListElement for CertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of certificate compression algorithms use `ListLength::U8`.
impl TlsListElement for CertificateCompressionAlgorithm {
    const SIZE_LEN: ListLength = ListLength::U8;
}
550
/// An extension as sent by a client.
///
/// Each variant corresponds to one `ExtensionType` (see `ext_type`); extension
/// types without a dedicated variant are carried in `Unknown`.
#[derive(Clone, Debug)]
pub enum ClientExtension {
    EcPointFormats(Vec<ECPointFormat>),
    NamedGroups(Vec<NamedGroup>),
    SignatureAlgorithms(Vec<SignatureScheme>),
    ServerName(Vec<ServerName>),
    SessionTicket(ClientSessionTicket),
    Protocols(Vec<ProtocolName>),
    SupportedVersions(Vec<ProtocolVersion>),
    KeyShare(Vec<KeyShareEntry>),
    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
    PresharedKey(PresharedKeyOffer),
    Cookie(PayloadU16),
    // Encoded with an empty body.
    ExtendedMasterSecretRequest,
    CertificateStatusRequest(CertificateStatusRequest),
    ServerCertTypes(Vec<CertificateType>),
    ClientCertTypes(Vec<CertificateType>),
    // Body is carried as raw bytes and written verbatim.
    TransportParameters(Vec<u8>),
    // Body is carried as raw bytes and written verbatim.
    TransportParametersDraft(Vec<u8>),
    // Encoded with an empty body.
    EarlyData,
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    EncryptedClientHello(EncryptedClientHello),
    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
    // Any extension type without a dedicated variant above.
    Unknown(UnknownExtension),
}
576
577impl ClientExtension {
578    pub(crate) fn ext_type(&self) -> ExtensionType {
579        match *self {
580            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
581            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
582            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
583            Self::ServerName(_) => ExtensionType::ServerName,
584            Self::SessionTicket(_) => ExtensionType::SessionTicket,
585            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
586            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
587            Self::KeyShare(_) => ExtensionType::KeyShare,
588            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
589            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
590            Self::Cookie(_) => ExtensionType::Cookie,
591            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
592            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
593            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
594            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
595            Self::TransportParameters(_) => ExtensionType::TransportParameters,
596            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
597            Self::EarlyData => ExtensionType::EarlyData,
598            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
599            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
600            Self::EncryptedClientHelloOuterExtensions(_) => {
601                ExtensionType::EncryptedClientHelloOuterExtensions
602            }
603            Self::Unknown(ref r) => r.typ,
604        }
605    }
606}
607
608impl Codec<'_> for ClientExtension {
609    fn encode(&self, bytes: &mut Vec<u8>) {
610        self.ext_type().encode(bytes);
611
612        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
613        match *self {
614            Self::EcPointFormats(ref r) => r.encode(nested.buf),
615            Self::NamedGroups(ref r) => r.encode(nested.buf),
616            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
617            Self::ServerName(ref r) => r.encode(nested.buf),
618            Self::SessionTicket(ClientSessionTicket::Request)
619            | Self::ExtendedMasterSecretRequest
620            | Self::EarlyData => {}
621            Self::SessionTicket(ClientSessionTicket::Offer(ref r)) => r.encode(nested.buf),
622            Self::Protocols(ref r) => r.encode(nested.buf),
623            Self::SupportedVersions(ref r) => r.encode(nested.buf),
624            Self::KeyShare(ref r) => r.encode(nested.buf),
625            Self::PresharedKeyModes(ref r) => r.encode(nested.buf),
626            Self::PresharedKey(ref r) => r.encode(nested.buf),
627            Self::Cookie(ref r) => r.encode(nested.buf),
628            Self::CertificateStatusRequest(ref r) => r.encode(nested.buf),
629            Self::ClientCertTypes(ref r) => r.encode(nested.buf),
630            Self::ServerCertTypes(ref r) => r.encode(nested.buf),
631            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
632                nested.buf.extend_from_slice(r);
633            }
634            Self::CertificateCompressionAlgorithms(ref r) => r.encode(nested.buf),
635            Self::EncryptedClientHello(ref r) => r.encode(nested.buf),
636            Self::EncryptedClientHelloOuterExtensions(ref r) => r.encode(nested.buf),
637            Self::Unknown(ref r) => r.encode(nested.buf),
638        }
639    }
640
641    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
642        let typ = ExtensionType::read(r)?;
643        let len = u16::read(r)? as usize;
644        let mut sub = r.sub(len)?;
645
646        let ext = match typ {
647            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
648            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
649            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
650            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
651            ExtensionType::SessionTicket => {
652                if sub.any_left() {
653                    let contents = Payload::read(&mut sub).into_owned();
654                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
655                } else {
656                    Self::SessionTicket(ClientSessionTicket::Request)
657                }
658            }
659            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
660            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
661            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
662            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
663            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
664            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
665            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
666                Self::ExtendedMasterSecretRequest
667            }
668            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
669            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
670            ExtensionType::StatusRequest => {
671                let csr = CertificateStatusRequest::read(&mut sub)?;
672                Self::CertificateStatusRequest(csr)
673            }
674            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
675            ExtensionType::TransportParametersDraft => {
676                Self::TransportParametersDraft(sub.rest().to_vec())
677            }
678            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
679            ExtensionType::CompressCertificate => {
680                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
681            }
682            ExtensionType::EncryptedClientHelloOuterExtensions => {
683                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
684            }
685            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
686        };
687
688        sub.expect_empty("ClientExtension")
689            .map(|_| ext)
690    }
691}
692
693fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
694    let dns_name_str = dns_name.as_ref();
695
696    // RFC6066: "The hostname is represented as a byte string using
697    // ASCII encoding without a trailing dot"
698    if dns_name_str.ends_with('.') {
699        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
700        DnsName::try_from(trimmed)
701            .unwrap()
702            .to_owned()
703    } else {
704        dns_name.to_owned()
705    }
706}
707
708impl ClientExtension {
709    /// Make a basic SNI ServerNameRequest quoting `hostname`.
710    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
711        let name = ServerName {
712            typ: ServerNameType::HostName,
713            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
714        };
715
716        Self::ServerName(vec![name])
717    }
718}
719
/// Body of the client's session ticket extension.
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
    // Empty extension body.
    Request,
    // Non-empty extension body, kept as raw bytes.
    Offer(Payload<'static>),
}
725
/// An extension as sent by a server.
///
/// Each variant corresponds to one `ExtensionType` (see `ext_type`); extension
/// types without a dedicated variant are carried in `Unknown`. The `*Ack`
/// variants and `EarlyData` are encoded with an empty body.
#[derive(Clone, Debug)]
pub enum ServerExtension {
    EcPointFormats(Vec<ECPointFormat>),
    ServerNameAck,
    SessionTicketAck,
    RenegotiationInfo(PayloadU8),
    Protocols(Vec<ProtocolName>),
    KeyShare(KeyShareEntry),
    PresharedKey(u16),
    ExtendedMasterSecretAck,
    CertificateStatusAck,
    ServerCertType(CertificateType),
    ClientCertType(CertificateType),
    SupportedVersions(ProtocolVersion),
    // Body is carried as raw bytes and written verbatim.
    TransportParameters(Vec<u8>),
    // Body is carried as raw bytes and written verbatim.
    TransportParametersDraft(Vec<u8>),
    EarlyData,
    EncryptedClientHello(ServerEncryptedClientHello),
    // Any extension type without a dedicated variant above.
    Unknown(UnknownExtension),
}
746
747impl ServerExtension {
748    pub(crate) fn ext_type(&self) -> ExtensionType {
749        match *self {
750            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
751            Self::ServerNameAck => ExtensionType::ServerName,
752            Self::SessionTicketAck => ExtensionType::SessionTicket,
753            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
754            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
755            Self::KeyShare(_) => ExtensionType::KeyShare,
756            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
757            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
758            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
759            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
760            Self::CertificateStatusAck => ExtensionType::StatusRequest,
761            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
762            Self::TransportParameters(_) => ExtensionType::TransportParameters,
763            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
764            Self::EarlyData => ExtensionType::EarlyData,
765            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
766            Self::Unknown(ref r) => r.typ,
767        }
768    }
769}
770
771impl Codec<'_> for ServerExtension {
772    fn encode(&self, bytes: &mut Vec<u8>) {
773        self.ext_type().encode(bytes);
774
775        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
776        match *self {
777            Self::EcPointFormats(ref r) => r.encode(nested.buf),
778            Self::ServerNameAck
779            | Self::SessionTicketAck
780            | Self::ExtendedMasterSecretAck
781            | Self::CertificateStatusAck
782            | Self::EarlyData => {}
783            Self::RenegotiationInfo(ref r) => r.encode(nested.buf),
784            Self::Protocols(ref r) => r.encode(nested.buf),
785            Self::KeyShare(ref r) => r.encode(nested.buf),
786            Self::PresharedKey(r) => r.encode(nested.buf),
787            Self::ClientCertType(r) => r.encode(nested.buf),
788            Self::ServerCertType(r) => r.encode(nested.buf),
789            Self::SupportedVersions(ref r) => r.encode(nested.buf),
790            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
791                nested.buf.extend_from_slice(r);
792            }
793            Self::EncryptedClientHello(ref r) => r.encode(nested.buf),
794            Self::Unknown(ref r) => r.encode(nested.buf),
795        }
796    }
797
798    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
799        let typ = ExtensionType::read(r)?;
800        let len = u16::read(r)? as usize;
801        let mut sub = r.sub(len)?;
802
803        let ext = match typ {
804            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
805            ExtensionType::ServerName => Self::ServerNameAck,
806            ExtensionType::SessionTicket => Self::SessionTicketAck,
807            ExtensionType::StatusRequest => Self::CertificateStatusAck,
808            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
809            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
810            ExtensionType::ClientCertificateType => {
811                Self::ClientCertType(CertificateType::read(&mut sub)?)
812            }
813            ExtensionType::ServerCertificateType => {
814                Self::ServerCertType(CertificateType::read(&mut sub)?)
815            }
816            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
817            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
818            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
819            ExtensionType::SupportedVersions => {
820                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
821            }
822            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
823            ExtensionType::TransportParametersDraft => {
824                Self::TransportParametersDraft(sub.rest().to_vec())
825            }
826            ExtensionType::EarlyData => Self::EarlyData,
827            ExtensionType::EncryptedClientHello => {
828                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
829            }
830            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
831        };
832
833        sub.expect_empty("ServerExtension")
834            .map(|_| ext)
835    }
836}
837
838impl ServerExtension {
839    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
840        Self::Protocols(Vec::from_slices(proto))
841    }
842
843    #[cfg(feature = "tls12")]
844    pub(crate) fn make_empty_renegotiation_info() -> Self {
845        let empty = Vec::new();
846        Self::RenegotiationInfo(PayloadU8::new(empty))
847    }
848}
849
/// The contents of a ClientHello message.
///
/// Fields are listed in the order in which `read` parses them.
#[derive(Clone, Debug)]
pub struct ClientHelloPayload {
    pub client_version: ProtocolVersion,
    pub random: Random,
    pub session_id: SessionId,
    pub cipher_suites: Vec<CipherSuite>,
    pub compression_methods: Vec<Compression>,
    // `read` rejects a message that ends up with no extensions.
    pub extensions: Vec<ClientExtension>,
}
859
860impl Codec<'_> for ClientHelloPayload {
861    fn encode(&self, bytes: &mut Vec<u8>) {
862        self.payload_encode(bytes, Encoding::Standard)
863    }
864
865    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
866        let mut ret = Self {
867            client_version: ProtocolVersion::read(r)?,
868            random: Random::read(r)?,
869            session_id: SessionId::read(r)?,
870            cipher_suites: Vec::read(r)?,
871            compression_methods: Vec::read(r)?,
872            extensions: Vec::new(),
873        };
874
875        if r.any_left() {
876            ret.extensions = Vec::read(r)?;
877        }
878
879        match (r.any_left(), ret.extensions.is_empty()) {
880            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
881            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
882            _ => Ok(ret),
883        }
884    }
885}
886
// Lists of cipher suites use `ListLength::U16`.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of compression methods use `ListLength::U8`.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of client extensions use `ListLength::U16`.
impl TlsListElement for ClientExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of extension types use `ListLength::U8`.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
902
903impl ClientHelloPayload {
904    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
905        let mut bytes = Vec::new();
906        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
907        bytes
908    }
909
910    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
911        self.client_version.encode(bytes);
912        self.random.encode(bytes);
913
914        match purpose {
915            // SessionID is required to be empty in the encoded inner client hello.
916            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
917            _ => self.session_id.encode(bytes),
918        }
919
920        self.cipher_suites.encode(bytes);
921        self.compression_methods.encode(bytes);
922
923        let to_compress = match purpose {
924            // Compressed extensions must be replaced in the encoded inner client hello.
925            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
926            _ => {
927                if !self.extensions.is_empty() {
928                    self.extensions.encode(bytes);
929                }
930                return;
931            }
932        };
933
934        // Safety: not empty check in match guard.
935        let first_compressed_type = *to_compress.first().unwrap();
936
937        // Compressed extensions are in a contiguous range and must be replaced
938        // with a marker extension.
939        let compressed_start_idx = self
940            .extensions
941            .iter()
942            .position(|ext| ext.ext_type() == first_compressed_type);
943        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
944        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
945
946        let exts = self
947            .extensions
948            .iter()
949            .enumerate()
950            .filter_map(|(i, ext)| {
951                if Some(i) == compressed_start_idx {
952                    Some(&marker_ext)
953                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
954                    None
955                } else {
956                    Some(ext)
957                }
958            });
959
960        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
961        for ext in exts {
962            ext.encode(nested.buf);
963        }
964    }
965
966    /// Returns true if there is more than one extension of a given
967    /// type.
968    pub(crate) fn has_duplicate_extension(&self) -> bool {
969        has_duplicates::<_, _, u16>(
970            self.extensions
971                .iter()
972                .map(|ext| ext.ext_type()),
973        )
974    }
975
976    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
977        self.extensions
978            .iter()
979            .find(|x| x.ext_type() == ext)
980    }
981
982    pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> {
983        let ext = self.find_extension(ExtensionType::ServerName)?;
984        match *ext {
985            // Does this comply with RFC6066?
986            //
987            // [RFC6066][] specifies that literal IP addresses are illegal in
988            // `ServerName`s with a `name_type` of `host_name`.
989            //
990            // Some clients incorrectly send such extensions: we choose to
991            // successfully parse these (into `ServerNamePayload::IpAddress`)
992            // but then act like the client sent no `server_name` extension.
993            //
994            // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3
995            ClientExtension::ServerName(ref req)
996                if !req
997                    .iter()
998                    .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) =>
999            {
1000                Some(req)
1001            }
1002            _ => None,
1003        }
1004    }
1005
1006    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1007        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1008        match *ext {
1009            ClientExtension::SignatureAlgorithms(ref req) => Some(req),
1010            _ => None,
1011        }
1012    }
1013
1014    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1015        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1016        match *ext {
1017            ClientExtension::NamedGroups(ref req) => Some(req),
1018            _ => None,
1019        }
1020    }
1021
1022    #[cfg(feature = "tls12")]
1023    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1024        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1025        match *ext {
1026            ClientExtension::EcPointFormats(ref req) => Some(req),
1027            _ => None,
1028        }
1029    }
1030
1031    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1032        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1033        match ext {
1034            ClientExtension::ServerCertTypes(req) => Some(req),
1035            _ => None,
1036        }
1037    }
1038
1039    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1040        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1041        match ext {
1042            ClientExtension::ClientCertTypes(req) => Some(req),
1043            _ => None,
1044        }
1045    }
1046
1047    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1048        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1049        match *ext {
1050            ClientExtension::Protocols(ref req) => Some(req),
1051            _ => None,
1052        }
1053    }
1054
1055    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1056        let ext = self
1057            .find_extension(ExtensionType::TransportParameters)
1058            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1059        match *ext {
1060            ClientExtension::TransportParameters(ref bytes)
1061            | ClientExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
1062            _ => None,
1063        }
1064    }
1065
1066    #[cfg(feature = "tls12")]
1067    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1068        self.find_extension(ExtensionType::SessionTicket)
1069    }
1070
1071    pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> {
1072        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1073        match *ext {
1074            ClientExtension::SupportedVersions(ref vers) => Some(vers),
1075            _ => None,
1076        }
1077    }
1078
1079    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1080        let ext = self.find_extension(ExtensionType::KeyShare)?;
1081        match *ext {
1082            ClientExtension::KeyShare(ref shares) => Some(shares),
1083            _ => None,
1084        }
1085    }
1086
1087    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1088        self.keyshare_extension()
1089            .map(|entries| {
1090                has_duplicates::<_, _, u16>(
1091                    entries
1092                        .iter()
1093                        .map(|kse| u16::from(kse.group)),
1094                )
1095            })
1096            .unwrap_or_default()
1097    }
1098
1099    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1100        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1101        match *ext {
1102            ClientExtension::PresharedKey(ref psk) => Some(psk),
1103            _ => None,
1104        }
1105    }
1106
1107    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1108        self.extensions
1109            .last()
1110            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1111    }
1112
1113    pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
1114        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1115        match *ext {
1116            ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
1117            _ => None,
1118        }
1119    }
1120
1121    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
1122        self.psk_modes()
1123            .map(|modes| modes.contains(&mode))
1124            .unwrap_or(false)
1125    }
1126
1127    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1128        let last_extension = self.extensions.last_mut();
1129        if let Some(ClientExtension::PresharedKey(ref mut offer)) = last_extension {
1130            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1131        }
1132    }
1133
1134    #[cfg(feature = "tls12")]
1135    pub(crate) fn ems_support_offered(&self) -> bool {
1136        self.find_extension(ExtensionType::ExtendedMasterSecret)
1137            .is_some()
1138    }
1139
1140    pub(crate) fn early_data_extension_offered(&self) -> bool {
1141        self.find_extension(ExtensionType::EarlyData)
1142            .is_some()
1143    }
1144
1145    pub(crate) fn certificate_compression_extension(
1146        &self,
1147    ) -> Option<&[CertificateCompressionAlgorithm]> {
1148        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1149        match *ext {
1150            ClientExtension::CertificateCompressionAlgorithms(ref algs) => Some(algs),
1151            _ => None,
1152        }
1153    }
1154
1155    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1156        if let Some(algs) = self.certificate_compression_extension() {
1157            has_duplicates::<_, _, u16>(algs.iter().cloned())
1158        } else {
1159            false
1160        }
1161    }
1162}
1163
1164#[derive(Clone, Debug)]
1165pub(crate) enum HelloRetryExtension {
1166    KeyShare(NamedGroup),
1167    Cookie(PayloadU16),
1168    SupportedVersions(ProtocolVersion),
1169    EchHelloRetryRequest(Vec<u8>),
1170    Unknown(UnknownExtension),
1171}
1172
1173impl HelloRetryExtension {
1174    pub(crate) fn ext_type(&self) -> ExtensionType {
1175        match *self {
1176            Self::KeyShare(_) => ExtensionType::KeyShare,
1177            Self::Cookie(_) => ExtensionType::Cookie,
1178            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1179            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1180            Self::Unknown(ref r) => r.typ,
1181        }
1182    }
1183}
1184
1185impl Codec<'_> for HelloRetryExtension {
1186    fn encode(&self, bytes: &mut Vec<u8>) {
1187        self.ext_type().encode(bytes);
1188
1189        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1190        match *self {
1191            Self::KeyShare(ref r) => r.encode(nested.buf),
1192            Self::Cookie(ref r) => r.encode(nested.buf),
1193            Self::SupportedVersions(ref r) => r.encode(nested.buf),
1194            Self::EchHelloRetryRequest(ref r) => {
1195                nested.buf.extend_from_slice(r);
1196            }
1197            Self::Unknown(ref r) => r.encode(nested.buf),
1198        }
1199    }
1200
1201    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1202        let typ = ExtensionType::read(r)?;
1203        let len = u16::read(r)? as usize;
1204        let mut sub = r.sub(len)?;
1205
1206        let ext = match typ {
1207            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1208            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1209            ExtensionType::SupportedVersions => {
1210                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1211            }
1212            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1213            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1214        };
1215
1216        sub.expect_empty("HelloRetryExtension")
1217            .map(|_| ext)
1218    }
1219}
1220
1221impl TlsListElement for HelloRetryExtension {
1222    const SIZE_LEN: ListLength = ListLength::U16;
1223}
1224
1225#[derive(Clone, Debug)]
1226pub struct HelloRetryRequest {
1227    pub(crate) legacy_version: ProtocolVersion,
1228    pub session_id: SessionId,
1229    pub(crate) cipher_suite: CipherSuite,
1230    pub(crate) extensions: Vec<HelloRetryExtension>,
1231}
1232
1233impl Codec<'_> for HelloRetryRequest {
1234    fn encode(&self, bytes: &mut Vec<u8>) {
1235        self.payload_encode(bytes, Encoding::Standard)
1236    }
1237
1238    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1239        let session_id = SessionId::read(r)?;
1240        let cipher_suite = CipherSuite::read(r)?;
1241        let compression = Compression::read(r)?;
1242
1243        if compression != Compression::Null {
1244            return Err(InvalidMessage::UnsupportedCompression);
1245        }
1246
1247        Ok(Self {
1248            legacy_version: ProtocolVersion::Unknown(0),
1249            session_id,
1250            cipher_suite,
1251            extensions: Vec::read(r)?,
1252        })
1253    }
1254}
1255
1256impl HelloRetryRequest {
1257    /// Returns true if there is more than one extension of a given
1258    /// type.
1259    pub(crate) fn has_duplicate_extension(&self) -> bool {
1260        has_duplicates::<_, _, u16>(
1261            self.extensions
1262                .iter()
1263                .map(|ext| ext.ext_type()),
1264        )
1265    }
1266
1267    pub(crate) fn has_unknown_extension(&self) -> bool {
1268        self.extensions.iter().any(|ext| {
1269            ext.ext_type() != ExtensionType::KeyShare
1270                && ext.ext_type() != ExtensionType::SupportedVersions
1271                && ext.ext_type() != ExtensionType::Cookie
1272                && ext.ext_type() != ExtensionType::EncryptedClientHello
1273        })
1274    }
1275
1276    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1277        self.extensions
1278            .iter()
1279            .find(|x| x.ext_type() == ext)
1280    }
1281
1282    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1283        let ext = self.find_extension(ExtensionType::KeyShare)?;
1284        match *ext {
1285            HelloRetryExtension::KeyShare(grp) => Some(grp),
1286            _ => None,
1287        }
1288    }
1289
1290    pub(crate) fn cookie(&self) -> Option<&PayloadU16> {
1291        let ext = self.find_extension(ExtensionType::Cookie)?;
1292        match *ext {
1293            HelloRetryExtension::Cookie(ref ck) => Some(ck),
1294            _ => None,
1295        }
1296    }
1297
1298    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1299        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1300        match *ext {
1301            HelloRetryExtension::SupportedVersions(ver) => Some(ver),
1302            _ => None,
1303        }
1304    }
1305
1306    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1307        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1308        match *ext {
1309            HelloRetryExtension::EchHelloRetryRequest(ref ech) => Some(ech),
1310            _ => None,
1311        }
1312    }
1313
1314    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1315        self.legacy_version.encode(bytes);
1316        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1317        self.session_id.encode(bytes);
1318        self.cipher_suite.encode(bytes);
1319        Compression::Null.encode(bytes);
1320
1321        match purpose {
1322            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1323            // must have its payload replaced by 8 zero bytes.
1324            //
1325            // See draft-ietf-tls-esni-18 7.2.1:
1326            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1327            Encoding::EchConfirmation => {
1328                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1329                for ext in &self.extensions {
1330                    match ext.ext_type() {
1331                        ExtensionType::EncryptedClientHello => {
1332                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1333                                .encode(extensions.buf);
1334                        }
1335                        _ => {
1336                            ext.encode(extensions.buf);
1337                        }
1338                    }
1339                }
1340            }
1341            _ => {
1342                self.extensions.encode(bytes);
1343            }
1344        }
1345    }
1346}
1347
1348#[derive(Clone, Debug)]
1349pub struct ServerHelloPayload {
1350    pub extensions: Vec<ServerExtension>,
1351    pub(crate) legacy_version: ProtocolVersion,
1352    pub(crate) random: Random,
1353    pub(crate) session_id: SessionId,
1354    pub(crate) cipher_suite: CipherSuite,
1355    pub(crate) compression_method: Compression,
1356}
1357
1358impl Codec<'_> for ServerHelloPayload {
1359    fn encode(&self, bytes: &mut Vec<u8>) {
1360        self.payload_encode(bytes, Encoding::Standard)
1361    }
1362
1363    // minus version and random, which have already been read.
1364    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1365        let session_id = SessionId::read(r)?;
1366        let suite = CipherSuite::read(r)?;
1367        let compression = Compression::read(r)?;
1368
1369        // RFC5246:
1370        // "The presence of extensions can be detected by determining whether
1371        //  there are bytes following the compression_method field at the end of
1372        //  the ServerHello."
1373        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1374
1375        let ret = Self {
1376            legacy_version: ProtocolVersion::Unknown(0),
1377            random: ZERO_RANDOM,
1378            session_id,
1379            cipher_suite: suite,
1380            compression_method: compression,
1381            extensions,
1382        };
1383
1384        r.expect_empty("ServerHelloPayload")
1385            .map(|_| ret)
1386    }
1387}
1388
1389impl HasServerExtensions for ServerHelloPayload {
1390    fn extensions(&self) -> &[ServerExtension] {
1391        &self.extensions
1392    }
1393}
1394
1395impl ServerHelloPayload {
1396    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1397        let ext = self.find_extension(ExtensionType::KeyShare)?;
1398        match *ext {
1399            ServerExtension::KeyShare(ref share) => Some(share),
1400            _ => None,
1401        }
1402    }
1403
1404    pub(crate) fn psk_index(&self) -> Option<u16> {
1405        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1406        match *ext {
1407            ServerExtension::PresharedKey(ref index) => Some(*index),
1408            _ => None,
1409        }
1410    }
1411
1412    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1413        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1414        match *ext {
1415            ServerExtension::EcPointFormats(ref fmts) => Some(fmts),
1416            _ => None,
1417        }
1418    }
1419
1420    #[cfg(feature = "tls12")]
1421    pub(crate) fn ems_support_acked(&self) -> bool {
1422        self.find_extension(ExtensionType::ExtendedMasterSecret)
1423            .is_some()
1424    }
1425
1426    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1427        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1428        match *ext {
1429            ServerExtension::SupportedVersions(vers) => Some(vers),
1430            _ => None,
1431        }
1432    }
1433
1434    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1435        self.legacy_version.encode(bytes);
1436
1437        match encoding {
1438            // When encoding a ServerHello for ECH confirmation, the random value
1439            // has the last 8 bytes zeroed out.
1440            Encoding::EchConfirmation => {
1441                // Indexing safety: self.random is 32 bytes long by definition.
1442                let rand_vec = self.random.get_encoding();
1443                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1444                bytes.extend_from_slice(&[0u8; 8]);
1445            }
1446            _ => self.random.encode(bytes),
1447        }
1448
1449        self.session_id.encode(bytes);
1450        self.cipher_suite.encode(bytes);
1451        self.compression_method.encode(bytes);
1452
1453        if !self.extensions.is_empty() {
1454            self.extensions.encode(bytes);
1455        }
1456    }
1457}
1458
1459#[derive(Clone, Default, Debug)]
1460pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1461
1462impl CertificateChain<'_> {
1463    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1464        CertificateChain(
1465            self.0
1466                .into_iter()
1467                .map(|c| c.into_owned())
1468                .collect(),
1469        )
1470    }
1471}
1472
1473impl<'a> Codec<'a> for CertificateChain<'a> {
1474    fn encode(&self, bytes: &mut Vec<u8>) {
1475        Vec::encode(&self.0, bytes)
1476    }
1477
1478    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1479        Vec::read(r).map(Self)
1480    }
1481}
1482
1483impl<'a> Deref for CertificateChain<'a> {
1484    type Target = [CertificateDer<'a>];
1485
1486    fn deref(&self) -> &[CertificateDer<'a>] {
1487        &self.0
1488    }
1489}
1490
1491impl TlsListElement for CertificateDer<'_> {
1492    const SIZE_LEN: ListLength = ListLength::U24 {
1493        max: CERTIFICATE_MAX_SIZE_LIMIT,
1494        error: InvalidMessage::CertificatePayloadTooLarge,
1495    };
1496}
1497
/// Upper bound, in bytes, applied to certificate payloads.
///
/// TLS itself allows up to 16MB for any handshake message, and up to 16MB
/// for any single certificate. That is reduced to 64KB here so that the
/// amount of memory allocation directly controllable by the peer stays
/// small.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 64 * 1024;
1504
1505#[derive(Debug)]
1506pub(crate) enum CertificateExtension<'a> {
1507    CertificateStatus(CertificateStatus<'a>),
1508    Unknown(UnknownExtension),
1509}
1510
1511impl CertificateExtension<'_> {
1512    pub(crate) fn ext_type(&self) -> ExtensionType {
1513        match *self {
1514            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1515            Self::Unknown(ref r) => r.typ,
1516        }
1517    }
1518
1519    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1520        match *self {
1521            Self::CertificateStatus(ref cs) => Some(cs.ocsp_response.0.bytes()),
1522            _ => None,
1523        }
1524    }
1525
1526    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1527        match self {
1528            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1529            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1530        }
1531    }
1532}
1533
1534impl<'a> Codec<'a> for CertificateExtension<'a> {
1535    fn encode(&self, bytes: &mut Vec<u8>) {
1536        self.ext_type().encode(bytes);
1537
1538        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1539        match *self {
1540            Self::CertificateStatus(ref r) => r.encode(nested.buf),
1541            Self::Unknown(ref r) => r.encode(nested.buf),
1542        }
1543    }
1544
1545    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1546        let typ = ExtensionType::read(r)?;
1547        let len = u16::read(r)? as usize;
1548        let mut sub = r.sub(len)?;
1549
1550        let ext = match typ {
1551            ExtensionType::StatusRequest => {
1552                let st = CertificateStatus::read(&mut sub)?;
1553                Self::CertificateStatus(st)
1554            }
1555            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1556        };
1557
1558        sub.expect_empty("CertificateExtension")
1559            .map(|_| ext)
1560    }
1561}
1562
1563impl TlsListElement for CertificateExtension<'_> {
1564    const SIZE_LEN: ListLength = ListLength::U16;
1565}
1566
1567#[derive(Debug)]
1568pub(crate) struct CertificateEntry<'a> {
1569    pub(crate) cert: CertificateDer<'a>,
1570    pub(crate) exts: Vec<CertificateExtension<'a>>,
1571}
1572
1573impl<'a> Codec<'a> for CertificateEntry<'a> {
1574    fn encode(&self, bytes: &mut Vec<u8>) {
1575        self.cert.encode(bytes);
1576        self.exts.encode(bytes);
1577    }
1578
1579    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1580        Ok(Self {
1581            cert: CertificateDer::read(r)?,
1582            exts: Vec::read(r)?,
1583        })
1584    }
1585}
1586
1587impl<'a> CertificateEntry<'a> {
1588    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1589        Self {
1590            cert,
1591            exts: Vec::new(),
1592        }
1593    }
1594
1595    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1596        CertificateEntry {
1597            cert: self.cert.into_owned(),
1598            exts: self
1599                .exts
1600                .into_iter()
1601                .map(CertificateExtension::into_owned)
1602                .collect(),
1603        }
1604    }
1605
1606    pub(crate) fn has_duplicate_extension(&self) -> bool {
1607        has_duplicates::<_, _, u16>(
1608            self.exts
1609                .iter()
1610                .map(|ext| ext.ext_type()),
1611        )
1612    }
1613
1614    pub(crate) fn has_unknown_extension(&self) -> bool {
1615        self.exts
1616            .iter()
1617            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1618    }
1619
1620    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1621        self.exts
1622            .iter()
1623            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1624            .and_then(CertificateExtension::cert_status)
1625    }
1626}
1627
1628impl TlsListElement for CertificateEntry<'_> {
1629    const SIZE_LEN: ListLength = ListLength::U24 {
1630        max: CERTIFICATE_MAX_SIZE_LIMIT,
1631        error: InvalidMessage::CertificatePayloadTooLarge,
1632    };
1633}
1634
1635#[derive(Debug)]
1636pub struct CertificatePayloadTls13<'a> {
1637    pub(crate) context: PayloadU8,
1638    pub(crate) entries: Vec<CertificateEntry<'a>>,
1639}
1640
1641impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1642    fn encode(&self, bytes: &mut Vec<u8>) {
1643        self.context.encode(bytes);
1644        self.entries.encode(bytes);
1645    }
1646
1647    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1648        Ok(Self {
1649            context: PayloadU8::read(r)?,
1650            entries: Vec::read(r)?,
1651        })
1652    }
1653}
1654
1655impl<'a> CertificatePayloadTls13<'a> {
1656    pub(crate) fn new(
1657        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1658        ocsp_response: Option<&'a [u8]>,
1659    ) -> Self {
1660        Self {
1661            context: PayloadU8::empty(),
1662            entries: certs
1663                // zip certificate iterator with `ocsp_response` followed by
1664                // an infinite-length iterator of `None`.
1665                .zip(
1666                    ocsp_response
1667                        .into_iter()
1668                        .map(Some)
1669                        .chain(iter::repeat(None)),
1670                )
1671                .map(|(cert, ocsp)| {
1672                    let mut e = CertificateEntry::new(cert.clone());
1673                    if let Some(ocsp) = ocsp {
1674                        e.exts
1675                            .push(CertificateExtension::CertificateStatus(
1676                                CertificateStatus::new(ocsp),
1677                            ));
1678                    }
1679                    e
1680                })
1681                .collect(),
1682        }
1683    }
1684
1685    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1686        CertificatePayloadTls13 {
1687            context: self.context,
1688            entries: self
1689                .entries
1690                .into_iter()
1691                .map(CertificateEntry::into_owned)
1692                .collect(),
1693        }
1694    }
1695
1696    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1697        for entry in &self.entries {
1698            if entry.has_duplicate_extension() {
1699                return true;
1700            }
1701        }
1702
1703        false
1704    }
1705
1706    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1707        for entry in &self.entries {
1708            if entry.has_unknown_extension() {
1709                return true;
1710            }
1711        }
1712
1713        false
1714    }
1715
1716    pub(crate) fn any_entry_has_extension(&self) -> bool {
1717        for entry in &self.entries {
1718            if !entry.exts.is_empty() {
1719                return true;
1720            }
1721        }
1722
1723        false
1724    }
1725
1726    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1727        self.entries
1728            .first()
1729            .and_then(CertificateEntry::ocsp_response)
1730            .map(|resp| resp.to_vec())
1731            .unwrap_or_default()
1732    }
1733
1734    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1735        CertificateChain(
1736            self.entries
1737                .into_iter()
1738                .map(|e| e.cert)
1739                .collect(),
1740        )
1741    }
1742}
1743
/// The key exchange mechanisms that are supported.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman key exchange, restricted to the known parameters
    /// defined in [RFC 7919].
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Elliptic curve Diffie-Hellman key exchange.
    ECDHE,
}
1755
1756pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
1757    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1758
1759// We don't support arbitrary curves.  It's a terrible
1760// idea and unnecessary attack surface.  Please,
1761// get a grip.
1762#[derive(Debug)]
1763pub(crate) struct EcParameters {
1764    pub(crate) curve_type: ECCurveType,
1765    pub(crate) named_group: NamedGroup,
1766}
1767
1768impl Codec<'_> for EcParameters {
1769    fn encode(&self, bytes: &mut Vec<u8>) {
1770        self.curve_type.encode(bytes);
1771        self.named_group.encode(bytes);
1772    }
1773
1774    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1775        let ct = ECCurveType::read(r)?;
1776        if ct != ECCurveType::NamedCurve {
1777            return Err(InvalidMessage::UnsupportedCurveType);
1778        }
1779
1780        let grp = NamedGroup::read(r)?;
1781
1782        Ok(Self {
1783            curve_type: ct,
1784            named_group: grp,
1785        })
1786    }
1787}
1788
/// Decoding for key exchange messages whose layout depends on the
/// [`KeyExchangeAlgorithm`] supplied by the caller.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1794
/// Client key exchange parameters, with one variant per key exchange
/// mechanism.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Parameters for an ECDHE key exchange.
    Ecdh(ClientEcdhParams),
    /// Parameters for a DHE key exchange.
    Dh(ClientDhParams),
}
1801
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Borrows the public key bytes held by whichever variant is present.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Ecdh(ClientEcdhParams { public }) => public.0.as_slice(),
            Self::Dh(ClientDhParams { public }) => public.0.as_slice(),
        }
    }

    /// Appends the encoding of the contained parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Ecdh(params) => params.encode(buf),
            Self::Dh(params) => params.encode(buf),
        }
    }
}
1818
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Reads the client parameters matching `algo`: `ClientEcdhParams` for
    /// ECDHE, `ClientDhParams` for DHE.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1829
/// Client-side ECDHE parameters: just the public key.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// The client's public key, stored as a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1835
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    /// Writes the public key payload.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads a single `PayloadU8` as the public key.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}
1847
/// Client-side DHE parameters: just the public key.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// The client's public key, stored as a `PayloadU16`.
    pub(crate) public: PayloadU16,
}
1853
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Writes the public key payload.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads a single `PayloadU16` as the public key.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1866
/// Server-side ECDHE parameters: the curve description and the public key.
#[derive(Debug)]
pub(crate) struct ServerEcdhParams {
    /// Curve type and named group for this exchange.
    pub(crate) curve_params: EcParameters,
    /// The server's public key, stored as a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1872
1873impl ServerEcdhParams {
1874    #[cfg(feature = "tls12")]
1875    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1876        Self {
1877            curve_params: EcParameters {
1878                curve_type: ECCurveType::NamedCurve,
1879                named_group: kx.group(),
1880            },
1881            public: PayloadU8::new(kx.pub_key().to_vec()),
1882        }
1883    }
1884}
1885
1886impl Codec<'_> for ServerEcdhParams {
1887    fn encode(&self, bytes: &mut Vec<u8>) {
1888        self.curve_params.encode(bytes);
1889        self.public.encode(bytes);
1890    }
1891
1892    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1893        let cp = EcParameters::read(r)?;
1894        let pb = PayloadU8::read(r)?;
1895
1896        Ok(Self {
1897            curve_params: cp,
1898            public: pb,
1899        })
1900    }
1901}
1902
/// Server-side DHE parameters: the group's `p` and `g` values and the
/// server's public key.
#[derive(Debug)]
#[allow(non_snake_case)]
pub(crate) struct ServerDhParams {
    /// The `p` value of the Diffie-Hellman group.
    pub(crate) dh_p: PayloadU16,
    /// The `g` value of the Diffie-Hellman group.
    pub(crate) dh_g: PayloadU16,
    /// The server's public key (filled from `kx.pub_key()` by `new`).
    pub(crate) dh_Ys: PayloadU16,
}
1910
1911impl ServerDhParams {
1912    #[cfg(feature = "tls12")]
1913    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1914        let Some(params) = kx.ffdhe_group() else {
1915            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
1916        };
1917
1918        Self {
1919            dh_p: PayloadU16::new(params.p.to_vec()),
1920            dh_g: PayloadU16::new(params.g.to_vec()),
1921            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1922        }
1923    }
1924
1925    #[cfg(feature = "tls12")]
1926    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
1927        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
1928    }
1929}
1930
1931impl Codec<'_> for ServerDhParams {
1932    fn encode(&self, bytes: &mut Vec<u8>) {
1933        self.dh_p.encode(bytes);
1934        self.dh_g.encode(bytes);
1935        self.dh_Ys.encode(bytes);
1936    }
1937
1938    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1939        Ok(Self {
1940            dh_p: PayloadU16::read(r)?,
1941            dh_g: PayloadU16::read(r)?,
1942            dh_Ys: PayloadU16::read(r)?,
1943        })
1944    }
1945}
1946
/// Server key exchange parameters, with one variant per key exchange
/// mechanism.
// NOTE(review): `dead_code` is presumably allowed because most users of the
// variants are gated on the `tls12` feature — confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ServerKeyExchangeParams {
    /// Parameters for an ECDHE key exchange.
    Ecdh(ServerEcdhParams),
    /// Parameters for a DHE key exchange.
    Dh(ServerDhParams),
}
1953
1954impl ServerKeyExchangeParams {
1955    #[cfg(feature = "tls12")]
1956    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1957        match kx.group().key_exchange_algorithm() {
1958            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
1959            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
1960        }
1961    }
1962
1963    #[cfg(feature = "tls12")]
1964    pub(crate) fn pub_key(&self) -> &[u8] {
1965        match self {
1966            Self::Ecdh(ecdh) => &ecdh.public.0,
1967            Self::Dh(dh) => &dh.dh_Ys.0,
1968        }
1969    }
1970
1971    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1972        match self {
1973            Self::Ecdh(ecdh) => ecdh.encode(buf),
1974            Self::Dh(dh) => dh.encode(buf),
1975        }
1976    }
1977}
1978
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Reads the server parameters matching `algo`: `ServerEcdhParams` for
    /// ECDHE, `ServerDhParams` for DHE.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ServerEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ServerDhParams::read(r).map(Self::Dh),
        }
    }
}
1989
/// A fully parsed server key exchange: the key exchange parameters plus a
/// `DigitallySignedStruct`.
#[derive(Debug)]
pub struct ServerKeyExchange {
    /// The key exchange parameters.
    pub(crate) params: ServerKeyExchangeParams,
    /// The signature structure that follows the parameters in the encoding.
    pub(crate) dss: DigitallySignedStruct,
}
1995
1996impl ServerKeyExchange {
1997    pub fn encode(&self, buf: &mut Vec<u8>) {
1998        self.params.encode(buf);
1999        self.dss.encode(buf);
2000    }
2001}
2002
/// A server key exchange message body, either parsed or still raw.
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
    /// A parsed message.
    Known(ServerKeyExchange),
    /// The unparsed bytes; this is what `read` always produces, because
    /// parsing needs the `KeyExchangeAlgorithm`.
    Unknown(Payload<'static>),
}
2008
2009impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2010    fn from(value: ServerKeyExchange) -> Self {
2011        Self::Known(value)
2012    }
2013}
2014
2015impl Codec<'_> for ServerKeyExchangePayload {
2016    fn encode(&self, bytes: &mut Vec<u8>) {
2017        match *self {
2018            Self::Known(ref x) => x.encode(bytes),
2019            Self::Unknown(ref x) => x.encode(bytes),
2020        }
2021    }
2022
2023    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2024        // read as Unknown, fully parse when we know the
2025        // KeyExchangeAlgorithm
2026        Ok(Self::Unknown(Payload::read(r).into_owned()))
2027    }
2028}
2029
2030impl ServerKeyExchangePayload {
2031    #[cfg(feature = "tls12")]
2032    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2033        if let Self::Unknown(ref unk) = *self {
2034            let mut rd = Reader::init(unk.bytes());
2035
2036            let result = ServerKeyExchange {
2037                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2038                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2039            };
2040
2041            if !rd.any_left() {
2042                return Some(result);
2043            };
2044        }
2045
2046        None
2047    }
2048}
2049
2050// -- EncryptedExtensions (TLS1.3 only) --
2051
// Lists of `ServerExtension` use `ListLength::U16` as their size length.
impl TlsListElement for ServerExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2055
2056pub(crate) trait HasServerExtensions {
2057    fn extensions(&self) -> &[ServerExtension];
2058
2059    /// Returns true if there is more than one extension of a given
2060    /// type.
2061    fn has_duplicate_extension(&self) -> bool {
2062        has_duplicates::<_, _, u16>(
2063            self.extensions()
2064                .iter()
2065                .map(|ext| ext.ext_type()),
2066        )
2067    }
2068
2069    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2070        self.extensions()
2071            .iter()
2072            .find(|x| x.ext_type() == ext)
2073    }
2074
2075    fn alpn_protocol(&self) -> Option<&[u8]> {
2076        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2077        match *ext {
2078            ServerExtension::Protocols(ref protos) => protos.as_single_slice(),
2079            _ => None,
2080        }
2081    }
2082
2083    fn server_cert_type(&self) -> Option<&CertificateType> {
2084        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2085        match ext {
2086            ServerExtension::ServerCertType(req) => Some(req),
2087            _ => None,
2088        }
2089    }
2090
2091    fn client_cert_type(&self) -> Option<&CertificateType> {
2092        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2093        match ext {
2094            ServerExtension::ClientCertType(req) => Some(req),
2095            _ => None,
2096        }
2097    }
2098
2099    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2100        let ext = self
2101            .find_extension(ExtensionType::TransportParameters)
2102            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2103        match *ext {
2104            ServerExtension::TransportParameters(ref bytes)
2105            | ServerExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
2106            _ => None,
2107        }
2108    }
2109
2110    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2111        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2112        match ext {
2113            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2114            _ => None,
2115        }
2116    }
2117
2118    fn early_data_extension_offered(&self) -> bool {
2119        self.find_extension(ExtensionType::EarlyData)
2120            .is_some()
2121    }
2122}
2123
2124impl HasServerExtensions for Vec<ServerExtension> {
2125    fn extensions(&self) -> &[ServerExtension] {
2126        self
2127    }
2128}
2129
// Lists of `ClientCertificateType` use `ListLength::U8` as their size length.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
2133
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// The underlying bytes are available through the `AsRef<[u8]>`
    /// implementation:
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    pub struct DistinguishedName,
    PayloadU16,
);
2150
2151impl DistinguishedName {
2152    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2153    ///
2154    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2155    ///
2156    /// ```ignore
2157    /// use x509_parser::prelude::FromDer;
2158    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2159    /// ```
2160    pub fn in_sequence(bytes: &[u8]) -> Self {
2161        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2162    }
2163}
2164
// Lists of `DistinguishedName` use `ListLength::U16` as their size length.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2168
/// A certificate request made of three lists, encoded in field order.
#[derive(Debug)]
pub struct CertificateRequestPayload {
    /// The client certificate types listed in the request.
    pub(crate) certtypes: Vec<ClientCertificateType>,
    /// The signature schemes listed in the request; decoding rejects an
    /// empty list.
    pub(crate) sigschemes: Vec<SignatureScheme>,
    /// The distinguished names listed in the request.
    pub(crate) canames: Vec<DistinguishedName>,
}
2175
2176impl Codec<'_> for CertificateRequestPayload {
2177    fn encode(&self, bytes: &mut Vec<u8>) {
2178        self.certtypes.encode(bytes);
2179        self.sigschemes.encode(bytes);
2180        self.canames.encode(bytes);
2181    }
2182
2183    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2184        let certtypes = Vec::read(r)?;
2185        let sigschemes = Vec::read(r)?;
2186        let canames = Vec::read(r)?;
2187
2188        if sigschemes.is_empty() {
2189            warn!("meaningless CertificateRequest message");
2190            Err(InvalidMessage::NoSignatureSchemes)
2191        } else {
2192            Ok(Self {
2193                certtypes,
2194                sigschemes,
2195                canames,
2196            })
2197        }
2198    }
2199}
2200
/// An extension carried in a [`CertificateRequestPayloadTls13`].
#[derive(Debug)]
pub(crate) enum CertReqExtension {
    /// Signature schemes; decoding rejects an empty list.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// Distinguished names from the certificate authorities extension.
    AuthorityNames(Vec<DistinguishedName>),
    /// Certificate compression algorithms.
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
2208
2209impl CertReqExtension {
2210    pub(crate) fn ext_type(&self) -> ExtensionType {
2211        match *self {
2212            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2213            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2214            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2215            Self::Unknown(ref r) => r.typ,
2216        }
2217    }
2218}
2219
2220impl Codec<'_> for CertReqExtension {
2221    fn encode(&self, bytes: &mut Vec<u8>) {
2222        self.ext_type().encode(bytes);
2223
2224        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2225        match *self {
2226            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
2227            Self::AuthorityNames(ref r) => r.encode(nested.buf),
2228            Self::CertificateCompressionAlgorithms(ref r) => r.encode(nested.buf),
2229            Self::Unknown(ref r) => r.encode(nested.buf),
2230        }
2231    }
2232
2233    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2234        let typ = ExtensionType::read(r)?;
2235        let len = u16::read(r)? as usize;
2236        let mut sub = r.sub(len)?;
2237
2238        let ext = match typ {
2239            ExtensionType::SignatureAlgorithms => {
2240                let schemes = Vec::read(&mut sub)?;
2241                if schemes.is_empty() {
2242                    return Err(InvalidMessage::NoSignatureSchemes);
2243                }
2244                Self::SignatureAlgorithms(schemes)
2245            }
2246            ExtensionType::CertificateAuthorities => {
2247                let cas = Vec::read(&mut sub)?;
2248                Self::AuthorityNames(cas)
2249            }
2250            ExtensionType::CompressCertificate => {
2251                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2252            }
2253            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2254        };
2255
2256        sub.expect_empty("CertReqExtension")
2257            .map(|_| ext)
2258    }
2259}
2260
// Lists of `CertReqExtension` use `ListLength::U16` as their size length.
impl TlsListElement for CertReqExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2264
/// The TLS1.3 form of a certificate request: a context followed by a
/// list of extensions.
#[derive(Debug)]
pub struct CertificateRequestPayloadTls13 {
    /// The request context, stored as a `PayloadU8`.
    pub(crate) context: PayloadU8,
    /// The extensions carried by the request.
    pub(crate) extensions: Vec<CertReqExtension>,
}
2270
2271impl Codec<'_> for CertificateRequestPayloadTls13 {
2272    fn encode(&self, bytes: &mut Vec<u8>) {
2273        self.context.encode(bytes);
2274        self.extensions.encode(bytes);
2275    }
2276
2277    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2278        let context = PayloadU8::read(r)?;
2279        let extensions = Vec::read(r)?;
2280
2281        Ok(Self {
2282            context,
2283            extensions,
2284        })
2285    }
2286}
2287
2288impl CertificateRequestPayloadTls13 {
2289    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2290        self.extensions
2291            .iter()
2292            .find(|x| x.ext_type() == ext)
2293    }
2294
2295    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2296        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2297        match *ext {
2298            CertReqExtension::SignatureAlgorithms(ref sa) => Some(sa),
2299            _ => None,
2300        }
2301    }
2302
2303    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2304        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2305        match *ext {
2306            CertReqExtension::AuthorityNames(ref an) => Some(an),
2307            _ => None,
2308        }
2309    }
2310
2311    pub(crate) fn certificate_compression_extension(
2312        &self,
2313    ) -> Option<&[CertificateCompressionAlgorithm]> {
2314        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2315        match *ext {
2316            CertReqExtension::CertificateCompressionAlgorithms(ref comps) => Some(comps),
2317            _ => None,
2318        }
2319    }
2320}
2321
2322// -- NewSessionTicket --
/// A new session ticket message: a lifetime hint followed by the ticket.
#[derive(Debug)]
pub struct NewSessionTicketPayload {
    /// The lifetime hint, encoded as a `u32` ahead of the ticket.
    pub(crate) lifetime_hint: u32,
    // Tickets can be large (KB), so we deserialise this straight
    // into an Arc, so it can be passed directly into the client's
    // session object without copying.
    pub(crate) ticket: Arc<PayloadU16>,
}
2331
2332impl NewSessionTicketPayload {
2333    #[cfg(feature = "tls12")]
2334    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2335        Self {
2336            lifetime_hint,
2337            ticket: Arc::new(PayloadU16::new(ticket)),
2338        }
2339    }
2340}
2341
2342impl Codec<'_> for NewSessionTicketPayload {
2343    fn encode(&self, bytes: &mut Vec<u8>) {
2344        self.lifetime_hint.encode(bytes);
2345        self.ticket.encode(bytes);
2346    }
2347
2348    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2349        let lifetime = u32::read(r)?;
2350        let ticket = Arc::new(PayloadU16::read(r)?);
2351
2352        Ok(Self {
2353            lifetime_hint: lifetime,
2354            ticket,
2355        })
2356    }
2357}
2358
// -- NewSessionTicket (TLS1.3) --
/// An extension carried in a [`NewSessionTicketPayloadTls13`].
#[derive(Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// The `EarlyData` extension, holding a `u32` value.
    EarlyData(u32),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
2365
2366impl NewSessionTicketExtension {
2367    pub(crate) fn ext_type(&self) -> ExtensionType {
2368        match *self {
2369            Self::EarlyData(_) => ExtensionType::EarlyData,
2370            Self::Unknown(ref r) => r.typ,
2371        }
2372    }
2373}
2374
2375impl Codec<'_> for NewSessionTicketExtension {
2376    fn encode(&self, bytes: &mut Vec<u8>) {
2377        self.ext_type().encode(bytes);
2378
2379        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2380        match *self {
2381            Self::EarlyData(r) => r.encode(nested.buf),
2382            Self::Unknown(ref r) => r.encode(nested.buf),
2383        }
2384    }
2385
2386    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2387        let typ = ExtensionType::read(r)?;
2388        let len = u16::read(r)? as usize;
2389        let mut sub = r.sub(len)?;
2390
2391        let ext = match typ {
2392            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2393            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2394        };
2395
2396        sub.expect_empty("NewSessionTicketExtension")
2397            .map(|_| ext)
2398    }
2399}
2400
// Lists of `NewSessionTicketExtension` use `ListLength::U16` as their size length.
impl TlsListElement for NewSessionTicketExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2404
/// The TLS1.3 form of a new session ticket message.
///
/// Fields are declared in the order they are encoded.
#[derive(Debug)]
pub struct NewSessionTicketPayloadTls13 {
    /// The ticket lifetime, as a `u32`.
    pub(crate) lifetime: u32,
    /// The ticket's `age_add` value, as a `u32`.
    pub(crate) age_add: u32,
    /// The ticket nonce, stored as a `PayloadU8`.
    pub(crate) nonce: PayloadU8,
    /// The ticket itself, held in an `Arc`.
    pub(crate) ticket: Arc<PayloadU16>,
    /// Extensions attached to the ticket; `new` leaves this empty.
    pub(crate) exts: Vec<NewSessionTicketExtension>,
}
2413
2414impl NewSessionTicketPayloadTls13 {
2415    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2416        Self {
2417            lifetime,
2418            age_add,
2419            nonce: PayloadU8::new(nonce),
2420            ticket: Arc::new(PayloadU16::new(ticket)),
2421            exts: vec![],
2422        }
2423    }
2424
2425    pub(crate) fn has_duplicate_extension(&self) -> bool {
2426        has_duplicates::<_, _, u16>(
2427            self.exts
2428                .iter()
2429                .map(|ext| ext.ext_type()),
2430        )
2431    }
2432
2433    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2434        self.exts
2435            .iter()
2436            .find(|x| x.ext_type() == ext)
2437    }
2438
2439    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2440        let ext = self.find_extension(ExtensionType::EarlyData)?;
2441        match *ext {
2442            NewSessionTicketExtension::EarlyData(ref sz) => Some(*sz),
2443            _ => None,
2444        }
2445    }
2446}
2447
2448impl Codec<'_> for NewSessionTicketPayloadTls13 {
2449    fn encode(&self, bytes: &mut Vec<u8>) {
2450        self.lifetime.encode(bytes);
2451        self.age_add.encode(bytes);
2452        self.nonce.encode(bytes);
2453        self.ticket.encode(bytes);
2454        self.exts.encode(bytes);
2455    }
2456
2457    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2458        let lifetime = u32::read(r)?;
2459        let age_add = u32::read(r)?;
2460        let nonce = PayloadU8::read(r)?;
2461        let ticket = Arc::new(PayloadU16::read(r)?);
2462        let exts = Vec::read(r)?;
2463
2464        Ok(Self {
2465            lifetime,
2466            age_add,
2467            nonce,
2468            ticket,
2469            exts,
2470        })
2471    }
2472}
2473
2474// -- RFC6066 certificate status types
2475
/// A certificate status message.
///
/// Only supports OCSP: decoding rejects any other status type, and
/// encoding always writes `CertificateStatusType::OCSP`.
#[derive(Debug)]
pub struct CertificateStatus<'a> {
    /// The OCSP response bytes.
    pub(crate) ocsp_response: PayloadU24<'a>,
}
2481
2482impl<'a> Codec<'a> for CertificateStatus<'a> {
2483    fn encode(&self, bytes: &mut Vec<u8>) {
2484        CertificateStatusType::OCSP.encode(bytes);
2485        self.ocsp_response.encode(bytes);
2486    }
2487
2488    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2489        let typ = CertificateStatusType::read(r)?;
2490
2491        match typ {
2492            CertificateStatusType::OCSP => Ok(Self {
2493                ocsp_response: PayloadU24::read(r)?,
2494            }),
2495            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2496        }
2497    }
2498}
2499
2500impl<'a> CertificateStatus<'a> {
2501    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2502        CertificateStatus {
2503            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2504        }
2505    }
2506
2507    #[cfg(feature = "tls12")]
2508    pub(crate) fn into_inner(self) -> Vec<u8> {
2509        self.ocsp_response.0.into_vec()
2510    }
2511
2512    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2513        CertificateStatus {
2514            ocsp_response: self.ocsp_response.into_owned(),
2515        }
2516    }
2517}
2518
2519// -- RFC8879 compressed certificates
2520
/// A compressed certificate message.
///
/// Fields are declared in the order they are encoded.
#[derive(Debug)]
pub struct CompressedCertificatePayload<'a> {
    /// The compression algorithm.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// The uncompressed length; encoded and decoded as a `codec::u24`.
    pub(crate) uncompressed_len: u32,
    /// The compressed bytes.
    pub(crate) compressed: PayloadU24<'a>,
}
2527
2528impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2529    fn encode(&self, bytes: &mut Vec<u8>) {
2530        self.alg.encode(bytes);
2531        codec::u24(self.uncompressed_len).encode(bytes);
2532        self.compressed.encode(bytes);
2533    }
2534
2535    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2536        Ok(Self {
2537            alg: CertificateCompressionAlgorithm::read(r)?,
2538            uncompressed_len: codec::u24::read(r)?.0,
2539            compressed: PayloadU24::read(r)?,
2540        })
2541    }
2542}
2543
2544impl CompressedCertificatePayload<'_> {
2545    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2546        CompressedCertificatePayload {
2547            compressed: self.compressed.into_owned(),
2548            ..self
2549        }
2550    }
2551
2552    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2553        CompressedCertificatePayload {
2554            alg: self.alg,
2555            uncompressed_len: self.uncompressed_len,
2556            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2557        }
2558    }
2559}
2560
/// The body of a handshake message, with one variant per kind of message.
///
/// Variants parameterised by `'a` may borrow their data; `into_owned`
/// converts those to `'static`.
#[derive(Debug)]
pub enum HandshakePayload<'a> {
    // Encodes to nothing.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    HelloRetryRequest(HelloRetryRequest),
    Certificate(CertificateChain<'a>),
    CertificateTls13(CertificatePayloadTls13<'a>),
    CompressedCertificate(CompressedCertificatePayload<'a>),
    ServerKeyExchange(ServerKeyExchangePayload),
    CertificateRequest(CertificateRequestPayload),
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    // Encodes to nothing.
    ServerHelloDone,
    // Encodes to nothing.
    EndOfEarlyData,
    ClientKeyExchange(Payload<'a>),
    NewSessionTicket(NewSessionTicketPayload),
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    MessageHash(Payload<'a>),
    Unknown(Payload<'a>),
}
2586
2587impl HandshakePayload<'_> {
2588    fn encode(&self, bytes: &mut Vec<u8>) {
2589        use self::HandshakePayload::*;
2590        match *self {
2591            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2592            ClientHello(ref x) => x.encode(bytes),
2593            ServerHello(ref x) => x.encode(bytes),
2594            HelloRetryRequest(ref x) => x.encode(bytes),
2595            Certificate(ref x) => x.encode(bytes),
2596            CertificateTls13(ref x) => x.encode(bytes),
2597            CompressedCertificate(ref x) => x.encode(bytes),
2598            ServerKeyExchange(ref x) => x.encode(bytes),
2599            ClientKeyExchange(ref x) => x.encode(bytes),
2600            CertificateRequest(ref x) => x.encode(bytes),
2601            CertificateRequestTls13(ref x) => x.encode(bytes),
2602            CertificateVerify(ref x) => x.encode(bytes),
2603            NewSessionTicket(ref x) => x.encode(bytes),
2604            NewSessionTicketTls13(ref x) => x.encode(bytes),
2605            EncryptedExtensions(ref x) => x.encode(bytes),
2606            KeyUpdate(ref x) => x.encode(bytes),
2607            Finished(ref x) => x.encode(bytes),
2608            CertificateStatus(ref x) => x.encode(bytes),
2609            MessageHash(ref x) => x.encode(bytes),
2610            Unknown(ref x) => x.encode(bytes),
2611        }
2612    }
2613
2614    fn into_owned(self) -> HandshakePayload<'static> {
2615        use HandshakePayload::*;
2616
2617        match self {
2618            HelloRequest => HelloRequest,
2619            ClientHello(x) => ClientHello(x),
2620            ServerHello(x) => ServerHello(x),
2621            HelloRetryRequest(x) => HelloRetryRequest(x),
2622            Certificate(x) => Certificate(x.into_owned()),
2623            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2624            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2625            ServerKeyExchange(x) => ServerKeyExchange(x),
2626            CertificateRequest(x) => CertificateRequest(x),
2627            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2628            CertificateVerify(x) => CertificateVerify(x),
2629            ServerHelloDone => ServerHelloDone,
2630            EndOfEarlyData => EndOfEarlyData,
2631            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2632            NewSessionTicket(x) => NewSessionTicket(x),
2633            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2634            EncryptedExtensions(x) => EncryptedExtensions(x),
2635            KeyUpdate(x) => KeyUpdate(x),
2636            Finished(x) => Finished(x.into_owned()),
2637            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2638            MessageHash(x) => MessageHash(x.into_owned()),
2639            Unknown(x) => Unknown(x.into_owned()),
2640        }
2641    }
2642}
2643
/// A handshake message: its type together with its body.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a> {
    /// The handshake message type.
    pub typ: HandshakeType,
    /// The message body.
    pub payload: HandshakePayload<'a>,
}
2649
2650impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2651    fn encode(&self, bytes: &mut Vec<u8>) {
2652        self.payload_encode(bytes, Encoding::Standard);
2653    }
2654
2655    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2656        Self::read_version(r, ProtocolVersion::TLSv1_2)
2657    }
2658}
2659
2660impl<'a> HandshakeMessagePayload<'a> {
2661    pub(crate) fn read_version(
2662        r: &mut Reader<'a>,
2663        vers: ProtocolVersion,
2664    ) -> Result<Self, InvalidMessage> {
2665        let mut typ = HandshakeType::read(r)?;
2666        let len = codec::u24::read(r)?.0 as usize;
2667        let mut sub = r.sub(len)?;
2668
2669        let payload = match typ {
2670            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
2671            HandshakeType::ClientHello => {
2672                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
2673            }
2674            HandshakeType::ServerHello => {
2675                let version = ProtocolVersion::read(&mut sub)?;
2676                let random = Random::read(&mut sub)?;
2677
2678                if random == HELLO_RETRY_REQUEST_RANDOM {
2679                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
2680                    hrr.legacy_version = version;
2681                    typ = HandshakeType::HelloRetryRequest;
2682                    HandshakePayload::HelloRetryRequest(hrr)
2683                } else {
2684                    let mut shp = ServerHelloPayload::read(&mut sub)?;
2685                    shp.legacy_version = version;
2686                    shp.random = random;
2687                    HandshakePayload::ServerHello(shp)
2688                }
2689            }
2690            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
2691                let p = CertificatePayloadTls13::read(&mut sub)?;
2692                HandshakePayload::CertificateTls13(p)
2693            }
2694            HandshakeType::Certificate => {
2695                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
2696            }
2697            HandshakeType::ServerKeyExchange => {
2698                let p = ServerKeyExchangePayload::read(&mut sub)?;
2699                HandshakePayload::ServerKeyExchange(p)
2700            }
2701            HandshakeType::ServerHelloDone => {
2702                sub.expect_empty("ServerHelloDone")?;
2703                HandshakePayload::ServerHelloDone
2704            }
2705            HandshakeType::ClientKeyExchange => {
2706                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
2707            }
2708            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
2709                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
2710                HandshakePayload::CertificateRequestTls13(p)
2711            }
2712            HandshakeType::CertificateRequest => {
2713                let p = CertificateRequestPayload::read(&mut sub)?;
2714                HandshakePayload::CertificateRequest(p)
2715            }
2716            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
2717                CompressedCertificatePayload::read(&mut sub)?,
2718            ),
2719            HandshakeType::CertificateVerify => {
2720                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
2721            }
2722            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
2723                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
2724                HandshakePayload::NewSessionTicketTls13(p)
2725            }
2726            HandshakeType::NewSessionTicket => {
2727                let p = NewSessionTicketPayload::read(&mut sub)?;
2728                HandshakePayload::NewSessionTicket(p)
2729            }
2730            HandshakeType::EncryptedExtensions => {
2731                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
2732            }
2733            HandshakeType::KeyUpdate => {
2734                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
2735            }
2736            HandshakeType::EndOfEarlyData => {
2737                sub.expect_empty("EndOfEarlyData")?;
2738                HandshakePayload::EndOfEarlyData
2739            }
2740            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
2741            HandshakeType::CertificateStatus => {
2742                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
2743            }
2744            HandshakeType::MessageHash => {
2745                // does not appear on the wire
2746                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
2747            }
2748            HandshakeType::HelloRetryRequest => {
2749                // not legal on wire
2750                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
2751            }
2752            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
2753        };
2754
2755        sub.expect_empty("HandshakeMessagePayload")
2756            .map(|_| Self { typ, payload })
2757    }
2758
2759    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
2760        let mut ret = self.get_encoding();
2761
2762        let binder_len = match self.payload {
2763            HandshakePayload::ClientHello(ref ch) => match ch.extensions.last() {
2764                Some(ClientExtension::PresharedKey(ref offer)) => {
2765                    let mut binders_encoding = Vec::new();
2766                    offer
2767                        .binders
2768                        .encode(&mut binders_encoding);
2769                    binders_encoding.len()
2770                }
2771                _ => 0,
2772            },
2773            _ => 0,
2774        };
2775
2776        let ret_len = ret.len() - binder_len;
2777        ret.truncate(ret_len);
2778        ret
2779    }
2780
2781    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
2782        // output type, length, and encoded payload
2783        match self.typ {
2784            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2785            _ => self.typ,
2786        }
2787        .encode(bytes);
2788
2789        let nested = LengthPrefixedBuffer::new(
2790            ListLength::U24 {
2791                max: usize::MAX,
2792                error: InvalidMessage::MessageTooLarge,
2793            },
2794            bytes,
2795        );
2796
2797        match &self.payload {
2798            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
2799            // differently based on the purpose of the encoding.
2800            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
2801            HandshakePayload::HelloRetryRequest(payload) => {
2802                payload.payload_encode(nested.buf, encoding)
2803            }
2804
2805            // All other payload types are encoded the same regardless of purpose.
2806            _ => self.payload.encode(nested.buf),
2807        }
2808    }
2809
2810    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
2811        Self {
2812            typ: HandshakeType::MessageHash,
2813            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
2814        }
2815    }
2816
2817    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
2818        let Self { typ, payload } = self;
2819        HandshakeMessagePayload {
2820            typ,
2821            payload: payload.into_owned(),
2822        }
2823    }
2824}
2825
2826#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2827pub struct HpkeSymmetricCipherSuite {
2828    pub kdf_id: HpkeKdf,
2829    pub aead_id: HpkeAead,
2830}
2831
2832impl Codec<'_> for HpkeSymmetricCipherSuite {
2833    fn encode(&self, bytes: &mut Vec<u8>) {
2834        self.kdf_id.encode(bytes);
2835        self.aead_id.encode(bytes);
2836    }
2837
2838    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2839        Ok(Self {
2840            kdf_id: HpkeKdf::read(r)?,
2841            aead_id: HpkeAead::read(r)?,
2842        })
2843    }
2844}
2845
2846impl TlsListElement for HpkeSymmetricCipherSuite {
2847    const SIZE_LEN: ListLength = ListLength::U16;
2848}
2849
2850#[derive(Clone, Debug, PartialEq)]
2851pub struct HpkeKeyConfig {
2852    pub config_id: u8,
2853    pub kem_id: HpkeKem,
2854    pub public_key: PayloadU16,
2855    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2856}
2857
2858impl Codec<'_> for HpkeKeyConfig {
2859    fn encode(&self, bytes: &mut Vec<u8>) {
2860        self.config_id.encode(bytes);
2861        self.kem_id.encode(bytes);
2862        self.public_key.encode(bytes);
2863        self.symmetric_cipher_suites
2864            .encode(bytes);
2865    }
2866
2867    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2868        Ok(Self {
2869            config_id: u8::read(r)?,
2870            kem_id: HpkeKem::read(r)?,
2871            public_key: PayloadU16::read(r)?,
2872            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2873        })
2874    }
2875}
2876
2877#[derive(Clone, Debug, PartialEq)]
2878pub struct EchConfigContents {
2879    pub key_config: HpkeKeyConfig,
2880    pub maximum_name_length: u8,
2881    pub public_name: DnsName<'static>,
2882    pub extensions: Vec<EchConfigExtension>,
2883}
2884
2885impl EchConfigContents {
2886    /// Returns true if there is more than one extension of a given
2887    /// type.
2888    pub(crate) fn has_duplicate_extension(&self) -> bool {
2889        has_duplicates::<_, _, u16>(
2890            self.extensions
2891                .iter()
2892                .map(|ext| ext.ext_type()),
2893        )
2894    }
2895
2896    /// Returns true if there is at least one mandatory unsupported extension.
2897    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2898        self.extensions
2899            .iter()
2900            // An extension is considered mandatory if the high bit of its type is set.
2901            .any(|ext| {
2902                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2903                    && u16::from(ext.ext_type()) & 0x8000 != 0
2904            })
2905    }
2906}
2907
2908impl Codec<'_> for EchConfigContents {
2909    fn encode(&self, bytes: &mut Vec<u8>) {
2910        self.key_config.encode(bytes);
2911        self.maximum_name_length.encode(bytes);
2912        let dns_name = &self.public_name.borrow();
2913        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2914        self.extensions.encode(bytes);
2915    }
2916
2917    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2918        Ok(Self {
2919            key_config: HpkeKeyConfig::read(r)?,
2920            maximum_name_length: u8::read(r)?,
2921            public_name: {
2922                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2923                    .map_err(|_| InvalidMessage::InvalidServerName)?
2924                    .to_owned()
2925            },
2926            extensions: Vec::read(r)?,
2927        })
2928    }
2929}
2930
2931/// An encrypted client hello (ECH) config.
2932#[derive(Clone, Debug, PartialEq)]
2933pub enum EchConfigPayload {
2934    /// A recognized V18 ECH configuration.
2935    V18(EchConfigContents),
2936    /// An unknown version ECH configuration.
2937    Unknown {
2938        version: EchVersion,
2939        contents: PayloadU16,
2940    },
2941}
2942
2943impl TlsListElement for EchConfigPayload {
2944    const SIZE_LEN: ListLength = ListLength::U16;
2945}
2946
2947impl Codec<'_> for EchConfigPayload {
2948    fn encode(&self, bytes: &mut Vec<u8>) {
2949        match self {
2950            Self::V18(c) => {
2951                // Write the version, the length, and the contents.
2952                EchVersion::V18.encode(bytes);
2953                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2954                c.encode(inner.buf);
2955            }
2956            Self::Unknown { version, contents } => {
2957                // Unknown configuration versions are opaque.
2958                version.encode(bytes);
2959                contents.encode(bytes);
2960            }
2961        }
2962    }
2963
2964    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2965        let version = EchVersion::read(r)?;
2966        let length = u16::read(r)?;
2967        let mut contents = r.sub(length as usize)?;
2968
2969        Ok(match version {
2970            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2971            _ => {
2972                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2973                let data = PayloadU16::new(contents.rest().into());
2974                Self::Unknown {
2975                    version,
2976                    contents: data,
2977                }
2978            }
2979        })
2980    }
2981}
2982
2983#[derive(Clone, Debug, PartialEq)]
2984pub enum EchConfigExtension {
2985    Unknown(UnknownExtension),
2986}
2987
2988impl EchConfigExtension {
2989    pub(crate) fn ext_type(&self) -> ExtensionType {
2990        match *self {
2991            Self::Unknown(ref r) => r.typ,
2992        }
2993    }
2994}
2995
2996impl Codec<'_> for EchConfigExtension {
2997    fn encode(&self, bytes: &mut Vec<u8>) {
2998        self.ext_type().encode(bytes);
2999
3000        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3001        match *self {
3002            Self::Unknown(ref r) => r.encode(nested.buf),
3003        }
3004    }
3005
3006    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3007        let typ = ExtensionType::read(r)?;
3008        let len = u16::read(r)? as usize;
3009        let mut sub = r.sub(len)?;
3010
3011        #[allow(clippy::match_single_binding)] // Future-proofing.
3012        let ext = match typ {
3013            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3014        };
3015
3016        sub.expect_empty("EchConfigExtension")
3017            .map(|_| ext)
3018    }
3019}
3020
3021impl TlsListElement for EchConfigExtension {
3022    const SIZE_LEN: ListLength = ListLength::U16;
3023}
3024
3025/// Representation of the `ECHClientHello` client extension specified in
3026/// [draft-ietf-tls-esni Section 5].
3027///
3028/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3029#[derive(Clone, Debug)]
3030pub enum EncryptedClientHello {
3031    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
3032    Outer(EncryptedClientHelloOuter),
3033    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
3034    ///
3035    /// This variant has no payload.
3036    Inner,
3037}
3038
3039impl Codec<'_> for EncryptedClientHello {
3040    fn encode(&self, bytes: &mut Vec<u8>) {
3041        match self {
3042            Self::Outer(payload) => {
3043                EchClientHelloType::ClientHelloOuter.encode(bytes);
3044                payload.encode(bytes);
3045            }
3046            Self::Inner => {
3047                EchClientHelloType::ClientHelloInner.encode(bytes);
3048                // Empty payload.
3049            }
3050        }
3051    }
3052
3053    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3054        match EchClientHelloType::read(r)? {
3055            EchClientHelloType::ClientHelloOuter => {
3056                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3057            }
3058            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3059            _ => Err(InvalidMessage::InvalidContentType),
3060        }
3061    }
3062}
3063
3064/// Representation of the ECHClientHello extension with type outer specified in
3065/// [draft-ietf-tls-esni Section 5].
3066///
3067/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3068#[derive(Clone, Debug)]
3069pub struct EncryptedClientHelloOuter {
3070    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3071    /// ECHConfigContents.cipher_suites list.
3072    pub cipher_suite: HpkeSymmetricCipherSuite,
3073    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3074    pub config_id: u8,
3075    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3076    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3077    pub enc: PayloadU16,
3078    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3079    pub payload: PayloadU16,
3080}
3081
3082impl Codec<'_> for EncryptedClientHelloOuter {
3083    fn encode(&self, bytes: &mut Vec<u8>) {
3084        self.cipher_suite.encode(bytes);
3085        self.config_id.encode(bytes);
3086        self.enc.encode(bytes);
3087        self.payload.encode(bytes);
3088    }
3089
3090    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3091        Ok(Self {
3092            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3093            config_id: u8::read(r)?,
3094            enc: PayloadU16::read(r)?,
3095            payload: PayloadU16::read(r)?,
3096        })
3097    }
3098}
3099
3100/// Representation of the ECHEncryptedExtensions extension specified in
3101/// [draft-ietf-tls-esni Section 5].
3102///
3103/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3104#[derive(Clone, Debug)]
3105pub struct ServerEncryptedClientHello {
3106    pub(crate) retry_configs: Vec<EchConfigPayload>,
3107}
3108
3109impl Codec<'_> for ServerEncryptedClientHello {
3110    fn encode(&self, bytes: &mut Vec<u8>) {
3111        self.retry_configs.encode(bytes);
3112    }
3113
3114    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3115        Ok(Self {
3116            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3117        })
3118    }
3119}
3120
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
///
/// `HandshakeMessagePayload::payload_encode` passes this value on to `ServerHello` and
/// `HelloRetryRequest` payloads; all other payloads are encoded without consulting it.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    // NOTE(review): `to_compress` presumably lists the extension types to compress in
    // the inner hello; confirm against the code that consumes this variant.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3134
/// Returns true if two items yielded by `iter` convert to the same `T` value.
///
/// Iteration stops at the first repeated value.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen: BTreeSet<T> = BTreeSet::new();

    // `insert` returns false when the value is already in the set.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3146
#[cfg(test)]
mod tests {
    use super::*;

    /// The same non-mandatory unknown extension listed twice is reported as a
    /// duplicate, and not as an unknown mandatory extension.
    #[test]
    fn test_ech_config_dupe_exts() {
        let ext = EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(0x42),
            payload: Payload::new(vec![0x42]),
        });
        let config = EchConfigContents {
            extensions: vec![ext.clone(), ext],
            ..config_template()
        };

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    /// A single unknown extension with the high type bit set is reported as an unknown
    /// mandatory extension, and not as a duplicate.
    #[test]
    fn test_ech_config_mandatory_exts() {
        let ext = EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(0x42 | 0x8000), // Note: high bit set.
            payload: Payload::new(vec![0x42]),
        });
        let config = EchConfigContents {
            extensions: vec![ext],
            ..config_template()
        };

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Builds a minimal configuration with no extensions for the tests to extend.
    fn config_template() -> EchConfigContents {
        let cipher_suite = HpkeSymmetricCipherSuite {
            kdf_id: HpkeKdf::HKDF_SHA256,
            aead_id: HpkeAead::AES_128_GCM,
        };
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16(b"xxx".into()),
            symmetric_cipher_suites: vec![cipher_suite],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}