1use alloc::borrow::Cow;
4use alloc::vec::Vec;
5use num_bigint::{BigInt, BigUint, IntoBigInt, IntoBigUint, ModInverse, RandBigInt, ToBigInt};
6use num_integer::{sqrt, Integer};
7use num_traits::{FromPrimitive, One, Pow, Signed, Zero};
8use rand_core::CryptoRngCore;
9use zeroize::{Zeroize, Zeroizing};
10
11use crate::errors::{Error, Result};
12use crate::traits::{PrivateKeyParts, PublicKeyParts};
13
14#[inline]
21pub fn rsa_encrypt<K: PublicKeyParts>(key: &K, m: &BigUint) -> Result<BigUint> {
22 Ok(m.modpow(key.e(), key.n()))
23}
24
25#[inline]
34pub fn rsa_decrypt<R: CryptoRngCore + ?Sized>(
35 mut rng: Option<&mut R>,
36 priv_key: &impl PrivateKeyParts,
37 c: &BigUint,
38) -> Result<BigUint> {
39 if c >= priv_key.n() {
40 return Err(Error::Decryption);
41 }
42
43 if priv_key.n().is_zero() {
44 return Err(Error::Decryption);
45 }
46
47 let mut ir = None;
48
49 let c = if let Some(ref mut rng) = rng {
50 let (blinded, unblinder) = blind(rng, priv_key, c);
51 ir = Some(unblinder);
52 Cow::Owned(blinded)
53 } else {
54 Cow::Borrowed(c)
55 };
56
57 let dp = priv_key.dp();
58 let dq = priv_key.dq();
59 let qinv = priv_key.qinv();
60 let crt_values = priv_key.crt_values();
61
62 let m = match (dp, dq, qinv, crt_values) {
63 (Some(dp), Some(dq), Some(qinv), Some(crt_values)) => {
64 let p = &priv_key.primes()[0];
67 let q = &priv_key.primes()[1];
68
69 let mut m = c.modpow(dp, p).into_bigint().unwrap();
70 let mut m2 = c.modpow(dq, q).into_bigint().unwrap();
71
72 m -= &m2;
73
74 let mut primes: Vec<_> = priv_key
75 .primes()
76 .iter()
77 .map(ToBigInt::to_bigint)
78 .map(Option::unwrap)
79 .collect();
80
81 while m.is_negative() {
82 m += &primes[0];
83 }
84 m *= qinv;
85 m %= &primes[0];
86 m *= &primes[1];
87 m += &m2;
88
89 let mut c = c.into_owned().into_bigint().unwrap();
90 for (i, value) in crt_values.iter().enumerate() {
91 let prime = &primes[2 + i];
92 m2 = c.modpow(&value.exp, prime);
93 m2 -= &m;
94 m2 *= &value.coeff;
95 m2 %= prime;
96 while m2.is_negative() {
97 m2 += prime;
98 }
99 m2 *= &value.r;
100 m += &m2;
101 }
102
103 for prime in primes.iter_mut() {
105 prime.zeroize();
106 }
107 primes.clear();
108 c.zeroize();
109 m2.zeroize();
110
111 m.into_biguint().expect("failed to decrypt")
112 }
113 _ => c.modpow(priv_key.d(), priv_key.n()),
114 };
115
116 match ir {
117 Some(ref ir) => {
118 Ok(unblind(priv_key, &m, ir))
120 }
121 None => Ok(m),
122 }
123}
124
125#[inline]
135pub fn rsa_decrypt_and_check<R: CryptoRngCore + ?Sized>(
136 priv_key: &impl PrivateKeyParts,
137 rng: Option<&mut R>,
138 c: &BigUint,
139) -> Result<BigUint> {
140 let m = rsa_decrypt(rng, priv_key, c)?;
141
142 let check = rsa_encrypt(priv_key, &m)?;
145
146 if c != &check {
147 return Err(Error::Internal);
148 }
149
150 Ok(m)
151}
152
153fn blind<R: CryptoRngCore, K: PublicKeyParts>(
155 rng: &mut R,
156 key: &K,
157 c: &BigUint,
158) -> (BigUint, BigUint) {
159 let mut r: BigUint;
165 let mut ir: Option<BigInt>;
166 let unblinder;
167 loop {
168 r = rng.gen_biguint_below(key.n());
169 if r.is_zero() {
170 r = BigUint::one();
171 }
172 ir = r.clone().mod_inverse(key.n());
173 if let Some(ir) = ir {
174 if let Some(ub) = ir.into_biguint() {
175 unblinder = ub;
176 break;
177 }
178 }
179 }
180
181 let c = {
182 let mut rpowe = r.modpow(key.e(), key.n()); let mut c = c * &rpowe;
184 c %= key.n();
185
186 rpowe.zeroize();
187
188 c
189 };
190
191 (c, unblinder)
192}
193
194fn unblind(key: &impl PublicKeyParts, m: &BigUint, unblinder: &BigUint) -> BigUint {
196 (m * unblinder) % key.n()
197}
198
199pub fn recover_primes(n: &BigUint, e: &BigUint, d: &BigUint) -> Result<(BigUint, BigUint)> {
203 let two = BigUint::from_u8(2).unwrap();
205 if e <= &two.pow(16u32) || e >= &two.pow(256u32) {
206 return Err(Error::InvalidArguments);
207 }
208
209 let one = BigUint::one();
211 let a = Zeroizing::new((d * e - &one) * (n - &one).gcd(&(d * e - &one)));
212
213 let m = Zeroizing::new(&*a / n);
215 let r = Zeroizing::new(&*a - &*m * n);
216
217 let modulus_check = Zeroizing::new((n - &*r) % (&*m + &one));
220 if !modulus_check.is_zero() {
221 return Err(Error::InvalidArguments);
222 }
223 let b = Zeroizing::new((n - &*r) / (&*m + &one) + one);
224
225 let four = BigUint::from_u8(4).unwrap();
226 let four_n = Zeroizing::new(n * four);
227 let b_squared = Zeroizing::new(b.pow(2u32));
228 if *b_squared <= *four_n {
229 return Err(Error::InvalidArguments);
230 }
231 let b_squared_minus_four_n = Zeroizing::new(&*b_squared - &*four_n);
232
233 let y = Zeroizing::new(sqrt((*b_squared_minus_four_n).clone()));
236
237 let y_squared = Zeroizing::new(y.pow(2u32));
238 let sqrt_is_whole_number = y_squared == b_squared_minus_four_n;
239 if !sqrt_is_whole_number {
240 return Err(Error::InvalidArguments);
241 }
242 let p = (&*b + &*y) / &two;
243 let q = (&*b - &*y) / two;
244
245 Ok((p, q))
246}
247
248pub(crate) fn compute_modulus(primes: &[BigUint]) -> BigUint {
250 primes.iter().product()
251}
252
253#[inline]
256pub(crate) fn compute_private_exponent_euler_totient(
257 primes: &[BigUint],
258 exp: &BigUint,
259) -> Result<BigUint> {
260 if primes.len() < 2 {
261 return Err(Error::InvalidPrime);
262 }
263
264 let mut totient = BigUint::one();
265
266 for prime in primes {
267 totient *= prime - BigUint::one();
268 }
269
270 if let Some(d) = exp.mod_inverse(totient) {
273 Ok(d.to_biguint().unwrap())
274 } else {
275 Err(Error::InvalidPrime)
277 }
278}
279
280#[inline]
289pub(crate) fn compute_private_exponent_carmicheal(
290 p: &BigUint,
291 q: &BigUint,
292 exp: &BigUint,
293) -> Result<BigUint> {
294 let p1 = p - BigUint::one();
295 let q1 = q - BigUint::one();
296
297 let lcm = p1.lcm(&q1);
298 if let Some(d) = exp.mod_inverse(lcm) {
299 Ok(d.to_biguint().unwrap())
300 } else {
301 Err(Error::InvalidPrime)
303 }
304}
305
#[cfg(test)]
mod tests {
    use num_traits::FromPrimitive;

    use super::*;

    /// Parses a hex-encoded big-integer test fixture.
    fn hex(digits: &[u8]) -> BigUint {
        BigUint::parse_bytes(digits, 16).unwrap()
    }

    #[test]
    fn recover_primes_works() {
        let n = hex(b"00d397b84d98a4c26138ed1b695a8106ead91d553bf06041b62d3fdc50a041e222b8f4529689c1b82c5e71554f5dd69fa2f4b6158cf0dbeb57811a0fc327e1f28e74fe74d3bc166c1eabdc1b8b57b934ca8be5b00b4f29975bcc99acaf415b59bb28a6782bb41a2c3c2976b3c18dbadef62f00c6bb226640095096c0cc60d22fe7ef987d75c6a81b10d96bf292028af110dc7cc1bbc43d22adab379a0cd5d8078cc780ff5cd6209dea34c922cf784f7717e428d75b5aec8ff30e5f0141510766e2e0ab8d473c84e8710b2b98227c3db095337ad3452f19e2b9bfbccdd8148abf6776fa552775e6e75956e45229ae5a9c46949bab1e622f0e48f56524a84ed3483b");
        let e = BigUint::from_u64(65537).unwrap();
        let d = hex(b"00c4e70c689162c94c660828191b52b4d8392115df486a9adbe831e458d73958320dc1b755456e93701e9702d76fb0b92f90e01d1fe248153281fe79aa9763a92fae69d8d7ecd144de29fa135bd14f9573e349e45031e3b76982f583003826c552e89a397c1a06bd2163488630d92e8c2bb643d7abef700da95d685c941489a46f54b5316f62b5d2c3a7f1bbd134cb37353a44683fdc9d95d36458de22f6c44057fe74a0a436c4308f73f4da42f35c47ac16a7138d483afc91e41dc3a1127382e0c0f5119b0221b4fc639d6b9c38177a6de9b526ebd88c38d7982c07f98a0efd877d508aae275b946915c02e2e1106d175d74ec6777f5e80d12c053d9c7be1e341");
        let expected_p = hex(b"00f827bbf3a41877c7cc59aebf42ed4b29c32defcb8ed96863d5b090a05a8930dd624a21c9dcf9838568fdfa0df65b8462a5f2ac913d6c56f975532bd8e78fb07bd405ca99a484bcf59f019bbddcb3933f2bce706300b4f7b110120c5df9018159067c35da3061a56c8635a52b54273b31271b4311f0795df6021e6355e1a42e61");
        let expected_q = hex(b"00da4817ce0089dd36f2ade6a3ff410c73ec34bf1b4f6bda38431bfede11cef1f7f6efa70e5f8063a3b1f6e17296ffb15feefa0912a0325b8d1fd65a559e717b5b961ec345072e0ec5203d03441d29af4d64054a04507410cf1da78e7b6119d909ec66e6ad625bf995b279a4b3c5be7d895cd7c5b9c4c497fde730916fcdb4e41b");

        let (first, second) = recover_primes(&n, &e, &d).unwrap();

        // Compare larger-first so the assertion does not depend on output order.
        let (larger, smaller) = if first < second {
            (second, first)
        } else {
            (first, second)
        };
        assert_eq!(expected_p, larger);
        assert_eq!(expected_q, smaller);
    }
}