1use alloc::string::String;
4use alloc::vec::Vec;
5
6use digest::{Digest, DynDigest, FixedOutputReset};
7use rand_core::CryptoRngCore;
8use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
9use zeroize::Zeroizing;
10
11use super::mgf::{mgf1_xor, mgf1_xor_digest};
12use crate::errors::{Error, Result};
13
/// Largest accepted OAEP label length in bytes: 2^61 - 1.
///
/// RFC 8017 bounds the label by the hash function's input limitation, which is
/// 2^61 - 1 octets for SHA-1; the same bound is applied regardless of the digest.
const MAX_LABEL_LEN: u64 = (1 << 61) - 1;
17
18#[inline]
19fn encrypt_internal<R: CryptoRngCore + ?Sized, MGF: FnMut(&mut [u8], &mut [u8])>(
20 rng: &mut R,
21 msg: &[u8],
22 p_hash: &[u8],
23 h_size: usize,
24 k: usize,
25 mut mgf: MGF,
26) -> Result<Zeroizing<Vec<u8>>> {
27 if msg.len() + 2 * h_size + 2 > k {
28 return Err(Error::MessageTooLong);
29 }
30
31 let mut em = Zeroizing::new(vec![0u8; k]);
32
33 let (_, payload) = em.split_at_mut(1);
34 let (seed, db) = payload.split_at_mut(h_size);
35 rng.fill_bytes(seed);
36
37 let db_len = k - h_size - 1;
39
40 db[0..h_size].copy_from_slice(p_hash);
41 db[db_len - msg.len() - 1] = 1;
42 db[db_len - msg.len()..].copy_from_slice(msg);
43
44 mgf(seed, db);
45
46 Ok(em)
47}
48
49#[inline]
57pub(crate) fn oaep_encrypt<R: CryptoRngCore + ?Sized>(
58 rng: &mut R,
59 msg: &[u8],
60 digest: &mut dyn DynDigest,
61 mgf_digest: &mut dyn DynDigest,
62 label: Option<String>,
63 k: usize,
64) -> Result<Zeroizing<Vec<u8>>> {
65 let h_size = digest.output_size();
66
67 let label = label.unwrap_or_default();
68 if label.len() as u64 > MAX_LABEL_LEN {
69 return Err(Error::LabelTooLong);
70 }
71
72 digest.update(label.as_bytes());
73 let p_hash = digest.finalize_reset();
74
75 encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
76 mgf1_xor(db, mgf_digest, seed);
77 mgf1_xor(seed, mgf_digest, db);
78 })
79}
80
81#[inline]
89pub(crate) fn oaep_encrypt_digest<
90 R: CryptoRngCore + ?Sized,
91 D: Digest,
92 MGD: Digest + FixedOutputReset,
93>(
94 rng: &mut R,
95 msg: &[u8],
96 label: Option<String>,
97 k: usize,
98) -> Result<Zeroizing<Vec<u8>>> {
99 let h_size = <D as Digest>::output_size();
100
101 let label = label.unwrap_or_default();
102 if label.len() as u64 > MAX_LABEL_LEN {
103 return Err(Error::LabelTooLong);
104 }
105
106 let p_hash = D::digest(label.as_bytes());
107
108 encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
109 let mut mgf_digest = MGD::new();
110 mgf1_xor_digest(db, &mut mgf_digest, seed);
111 mgf1_xor_digest(seed, &mut mgf_digest, db);
112 })
113}
114
115#[inline]
126pub(crate) fn oaep_decrypt(
127 em: &mut [u8],
128 digest: &mut dyn DynDigest,
129 mgf_digest: &mut dyn DynDigest,
130 label: Option<String>,
131 k: usize,
132) -> Result<Vec<u8>> {
133 let h_size = digest.output_size();
134
135 let label = label.unwrap_or_default();
136 if label.len() as u64 > MAX_LABEL_LEN {
137 return Err(Error::Decryption);
138 }
139
140 digest.update(label.as_bytes());
141
142 let expected_p_hash = digest.finalize_reset();
143
144 let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
145 mgf1_xor(seed, mgf_digest, db);
146 mgf1_xor(db, mgf_digest, seed);
147 })?;
148 if res.is_none().into() {
149 return Err(Error::Decryption);
150 }
151
152 let (out, index) = res.unwrap();
153
154 Ok(out[index as usize..].to_vec())
155}
156
157#[inline]
168pub(crate) fn oaep_decrypt_digest<D: Digest, MGD: Digest + FixedOutputReset>(
169 em: &mut [u8],
170 label: Option<String>,
171 k: usize,
172) -> Result<Vec<u8>> {
173 let h_size = <D as Digest>::output_size();
174
175 let label = label.unwrap_or_default();
176 if label.len() as u64 > MAX_LABEL_LEN {
177 return Err(Error::LabelTooLong);
178 }
179
180 let expected_p_hash = D::digest(label.as_bytes());
181
182 let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
183 let mut mgf_digest = MGD::new();
184 mgf1_xor_digest(seed, &mut mgf_digest, db);
185 mgf1_xor_digest(db, &mut mgf_digest, seed);
186 })?;
187 if res.is_none().into() {
188 return Err(Error::Decryption);
189 }
190
191 let (out, index) = res.unwrap();
192
193 Ok(out[index as usize..].to_vec())
194}
195
196#[inline]
199fn decrypt_inner<MGF: FnMut(&mut [u8], &mut [u8])>(
200 em: &mut [u8],
201 h_size: usize,
202 expected_p_hash: &[u8],
203 k: usize,
204 mut mgf: MGF,
205) -> Result<CtOption<(Vec<u8>, u32)>> {
206 if k < 11 {
207 return Err(Error::Decryption);
208 }
209
210 if k < h_size * 2 + 2 {
211 return Err(Error::Decryption);
212 }
213
214 let first_byte_is_zero = em[0].ct_eq(&0u8);
215
216 let (_, payload) = em.split_at_mut(1);
217 let (seed, db) = payload.split_at_mut(h_size);
218
219 mgf(seed, db);
220
221 let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);
222
223 let mut looking_for_index = Choice::from(1u8);
229 let mut index = 0u32;
230 let mut nonzero_before_one = Choice::from(0u8);
231
232 for (i, el) in db.iter().skip(h_size).enumerate() {
233 let equals0 = el.ct_eq(&0u8);
234 let equals1 = el.ct_eq(&1u8);
235 index.conditional_assign(&(i as u32), looking_for_index & equals1);
236 looking_for_index &= !equals1;
237 nonzero_before_one |= looking_for_index & !equals0;
238 }
239
240 let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;
241
242 Ok(CtOption::new(
243 (em.to_vec(), index + 2 + (h_size * 2) as u32),
244 valid,
245 ))
246}