1use crate::{
20 aead::{aes, chacha},
21 cpu, error, hkdf,
22};
23
24pub struct HeaderProtectionKey {
26 inner: KeyInner,
27 algorithm: &'static Algorithm,
28}
29
30#[allow(clippy::large_enum_variant, variant_size_differences)]
31enum KeyInner {
32 Aes(aes::Key),
33 ChaCha20(chacha::Key),
34}
35
36impl From<hkdf::Okm<'_, &'static Algorithm>> for HeaderProtectionKey {
37 fn from(okm: hkdf::Okm<&'static Algorithm>) -> Self {
38 let mut key_bytes = [0; super::MAX_KEY_LEN];
39 let algorithm = *okm.len();
40 let key_bytes = &mut key_bytes[..algorithm.key_len()];
41 okm.fill(key_bytes).unwrap();
42 Self::new(algorithm, key_bytes).unwrap()
43 }
44}
45
46impl HeaderProtectionKey {
47 pub fn new(
51 algorithm: &'static Algorithm,
52 key_bytes: &[u8],
53 ) -> Result<Self, error::Unspecified> {
54 Ok(Self {
55 inner: (algorithm.init)(key_bytes, cpu::features())?,
56 algorithm,
57 })
58 }
59
60 pub fn new_mask(&self, sample: &[u8]) -> Result<[u8; 5], error::Unspecified> {
64 let sample = <&[u8; SAMPLE_LEN]>::try_from(sample)?;
65
66 let out = (self.algorithm.new_mask)(&self.inner, *sample);
67 Ok(out)
68 }
69
70 #[inline(always)]
72 pub fn algorithm(&self) -> &'static Algorithm {
73 self.algorithm
74 }
75}
76
77const SAMPLE_LEN: usize = super::TAG_LEN;
78
79pub type Sample = [u8; SAMPLE_LEN];
81
82pub struct Algorithm {
84 init: fn(key: &[u8], cpu_features: cpu::Features) -> Result<KeyInner, error::Unspecified>,
85
86 new_mask: fn(key: &KeyInner, sample: Sample) -> [u8; 5],
87
88 key_len: usize,
89 id: AlgorithmID,
90}
91
92impl hkdf::KeyType for &'static Algorithm {
93 #[inline]
94 fn len(&self) -> usize {
95 self.key_len()
96 }
97}
98
99impl Algorithm {
100 #[inline(always)]
102 pub fn key_len(&self) -> usize {
103 self.key_len
104 }
105
106 #[inline(always)]
108 pub fn sample_len(&self) -> usize {
109 SAMPLE_LEN
110 }
111}
112
113derive_debug_via_id!(Algorithm);
114
/// Identifier distinguishing the header protection algorithms defined in this
/// file; each `Algorithm` static carries one variant in its `id` field.
#[derive(Debug, Eq, PartialEq)]
enum AlgorithmID {
    /// Identifier of the `AES_128` algorithm.
    AES_128,
    /// Identifier of the `AES_256` algorithm.
    AES_256,
    /// Identifier of the `CHACHA20` algorithm.
    CHACHA20,
}
121
122impl PartialEq for Algorithm {
123 fn eq(&self, other: &Self) -> bool {
124 self.id == other.id
125 }
126}
127
128impl Eq for Algorithm {}
129
130pub static AES_128: Algorithm = Algorithm {
132 key_len: 16,
133 init: aes_init_128,
134 new_mask: aes_new_mask,
135 id: AlgorithmID::AES_128,
136};
137
138pub static AES_256: Algorithm = Algorithm {
140 key_len: 32,
141 init: aes_init_256,
142 new_mask: aes_new_mask,
143 id: AlgorithmID::AES_256,
144};
145
146fn aes_init_128(key: &[u8], cpu_features: cpu::Features) -> Result<KeyInner, error::Unspecified> {
147 let key = key.try_into().map_err(|_| error::Unspecified)?;
148 let aes_key = aes::Key::new(aes::KeyBytes::AES_128(key), cpu_features)?;
149 Ok(KeyInner::Aes(aes_key))
150}
151
152fn aes_init_256(key: &[u8], cpu_features: cpu::Features) -> Result<KeyInner, error::Unspecified> {
153 let key = key.try_into().map_err(|_| error::Unspecified)?;
154 let aes_key = aes::Key::new(aes::KeyBytes::AES_256(key), cpu_features)?;
155 Ok(KeyInner::Aes(aes_key))
156}
157
158fn aes_new_mask(key: &KeyInner, sample: Sample) -> [u8; 5] {
159 let aes_key = match key {
160 KeyInner::Aes(key) => key,
161 _ => unreachable!(),
162 };
163
164 aes_key.new_mask(sample)
165}
166
167pub static CHACHA20: Algorithm = Algorithm {
169 key_len: chacha::KEY_LEN,
170 init: chacha20_init,
171 new_mask: chacha20_new_mask,
172 id: AlgorithmID::CHACHA20,
173};
174
175fn chacha20_init(key: &[u8], _cpu_features: cpu::Features) -> Result<KeyInner, error::Unspecified> {
176 let chacha20_key: [u8; chacha::KEY_LEN] = key.try_into()?;
177 Ok(KeyInner::ChaCha20(chacha::Key::new(chacha20_key)))
178}
179
180fn chacha20_new_mask(key: &KeyInner, sample: Sample) -> [u8; 5] {
181 let chacha20_key = match key {
182 KeyInner::ChaCha20(key) => key,
183 _ => unreachable!(),
184 };
185
186 chacha20_key.new_mask(sample)
187}