rsa/algorithms/
oaep.rs

1//! Encryption and Decryption using [OAEP padding](https://datatracker.ietf.org/doc/html/rfc8017#section-7.1).
2//!
3use alloc::string::String;
4use alloc::vec::Vec;
5
6use digest::{Digest, DynDigest, FixedOutputReset};
7use rand_core::CryptoRngCore;
8use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
9use zeroize::Zeroizing;
10
11use super::mgf::{mgf1_xor, mgf1_xor_digest};
12use crate::errors::{Error, Result};
13
// Largest accepted label length in bytes: 2^61 - 1, expressed as a shift so the
// derivation is visible instead of a long digit literal.
// TODO: This is the maximum for SHA-1, unclear from the RFC what the values are for other hashing functions.
const MAX_LABEL_LEN: u64 = (1 << 61) - 1;
17
18#[inline]
19fn encrypt_internal<R: CryptoRngCore + ?Sized, MGF: FnMut(&mut [u8], &mut [u8])>(
20    rng: &mut R,
21    msg: &[u8],
22    p_hash: &[u8],
23    h_size: usize,
24    k: usize,
25    mut mgf: MGF,
26) -> Result<Zeroizing<Vec<u8>>> {
27    if msg.len() + 2 * h_size + 2 > k {
28        return Err(Error::MessageTooLong);
29    }
30
31    let mut em = Zeroizing::new(vec![0u8; k]);
32
33    let (_, payload) = em.split_at_mut(1);
34    let (seed, db) = payload.split_at_mut(h_size);
35    rng.fill_bytes(seed);
36
37    // Data block DB =  pHash || PS || 01 || M
38    let db_len = k - h_size - 1;
39
40    db[0..h_size].copy_from_slice(p_hash);
41    db[db_len - msg.len() - 1] = 1;
42    db[db_len - msg.len()..].copy_from_slice(msg);
43
44    mgf(seed, db);
45
46    Ok(em)
47}
48
49/// Encrypts the given message with RSA and the padding scheme from
50/// [PKCS#1 OAEP].
51///
52/// The message must be no longer than the length of the public modulus minus
53/// `2 + (2 * hash.size())`.
54///
55/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
56#[inline]
57pub(crate) fn oaep_encrypt<R: CryptoRngCore + ?Sized>(
58    rng: &mut R,
59    msg: &[u8],
60    digest: &mut dyn DynDigest,
61    mgf_digest: &mut dyn DynDigest,
62    label: Option<String>,
63    k: usize,
64) -> Result<Zeroizing<Vec<u8>>> {
65    let h_size = digest.output_size();
66
67    let label = label.unwrap_or_default();
68    if label.len() as u64 > MAX_LABEL_LEN {
69        return Err(Error::LabelTooLong);
70    }
71
72    digest.update(label.as_bytes());
73    let p_hash = digest.finalize_reset();
74
75    encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
76        mgf1_xor(db, mgf_digest, seed);
77        mgf1_xor(seed, mgf_digest, db);
78    })
79}
80
81/// Encrypts the given message with RSA and the padding scheme from
82/// [PKCS#1 OAEP].
83///
84/// The message must be no longer than the length of the public modulus minus
85/// `2 + (2 * hash.size())`.
86///
87/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
88#[inline]
89pub(crate) fn oaep_encrypt_digest<
90    R: CryptoRngCore + ?Sized,
91    D: Digest,
92    MGD: Digest + FixedOutputReset,
93>(
94    rng: &mut R,
95    msg: &[u8],
96    label: Option<String>,
97    k: usize,
98) -> Result<Zeroizing<Vec<u8>>> {
99    let h_size = <D as Digest>::output_size();
100
101    let label = label.unwrap_or_default();
102    if label.len() as u64 > MAX_LABEL_LEN {
103        return Err(Error::LabelTooLong);
104    }
105
106    let p_hash = D::digest(label.as_bytes());
107
108    encrypt_internal(rng, msg, &p_hash, h_size, k, |seed, db| {
109        let mut mgf_digest = MGD::new();
110        mgf1_xor_digest(db, &mut mgf_digest, seed);
111        mgf1_xor_digest(seed, &mut mgf_digest, db);
112    })
113}
114
115///Decrypts OAEP padding.
116///
117/// Note that whether this function returns an error or not discloses secret
118/// information. If an attacker can cause this function to run repeatedly and
119/// learn whether each instance returned an error then they can decrypt and
120/// forge signatures as if they had the private key.
121///
122/// See `decrypt_session_key` for a way of solving this problem.
123///
124/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
125#[inline]
126pub(crate) fn oaep_decrypt(
127    em: &mut [u8],
128    digest: &mut dyn DynDigest,
129    mgf_digest: &mut dyn DynDigest,
130    label: Option<String>,
131    k: usize,
132) -> Result<Vec<u8>> {
133    let h_size = digest.output_size();
134
135    let label = label.unwrap_or_default();
136    if label.len() as u64 > MAX_LABEL_LEN {
137        return Err(Error::Decryption);
138    }
139
140    digest.update(label.as_bytes());
141
142    let expected_p_hash = digest.finalize_reset();
143
144    let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
145        mgf1_xor(seed, mgf_digest, db);
146        mgf1_xor(db, mgf_digest, seed);
147    })?;
148    if res.is_none().into() {
149        return Err(Error::Decryption);
150    }
151
152    let (out, index) = res.unwrap();
153
154    Ok(out[index as usize..].to_vec())
155}
156
157///Decrypts OAEP padding.
158///
159/// Note that whether this function returns an error or not discloses secret
160/// information. If an attacker can cause this function to run repeatedly and
161/// learn whether each instance returned an error then they can decrypt and
162/// forge signatures as if they had the private key.
163///
164/// See `decrypt_session_key` for a way of solving this problem.
165///
166/// [PKCS#1 OAEP]: https://datatracker.ietf.org/doc/html/rfc8017#section-7.1
167#[inline]
168pub(crate) fn oaep_decrypt_digest<D: Digest, MGD: Digest + FixedOutputReset>(
169    em: &mut [u8],
170    label: Option<String>,
171    k: usize,
172) -> Result<Vec<u8>> {
173    let h_size = <D as Digest>::output_size();
174
175    let label = label.unwrap_or_default();
176    if label.len() as u64 > MAX_LABEL_LEN {
177        return Err(Error::LabelTooLong);
178    }
179
180    let expected_p_hash = D::digest(label.as_bytes());
181
182    let res = decrypt_inner(em, h_size, &expected_p_hash, k, |seed, db| {
183        let mut mgf_digest = MGD::new();
184        mgf1_xor_digest(seed, &mut mgf_digest, db);
185        mgf1_xor_digest(db, &mut mgf_digest, seed);
186    })?;
187    if res.is_none().into() {
188        return Err(Error::Decryption);
189    }
190
191    let (out, index) = res.unwrap();
192
193    Ok(out[index as usize..].to_vec())
194}
195
196/// Decrypts OAEP padding. It returns one or zero in valid that indicates whether the
197/// plaintext was correctly structured.
198#[inline]
199fn decrypt_inner<MGF: FnMut(&mut [u8], &mut [u8])>(
200    em: &mut [u8],
201    h_size: usize,
202    expected_p_hash: &[u8],
203    k: usize,
204    mut mgf: MGF,
205) -> Result<CtOption<(Vec<u8>, u32)>> {
206    if k < 11 {
207        return Err(Error::Decryption);
208    }
209
210    if k < h_size * 2 + 2 {
211        return Err(Error::Decryption);
212    }
213
214    let first_byte_is_zero = em[0].ct_eq(&0u8);
215
216    let (_, payload) = em.split_at_mut(1);
217    let (seed, db) = payload.split_at_mut(h_size);
218
219    mgf(seed, db);
220
221    let hash_are_equal = db[0..h_size].ct_eq(expected_p_hash);
222
223    // The remainder of the plaintext must be zero or more 0x00, followed
224    // by 0x01, followed by the message.
225    //   looking_for_index: 1 if we are still looking for the 0x01
226    //   index: the offset of the first 0x01 byte
227    //   zero_before_one: 1 if we saw a non-zero byte before the 1
228    let mut looking_for_index = Choice::from(1u8);
229    let mut index = 0u32;
230    let mut nonzero_before_one = Choice::from(0u8);
231
232    for (i, el) in db.iter().skip(h_size).enumerate() {
233        let equals0 = el.ct_eq(&0u8);
234        let equals1 = el.ct_eq(&1u8);
235        index.conditional_assign(&(i as u32), looking_for_index & equals1);
236        looking_for_index &= !equals1;
237        nonzero_before_one |= looking_for_index & !equals0;
238    }
239
240    let valid = first_byte_is_zero & hash_are_equal & !nonzero_before_one & !looking_for_index;
241
242    Ok(CtOption::new(
243        (em.to_vec(), index + 2 + (h_size * 2) as u32),
244        valid,
245    ))
246}