// rsa/algorithms/pss.rs

//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS.
//!
//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1].
//!
//! # Usage
//!
//! See [code example in the top-level rustdoc](../index.html#pss-signatures).
//!
//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme
//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
11
12use alloc::vec::Vec;
13use digest::{Digest, DynDigest, FixedOutputReset};
14use subtle::{Choice, ConstantTimeEq};
15
16use super::mgf::{mgf1_xor, mgf1_xor_digest};
17use crate::errors::{Error, Result};
18
19pub(crate) fn emsa_pss_encode(
20    m_hash: &[u8],
21    em_bits: usize,
22    salt: &[u8],
23    hash: &mut dyn DynDigest,
24) -> Result<Vec<u8>> {
25    // See [1], section 9.1.1
26    let h_len = hash.output_size();
27    let s_len = salt.len();
28    let em_len = (em_bits + 7) / 8;
29
30    // 1. If the length of M is greater than the input limitation for the
31    //     hash function (2^61 - 1 octets for SHA-1), output "message too
32    //     long" and stop.
33    //
34    // 2.  Let mHash = Hash(M), an octet string of length hLen.
35    if m_hash.len() != h_len {
36        return Err(Error::InputNotHashed);
37    }
38
39    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
40    if em_len < h_len + s_len + 2 {
41        // TODO: Key size too small
42        return Err(Error::Internal);
43    }
44
45    let mut em = vec![0; em_len];
46
47    let (db, h) = em.split_at_mut(em_len - h_len - 1);
48    let h = &mut h[..(em_len - 1) - db.len()];
49
50    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
51    //     then salt is the empty string.
52    //
53    // 5.  Let
54    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
55    //
56    //     M' is an octet string of length 8 + h_len + s_len with eight
57    //     initial zero octets.
58    //
59    // 6.  Let H = Hash(M'), an octet string of length h_len.
60    let prefix = [0u8; 8];
61
62    hash.update(&prefix);
63    hash.update(m_hash);
64    hash.update(salt);
65
66    let hashed = hash.finalize_reset();
67    h.copy_from_slice(&hashed);
68
69    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
70    //     zero octets. The length of PS may be 0.
71    //
72    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
73    //     emLen - hLen - 1.
74    db[em_len - s_len - h_len - 2] = 0x01;
75    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
76
77    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
78    //
79    // 10. Let maskedDB = DB \xor dbMask.
80    mgf1_xor(db, hash, h);
81
82    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
83    //     maskedDB to zero.
84    db[0] &= 0xFF >> (8 * em_len - em_bits);
85
86    // 12. Let EM = maskedDB || H || 0xbc.
87    em[em_len - 1] = 0xBC;
88
89    Ok(em)
90}
91
92pub(crate) fn emsa_pss_encode_digest<D>(
93    m_hash: &[u8],
94    em_bits: usize,
95    salt: &[u8],
96) -> Result<Vec<u8>>
97where
98    D: Digest + FixedOutputReset,
99{
100    // See [1], section 9.1.1
101    let h_len = <D as Digest>::output_size();
102    let s_len = salt.len();
103    let em_len = (em_bits + 7) / 8;
104
105    // 1. If the length of M is greater than the input limitation for the
106    //     hash function (2^61 - 1 octets for SHA-1), output "message too
107    //     long" and stop.
108    //
109    // 2.  Let mHash = Hash(M), an octet string of length hLen.
110    if m_hash.len() != h_len {
111        return Err(Error::InputNotHashed);
112    }
113
114    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
115    if em_len < h_len + s_len + 2 {
116        // TODO: Key size too small
117        return Err(Error::Internal);
118    }
119
120    let mut em = vec![0; em_len];
121
122    let (db, h) = em.split_at_mut(em_len - h_len - 1);
123    let h = &mut h[..(em_len - 1) - db.len()];
124
125    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
126    //     then salt is the empty string.
127    //
128    // 5.  Let
129    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
130    //
131    //     M' is an octet string of length 8 + h_len + s_len with eight
132    //     initial zero octets.
133    //
134    // 6.  Let H = Hash(M'), an octet string of length h_len.
135    let prefix = [0u8; 8];
136
137    let mut hash = D::new();
138
139    Digest::update(&mut hash, prefix);
140    Digest::update(&mut hash, m_hash);
141    Digest::update(&mut hash, salt);
142
143    let hashed = hash.finalize_reset();
144    h.copy_from_slice(&hashed);
145
146    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
147    //     zero octets. The length of PS may be 0.
148    //
149    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
150    //     emLen - hLen - 1.
151    db[em_len - s_len - h_len - 2] = 0x01;
152    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
153
154    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
155    //
156    // 10. Let maskedDB = DB \xor dbMask.
157    mgf1_xor_digest(db, &mut hash, h);
158
159    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
160    //     maskedDB to zero.
161    db[0] &= 0xFF >> (8 * em_len - em_bits);
162
163    // 12. Let EM = maskedDB || H || 0xbc.
164    em[em_len - 1] = 0xBC;
165
166    Ok(em)
167}
168
169fn emsa_pss_verify_pre<'a>(
170    m_hash: &[u8],
171    em: &'a mut [u8],
172    em_bits: usize,
173    s_len: usize,
174    h_len: usize,
175) -> Result<(&'a mut [u8], &'a mut [u8])> {
176    // 1. If the length of M is greater than the input limitation for the
177    //    hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
178    //    and stop.
179    //
180    // 2. Let mHash = Hash(M), an octet string of length hLen
181    if m_hash.len() != h_len {
182        return Err(Error::Verification);
183    }
184
185    // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
186    let em_len = em.len(); //(em_bits + 7) / 8;
187    if em_len < h_len + s_len + 2 {
188        return Err(Error::Verification);
189    }
190
191    // 4. If the rightmost octet of EM does not have hexadecimal value
192    //    0xbc, output "inconsistent" and stop.
193    if em[em.len() - 1] != 0xBC {
194        return Err(Error::Verification);
195    }
196
197    // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
198    //    let H be the next hLen octets.
199    let (db, h) = em.split_at_mut(em_len - h_len - 1);
200    let h = &mut h[..h_len];
201
202    // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
203    //    maskedDB are not all equal to zero, output "inconsistent" and
204    //    stop.
205    if db[0]
206        & (0xFF_u8
207            .checked_shl(8 - (8 * em_len - em_bits) as u32)
208            .unwrap_or(0))
209        != 0
210    {
211        return Err(Error::Verification);
212    }
213
214    Ok((db, h))
215}
216
217fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
218    // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
219    //     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
220    //     position is "position 1") does not have hexadecimal value 0x01,
221    //     output "inconsistent" and stop.
222    let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
223    let valid: Choice = zeroes
224        .iter()
225        .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00));
226
227    valid & rest[0].ct_eq(&0x01)
228}
229
230pub(crate) fn emsa_pss_verify(
231    m_hash: &[u8],
232    em: &mut [u8],
233    s_len: usize,
234    hash: &mut dyn DynDigest,
235    key_bits: usize,
236) -> Result<()> {
237    let em_bits = key_bits - 1;
238    let em_len = (em_bits + 7) / 8;
239    let key_len = (key_bits + 7) / 8;
240    let h_len = hash.output_size();
241
242    let em = &mut em[key_len - em_len..];
243
244    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
245
246    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
247    //
248    // 8. Let DB = maskedDB \xor dbMask
249    mgf1_xor(db, hash, &*h);
250
251    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
252    //     to zero.
253    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
254
255    let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
256
257    // 11. Let salt be the last s_len octets of DB.
258    let salt = &db[db.len() - s_len..];
259
260    // 12. Let
261    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
262    //     M' is an octet string of length 8 + hLen + sLen with eight
263    //     initial zero octets.
264    //
265    // 13. Let H' = Hash(M'), an octet string of length hLen.
266    let prefix = [0u8; 8];
267
268    hash.update(&prefix[..]);
269    hash.update(m_hash);
270    hash.update(salt);
271    let h0 = hash.finalize_reset();
272
273    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
274    if (salt_valid & h0.ct_eq(h)).into() {
275        Ok(())
276    } else {
277        Err(Error::Verification)
278    }
279}
280
281pub(crate) fn emsa_pss_verify_digest<D>(
282    m_hash: &[u8],
283    em: &mut [u8],
284    s_len: usize,
285    key_bits: usize,
286) -> Result<()>
287where
288    D: Digest + FixedOutputReset,
289{
290    let em_bits = key_bits - 1;
291    let em_len = (em_bits + 7) / 8;
292    let key_len = (key_bits + 7) / 8;
293    let h_len = <D as Digest>::output_size();
294
295    let em = &mut em[key_len - em_len..];
296
297    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
298
299    let mut hash = D::new();
300
301    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
302    //
303    // 8. Let DB = maskedDB \xor dbMask
304    mgf1_xor_digest::<D>(db, &mut hash, &*h);
305
306    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
307    //     to zero.
308    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
309
310    let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
311
312    // 11. Let salt be the last s_len octets of DB.
313    let salt = &db[db.len() - s_len..];
314
315    // 12. Let
316    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
317    //     M' is an octet string of length 8 + hLen + sLen with eight
318    //     initial zero octets.
319    //
320    // 13. Let H' = Hash(M'), an octet string of length hLen.
321    let prefix = [0u8; 8];
322
323    Digest::update(&mut hash, &prefix[..]);
324    Digest::update(&mut hash, m_hash);
325    Digest::update(&mut hash, salt);
326    let h0 = hash.finalize_reset();
327
328    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
329    if (salt_valid & h0.ct_eq(h)).into() {
330        Ok(())
331    } else {
332        Err(Error::Verification)
333    }
334}