rustls/msgs/
handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::ops::Deref;
7use core::{fmt, iter};
8
9use pki_types::{CertificateDer, DnsName};
10
11#[cfg(feature = "tls12")]
12use crate::crypto::ActiveKeyExchange;
13use crate::crypto::SecureRandom;
14use crate::enums::{
15    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
16    ProtocolVersion, SignatureScheme,
17};
18use crate::error::InvalidMessage;
19#[cfg(feature = "tls12")]
20use crate::ffdhe_groups::FfdheGroup;
21use crate::log::warn;
22use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType,
26    ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest,
27    NamedGroup, PskKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::sync::Arc;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Create a newtype wrapper around a given type.
///
/// This is used to create newtypes for the various TLS message types which wrap
/// the `PayloadU8` or `PayloadU16` types. This is typically used for types where we don't need
/// anything other than access to the underlying bytes.
///
/// The generated type derives `Clone` and `Debug`, and implements `From<Vec<u8>>`,
/// `AsRef<[u8]>` and `Codec`; each impl delegates to the wrapped `$inner` value.
macro_rules! wrapped_payload(
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident$(<$inner_ty:ty>)?,) => {
    $(#[$comment])*
    #[derive(Clone, Debug)]
    $vis struct $name($inner$(<$inner_ty>)?);

    // Build the wrapper from owned bytes via `$inner::new`.
    impl From<Vec<u8>> for $name {
        fn from(v: Vec<u8>) -> Self {
            Self($inner::new(v))
        }
    }

    // Borrow the bytes held by the inner payload (`self.0` is the payload,
    // `.0` of that is its byte container).
    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            self.0.0.as_slice()
        }
    }

    // Wire encoding and decoding are exactly those of `$inner`.
    impl Codec<'_> for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            self.0.encode(bytes);
        }

        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            Ok(Self($inner::read(r)?))
        }
    }
  }
);
68
/// A fixed-size 32-byte value, stored inline.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
/// Fixed `Random` constant.
///
/// NOTE(review): going by the name, this is the special marker value that
/// identifies a HelloRetryRequest (RFC 8446 section 4.1.3) -- confirm against
/// the code that compares against it.
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);
82
/// A `Random` whose 32 bytes are all zero.
static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
/// A byte string of at most 32 bytes, stored inline.
///
/// Only the first `len` bytes of `data` are meaningful; the remaining bytes
/// are padding.
#[derive(Copy, Clone)]
pub struct SessionId {
    // Number of bytes of `data` in use.
    len: usize,
    // Backing storage.
    data: [u8; 32],
}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    #[cfg(feature = "tls12")]
181    pub(crate) fn is_empty(&self) -> bool {
182        self.len == 0
183    }
184}
185
186impl AsRef<[u8]> for SessionId {
187    fn as_ref(&self) -> &[u8] {
188        &self.data[..self.len]
189    }
190}
191
/// An extension kept as its type code plus an uninterpreted, owned payload.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
    // The extension type this payload was received or will be sent under.
    pub(crate) typ: ExtensionType,
    // The extension body bytes, owned (`'static`).
    pub(crate) payload: Payload<'static>,
}
197
198impl UnknownExtension {
199    fn encode(&self, bytes: &mut Vec<u8>) {
200        self.payload.encode(bytes);
201    }
202
203    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204        let payload = Payload::read(r).into_owned();
205        Self { typ, payload }
206    }
207}
208
/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
impl TlsListElement for ECPointFormat {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
    };
}
215
/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
impl TlsListElement for NamedGroup {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
    };
}
222
/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
impl TlsListElement for SignatureScheme {
    // Unlike the other list types here, the empty-list error is the dedicated
    // `NoSignatureSchemes` value rather than `IllegalEmptyList`.
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::NoSignatureSchemes,
    };
}
229
/// The outcome of decoding (or the input to encoding) a server name list.
#[derive(Clone, Debug)]
pub enum ServerNamePayload<'a> {
    /// A successfully decoded value:
    SingleDnsName(DnsName<'a>),

    /// A DNS name which was actually an IP address
    IpAddress,

    /// A successfully decoded, but syntactically-invalid value.
    Invalid,
}
241
242impl ServerNamePayload<'_> {
243    fn into_owned(self) -> ServerNamePayload<'static> {
244        match self {
245            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
246            Self::IpAddress => ServerNamePayload::IpAddress,
247            Self::Invalid => ServerNamePayload::Invalid,
248        }
249    }
250
251    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
252    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
253        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
254    };
255}
256
257/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
258///
259/// This is possible because:
260///
261/// - the spec (RFC6066) disallows multiple names for a given name type
262/// - name types other than ServerNameType::HostName are not defined, and they and
263///   any data that follows them cannot be skipped over.
264impl<'a> Codec<'a> for ServerNamePayload<'a> {
265    fn encode(&self, bytes: &mut Vec<u8>) {
266        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
267
268        let ServerNamePayload::SingleDnsName(dns_name) = self else {
269            return;
270        };
271
272        ServerNameType::HostName.encode(server_name_list.buf);
273        let name_slice = dns_name.as_ref().as_bytes();
274        (name_slice.len() as u16).encode(server_name_list.buf);
275        server_name_list
276            .buf
277            .extend_from_slice(name_slice);
278    }
279
280    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
281        let mut found = None;
282
283        let len = Self::SIZE_LEN.read(r)?;
284        let mut sub = r.sub(len)?;
285
286        while sub.any_left() {
287            let typ = ServerNameType::read(&mut sub)?;
288
289            let payload = match typ {
290                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
291                _ => {
292                    // Consume remainder of extension bytes.  Since the length of the item
293                    // is an unknown encoding, we cannot continue.
294                    sub.rest();
295                    break;
296                }
297            };
298
299            // "The ServerNameList MUST NOT contain more than one name of
300            // the same name_type." - RFC6066
301            if found.is_some() {
302                warn!("Illegal SNI extension: duplicate host_name received");
303                return Err(InvalidMessage::InvalidServerName);
304            }
305
306            found = match payload {
307                HostNamePayload::HostName(dns_name) => {
308                    Some(Self::SingleDnsName(dns_name.to_owned()))
309                }
310
311                HostNamePayload::IpAddress(_invalid) => {
312                    warn!(
313                        "Illegal SNI extension: ignoring IP address presented as hostname ({:?})",
314                        _invalid
315                    );
316                    Some(Self::IpAddress)
317                }
318
319                HostNamePayload::Invalid(_invalid) => {
320                    warn!(
321                        "Illegal SNI hostname received {:?}",
322                        String::from_utf8_lossy(&_invalid.0)
323                    );
324                    Some(Self::Invalid)
325                }
326            };
327        }
328
329        Ok(found.unwrap_or(Self::Invalid))
330    }
331}
332
333impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
334    fn from(value: &DnsName<'a>) -> Self {
335        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
336    }
337}
338
/// Classification of a single `host_name` entry after parsing.
#[derive(Clone, Debug)]
pub(crate) enum HostNamePayload {
    // The bytes parsed as a DNS name.
    HostName(DnsName<'static>),
    // The bytes parsed as an IP address; the raw bytes are retained.
    IpAddress(PayloadU16<NonEmpty>),
    // The bytes parsed as neither; the raw bytes are retained.
    Invalid(PayloadU16<NonEmpty>),
}
345
346impl HostNamePayload {
347    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
348        use pki_types::ServerName;
349        let raw = PayloadU16::<NonEmpty>::read(r)?;
350
351        match ServerName::try_from(raw.0.as_slice()) {
352            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
353            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
354            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
355        }
356    }
357}
358
wrapped_payload!(
    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
    ///
    /// Newtype over `PayloadU8<NonEmpty>`.
    pub struct ProtocolName, PayloadU8<NonEmpty>,
);
363
/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
impl TlsListElement for ProtocolName {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
    };
}
370
/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
///
/// This type holds exactly one `ProtocolName`; its `Codec` impl wraps it in
/// the list framing.
#[derive(Clone, Debug)]
pub struct SingleProtocolName(ProtocolName);
374
375impl SingleProtocolName {
376    pub(crate) fn new(bytes: Vec<u8>) -> Self {
377        Self(ProtocolName::from(bytes))
378    }
379
380    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
381        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
382    };
383}
384
385impl Codec<'_> for SingleProtocolName {
386    fn encode(&self, bytes: &mut Vec<u8>) {
387        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
388        self.0.encode(body.buf);
389    }
390
391    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
392        let len = Self::SIZE_LEN.read(reader)?;
393        let mut sub = reader.sub(len)?;
394
395        let item = ProtocolName::read(&mut sub)?;
396
397        if sub.any_left() {
398            Err(InvalidMessage::TrailingData("SingleProtocolName"))
399        } else {
400            Ok(Self(item))
401        }
402    }
403}
404
405impl AsRef<[u8]> for SingleProtocolName {
406    fn as_ref(&self) -> &[u8] {
407        self.0.as_ref()
408    }
409}
410
411// --- TLS 1.3 Key shares ---
/// A named group paired with key-exchange bytes for that group.
#[derive(Clone, Debug)]
pub struct KeyShareEntry {
    // The group the payload belongs to.
    pub(crate) group: NamedGroup,
    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
    pub(crate) payload: PayloadU16<NonEmpty>,
}
418
419impl KeyShareEntry {
420    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
421        Self {
422            group,
423            payload: PayloadU16::new(payload.into()),
424        }
425    }
426
427    pub fn group(&self) -> NamedGroup {
428        self.group
429    }
430}
431
432impl Codec<'_> for KeyShareEntry {
433    fn encode(&self, bytes: &mut Vec<u8>) {
434        self.group.encode(bytes);
435        self.payload.encode(bytes);
436    }
437
438    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
439        let group = NamedGroup::read(r)?;
440        let payload = PayloadU16::read(r)?;
441
442        Ok(Self { group, payload })
443    }
444}
445
446// --- TLS 1.3 PresharedKey offers ---
/// One PSK identity together with its obfuscated ticket age.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    /// RFC8446: `opaque identity<1..2^16-1>;`
    pub(crate) identity: PayloadU16<NonEmpty>,
    // Encoded as a `u32` directly after the identity.
    pub(crate) obfuscated_ticket_age: u32,
}
453
454impl PresharedKeyIdentity {
455    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
456        Self {
457            identity: PayloadU16::new(id),
458            obfuscated_ticket_age: age,
459        }
460    }
461}
462
463impl Codec<'_> for PresharedKeyIdentity {
464    fn encode(&self, bytes: &mut Vec<u8>) {
465        self.identity.encode(bytes);
466        self.obfuscated_ticket_age.encode(bytes);
467    }
468
469    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
470        Ok(Self {
471            identity: PayloadU16::read(r)?,
472            obfuscated_ticket_age: u32::read(r)?,
473        })
474    }
475}
476
/// RFC8446: `PskIdentity identities<7..2^16-1>;`
impl TlsListElement for PresharedKeyIdentity {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
    };
}
483
wrapped_payload!(
    /// RFC8446: `opaque PskBinderEntry<32..255>;`
    ///
    /// Newtype over `PayloadU8<NonEmpty>`.
    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
);
488
/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
impl TlsListElement for PresharedKeyBinder {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
    };
}
495
/// A list of PSK identities and a list of binders, encoded in that order.
#[derive(Clone, Debug)]
pub struct PresharedKeyOffer {
    // Offered identities.
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    // Binder values; presumably one per identity -- not enforced by this type.
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
501
502impl PresharedKeyOffer {
503    /// Make a new one with one entry.
504    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
505        Self {
506            identities: vec![id],
507            binders: vec![PresharedKeyBinder::from(binder)],
508        }
509    }
510}
511
512impl Codec<'_> for PresharedKeyOffer {
513    fn encode(&self, bytes: &mut Vec<u8>) {
514        self.identities.encode(bytes);
515        self.binders.encode(bytes);
516    }
517
518    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
519        Ok(Self {
520            identities: Vec::read(r)?,
521            binders: Vec::read(r)?,
522        })
523    }
524}
525
526// --- RFC6066 certificate status request ---
// Newtype over `PayloadU16` holding one responder ID as raw bytes.
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
528
/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
impl TlsListElement for ResponderId {
    // Plain `U16` variant: no empty-list error is attached for this list.
    const SIZE_LEN: ListLength = ListLength::U16;
}
533
/// Body of an OCSP-type certificate status request: a responder ID list
/// followed by an opaque extensions blob.
#[derive(Clone, Debug)]
pub struct OcspCertificateStatusRequest {
    // Responder IDs; may be empty.
    pub(crate) responder_ids: Vec<ResponderId>,
    // Uninterpreted bytes with a `u16` length prefix.
    pub(crate) extensions: PayloadU16,
}
539
540impl Codec<'_> for OcspCertificateStatusRequest {
541    fn encode(&self, bytes: &mut Vec<u8>) {
542        CertificateStatusType::OCSP.encode(bytes);
543        self.responder_ids.encode(bytes);
544        self.extensions.encode(bytes);
545    }
546
547    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
548        Ok(Self {
549            responder_ids: Vec::read(r)?,
550            extensions: PayloadU16::read(r)?,
551        })
552    }
553}
554
/// A certificate status request, discriminated by its status type.
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
    // Status type `OCSP`, with its parsed body.
    Ocsp(OcspCertificateStatusRequest),
    // Any other status type, with the remaining bytes kept uninterpreted.
    Unknown((CertificateStatusType, Payload<'static>)),
}
560
561impl Codec<'_> for CertificateStatusRequest {
562    fn encode(&self, bytes: &mut Vec<u8>) {
563        match self {
564            Self::Ocsp(r) => r.encode(bytes),
565            Self::Unknown((typ, payload)) => {
566                typ.encode(bytes);
567                payload.encode(bytes);
568            }
569        }
570    }
571
572    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
573        let typ = CertificateStatusType::read(r)?;
574
575        match typ {
576            CertificateStatusType::OCSP => {
577                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
578                Ok(Self::Ocsp(ocsp_req))
579            }
580            _ => {
581                let data = Payload::read(r).into_owned();
582                Ok(Self::Unknown((typ, data)))
583            }
584        }
585    }
586}
587
588impl CertificateStatusRequest {
589    pub(crate) fn build_ocsp() -> Self {
590        let ocsp = OcspCertificateStatusRequest {
591            responder_ids: Vec::new(),
592            extensions: PayloadU16::empty(),
593        };
594        Self::Ocsp(ocsp)
595    }
596}
597
598// ---
599
/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
impl TlsListElement for PskKeyExchangeMode {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
    };
}
606
/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
impl TlsListElement for KeyShareEntry {
    // Plain `U16` variant: no empty-list error is attached for this list.
    const SIZE_LEN: ListLength = ListLength::U16;
}
611
/// The body of the `SupportedVersions` extension when it appears in a
/// `ClientHello`
///
/// This is documented as a preference-order vector, but we (as a server)
/// ignore the preference of the client.
///
/// RFC8446: `ProtocolVersion versions<2..254>;`
///
/// Only the presence of TLS 1.3 and TLS 1.2 is tracked; other versions seen
/// while decoding are dropped. `Default` yields both flags `false`.
#[derive(Clone, Copy, Debug, Default)]
pub struct SupportedProtocolVersions {
    // True when TLS 1.3 is listed.
    pub(crate) tls13: bool,
    // True when TLS 1.2 is listed.
    pub(crate) tls12: bool,
}
624
625impl SupportedProtocolVersions {
626    /// Return true if `filter` returns true for any enabled version.
627    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
628        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
629            return true;
630        }
631        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
632            return true;
633        }
634        false
635    }
636
637    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
638        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
639    };
640}
641
642impl Codec<'_> for SupportedProtocolVersions {
643    fn encode(&self, bytes: &mut Vec<u8>) {
644        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
645        if self.tls13 {
646            ProtocolVersion::TLSv1_3.encode(inner.buf);
647        }
648        if self.tls12 {
649            ProtocolVersion::TLSv1_2.encode(inner.buf);
650        }
651    }
652
653    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
654        let len = Self::LIST_LENGTH.read(reader)?;
655        let mut sub = reader.sub(len)?;
656
657        let mut tls12 = false;
658        let mut tls13 = false;
659
660        while sub.any_left() {
661            match ProtocolVersion::read(&mut sub)? {
662                ProtocolVersion::TLSv1_3 => tls13 = true,
663                ProtocolVersion::TLSv1_2 => tls12 = true,
664                _ => continue,
665            };
666        }
667
668        Ok(Self { tls13, tls12 })
669    }
670}
671
/// List framing for `ProtocolVersion` values: one-byte length, with an
/// `IllegalEmptyList("ProtocolVersions")` error value for the empty case.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
    };
}
677
/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
///
/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
impl TlsListElement for CertificateType {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
    };
}
686
/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
impl TlsListElement for CertificateCompressionAlgorithm {
    // `empty_error` is the error value tied to an empty list (presumably what
    // decoding reports; see `ListLength` in `codec`).
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
    };
}
693
/// A client-side extension, one variant per extension type this module
/// parses, plus `Unknown` for everything else.
///
/// The mapping from variant to `ExtensionType` is given by `ext_type()`.
/// Variants without a payload are encoded with an empty body.
#[derive(Clone, Debug)]
pub enum ClientExtension {
    EcPointFormats(Vec<ECPointFormat>),
    NamedGroups(Vec<NamedGroup>),
    SignatureAlgorithms(Vec<SignatureScheme>),
    ServerName(ServerNamePayload<'static>),
    SessionTicket(ClientSessionTicket),
    Protocols(Vec<ProtocolName>),
    SupportedVersions(SupportedProtocolVersions),
    KeyShare(Vec<KeyShareEntry>),
    PresharedKeyModes(Vec<PskKeyExchangeMode>),
    PresharedKey(PresharedKeyOffer),
    Cookie(PayloadU16<NonEmpty>),
    ExtendedMasterSecretRequest,
    CertificateStatusRequest(CertificateStatusRequest),
    ServerCertTypes(Vec<CertificateType>),
    ClientCertTypes(Vec<CertificateType>),
    // Raw body bytes, stored and encoded without interpretation.
    TransportParameters(Vec<u8>),
    // Raw body bytes, stored and encoded without interpretation.
    TransportParametersDraft(Vec<u8>),
    EarlyData,
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    EncryptedClientHello(EncryptedClientHello),
    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
    AuthorityNames(Vec<DistinguishedName>),
    // Any extension type not handled above, kept as opaque bytes.
    Unknown(UnknownExtension),
}
720
721impl ClientExtension {
722    pub(crate) fn ext_type(&self) -> ExtensionType {
723        match self {
724            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
725            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
726            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
727            Self::ServerName(_) => ExtensionType::ServerName,
728            Self::SessionTicket(_) => ExtensionType::SessionTicket,
729            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
730            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
731            Self::KeyShare(_) => ExtensionType::KeyShare,
732            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
733            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
734            Self::Cookie(_) => ExtensionType::Cookie,
735            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
736            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
737            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
738            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
739            Self::TransportParameters(_) => ExtensionType::TransportParameters,
740            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
741            Self::EarlyData => ExtensionType::EarlyData,
742            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
743            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
744            Self::EncryptedClientHelloOuterExtensions(_) => {
745                ExtensionType::EncryptedClientHelloOuterExtensions
746            }
747            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
748            Self::Unknown(r) => r.typ,
749        }
750    }
751}
752
impl Codec<'_> for ClientExtension {
    /// Writes the extension type, then the variant's body inside a
    /// `ListLength::U16` length-prefixed buffer.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match self {
            Self::EcPointFormats(r) => r.encode(nested.buf),
            Self::NamedGroups(r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
            Self::ServerName(r) => r.encode(nested.buf),
            // These variants have an empty body.
            Self::SessionTicket(ClientSessionTicket::Request)
            | Self::ExtendedMasterSecretRequest
            | Self::EarlyData => {}
            Self::SessionTicket(ClientSessionTicket::Offer(r)) => r.encode(nested.buf),
            Self::Protocols(r) => r.encode(nested.buf),
            Self::SupportedVersions(r) => r.encode(nested.buf),
            Self::KeyShare(r) => r.encode(nested.buf),
            Self::PresharedKeyModes(r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::Cookie(r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
            Self::ClientCertTypes(r) => r.encode(nested.buf),
            Self::ServerCertTypes(r) => r.encode(nested.buf),
            // Transport parameters are raw bytes, copied through verbatim.
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
            Self::AuthorityNames(r) => r.encode(nested.buf),
            Self::Unknown(r) => r.encode(nested.buf),
        }
    }

    /// Reads the extension type and a `u16` body length, then parses the body
    /// according to the type.
    ///
    /// Types without a dedicated arm become `Unknown`. So do
    /// `ExtendedMasterSecret` and `EarlyData` when their body is non-empty,
    /// because their arms are guarded and the match falls through to the
    /// catch-all. `EncryptedClientHello` has no arm here, so it also decodes as
    /// `Unknown`.
    ///
    /// After parsing, the body must be fully consumed; otherwise the error
    /// from `expect_empty("ClientExtension")` is returned.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => {
                Self::ServerName(ServerNamePayload::read(&mut sub)?.into_owned())
            }
            ExtensionType::SessionTicket => {
                // A non-empty body is an offered ticket; an empty one is a request.
                if sub.any_left() {
                    let contents = Payload::read(&mut sub).into_owned();
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
                } else {
                    Self::SessionTicket(ClientSessionTicket::Request)
                }
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(SupportedProtocolVersions::read(&mut sub)?)
            }
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            // Guarded: only matches when the body is empty.
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            // Guarded: only matches when the body is empty.
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            ExtensionType::CompressCertificate => {
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
            }
            ExtensionType::EncryptedClientHelloOuterExtensions => {
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
            }
            ExtensionType::CertificateAuthorities => Self::AuthorityNames({
                let items = Vec::read(&mut sub)?;
                // An empty authority list is rejected explicitly here.
                if items.is_empty() {
                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
                }
                items
            }),
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
849
850fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
851    let dns_name_str = dns_name.as_ref();
852
853    // RFC6066: "The hostname is represented as a byte string using
854    // ASCII encoding without a trailing dot"
855    if dns_name_str.ends_with('.') {
856        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
857        DnsName::try_from(trimmed)
858            .unwrap()
859            .to_owned()
860    } else {
861        dns_name.to_owned()
862    }
863}
864
/// Content of the client's session ticket extension.
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
    // The extension body is empty.
    Request,
    // The extension body is non-empty; its bytes are kept as an owned payload.
    Offer(Payload<'static>),
}
870
/// A server-side extension, one variant per extension type handled here, plus
/// `Unknown` for everything else.
///
/// The mapping from variant to `ExtensionType` is given by `ext_type()`.
#[derive(Clone, Debug)]
pub enum ServerExtension {
    EcPointFormats(Vec<ECPointFormat>),
    ServerNameAck,
    SessionTicketAck,
    RenegotiationInfo(PayloadU8),
    Protocols(SingleProtocolName),
    KeyShare(KeyShareEntry),
    PresharedKey(u16),
    ExtendedMasterSecretAck,
    CertificateStatusAck,
    ServerCertType(CertificateType),
    ClientCertType(CertificateType),
    SupportedVersions(ProtocolVersion),
    // Raw body bytes, stored without interpretation.
    TransportParameters(Vec<u8>),
    // Raw body bytes, stored without interpretation.
    TransportParametersDraft(Vec<u8>),
    EarlyData,
    EncryptedClientHello(ServerEncryptedClientHello),
    // Any extension type not handled above, kept as opaque bytes.
    Unknown(UnknownExtension),
}
891
892impl ServerExtension {
893    pub(crate) fn ext_type(&self) -> ExtensionType {
894        match self {
895            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
896            Self::ServerNameAck => ExtensionType::ServerName,
897            Self::SessionTicketAck => ExtensionType::SessionTicket,
898            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
899            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
900            Self::KeyShare(_) => ExtensionType::KeyShare,
901            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
902            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
903            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
904            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
905            Self::CertificateStatusAck => ExtensionType::StatusRequest,
906            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
907            Self::TransportParameters(_) => ExtensionType::TransportParameters,
908            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
909            Self::EarlyData => ExtensionType::EarlyData,
910            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
911            Self::Unknown(r) => r.typ,
912        }
913    }
914}
915
916impl Codec<'_> for ServerExtension {
917    fn encode(&self, bytes: &mut Vec<u8>) {
918        self.ext_type().encode(bytes);
919
920        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
921        match self {
922            Self::EcPointFormats(r) => r.encode(nested.buf),
923            Self::ServerNameAck
924            | Self::SessionTicketAck
925            | Self::ExtendedMasterSecretAck
926            | Self::CertificateStatusAck
927            | Self::EarlyData => {}
928            Self::RenegotiationInfo(r) => r.encode(nested.buf),
929            Self::Protocols(r) => r.encode(nested.buf),
930            Self::KeyShare(r) => r.encode(nested.buf),
931            Self::PresharedKey(r) => r.encode(nested.buf),
932            Self::ClientCertType(r) => r.encode(nested.buf),
933            Self::ServerCertType(r) => r.encode(nested.buf),
934            Self::SupportedVersions(r) => r.encode(nested.buf),
935            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
936                nested.buf.extend_from_slice(r);
937            }
938            Self::EncryptedClientHello(r) => r.encode(nested.buf),
939            Self::Unknown(r) => r.encode(nested.buf),
940        }
941    }
942
943    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
944        let typ = ExtensionType::read(r)?;
945        let len = u16::read(r)? as usize;
946        let mut sub = r.sub(len)?;
947
948        let ext = match typ {
949            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
950            ExtensionType::ServerName => Self::ServerNameAck,
951            ExtensionType::SessionTicket => Self::SessionTicketAck,
952            ExtensionType::StatusRequest => Self::CertificateStatusAck,
953            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
954            ExtensionType::ALProtocolNegotiation => {
955                Self::Protocols(SingleProtocolName::read(&mut sub)?)
956            }
957            ExtensionType::ClientCertificateType => {
958                Self::ClientCertType(CertificateType::read(&mut sub)?)
959            }
960            ExtensionType::ServerCertificateType => {
961                Self::ServerCertType(CertificateType::read(&mut sub)?)
962            }
963            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
964            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
965            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
966            ExtensionType::SupportedVersions => {
967                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
968            }
969            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
970            ExtensionType::TransportParametersDraft => {
971                Self::TransportParametersDraft(sub.rest().to_vec())
972            }
973            ExtensionType::EarlyData => Self::EarlyData,
974            ExtensionType::EncryptedClientHello => {
975                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
976            }
977            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
978        };
979
980        sub.expect_empty("ServerExtension")
981            .map(|_| ext)
982    }
983}
984
985impl ServerExtension {
986    #[cfg(feature = "tls12")]
987    pub(crate) fn make_empty_renegotiation_info() -> Self {
988        let empty = Vec::new();
989        Self::RenegotiationInfo(PayloadU8::new(empty))
990    }
991}
992
993#[derive(Clone, Debug)]
994pub struct ClientHelloPayload {
995    pub client_version: ProtocolVersion,
996    pub random: Random,
997    pub session_id: SessionId,
998    pub cipher_suites: Vec<CipherSuite>,
999    pub compression_methods: Vec<Compression>,
1000    pub extensions: Vec<ClientExtension>,
1001}
1002
1003impl Codec<'_> for ClientHelloPayload {
1004    fn encode(&self, bytes: &mut Vec<u8>) {
1005        self.payload_encode(bytes, Encoding::Standard)
1006    }
1007
1008    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1009        let mut ret = Self {
1010            client_version: ProtocolVersion::read(r)?,
1011            random: Random::read(r)?,
1012            session_id: SessionId::read(r)?,
1013            cipher_suites: Vec::read(r)?,
1014            compression_methods: Vec::read(r)?,
1015            extensions: Vec::new(),
1016        };
1017
1018        if r.any_left() {
1019            ret.extensions = Vec::read(r)?;
1020        }
1021
1022        match (r.any_left(), ret.extensions.is_empty()) {
1023            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1024            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
1025            _ => Ok(ret),
1026        }
1027    }
1028}
1029
1030/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
1031impl TlsListElement for CipherSuite {
1032    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
1033        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
1034    };
1035}
1036
1037/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
1038impl TlsListElement for Compression {
1039    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
1040        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
1041    };
1042}
1043
1044impl TlsListElement for ClientExtension {
1045    const SIZE_LEN: ListLength = ListLength::U16;
1046}
1047
1048/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
1049impl TlsListElement for ExtensionType {
1050    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
1051        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
1052    };
1053}
1054
1055impl ClientHelloPayload {
1056    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
1057        let mut bytes = Vec::new();
1058        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
1059        bytes
1060    }
1061
1062    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1063        self.client_version.encode(bytes);
1064        self.random.encode(bytes);
1065
1066        match purpose {
1067            // SessionID is required to be empty in the encoded inner client hello.
1068            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
1069            _ => self.session_id.encode(bytes),
1070        }
1071
1072        self.cipher_suites.encode(bytes);
1073        self.compression_methods.encode(bytes);
1074
1075        let to_compress = match purpose {
1076            // Compressed extensions must be replaced in the encoded inner client hello.
1077            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
1078            _ => {
1079                if !self.extensions.is_empty() {
1080                    self.extensions.encode(bytes);
1081                }
1082                return;
1083            }
1084        };
1085
1086        // Safety: not empty check in match guard.
1087        let first_compressed_type = *to_compress.first().unwrap();
1088
1089        // Compressed extensions are in a contiguous range and must be replaced
1090        // with a marker extension.
1091        let compressed_start_idx = self
1092            .extensions
1093            .iter()
1094            .position(|ext| ext.ext_type() == first_compressed_type);
1095        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
1096        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
1097
1098        let exts = self
1099            .extensions
1100            .iter()
1101            .enumerate()
1102            .filter_map(|(i, ext)| {
1103                if Some(i) == compressed_start_idx {
1104                    Some(&marker_ext)
1105                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
1106                    None
1107                } else {
1108                    Some(ext)
1109                }
1110            });
1111
1112        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1113        for ext in exts {
1114            ext.encode(nested.buf);
1115        }
1116    }
1117
1118    /// Returns true if there is more than one extension of a given
1119    /// type.
1120    pub(crate) fn has_duplicate_extension(&self) -> bool {
1121        has_duplicates::<_, _, u16>(
1122            self.extensions
1123                .iter()
1124                .map(|ext| ext.ext_type()),
1125        )
1126    }
1127
1128    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
1129        self.extensions
1130            .iter()
1131            .find(|x| x.ext_type() == ext)
1132    }
1133
1134    pub(crate) fn sni_extension(&self) -> Option<&ServerNamePayload<'_>> {
1135        let ext = self.find_extension(ExtensionType::ServerName)?;
1136        match ext {
1137            ClientExtension::ServerName(req) => Some(req),
1138            _ => None,
1139        }
1140    }
1141
1142    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1143        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1144        match ext {
1145            ClientExtension::SignatureAlgorithms(req) => Some(req),
1146            _ => None,
1147        }
1148    }
1149
1150    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1151        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1152        match ext {
1153            ClientExtension::NamedGroups(req) => Some(req),
1154            _ => None,
1155        }
1156    }
1157
1158    #[cfg(feature = "tls12")]
1159    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1160        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1161        match ext {
1162            ClientExtension::EcPointFormats(req) => Some(req),
1163            _ => None,
1164        }
1165    }
1166
1167    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1168        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1169        match ext {
1170            ClientExtension::ServerCertTypes(req) => Some(req),
1171            _ => None,
1172        }
1173    }
1174
1175    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1176        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1177        match ext {
1178            ClientExtension::ClientCertTypes(req) => Some(req),
1179            _ => None,
1180        }
1181    }
1182
1183    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1184        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1185        match ext {
1186            ClientExtension::Protocols(req) => Some(req),
1187            _ => None,
1188        }
1189    }
1190
1191    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1192        let ext = self
1193            .find_extension(ExtensionType::TransportParameters)
1194            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1195        match ext {
1196            ClientExtension::TransportParameters(bytes)
1197            | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
1198            _ => None,
1199        }
1200    }
1201
1202    #[cfg(feature = "tls12")]
1203    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1204        self.find_extension(ExtensionType::SessionTicket)
1205    }
1206
1207    pub(crate) fn versions_extension(&self) -> Option<SupportedProtocolVersions> {
1208        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1209        match ext {
1210            ClientExtension::SupportedVersions(vers) => Some(*vers),
1211            _ => None,
1212        }
1213    }
1214
1215    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1216        let ext = self.find_extension(ExtensionType::KeyShare)?;
1217        match ext {
1218            ClientExtension::KeyShare(shares) => Some(shares),
1219            _ => None,
1220        }
1221    }
1222
1223    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1224        self.keyshare_extension()
1225            .map(|entries| {
1226                has_duplicates::<_, _, u16>(
1227                    entries
1228                        .iter()
1229                        .map(|kse| u16::from(kse.group)),
1230                )
1231            })
1232            .unwrap_or_default()
1233    }
1234
1235    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1236        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1237        match ext {
1238            ClientExtension::PresharedKey(psk) => Some(psk),
1239            _ => None,
1240        }
1241    }
1242
1243    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1244        self.extensions
1245            .last()
1246            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1247    }
1248
1249    pub(crate) fn psk_modes(&self) -> Option<&[PskKeyExchangeMode]> {
1250        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1251        match ext {
1252            ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes),
1253            _ => None,
1254        }
1255    }
1256
1257    pub(crate) fn psk_mode_offered(&self, mode: PskKeyExchangeMode) -> bool {
1258        self.psk_modes()
1259            .map(|modes| modes.contains(&mode))
1260            .unwrap_or(false)
1261    }
1262
1263    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1264        let last_extension = self.extensions.last_mut();
1265        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1266            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1267        }
1268    }
1269
1270    #[cfg(feature = "tls12")]
1271    pub(crate) fn ems_support_offered(&self) -> bool {
1272        self.find_extension(ExtensionType::ExtendedMasterSecret)
1273            .is_some()
1274    }
1275
1276    pub(crate) fn early_data_extension_offered(&self) -> bool {
1277        self.find_extension(ExtensionType::EarlyData)
1278            .is_some()
1279    }
1280
1281    pub(crate) fn certificate_compression_extension(
1282        &self,
1283    ) -> Option<&[CertificateCompressionAlgorithm]> {
1284        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1285        match ext {
1286            ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs),
1287            _ => None,
1288        }
1289    }
1290
1291    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1292        if let Some(algs) = self.certificate_compression_extension() {
1293            has_duplicates::<_, _, u16>(algs.iter().cloned())
1294        } else {
1295            false
1296        }
1297    }
1298
1299    pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1300        match self.find_extension(ExtensionType::CertificateAuthorities)? {
1301            ClientExtension::AuthorityNames(ext) => Some(ext),
1302            _ => unreachable!("extension type checked"),
1303        }
1304    }
1305}
1306
1307#[derive(Clone, Debug)]
1308pub(crate) enum HelloRetryExtension {
1309    KeyShare(NamedGroup),
1310    Cookie(PayloadU16<NonEmpty>),
1311    SupportedVersions(ProtocolVersion),
1312    EchHelloRetryRequest(Vec<u8>),
1313    Unknown(UnknownExtension),
1314}
1315
1316impl HelloRetryExtension {
1317    pub(crate) fn ext_type(&self) -> ExtensionType {
1318        match self {
1319            Self::KeyShare(_) => ExtensionType::KeyShare,
1320            Self::Cookie(_) => ExtensionType::Cookie,
1321            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1322            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1323            Self::Unknown(r) => r.typ,
1324        }
1325    }
1326}
1327
1328impl Codec<'_> for HelloRetryExtension {
1329    fn encode(&self, bytes: &mut Vec<u8>) {
1330        self.ext_type().encode(bytes);
1331
1332        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1333        match self {
1334            Self::KeyShare(r) => r.encode(nested.buf),
1335            Self::Cookie(r) => r.encode(nested.buf),
1336            Self::SupportedVersions(r) => r.encode(nested.buf),
1337            Self::EchHelloRetryRequest(r) => {
1338                nested.buf.extend_from_slice(r);
1339            }
1340            Self::Unknown(r) => r.encode(nested.buf),
1341        }
1342    }
1343
1344    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1345        let typ = ExtensionType::read(r)?;
1346        let len = u16::read(r)? as usize;
1347        let mut sub = r.sub(len)?;
1348
1349        let ext = match typ {
1350            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1351            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1352            ExtensionType::SupportedVersions => {
1353                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1354            }
1355            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1356            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1357        };
1358
1359        sub.expect_empty("HelloRetryExtension")
1360            .map(|_| ext)
1361    }
1362}
1363
1364impl TlsListElement for HelloRetryExtension {
1365    const SIZE_LEN: ListLength = ListLength::U16;
1366}
1367
1368#[derive(Clone, Debug)]
1369pub struct HelloRetryRequest {
1370    pub(crate) legacy_version: ProtocolVersion,
1371    pub session_id: SessionId,
1372    pub(crate) cipher_suite: CipherSuite,
1373    pub(crate) extensions: Vec<HelloRetryExtension>,
1374}
1375
1376impl Codec<'_> for HelloRetryRequest {
1377    fn encode(&self, bytes: &mut Vec<u8>) {
1378        self.payload_encode(bytes, Encoding::Standard)
1379    }
1380
1381    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1382        let session_id = SessionId::read(r)?;
1383        let cipher_suite = CipherSuite::read(r)?;
1384        let compression = Compression::read(r)?;
1385
1386        if compression != Compression::Null {
1387            return Err(InvalidMessage::UnsupportedCompression);
1388        }
1389
1390        Ok(Self {
1391            legacy_version: ProtocolVersion::Unknown(0),
1392            session_id,
1393            cipher_suite,
1394            extensions: Vec::read(r)?,
1395        })
1396    }
1397}
1398
1399impl HelloRetryRequest {
1400    /// Returns true if there is more than one extension of a given
1401    /// type.
1402    pub(crate) fn has_duplicate_extension(&self) -> bool {
1403        has_duplicates::<_, _, u16>(
1404            self.extensions
1405                .iter()
1406                .map(|ext| ext.ext_type()),
1407        )
1408    }
1409
1410    pub(crate) fn has_unknown_extension(&self) -> bool {
1411        self.extensions.iter().any(|ext| {
1412            ext.ext_type() != ExtensionType::KeyShare
1413                && ext.ext_type() != ExtensionType::SupportedVersions
1414                && ext.ext_type() != ExtensionType::Cookie
1415                && ext.ext_type() != ExtensionType::EncryptedClientHello
1416        })
1417    }
1418
1419    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1420        self.extensions
1421            .iter()
1422            .find(|x| x.ext_type() == ext)
1423    }
1424
1425    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1426        let ext = self.find_extension(ExtensionType::KeyShare)?;
1427        match ext {
1428            HelloRetryExtension::KeyShare(grp) => Some(*grp),
1429            _ => None,
1430        }
1431    }
1432
1433    pub(crate) fn cookie(&self) -> Option<&PayloadU16<NonEmpty>> {
1434        let ext = self.find_extension(ExtensionType::Cookie)?;
1435        match ext {
1436            HelloRetryExtension::Cookie(ck) => Some(ck),
1437            _ => None,
1438        }
1439    }
1440
1441    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1442        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1443        match ext {
1444            HelloRetryExtension::SupportedVersions(ver) => Some(*ver),
1445            _ => None,
1446        }
1447    }
1448
1449    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1450        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1451        match ext {
1452            HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech),
1453            _ => None,
1454        }
1455    }
1456
1457    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1458        self.legacy_version.encode(bytes);
1459        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1460        self.session_id.encode(bytes);
1461        self.cipher_suite.encode(bytes);
1462        Compression::Null.encode(bytes);
1463
1464        match purpose {
1465            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1466            // must have its payload replaced by 8 zero bytes.
1467            //
1468            // See draft-ietf-tls-esni-18 7.2.1:
1469            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1470            Encoding::EchConfirmation => {
1471                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1472                for ext in &self.extensions {
1473                    match ext.ext_type() {
1474                        ExtensionType::EncryptedClientHello => {
1475                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1476                                .encode(extensions.buf);
1477                        }
1478                        _ => {
1479                            ext.encode(extensions.buf);
1480                        }
1481                    }
1482                }
1483            }
1484            _ => {
1485                self.extensions.encode(bytes);
1486            }
1487        }
1488    }
1489}
1490
1491#[derive(Clone, Debug)]
1492pub struct ServerHelloPayload {
1493    pub extensions: Vec<ServerExtension>,
1494    pub(crate) legacy_version: ProtocolVersion,
1495    pub(crate) random: Random,
1496    pub(crate) session_id: SessionId,
1497    pub(crate) cipher_suite: CipherSuite,
1498    pub(crate) compression_method: Compression,
1499}
1500
1501impl Codec<'_> for ServerHelloPayload {
1502    fn encode(&self, bytes: &mut Vec<u8>) {
1503        self.payload_encode(bytes, Encoding::Standard)
1504    }
1505
1506    // minus version and random, which have already been read.
1507    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1508        let session_id = SessionId::read(r)?;
1509        let suite = CipherSuite::read(r)?;
1510        let compression = Compression::read(r)?;
1511
1512        // RFC5246:
1513        // "The presence of extensions can be detected by determining whether
1514        //  there are bytes following the compression_method field at the end of
1515        //  the ServerHello."
1516        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1517
1518        let ret = Self {
1519            legacy_version: ProtocolVersion::Unknown(0),
1520            random: ZERO_RANDOM,
1521            session_id,
1522            cipher_suite: suite,
1523            compression_method: compression,
1524            extensions,
1525        };
1526
1527        r.expect_empty("ServerHelloPayload")
1528            .map(|_| ret)
1529    }
1530}
1531
1532impl HasServerExtensions for ServerHelloPayload {
1533    fn extensions(&self) -> &[ServerExtension] {
1534        &self.extensions
1535    }
1536}
1537
1538impl ServerHelloPayload {
1539    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1540        let ext = self.find_extension(ExtensionType::KeyShare)?;
1541        match ext {
1542            ServerExtension::KeyShare(share) => Some(share),
1543            _ => None,
1544        }
1545    }
1546
1547    pub(crate) fn psk_index(&self) -> Option<u16> {
1548        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1549        match ext {
1550            ServerExtension::PresharedKey(index) => Some(*index),
1551            _ => None,
1552        }
1553    }
1554
1555    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1556        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1557        match ext {
1558            ServerExtension::EcPointFormats(fmts) => Some(fmts),
1559            _ => None,
1560        }
1561    }
1562
1563    #[cfg(feature = "tls12")]
1564    pub(crate) fn ems_support_acked(&self) -> bool {
1565        self.find_extension(ExtensionType::ExtendedMasterSecret)
1566            .is_some()
1567    }
1568
1569    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1570        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1571        match ext {
1572            ServerExtension::SupportedVersions(vers) => Some(*vers),
1573            _ => None,
1574        }
1575    }
1576
1577    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1578        self.legacy_version.encode(bytes);
1579
1580        match encoding {
1581            // When encoding a ServerHello for ECH confirmation, the random value
1582            // has the last 8 bytes zeroed out.
1583            Encoding::EchConfirmation => {
1584                // Indexing safety: self.random is 32 bytes long by definition.
1585                let rand_vec = self.random.get_encoding();
1586                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1587                bytes.extend_from_slice(&[0u8; 8]);
1588            }
1589            _ => self.random.encode(bytes),
1590        }
1591
1592        self.session_id.encode(bytes);
1593        self.cipher_suite.encode(bytes);
1594        self.compression_method.encode(bytes);
1595
1596        if !self.extensions.is_empty() {
1597            self.extensions.encode(bytes);
1598        }
1599    }
1600}
1601
1602#[derive(Clone, Default, Debug)]
1603pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1604
1605impl CertificateChain<'_> {
1606    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1607        CertificateChain(
1608            self.0
1609                .into_iter()
1610                .map(|c| c.into_owned())
1611                .collect(),
1612        )
1613    }
1614}
1615
1616impl<'a> Codec<'a> for CertificateChain<'a> {
1617    fn encode(&self, bytes: &mut Vec<u8>) {
1618        Vec::encode(&self.0, bytes)
1619    }
1620
1621    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1622        Vec::read(r).map(Self)
1623    }
1624}
1625
1626impl<'a> Deref for CertificateChain<'a> {
1627    type Target = [CertificateDer<'a>];
1628
1629    fn deref(&self) -> &[CertificateDer<'a>] {
1630        &self.0
1631    }
1632}
1633
1634impl TlsListElement for CertificateDer<'_> {
1635    const SIZE_LEN: ListLength = ListLength::U24 {
1636        max: CERTIFICATE_MAX_SIZE_LIMIT,
1637        error: InvalidMessage::CertificatePayloadTooLarge,
1638    };
1639}
1640
/// Cap used for the `u24`-prefixed certificate lists in this file.
///
/// TLS has a 16MB size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We contract that to 64KB to limit the amount of memory allocation
/// that is directly controllable by the peer.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 1 << 16;
1647
1648#[derive(Debug)]
1649pub(crate) enum CertificateExtension<'a> {
1650    CertificateStatus(CertificateStatus<'a>),
1651    Unknown(UnknownExtension),
1652}
1653
1654impl CertificateExtension<'_> {
1655    pub(crate) fn ext_type(&self) -> ExtensionType {
1656        match self {
1657            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1658            Self::Unknown(r) => r.typ,
1659        }
1660    }
1661
1662    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1663        match self {
1664            Self::CertificateStatus(cs) => Some(cs.ocsp_response.0.bytes()),
1665            _ => None,
1666        }
1667    }
1668
1669    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1670        match self {
1671            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1672            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1673        }
1674    }
1675}
1676
1677impl<'a> Codec<'a> for CertificateExtension<'a> {
1678    fn encode(&self, bytes: &mut Vec<u8>) {
1679        self.ext_type().encode(bytes);
1680
1681        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1682        match self {
1683            Self::CertificateStatus(r) => r.encode(nested.buf),
1684            Self::Unknown(r) => r.encode(nested.buf),
1685        }
1686    }
1687
1688    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1689        let typ = ExtensionType::read(r)?;
1690        let len = u16::read(r)? as usize;
1691        let mut sub = r.sub(len)?;
1692
1693        let ext = match typ {
1694            ExtensionType::StatusRequest => {
1695                let st = CertificateStatus::read(&mut sub)?;
1696                Self::CertificateStatus(st)
1697            }
1698            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1699        };
1700
1701        sub.expect_empty("CertificateExtension")
1702            .map(|_| ext)
1703    }
1704}
1705
1706impl TlsListElement for CertificateExtension<'_> {
1707    const SIZE_LEN: ListLength = ListLength::U16;
1708}
1709
1710#[derive(Debug)]
1711pub(crate) struct CertificateEntry<'a> {
1712    pub(crate) cert: CertificateDer<'a>,
1713    pub(crate) exts: Vec<CertificateExtension<'a>>,
1714}
1715
1716impl<'a> Codec<'a> for CertificateEntry<'a> {
1717    fn encode(&self, bytes: &mut Vec<u8>) {
1718        self.cert.encode(bytes);
1719        self.exts.encode(bytes);
1720    }
1721
1722    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1723        Ok(Self {
1724            cert: CertificateDer::read(r)?,
1725            exts: Vec::read(r)?,
1726        })
1727    }
1728}
1729
1730impl<'a> CertificateEntry<'a> {
1731    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1732        Self {
1733            cert,
1734            exts: Vec::new(),
1735        }
1736    }
1737
1738    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1739        CertificateEntry {
1740            cert: self.cert.into_owned(),
1741            exts: self
1742                .exts
1743                .into_iter()
1744                .map(CertificateExtension::into_owned)
1745                .collect(),
1746        }
1747    }
1748
1749    pub(crate) fn has_duplicate_extension(&self) -> bool {
1750        has_duplicates::<_, _, u16>(
1751            self.exts
1752                .iter()
1753                .map(|ext| ext.ext_type()),
1754        )
1755    }
1756
1757    pub(crate) fn has_unknown_extension(&self) -> bool {
1758        self.exts
1759            .iter()
1760            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1761    }
1762
1763    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1764        self.exts
1765            .iter()
1766            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1767            .and_then(CertificateExtension::cert_status)
1768    }
1769}
1770
/// Lists of `CertificateEntry` use a `U24` size prefix with
/// `CERTIFICATE_MAX_SIZE_LIMIT` as the maximum and
/// `InvalidMessage::CertificatePayloadTooLarge` as the associated error.
impl TlsListElement for CertificateEntry<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1777
/// A TLS1.3 certificate message body: a context value followed by a list of
/// certificate entries.
#[derive(Debug)]
pub struct CertificatePayloadTls13<'a> {
    /// The context bytes, encoded with a one-byte length prefix type (`PayloadU8`).
    pub(crate) context: PayloadU8,
    /// The certificates, each with its own extensions.
    pub(crate) entries: Vec<CertificateEntry<'a>>,
}
1783
1784impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1785    fn encode(&self, bytes: &mut Vec<u8>) {
1786        self.context.encode(bytes);
1787        self.entries.encode(bytes);
1788    }
1789
1790    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1791        Ok(Self {
1792            context: PayloadU8::read(r)?,
1793            entries: Vec::read(r)?,
1794        })
1795    }
1796}
1797
1798impl<'a> CertificatePayloadTls13<'a> {
1799    pub(crate) fn new(
1800        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1801        ocsp_response: Option<&'a [u8]>,
1802    ) -> Self {
1803        Self {
1804            context: PayloadU8::empty(),
1805            entries: certs
1806                // zip certificate iterator with `ocsp_response` followed by
1807                // an infinite-length iterator of `None`.
1808                .zip(
1809                    ocsp_response
1810                        .into_iter()
1811                        .map(Some)
1812                        .chain(iter::repeat(None)),
1813                )
1814                .map(|(cert, ocsp)| {
1815                    let mut e = CertificateEntry::new(cert.clone());
1816                    if let Some(ocsp) = ocsp {
1817                        e.exts
1818                            .push(CertificateExtension::CertificateStatus(
1819                                CertificateStatus::new(ocsp),
1820                            ));
1821                    }
1822                    e
1823                })
1824                .collect(),
1825        }
1826    }
1827
1828    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1829        CertificatePayloadTls13 {
1830            context: self.context,
1831            entries: self
1832                .entries
1833                .into_iter()
1834                .map(CertificateEntry::into_owned)
1835                .collect(),
1836        }
1837    }
1838
1839    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1840        for entry in &self.entries {
1841            if entry.has_duplicate_extension() {
1842                return true;
1843            }
1844        }
1845
1846        false
1847    }
1848
1849    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1850        for entry in &self.entries {
1851            if entry.has_unknown_extension() {
1852                return true;
1853            }
1854        }
1855
1856        false
1857    }
1858
1859    pub(crate) fn any_entry_has_extension(&self) -> bool {
1860        for entry in &self.entries {
1861            if !entry.exts.is_empty() {
1862                return true;
1863            }
1864        }
1865
1866        false
1867    }
1868
1869    pub(crate) fn end_entity_ocsp(&self) -> &[u8] {
1870        self.entries
1871            .first()
1872            .and_then(CertificateEntry::ocsp_response)
1873            .unwrap_or_default()
1874    }
1875
1876    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1877        CertificateChain(
1878            self.entries
1879                .into_iter()
1880                .map(|e| e.cert)
1881                .collect(),
1882        )
1883    }
1884}
1885
/// Describes supported key exchange mechanisms.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}
1897
/// Every `KeyExchangeAlgorithm` variant, with `ECDHE` listed before `DHE`.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1900
// We don't support arbitrary curves.  It's a terrible
// idea and unnecessary attack surface.  Please,
// get a grip.
//
// Consequently the only curve description carried here is a curve type
// plus a named group.
#[derive(Debug)]
pub(crate) struct EcParameters {
    /// The curve type; decoding rejects anything other than
    /// `ECCurveType::NamedCurve`.
    pub(crate) curve_type: ECCurveType,
    /// The named group identifying the curve.
    pub(crate) named_group: NamedGroup,
}
1909
1910impl Codec<'_> for EcParameters {
1911    fn encode(&self, bytes: &mut Vec<u8>) {
1912        self.curve_type.encode(bytes);
1913        self.named_group.encode(bytes);
1914    }
1915
1916    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1917        let ct = ECCurveType::read(r)?;
1918        if ct != ECCurveType::NamedCurve {
1919            return Err(InvalidMessage::UnsupportedCurveType);
1920        }
1921
1922        let grp = NamedGroup::read(r)?;
1923
1924        Ok(Self {
1925            curve_type: ct,
1926            named_group: grp,
1927        })
1928    }
1929}
1930
/// Decoding for key exchange messages whose layout depends on the
/// negotiated key exchange algorithm.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1936
/// The client's key exchange parameters, one variant per
/// `KeyExchangeAlgorithm`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Parameters for an `ECDHE` key exchange.
    Ecdh(ClientEcdhParams),
    /// Parameters for a `DHE` key exchange.
    Dh(ClientDhParams),
}
1943
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the client's public value for whichever variant this is.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Ecdh(params) => &params.public.0,
            Self::Dh(params) => &params.public.0,
        }
    }

    /// Appends the encoding of the contained parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Ecdh(params) => params.encode(buf),
            Self::Dh(params) => params.encode(buf),
        }
    }
}
1960
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Decodes the parameter structure that corresponds to `algo`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1971
/// The client's ECDH parameters: just its public point.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}
1978
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    /// Appends the public point.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public point.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        Ok(Self {
            public: PayloadU8::read(r)?,
        })
    }
}
1990
/// The client's DH parameters: just its public value.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// RFC5246: `opaque dh_Yc<1..2^16-1>;`
    pub(crate) public: PayloadU16<NonEmpty>,
}
1997
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Appends the public value.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public value.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
2010
/// The server's ECDH parameters: the curve description followed by the
/// server's public point.
#[derive(Debug)]
pub(crate) struct ServerEcdhParams {
    /// Which curve the public point belongs to.
    pub(crate) curve_params: EcParameters,
    /// RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}
2017
2018impl ServerEcdhParams {
2019    #[cfg(feature = "tls12")]
2020    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2021        Self {
2022            curve_params: EcParameters {
2023                curve_type: ECCurveType::NamedCurve,
2024                named_group: kx.group(),
2025            },
2026            public: PayloadU8::new(kx.pub_key().to_vec()),
2027        }
2028    }
2029}
2030
2031impl Codec<'_> for ServerEcdhParams {
2032    fn encode(&self, bytes: &mut Vec<u8>) {
2033        self.curve_params.encode(bytes);
2034        self.public.encode(bytes);
2035    }
2036
2037    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2038        let cp = EcParameters::read(r)?;
2039        let pb = PayloadU8::read(r)?;
2040
2041        Ok(Self {
2042            curve_params: cp,
2043            public: pb,
2044        })
2045    }
2046}
2047
/// The server's DH parameters: prime, generator and the server's public value.
#[derive(Debug)]
// `dh_Ys` keeps the capitalisation used in the RFC5246 field name.
#[allow(non_snake_case)]
pub(crate) struct ServerDhParams {
    /// RFC5246: `opaque dh_p<1..2^16-1>;`
    pub(crate) dh_p: PayloadU16<NonEmpty>,
    /// RFC5246: `opaque dh_g<1..2^16-1>;`
    pub(crate) dh_g: PayloadU16<NonEmpty>,
    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
}
2058
2059impl ServerDhParams {
2060    #[cfg(feature = "tls12")]
2061    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2062        let Some(params) = kx.ffdhe_group() else {
2063            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
2064        };
2065
2066        Self {
2067            dh_p: PayloadU16::new(params.p.to_vec()),
2068            dh_g: PayloadU16::new(params.g.to_vec()),
2069            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
2070        }
2071    }
2072
2073    #[cfg(feature = "tls12")]
2074    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2075        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2076    }
2077}
2078
2079impl Codec<'_> for ServerDhParams {
2080    fn encode(&self, bytes: &mut Vec<u8>) {
2081        self.dh_p.encode(bytes);
2082        self.dh_g.encode(bytes);
2083        self.dh_Ys.encode(bytes);
2084    }
2085
2086    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2087        Ok(Self {
2088            dh_p: PayloadU16::read(r)?,
2089            dh_g: PayloadU16::read(r)?,
2090            dh_Ys: PayloadU16::read(r)?,
2091        })
2092    }
2093}
2094
/// The server's key exchange parameters, one variant per
/// `KeyExchangeAlgorithm`.
// NOTE(review): `dead_code` is presumably allowed because the constructors and
// decoder for this type are gated on the `tls12` feature — confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ServerKeyExchangeParams {
    /// Parameters for an `ECDHE` key exchange.
    Ecdh(ServerEcdhParams),
    /// Parameters for a `DHE` key exchange.
    Dh(ServerDhParams),
}
2101
2102impl ServerKeyExchangeParams {
2103    #[cfg(feature = "tls12")]
2104    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2105        match kx.group().key_exchange_algorithm() {
2106            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2107            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2108        }
2109    }
2110
2111    #[cfg(feature = "tls12")]
2112    pub(crate) fn pub_key(&self) -> &[u8] {
2113        match self {
2114            Self::Ecdh(ecdh) => &ecdh.public.0,
2115            Self::Dh(dh) => &dh.dh_Ys.0,
2116        }
2117    }
2118
2119    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2120        match self {
2121            Self::Ecdh(ecdh) => ecdh.encode(buf),
2122            Self::Dh(dh) => dh.encode(buf),
2123        }
2124    }
2125}
2126
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Decodes the parameter structure that corresponds to `algo`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ServerEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ServerDhParams::read(r).map(Self::Dh),
        }
    }
}
2137
/// A fully parsed server key exchange: the parameters plus the signature
/// structure that follows them on the wire.
#[derive(Debug)]
pub struct ServerKeyExchange {
    /// The key exchange parameters.
    pub(crate) params: ServerKeyExchangeParams,
    /// The `DigitallySignedStruct` encoded after `params`.
    pub(crate) dss: DigitallySignedStruct,
}
2143
2144impl ServerKeyExchange {
2145    pub fn encode(&self, buf: &mut Vec<u8>) {
2146        self.params.encode(buf);
2147        self.dss.encode(buf);
2148    }
2149}
2150
/// A server key exchange message body, either parsed or still raw.
///
/// Decoding always produces `Unknown`, because the layout depends on the
/// `KeyExchangeAlgorithm`, which the decoder is not given.
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
    /// A parsed message.
    Known(ServerKeyExchange),
    /// The raw, not yet parsed message bytes.
    Unknown(Payload<'static>),
}
2156
2157impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2158    fn from(value: ServerKeyExchange) -> Self {
2159        Self::Known(value)
2160    }
2161}
2162
2163impl Codec<'_> for ServerKeyExchangePayload {
2164    fn encode(&self, bytes: &mut Vec<u8>) {
2165        match self {
2166            Self::Known(x) => x.encode(bytes),
2167            Self::Unknown(x) => x.encode(bytes),
2168        }
2169    }
2170
2171    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2172        // read as Unknown, fully parse when we know the
2173        // KeyExchangeAlgorithm
2174        Ok(Self::Unknown(Payload::read(r).into_owned()))
2175    }
2176}
2177
2178impl ServerKeyExchangePayload {
2179    #[cfg(feature = "tls12")]
2180    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2181        if let Self::Unknown(unk) = self {
2182            let mut rd = Reader::init(unk.bytes());
2183
2184            let result = ServerKeyExchange {
2185                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2186                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2187            };
2188
2189            if !rd.any_left() {
2190                return Some(result);
2191            };
2192        }
2193
2194        None
2195    }
2196}
2197
2198// -- EncryptedExtensions (TLS1.3 only) --
2199
/// Lists of `ServerExtension` use `ListLength::U16` as their size prefix.
impl TlsListElement for ServerExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2203
2204pub(crate) trait HasServerExtensions {
2205    fn extensions(&self) -> &[ServerExtension];
2206
2207    /// Returns true if there is more than one extension of a given
2208    /// type.
2209    fn has_duplicate_extension(&self) -> bool {
2210        has_duplicates::<_, _, u16>(
2211            self.extensions()
2212                .iter()
2213                .map(|ext| ext.ext_type()),
2214        )
2215    }
2216
2217    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2218        self.extensions()
2219            .iter()
2220            .find(|x| x.ext_type() == ext)
2221    }
2222
2223    fn alpn_protocol(&self) -> Option<&[u8]> {
2224        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2225        match ext {
2226            ServerExtension::Protocols(protos) => Some(protos.as_ref()),
2227            _ => None,
2228        }
2229    }
2230
2231    fn server_cert_type(&self) -> Option<&CertificateType> {
2232        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2233        match ext {
2234            ServerExtension::ServerCertType(req) => Some(req),
2235            _ => None,
2236        }
2237    }
2238
2239    fn client_cert_type(&self) -> Option<&CertificateType> {
2240        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2241        match ext {
2242            ServerExtension::ClientCertType(req) => Some(req),
2243            _ => None,
2244        }
2245    }
2246
2247    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2248        let ext = self
2249            .find_extension(ExtensionType::TransportParameters)
2250            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2251        match ext {
2252            ServerExtension::TransportParameters(bytes)
2253            | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
2254            _ => None,
2255        }
2256    }
2257
2258    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2259        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2260        match ext {
2261            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2262            _ => None,
2263        }
2264    }
2265
2266    fn early_data_extension_offered(&self) -> bool {
2267        self.find_extension(ExtensionType::EarlyData)
2268            .is_some()
2269    }
2270}
2271
2272impl HasServerExtensions for Vec<ServerExtension> {
2273    fn extensions(&self) -> &[ServerExtension] {
2274        self
2275    }
2276}
2277
/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
///
/// The list length is `NonZeroU8`, with
/// `IllegalEmptyList("ClientCertificateTypes")` configured as the error for an
/// empty list.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
    };
}
2284
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// The wrapped bytes are reached through the `AsRef<[u8]>` implementation:
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    ///
    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
    pub struct DistinguishedName,
    PayloadU16<NonEmpty>,
);
2303
2304impl DistinguishedName {
2305    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2306    ///
2307    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2308    ///
2309    /// ```ignore
2310    /// use x509_parser::prelude::FromDer;
2311    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2312    /// ```
2313    pub fn in_sequence(bytes: &[u8]) -> Self {
2314        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2315    }
2316}
2317
/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
///
/// The list length used here is plain `ListLength::U16`, i.e. the list type
/// itself does not forbid an empty list.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2323
/// A certificate request message body made of three lists, encoded in field
/// order.
#[derive(Debug)]
pub struct CertificateRequestPayload {
    /// The client certificate types.
    pub(crate) certtypes: Vec<ClientCertificateType>,
    /// The signature schemes; decoding rejects an empty list.
    pub(crate) sigschemes: Vec<SignatureScheme>,
    /// The distinguished names of certificate authorities.
    pub(crate) canames: Vec<DistinguishedName>,
}
2330
2331impl Codec<'_> for CertificateRequestPayload {
2332    fn encode(&self, bytes: &mut Vec<u8>) {
2333        self.certtypes.encode(bytes);
2334        self.sigschemes.encode(bytes);
2335        self.canames.encode(bytes);
2336    }
2337
2338    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2339        let certtypes = Vec::read(r)?;
2340        let sigschemes = Vec::read(r)?;
2341        let canames = Vec::read(r)?;
2342
2343        if sigschemes.is_empty() {
2344            warn!("meaningless CertificateRequest message");
2345            Err(InvalidMessage::NoSignatureSchemes)
2346        } else {
2347            Ok(Self {
2348                certtypes,
2349                sigschemes,
2350                canames,
2351            })
2352        }
2353    }
2354}
2355
/// An extension carried in a `CertificateRequestPayloadTls13`.
#[derive(Debug)]
pub(crate) enum CertReqExtension {
    /// Body of the `SignatureAlgorithms` extension type.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// Body of the `CertificateAuthorities` extension type.
    AuthorityNames(Vec<DistinguishedName>),
    /// Body of the `CompressCertificate` extension type.
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    /// Any other extension type, kept undecoded.
    Unknown(UnknownExtension),
}
2363
2364impl CertReqExtension {
2365    pub(crate) fn ext_type(&self) -> ExtensionType {
2366        match self {
2367            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2368            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2369            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2370            Self::Unknown(r) => r.typ,
2371        }
2372    }
2373}
2374
2375impl Codec<'_> for CertReqExtension {
2376    fn encode(&self, bytes: &mut Vec<u8>) {
2377        self.ext_type().encode(bytes);
2378
2379        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2380        match self {
2381            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
2382            Self::AuthorityNames(r) => r.encode(nested.buf),
2383            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
2384            Self::Unknown(r) => r.encode(nested.buf),
2385        }
2386    }
2387
2388    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2389        let typ = ExtensionType::read(r)?;
2390        let len = u16::read(r)? as usize;
2391        let mut sub = r.sub(len)?;
2392
2393        let ext = match typ {
2394            ExtensionType::SignatureAlgorithms => {
2395                let schemes = Vec::read(&mut sub)?;
2396                if schemes.is_empty() {
2397                    return Err(InvalidMessage::NoSignatureSchemes);
2398                }
2399                Self::SignatureAlgorithms(schemes)
2400            }
2401            ExtensionType::CertificateAuthorities => {
2402                let cas = Vec::read(&mut sub)?;
2403                if cas.is_empty() {
2404                    return Err(InvalidMessage::IllegalEmptyList("DistinguishedNames"));
2405                }
2406                Self::AuthorityNames(cas)
2407            }
2408            ExtensionType::CompressCertificate => {
2409                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2410            }
2411            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2412        };
2413
2414        sub.expect_empty("CertReqExtension")
2415            .map(|_| ext)
2416    }
2417}
2418
/// Lists of `CertReqExtension` use `ListLength::U16` as their size prefix.
impl TlsListElement for CertReqExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2422
/// A TLS1.3 certificate request message body: a context value followed by a
/// list of extensions.
#[derive(Debug)]
pub struct CertificateRequestPayloadTls13 {
    /// The context bytes.
    pub(crate) context: PayloadU8,
    /// The extensions, in the order they were read or added.
    pub(crate) extensions: Vec<CertReqExtension>,
}
2428
2429impl Codec<'_> for CertificateRequestPayloadTls13 {
2430    fn encode(&self, bytes: &mut Vec<u8>) {
2431        self.context.encode(bytes);
2432        self.extensions.encode(bytes);
2433    }
2434
2435    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2436        let context = PayloadU8::read(r)?;
2437        let extensions = Vec::read(r)?;
2438
2439        Ok(Self {
2440            context,
2441            extensions,
2442        })
2443    }
2444}
2445
2446impl CertificateRequestPayloadTls13 {
2447    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2448        self.extensions
2449            .iter()
2450            .find(|x| x.ext_type() == ext)
2451    }
2452
2453    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2454        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2455        match ext {
2456            CertReqExtension::SignatureAlgorithms(sa) => Some(sa),
2457            _ => None,
2458        }
2459    }
2460
2461    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2462        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2463        match ext {
2464            CertReqExtension::AuthorityNames(an) => Some(an),
2465            _ => None,
2466        }
2467    }
2468
2469    pub(crate) fn certificate_compression_extension(
2470        &self,
2471    ) -> Option<&[CertificateCompressionAlgorithm]> {
2472        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2473        match ext {
2474            CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps),
2475            _ => None,
2476        }
2477    }
2478}
2479
2480// -- NewSessionTicket --
/// A new session ticket message body: a lifetime hint followed by the ticket.
#[derive(Debug)]
pub struct NewSessionTicketPayload {
    /// The lifetime hint, encoded as a `u32` ahead of the ticket.
    pub(crate) lifetime_hint: u32,
    // Tickets can be large (KB), so we deserialise this straight
    // into an Arc, so it can be passed directly into the client's
    // session object without copying.
    /// The ticket bytes, shared via `Arc`.
    pub(crate) ticket: Arc<PayloadU16>,
}
2489
2490impl NewSessionTicketPayload {
2491    #[cfg(feature = "tls12")]
2492    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2493        Self {
2494            lifetime_hint,
2495            ticket: Arc::new(PayloadU16::new(ticket)),
2496        }
2497    }
2498}
2499
2500impl Codec<'_> for NewSessionTicketPayload {
2501    fn encode(&self, bytes: &mut Vec<u8>) {
2502        self.lifetime_hint.encode(bytes);
2503        self.ticket.encode(bytes);
2504    }
2505
2506    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2507        let lifetime = u32::read(r)?;
2508        let ticket = Arc::new(PayloadU16::read(r)?);
2509
2510        Ok(Self {
2511            lifetime_hint: lifetime,
2512            ticket,
2513        })
2514    }
2515}
2516
2517// -- NewSessionTicket electric boogaloo --
/// An extension carried in a `NewSessionTicketPayloadTls13`.
#[derive(Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// Body of the `EarlyData` extension type: a single `u32`.
    EarlyData(u32),
    /// Any other extension type, kept undecoded.
    Unknown(UnknownExtension),
}
2523
2524impl NewSessionTicketExtension {
2525    pub(crate) fn ext_type(&self) -> ExtensionType {
2526        match self {
2527            Self::EarlyData(_) => ExtensionType::EarlyData,
2528            Self::Unknown(r) => r.typ,
2529        }
2530    }
2531}
2532
2533impl Codec<'_> for NewSessionTicketExtension {
2534    fn encode(&self, bytes: &mut Vec<u8>) {
2535        self.ext_type().encode(bytes);
2536
2537        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2538        match self {
2539            Self::EarlyData(r) => r.encode(nested.buf),
2540            Self::Unknown(r) => r.encode(nested.buf),
2541        }
2542    }
2543
2544    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2545        let typ = ExtensionType::read(r)?;
2546        let len = u16::read(r)? as usize;
2547        let mut sub = r.sub(len)?;
2548
2549        let ext = match typ {
2550            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2551            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2552        };
2553
2554        sub.expect_empty("NewSessionTicketExtension")
2555            .map(|_| ext)
2556    }
2557}
2558
/// Lists of `NewSessionTicketExtension` use `ListLength::U16` as their size
/// prefix.
impl TlsListElement for NewSessionTicketExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2562
/// A TLS1.3 new session ticket message body; fields are encoded in
/// declaration order.
#[derive(Debug)]
pub struct NewSessionTicketPayloadTls13 {
    /// The ticket lifetime, as a `u32`.
    pub(crate) lifetime: u32,
    /// The `age_add` value, as a `u32`.
    pub(crate) age_add: u32,
    /// The ticket nonce.
    pub(crate) nonce: PayloadU8,
    /// The ticket bytes, shared via `Arc`; decoding rejects an empty ticket.
    pub(crate) ticket: Arc<PayloadU16>,
    /// Extensions attached to the ticket.
    pub(crate) exts: Vec<NewSessionTicketExtension>,
}
2571
2572impl NewSessionTicketPayloadTls13 {
2573    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2574        Self {
2575            lifetime,
2576            age_add,
2577            nonce: PayloadU8::new(nonce),
2578            ticket: Arc::new(PayloadU16::new(ticket)),
2579            exts: vec![],
2580        }
2581    }
2582
2583    pub(crate) fn has_duplicate_extension(&self) -> bool {
2584        has_duplicates::<_, _, u16>(
2585            self.exts
2586                .iter()
2587                .map(|ext| ext.ext_type()),
2588        )
2589    }
2590
2591    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2592        self.exts
2593            .iter()
2594            .find(|x| x.ext_type() == ext)
2595    }
2596
2597    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2598        let ext = self.find_extension(ExtensionType::EarlyData)?;
2599        match ext {
2600            NewSessionTicketExtension::EarlyData(sz) => Some(*sz),
2601            _ => None,
2602        }
2603    }
2604}
2605
2606impl Codec<'_> for NewSessionTicketPayloadTls13 {
2607    fn encode(&self, bytes: &mut Vec<u8>) {
2608        self.lifetime.encode(bytes);
2609        self.age_add.encode(bytes);
2610        self.nonce.encode(bytes);
2611        self.ticket.encode(bytes);
2612        self.exts.encode(bytes);
2613    }
2614
2615    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2616        let lifetime = u32::read(r)?;
2617        let age_add = u32::read(r)?;
2618        let nonce = PayloadU8::read(r)?;
2619        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2620        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2621            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2622            Err(err) => Err(err),
2623            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2624        }?);
2625        let exts = Vec::read(r)?;
2626
2627        Ok(Self {
2628            lifetime,
2629            age_add,
2630            nonce,
2631            ticket,
2632            exts,
2633        })
2634    }
2635}
2636
2637// -- RFC6066 certificate status types
2638
/// Only supports OCSP
///
/// Encoding always writes `CertificateStatusType::OCSP`, and decoding rejects
/// any other status type.
#[derive(Debug)]
pub struct CertificateStatus<'a> {
    /// The OCSP response bytes.
    pub(crate) ocsp_response: PayloadU24<'a>,
}
2644
2645impl<'a> Codec<'a> for CertificateStatus<'a> {
2646    fn encode(&self, bytes: &mut Vec<u8>) {
2647        CertificateStatusType::OCSP.encode(bytes);
2648        self.ocsp_response.encode(bytes);
2649    }
2650
2651    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2652        let typ = CertificateStatusType::read(r)?;
2653
2654        match typ {
2655            CertificateStatusType::OCSP => Ok(Self {
2656                ocsp_response: PayloadU24::read(r)?,
2657            }),
2658            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2659        }
2660    }
2661}
2662
2663impl<'a> CertificateStatus<'a> {
2664    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2665        CertificateStatus {
2666            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2667        }
2668    }
2669
2670    #[cfg(feature = "tls12")]
2671    pub(crate) fn into_inner(self) -> Vec<u8> {
2672        self.ocsp_response.0.into_vec()
2673    }
2674
2675    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2676        CertificateStatus {
2677            ocsp_response: self.ocsp_response.into_owned(),
2678        }
2679    }
2680}
2681
2682// -- RFC8879 compressed certificates
2683
/// A compressed certificate message body; fields are encoded in declaration
/// order.
#[derive(Debug)]
pub struct CompressedCertificatePayload<'a> {
    /// The compression algorithm.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// The uncompressed length; encoded and decoded via `codec::u24`.
    pub(crate) uncompressed_len: u32,
    /// The compressed bytes.
    pub(crate) compressed: PayloadU24<'a>,
}
2690
2691impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2692    fn encode(&self, bytes: &mut Vec<u8>) {
2693        self.alg.encode(bytes);
2694        codec::u24(self.uncompressed_len).encode(bytes);
2695        self.compressed.encode(bytes);
2696    }
2697
2698    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2699        Ok(Self {
2700            alg: CertificateCompressionAlgorithm::read(r)?,
2701            uncompressed_len: codec::u24::read(r)?.0,
2702            compressed: PayloadU24::read(r)?,
2703        })
2704    }
2705}
2706
2707impl CompressedCertificatePayload<'_> {
2708    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2709        CompressedCertificatePayload {
2710            compressed: self.compressed.into_owned(),
2711            ..self
2712        }
2713    }
2714
2715    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2716        CompressedCertificatePayload {
2717            alg: self.alg,
2718            uncompressed_len: self.uncompressed_len,
2719            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2720        }
2721    }
2722}
2723
2724#[derive(Debug)]
2725pub enum HandshakePayload<'a> {
2726    HelloRequest,
2727    ClientHello(ClientHelloPayload),
2728    ServerHello(ServerHelloPayload),
2729    HelloRetryRequest(HelloRetryRequest),
2730    Certificate(CertificateChain<'a>),
2731    CertificateTls13(CertificatePayloadTls13<'a>),
2732    CompressedCertificate(CompressedCertificatePayload<'a>),
2733    ServerKeyExchange(ServerKeyExchangePayload),
2734    CertificateRequest(CertificateRequestPayload),
2735    CertificateRequestTls13(CertificateRequestPayloadTls13),
2736    CertificateVerify(DigitallySignedStruct),
2737    ServerHelloDone,
2738    EndOfEarlyData,
2739    ClientKeyExchange(Payload<'a>),
2740    NewSessionTicket(NewSessionTicketPayload),
2741    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
2742    EncryptedExtensions(Vec<ServerExtension>),
2743    KeyUpdate(KeyUpdateRequest),
2744    Finished(Payload<'a>),
2745    CertificateStatus(CertificateStatus<'a>),
2746    MessageHash(Payload<'a>),
2747    Unknown(Payload<'a>),
2748}
2749
2750impl HandshakePayload<'_> {
2751    fn encode(&self, bytes: &mut Vec<u8>) {
2752        use self::HandshakePayload::*;
2753        match self {
2754            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2755            ClientHello(x) => x.encode(bytes),
2756            ServerHello(x) => x.encode(bytes),
2757            HelloRetryRequest(x) => x.encode(bytes),
2758            Certificate(x) => x.encode(bytes),
2759            CertificateTls13(x) => x.encode(bytes),
2760            CompressedCertificate(x) => x.encode(bytes),
2761            ServerKeyExchange(x) => x.encode(bytes),
2762            ClientKeyExchange(x) => x.encode(bytes),
2763            CertificateRequest(x) => x.encode(bytes),
2764            CertificateRequestTls13(x) => x.encode(bytes),
2765            CertificateVerify(x) => x.encode(bytes),
2766            NewSessionTicket(x) => x.encode(bytes),
2767            NewSessionTicketTls13(x) => x.encode(bytes),
2768            EncryptedExtensions(x) => x.encode(bytes),
2769            KeyUpdate(x) => x.encode(bytes),
2770            Finished(x) => x.encode(bytes),
2771            CertificateStatus(x) => x.encode(bytes),
2772            MessageHash(x) => x.encode(bytes),
2773            Unknown(x) => x.encode(bytes),
2774        }
2775    }
2776
2777    fn into_owned(self) -> HandshakePayload<'static> {
2778        use HandshakePayload::*;
2779
2780        match self {
2781            HelloRequest => HelloRequest,
2782            ClientHello(x) => ClientHello(x),
2783            ServerHello(x) => ServerHello(x),
2784            HelloRetryRequest(x) => HelloRetryRequest(x),
2785            Certificate(x) => Certificate(x.into_owned()),
2786            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2787            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2788            ServerKeyExchange(x) => ServerKeyExchange(x),
2789            CertificateRequest(x) => CertificateRequest(x),
2790            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2791            CertificateVerify(x) => CertificateVerify(x),
2792            ServerHelloDone => ServerHelloDone,
2793            EndOfEarlyData => EndOfEarlyData,
2794            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2795            NewSessionTicket(x) => NewSessionTicket(x),
2796            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2797            EncryptedExtensions(x) => EncryptedExtensions(x),
2798            KeyUpdate(x) => KeyUpdate(x),
2799            Finished(x) => Finished(x.into_owned()),
2800            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2801            MessageHash(x) => MessageHash(x.into_owned()),
2802            Unknown(x) => Unknown(x.into_owned()),
2803        }
2804    }
2805}
2806
2807#[derive(Debug)]
2808pub struct HandshakeMessagePayload<'a> {
2809    pub typ: HandshakeType,
2810    pub payload: HandshakePayload<'a>,
2811}
2812
2813impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2814    fn encode(&self, bytes: &mut Vec<u8>) {
2815        self.payload_encode(bytes, Encoding::Standard);
2816    }
2817
2818    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2819        Self::read_version(r, ProtocolVersion::TLSv1_2)
2820    }
2821}
2822
2823impl<'a> HandshakeMessagePayload<'a> {
2824    pub(crate) fn read_version(
2825        r: &mut Reader<'a>,
2826        vers: ProtocolVersion,
2827    ) -> Result<Self, InvalidMessage> {
2828        let mut typ = HandshakeType::read(r)?;
2829        let len = codec::u24::read(r)?.0 as usize;
2830        let mut sub = r.sub(len)?;
2831
2832        let payload = match typ {
2833            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
2834            HandshakeType::ClientHello => {
2835                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
2836            }
2837            HandshakeType::ServerHello => {
2838                let version = ProtocolVersion::read(&mut sub)?;
2839                let random = Random::read(&mut sub)?;
2840
2841                if random == HELLO_RETRY_REQUEST_RANDOM {
2842                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
2843                    hrr.legacy_version = version;
2844                    typ = HandshakeType::HelloRetryRequest;
2845                    HandshakePayload::HelloRetryRequest(hrr)
2846                } else {
2847                    let mut shp = ServerHelloPayload::read(&mut sub)?;
2848                    shp.legacy_version = version;
2849                    shp.random = random;
2850                    HandshakePayload::ServerHello(shp)
2851                }
2852            }
2853            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
2854                let p = CertificatePayloadTls13::read(&mut sub)?;
2855                HandshakePayload::CertificateTls13(p)
2856            }
2857            HandshakeType::Certificate => {
2858                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
2859            }
2860            HandshakeType::ServerKeyExchange => {
2861                let p = ServerKeyExchangePayload::read(&mut sub)?;
2862                HandshakePayload::ServerKeyExchange(p)
2863            }
2864            HandshakeType::ServerHelloDone => {
2865                sub.expect_empty("ServerHelloDone")?;
2866                HandshakePayload::ServerHelloDone
2867            }
2868            HandshakeType::ClientKeyExchange => {
2869                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
2870            }
2871            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
2872                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
2873                HandshakePayload::CertificateRequestTls13(p)
2874            }
2875            HandshakeType::CertificateRequest => {
2876                let p = CertificateRequestPayload::read(&mut sub)?;
2877                HandshakePayload::CertificateRequest(p)
2878            }
2879            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
2880                CompressedCertificatePayload::read(&mut sub)?,
2881            ),
2882            HandshakeType::CertificateVerify => {
2883                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
2884            }
2885            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
2886                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
2887                HandshakePayload::NewSessionTicketTls13(p)
2888            }
2889            HandshakeType::NewSessionTicket => {
2890                let p = NewSessionTicketPayload::read(&mut sub)?;
2891                HandshakePayload::NewSessionTicket(p)
2892            }
2893            HandshakeType::EncryptedExtensions => {
2894                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
2895            }
2896            HandshakeType::KeyUpdate => {
2897                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
2898            }
2899            HandshakeType::EndOfEarlyData => {
2900                sub.expect_empty("EndOfEarlyData")?;
2901                HandshakePayload::EndOfEarlyData
2902            }
2903            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
2904            HandshakeType::CertificateStatus => {
2905                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
2906            }
2907            HandshakeType::MessageHash => {
2908                // does not appear on the wire
2909                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
2910            }
2911            HandshakeType::HelloRetryRequest => {
2912                // not legal on wire
2913                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
2914            }
2915            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
2916        };
2917
2918        sub.expect_empty("HandshakeMessagePayload")
2919            .map(|_| Self { typ, payload })
2920    }
2921
2922    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
2923        let mut ret = self.get_encoding();
2924        let ret_len = ret.len() - self.total_binder_length();
2925        ret.truncate(ret_len);
2926        ret
2927    }
2928
2929    pub(crate) fn total_binder_length(&self) -> usize {
2930        match &self.payload {
2931            HandshakePayload::ClientHello(ch) => match ch.extensions.last() {
2932                Some(ClientExtension::PresharedKey(offer)) => {
2933                    let mut binders_encoding = Vec::new();
2934                    offer
2935                        .binders
2936                        .encode(&mut binders_encoding);
2937                    binders_encoding.len()
2938                }
2939                _ => 0,
2940            },
2941            _ => 0,
2942        }
2943    }
2944
2945    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
2946        // output type, length, and encoded payload
2947        match self.typ {
2948            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2949            _ => self.typ,
2950        }
2951        .encode(bytes);
2952
2953        let nested = LengthPrefixedBuffer::new(
2954            ListLength::U24 {
2955                max: usize::MAX,
2956                error: InvalidMessage::MessageTooLarge,
2957            },
2958            bytes,
2959        );
2960
2961        match &self.payload {
2962            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
2963            // differently based on the purpose of the encoding.
2964            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
2965            HandshakePayload::HelloRetryRequest(payload) => {
2966                payload.payload_encode(nested.buf, encoding)
2967            }
2968
2969            // All other payload types are encoded the same regardless of purpose.
2970            _ => self.payload.encode(nested.buf),
2971        }
2972    }
2973
2974    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
2975        Self {
2976            typ: HandshakeType::MessageHash,
2977            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
2978        }
2979    }
2980
2981    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
2982        let Self { typ, payload } = self;
2983        HandshakeMessagePayload {
2984            typ,
2985            payload: payload.into_owned(),
2986        }
2987    }
2988}
2989
2990#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2991pub struct HpkeSymmetricCipherSuite {
2992    pub kdf_id: HpkeKdf,
2993    pub aead_id: HpkeAead,
2994}
2995
2996impl Codec<'_> for HpkeSymmetricCipherSuite {
2997    fn encode(&self, bytes: &mut Vec<u8>) {
2998        self.kdf_id.encode(bytes);
2999        self.aead_id.encode(bytes);
3000    }
3001
3002    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3003        Ok(Self {
3004            kdf_id: HpkeKdf::read(r)?,
3005            aead_id: HpkeAead::read(r)?,
3006        })
3007    }
3008}
3009
3010/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
3011impl TlsListElement for HpkeSymmetricCipherSuite {
3012    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
3013        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
3014    };
3015}
3016
3017#[derive(Clone, Debug, PartialEq)]
3018pub struct HpkeKeyConfig {
3019    pub config_id: u8,
3020    pub kem_id: HpkeKem,
3021    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
3022    pub public_key: PayloadU16<NonEmpty>,
3023    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
3024}
3025
3026impl Codec<'_> for HpkeKeyConfig {
3027    fn encode(&self, bytes: &mut Vec<u8>) {
3028        self.config_id.encode(bytes);
3029        self.kem_id.encode(bytes);
3030        self.public_key.encode(bytes);
3031        self.symmetric_cipher_suites
3032            .encode(bytes);
3033    }
3034
3035    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3036        Ok(Self {
3037            config_id: u8::read(r)?,
3038            kem_id: HpkeKem::read(r)?,
3039            public_key: PayloadU16::read(r)?,
3040            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
3041        })
3042    }
3043}
3044
3045#[derive(Clone, Debug, PartialEq)]
3046pub struct EchConfigContents {
3047    pub key_config: HpkeKeyConfig,
3048    pub maximum_name_length: u8,
3049    pub public_name: DnsName<'static>,
3050    pub extensions: Vec<EchConfigExtension>,
3051}
3052
3053impl EchConfigContents {
3054    /// Returns true if there is more than one extension of a given
3055    /// type.
3056    pub(crate) fn has_duplicate_extension(&self) -> bool {
3057        has_duplicates::<_, _, u16>(
3058            self.extensions
3059                .iter()
3060                .map(|ext| ext.ext_type()),
3061        )
3062    }
3063
3064    /// Returns true if there is at least one mandatory unsupported extension.
3065    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
3066        self.extensions
3067            .iter()
3068            // An extension is considered mandatory if the high bit of its type is set.
3069            .any(|ext| {
3070                matches!(ext.ext_type(), ExtensionType::Unknown(_))
3071                    && u16::from(ext.ext_type()) & 0x8000 != 0
3072            })
3073    }
3074}
3075
3076impl Codec<'_> for EchConfigContents {
3077    fn encode(&self, bytes: &mut Vec<u8>) {
3078        self.key_config.encode(bytes);
3079        self.maximum_name_length.encode(bytes);
3080        let dns_name = &self.public_name.borrow();
3081        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
3082        self.extensions.encode(bytes);
3083    }
3084
3085    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3086        Ok(Self {
3087            key_config: HpkeKeyConfig::read(r)?,
3088            maximum_name_length: u8::read(r)?,
3089            public_name: {
3090                DnsName::try_from(
3091                    PayloadU8::<MaybeEmpty>::read(r)?
3092                        .0
3093                        .as_slice(),
3094                )
3095                .map_err(|_| InvalidMessage::InvalidServerName)?
3096                .to_owned()
3097            },
3098            extensions: Vec::read(r)?,
3099        })
3100    }
3101}
3102
3103/// An encrypted client hello (ECH) config.
3104#[derive(Clone, Debug, PartialEq)]
3105pub enum EchConfigPayload {
3106    /// A recognized V18 ECH configuration.
3107    V18(EchConfigContents),
3108    /// An unknown version ECH configuration.
3109    Unknown {
3110        version: EchVersion,
3111        contents: PayloadU16,
3112    },
3113}
3114
3115impl TlsListElement for EchConfigPayload {
3116    const SIZE_LEN: ListLength = ListLength::U16;
3117}
3118
3119impl Codec<'_> for EchConfigPayload {
3120    fn encode(&self, bytes: &mut Vec<u8>) {
3121        match self {
3122            Self::V18(c) => {
3123                // Write the version, the length, and the contents.
3124                EchVersion::V18.encode(bytes);
3125                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3126                c.encode(inner.buf);
3127            }
3128            Self::Unknown { version, contents } => {
3129                // Unknown configuration versions are opaque.
3130                version.encode(bytes);
3131                contents.encode(bytes);
3132            }
3133        }
3134    }
3135
3136    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3137        let version = EchVersion::read(r)?;
3138        let length = u16::read(r)?;
3139        let mut contents = r.sub(length as usize)?;
3140
3141        Ok(match version {
3142            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
3143            _ => {
3144                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
3145                let data = PayloadU16::new(contents.rest().into());
3146                Self::Unknown {
3147                    version,
3148                    contents: data,
3149                }
3150            }
3151        })
3152    }
3153}
3154
3155#[derive(Clone, Debug, PartialEq)]
3156pub enum EchConfigExtension {
3157    Unknown(UnknownExtension),
3158}
3159
3160impl EchConfigExtension {
3161    pub(crate) fn ext_type(&self) -> ExtensionType {
3162        match self {
3163            Self::Unknown(r) => r.typ,
3164        }
3165    }
3166}
3167
3168impl Codec<'_> for EchConfigExtension {
3169    fn encode(&self, bytes: &mut Vec<u8>) {
3170        self.ext_type().encode(bytes);
3171
3172        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3173        match self {
3174            Self::Unknown(r) => r.encode(nested.buf),
3175        }
3176    }
3177
3178    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3179        let typ = ExtensionType::read(r)?;
3180        let len = u16::read(r)? as usize;
3181        let mut sub = r.sub(len)?;
3182
3183        #[allow(clippy::match_single_binding)] // Future-proofing.
3184        let ext = match typ {
3185            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3186        };
3187
3188        sub.expect_empty("EchConfigExtension")
3189            .map(|_| ext)
3190    }
3191}
3192
3193impl TlsListElement for EchConfigExtension {
3194    const SIZE_LEN: ListLength = ListLength::U16;
3195}
3196
3197/// Representation of the `ECHClientHello` client extension specified in
3198/// [draft-ietf-tls-esni Section 5].
3199///
3200/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3201#[derive(Clone, Debug)]
3202pub enum EncryptedClientHello {
3203    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
3204    Outer(EncryptedClientHelloOuter),
3205    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
3206    ///
3207    /// This variant has no payload.
3208    Inner,
3209}
3210
3211impl Codec<'_> for EncryptedClientHello {
3212    fn encode(&self, bytes: &mut Vec<u8>) {
3213        match self {
3214            Self::Outer(payload) => {
3215                EchClientHelloType::ClientHelloOuter.encode(bytes);
3216                payload.encode(bytes);
3217            }
3218            Self::Inner => {
3219                EchClientHelloType::ClientHelloInner.encode(bytes);
3220                // Empty payload.
3221            }
3222        }
3223    }
3224
3225    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3226        match EchClientHelloType::read(r)? {
3227            EchClientHelloType::ClientHelloOuter => {
3228                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3229            }
3230            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3231            _ => Err(InvalidMessage::InvalidContentType),
3232        }
3233    }
3234}
3235
3236/// Representation of the ECHClientHello extension with type outer specified in
3237/// [draft-ietf-tls-esni Section 5].
3238///
3239/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3240#[derive(Clone, Debug)]
3241pub struct EncryptedClientHelloOuter {
3242    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3243    /// ECHConfigContents.cipher_suites list.
3244    pub cipher_suite: HpkeSymmetricCipherSuite,
3245    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3246    pub config_id: u8,
3247    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3248    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3249    pub enc: PayloadU16,
3250    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3251    pub payload: PayloadU16<NonEmpty>,
3252}
3253
3254impl Codec<'_> for EncryptedClientHelloOuter {
3255    fn encode(&self, bytes: &mut Vec<u8>) {
3256        self.cipher_suite.encode(bytes);
3257        self.config_id.encode(bytes);
3258        self.enc.encode(bytes);
3259        self.payload.encode(bytes);
3260    }
3261
3262    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3263        Ok(Self {
3264            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3265            config_id: u8::read(r)?,
3266            enc: PayloadU16::read(r)?,
3267            payload: PayloadU16::read(r)?,
3268        })
3269    }
3270}
3271
3272/// Representation of the ECHEncryptedExtensions extension specified in
3273/// [draft-ietf-tls-esni Section 5].
3274///
3275/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3276#[derive(Clone, Debug)]
3277pub struct ServerEncryptedClientHello {
3278    pub(crate) retry_configs: Vec<EchConfigPayload>,
3279}
3280
3281impl Codec<'_> for ServerEncryptedClientHello {
3282    fn encode(&self, bytes: &mut Vec<u8>) {
3283        self.retry_configs.encode(bytes);
3284    }
3285
3286    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3287        Ok(Self {
3288            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3289        })
3290    }
3291}
3292
3293/// The method of encoding to use for a handshake message.
3294///
3295/// In some cases a handshake message may be encoded differently depending on the purpose
3296/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
3297/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
3298pub(crate) enum Encoding {
3299    /// Standard RFC 8446 encoding.
3300    Standard,
3301    /// Encoding for ECH confirmation.
3302    EchConfirmation,
3303    /// Encoding for ECH inner client hello.
3304    EchInnerHello { to_compress: Vec<ExtensionType> },
3305}
3306
/// Returns true if two items of `iter` convert (via `Into<T>`) to equal values.
///
/// Iteration stops at the first repeated value.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();

    // `insert` returns false when the key is already in the set.
    iter.into_iter()
        .map(Into::into)
        .any(|key| !seen.insert(key))
}
3318
#[cfg(test)]
mod tests {
    use super::*;

    /// Two identical extensions with the high type bit clear count as
    /// duplicates but not as mandatory.
    #[test]
    fn test_ech_config_dupe_exts() {
        let ext = unknown_extension(0x42);
        let config = EchConfigContents {
            extensions: vec![ext.clone(), ext],
            ..config_template()
        };

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    /// A single unknown extension with the high type bit set counts as
    /// mandatory but not as a duplicate.
    #[test]
    fn test_ech_config_mandatory_exts() {
        // Note: high bit set.
        let ext = unknown_extension(0x42 | 0x8000);
        let config = EchConfigContents {
            extensions: vec![ext],
            ..config_template()
        };

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Builds an unknown extension of type `typ` with a one-byte payload.
    fn unknown_extension(typ: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(typ),
            payload: Payload::new(vec![0x42]),
        })
    }

    /// A minimal configuration with no extensions.
    fn config_template() -> EchConfigContents {
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                kdf_id: HpkeKdf::HKDF_SHA256,
                aead_id: HpkeAead::AES_128_GCM,
            }],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}