rsa/algorithms/
mgf.rs

//! Mask generation function (MGF1) shared by the PSS and OAEP padding schemes.
2
3use digest::{Digest, DynDigest, FixedOutputReset};
4
5/// Mask generation function.
6///
7/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
8pub(crate) fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) {
9    let mut counter = [0u8; 4];
10    let mut i = 0;
11
12    const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
13    assert!(out.len() as u64 <= MAX_LEN);
14
15    while i < out.len() {
16        let mut digest_input = vec![0u8; seed.len() + 4];
17        digest_input[0..seed.len()].copy_from_slice(seed);
18        digest_input[seed.len()..].copy_from_slice(&counter);
19
20        digest.update(digest_input.as_slice());
21        let digest_output = &*digest.finalize_reset();
22        let mut j = 0;
23        loop {
24            if j >= digest_output.len() || i >= out.len() {
25                break;
26            }
27
28            out[i] ^= digest_output[j];
29            j += 1;
30            i += 1;
31        }
32        inc_counter(&mut counter);
33    }
34}
35
36/// Mask generation function.
37///
38/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
39pub(crate) fn mgf1_xor_digest<D>(out: &mut [u8], digest: &mut D, seed: &[u8])
40where
41    D: Digest + FixedOutputReset,
42{
43    let mut counter = [0u8; 4];
44    let mut i = 0;
45
46    const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
47    assert!(out.len() as u64 <= MAX_LEN);
48
49    while i < out.len() {
50        Digest::update(digest, seed);
51        Digest::update(digest, counter);
52
53        let digest_output = digest.finalize_reset();
54        let mut j = 0;
55        loop {
56            if j >= digest_output.len() || i >= out.len() {
57                break;
58            }
59
60            out[i] ^= digest_output[j];
61            j += 1;
62            i += 1;
63        }
64        inc_counter(&mut counter);
65    }
66}
/// Advances `counter` by one, interpreting the four bytes as a big-endian
/// unsigned integer (the last byte is the least significant).
///
/// The counter wraps around to all zeroes after `[0xFF; 4]`.
fn inc_counter(counter: &mut [u8; 4]) {
    let next = u32::from_be_bytes(*counter).wrapping_add(1);
    *counter = next.to_be_bytes();
}
75}