rsa/algorithms/pss.rs
1//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS.
2//!
3//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1].
4//!
5//! # Usage
6//!
7//! See [code example in the toplevel rustdoc](../index.html#pss-signatures).
8//!
9//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme
10//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
11
12use alloc::vec::Vec;
13use digest::{Digest, DynDigest, FixedOutputReset};
14use subtle::{Choice, ConstantTimeEq};
15
16use super::mgf::{mgf1_xor, mgf1_xor_digest};
17use crate::errors::{Error, Result};
18
19pub(crate) fn emsa_pss_encode(
20 m_hash: &[u8],
21 em_bits: usize,
22 salt: &[u8],
23 hash: &mut dyn DynDigest,
24) -> Result<Vec<u8>> {
25 // See [1], section 9.1.1
26 let h_len = hash.output_size();
27 let s_len = salt.len();
28 let em_len = (em_bits + 7) / 8;
29
30 // 1. If the length of M is greater than the input limitation for the
31 // hash function (2^61 - 1 octets for SHA-1), output "message too
32 // long" and stop.
33 //
34 // 2. Let mHash = Hash(M), an octet string of length hLen.
35 if m_hash.len() != h_len {
36 return Err(Error::InputNotHashed);
37 }
38
39 // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
40 if em_len < h_len + s_len + 2 {
41 // TODO: Key size too small
42 return Err(Error::Internal);
43 }
44
45 let mut em = vec![0; em_len];
46
47 let (db, h) = em.split_at_mut(em_len - h_len - 1);
48 let h = &mut h[..(em_len - 1) - db.len()];
49
50 // 4. Generate a random octet string salt of length s_len; if s_len = 0,
51 // then salt is the empty string.
52 //
53 // 5. Let
54 // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
55 //
56 // M' is an octet string of length 8 + h_len + s_len with eight
57 // initial zero octets.
58 //
59 // 6. Let H = Hash(M'), an octet string of length h_len.
60 let prefix = [0u8; 8];
61
62 hash.update(&prefix);
63 hash.update(m_hash);
64 hash.update(salt);
65
66 let hashed = hash.finalize_reset();
67 h.copy_from_slice(&hashed);
68
69 // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2
70 // zero octets. The length of PS may be 0.
71 //
72 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
73 // emLen - hLen - 1.
74 db[em_len - s_len - h_len - 2] = 0x01;
75 db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
76
77 // 9. Let dbMask = MGF(H, emLen - hLen - 1).
78 //
79 // 10. Let maskedDB = DB \xor dbMask.
80 mgf1_xor(db, hash, h);
81
82 // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
83 // maskedDB to zero.
84 db[0] &= 0xFF >> (8 * em_len - em_bits);
85
86 // 12. Let EM = maskedDB || H || 0xbc.
87 em[em_len - 1] = 0xBC;
88
89 Ok(em)
90}
91
92pub(crate) fn emsa_pss_encode_digest<D>(
93 m_hash: &[u8],
94 em_bits: usize,
95 salt: &[u8],
96) -> Result<Vec<u8>>
97where
98 D: Digest + FixedOutputReset,
99{
100 // See [1], section 9.1.1
101 let h_len = <D as Digest>::output_size();
102 let s_len = salt.len();
103 let em_len = (em_bits + 7) / 8;
104
105 // 1. If the length of M is greater than the input limitation for the
106 // hash function (2^61 - 1 octets for SHA-1), output "message too
107 // long" and stop.
108 //
109 // 2. Let mHash = Hash(M), an octet string of length hLen.
110 if m_hash.len() != h_len {
111 return Err(Error::InputNotHashed);
112 }
113
114 // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
115 if em_len < h_len + s_len + 2 {
116 // TODO: Key size too small
117 return Err(Error::Internal);
118 }
119
120 let mut em = vec![0; em_len];
121
122 let (db, h) = em.split_at_mut(em_len - h_len - 1);
123 let h = &mut h[..(em_len - 1) - db.len()];
124
125 // 4. Generate a random octet string salt of length s_len; if s_len = 0,
126 // then salt is the empty string.
127 //
128 // 5. Let
129 // M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
130 //
131 // M' is an octet string of length 8 + h_len + s_len with eight
132 // initial zero octets.
133 //
134 // 6. Let H = Hash(M'), an octet string of length h_len.
135 let prefix = [0u8; 8];
136
137 let mut hash = D::new();
138
139 Digest::update(&mut hash, prefix);
140 Digest::update(&mut hash, m_hash);
141 Digest::update(&mut hash, salt);
142
143 let hashed = hash.finalize_reset();
144 h.copy_from_slice(&hashed);
145
146 // 7. Generate an octet string PS consisting of em_len - s_len - h_len - 2
147 // zero octets. The length of PS may be 0.
148 //
149 // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
150 // emLen - hLen - 1.
151 db[em_len - s_len - h_len - 2] = 0x01;
152 db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
153
154 // 9. Let dbMask = MGF(H, emLen - hLen - 1).
155 //
156 // 10. Let maskedDB = DB \xor dbMask.
157 mgf1_xor_digest(db, &mut hash, h);
158
159 // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
160 // maskedDB to zero.
161 db[0] &= 0xFF >> (8 * em_len - em_bits);
162
163 // 12. Let EM = maskedDB || H || 0xbc.
164 em[em_len - 1] = 0xBC;
165
166 Ok(em)
167}
168
169fn emsa_pss_verify_pre<'a>(
170 m_hash: &[u8],
171 em: &'a mut [u8],
172 em_bits: usize,
173 s_len: usize,
174 h_len: usize,
175) -> Result<(&'a mut [u8], &'a mut [u8])> {
176 // 1. If the length of M is greater than the input limitation for the
177 // hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
178 // and stop.
179 //
180 // 2. Let mHash = Hash(M), an octet string of length hLen
181 if m_hash.len() != h_len {
182 return Err(Error::Verification);
183 }
184
185 // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
186 let em_len = em.len(); //(em_bits + 7) / 8;
187 if em_len < h_len + s_len + 2 {
188 return Err(Error::Verification);
189 }
190
191 // 4. If the rightmost octet of EM does not have hexadecimal value
192 // 0xbc, output "inconsistent" and stop.
193 if em[em.len() - 1] != 0xBC {
194 return Err(Error::Verification);
195 }
196
197 // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
198 // let H be the next hLen octets.
199 let (db, h) = em.split_at_mut(em_len - h_len - 1);
200 let h = &mut h[..h_len];
201
202 // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
203 // maskedDB are not all equal to zero, output "inconsistent" and
204 // stop.
205 if db[0]
206 & (0xFF_u8
207 .checked_shl(8 - (8 * em_len - em_bits) as u32)
208 .unwrap_or(0))
209 != 0
210 {
211 return Err(Error::Verification);
212 }
213
214 Ok((db, h))
215}
216
217fn emsa_pss_verify_salt(db: &[u8], em_len: usize, s_len: usize, h_len: usize) -> Choice {
218 // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
219 // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
220 // position is "position 1") does not have hexadecimal value 0x01,
221 // output "inconsistent" and stop.
222 let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
223 let valid: Choice = zeroes
224 .iter()
225 .fold(Choice::from(1u8), |a, e| a & e.ct_eq(&0x00));
226
227 valid & rest[0].ct_eq(&0x01)
228}
229
230pub(crate) fn emsa_pss_verify(
231 m_hash: &[u8],
232 em: &mut [u8],
233 s_len: usize,
234 hash: &mut dyn DynDigest,
235 key_bits: usize,
236) -> Result<()> {
237 let em_bits = key_bits - 1;
238 let em_len = (em_bits + 7) / 8;
239 let key_len = (key_bits + 7) / 8;
240 let h_len = hash.output_size();
241
242 let em = &mut em[key_len - em_len..];
243
244 let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
245
246 // 7. Let dbMask = MGF(H, em_len - h_len - 1)
247 //
248 // 8. Let DB = maskedDB \xor dbMask
249 mgf1_xor(db, hash, &*h);
250
251 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
252 // to zero.
253 db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
254
255 let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
256
257 // 11. Let salt be the last s_len octets of DB.
258 let salt = &db[db.len() - s_len..];
259
260 // 12. Let
261 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
262 // M' is an octet string of length 8 + hLen + sLen with eight
263 // initial zero octets.
264 //
265 // 13. Let H' = Hash(M'), an octet string of length hLen.
266 let prefix = [0u8; 8];
267
268 hash.update(&prefix[..]);
269 hash.update(m_hash);
270 hash.update(salt);
271 let h0 = hash.finalize_reset();
272
273 // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
274 if (salt_valid & h0.ct_eq(h)).into() {
275 Ok(())
276 } else {
277 Err(Error::Verification)
278 }
279}
280
281pub(crate) fn emsa_pss_verify_digest<D>(
282 m_hash: &[u8],
283 em: &mut [u8],
284 s_len: usize,
285 key_bits: usize,
286) -> Result<()>
287where
288 D: Digest + FixedOutputReset,
289{
290 let em_bits = key_bits - 1;
291 let em_len = (em_bits + 7) / 8;
292 let key_len = (key_bits + 7) / 8;
293 let h_len = <D as Digest>::output_size();
294
295 let em = &mut em[key_len - em_len..];
296
297 let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
298
299 let mut hash = D::new();
300
301 // 7. Let dbMask = MGF(H, em_len - h_len - 1)
302 //
303 // 8. Let DB = maskedDB \xor dbMask
304 mgf1_xor_digest::<D>(db, &mut hash, &*h);
305
306 // 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
307 // to zero.
308 db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
309
310 let salt_valid = emsa_pss_verify_salt(db, em_len, s_len, h_len);
311
312 // 11. Let salt be the last s_len octets of DB.
313 let salt = &db[db.len() - s_len..];
314
315 // 12. Let
316 // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
317 // M' is an octet string of length 8 + hLen + sLen with eight
318 // initial zero octets.
319 //
320 // 13. Let H' = Hash(M'), an octet string of length hLen.
321 let prefix = [0u8; 8];
322
323 Digest::update(&mut hash, &prefix[..]);
324 Digest::update(&mut hash, m_hash);
325 Digest::update(&mut hash, salt);
326 let h0 = hash.finalize_reset();
327
328 // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
329 if (salt_valid & h0.ct_eq(h)).into() {
330 Ok(())
331 } else {
332 Err(Error::Verification)
333 }
334}