// crypto_bigint/uint/inv_mod.rs

1use super::Uint;
2use crate::CtChoice;
3
4impl<const LIMBS: usize> Uint<LIMBS> {
5    /// Computes 1/`self` mod `2^k`.
6    /// This method is constant-time w.r.t. `self` but not `k`.
7    ///
8    /// Conditions: `self` < 2^k and `self` must be odd
9    pub const fn inv_mod2k_vartime(&self, k: usize) -> Self {
10        // Using the Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2k"
11        // by Sadiel de la Fe and Carles Ferrer.
12        // See <https://www.mdpi.com/2410-387X/2/3/23>.
13
14        // Note that we are not using Alrgorithm 4, since we have a different approach
15        // of enforcing constant-timeness w.r.t. `self`.
16
17        let mut x = Self::ZERO; // keeps `x` during iterations
18        let mut b = Self::ONE; // keeps `b_i` during iterations
19        let mut i = 0;
20
21        while i < k {
22            // X_i = b_i mod 2
23            let x_i = b.limbs[0].0 & 1;
24            let x_i_choice = CtChoice::from_lsb(x_i);
25            // b_{i+1} = (b_i - a * X_i) / 2
26            b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
27            // Store the X_i bit in the result (x = x | (1 << X_i))
28            x = x.bitor(&Uint::from_word(x_i).shl_vartime(i));
29
30            i += 1;
31        }
32
33        x
34    }
35
36    /// Computes 1/`self` mod `2^k`.
37    ///
38    /// Conditions: `self` < 2^k and `self` must be odd
39    pub const fn inv_mod2k(&self, k: usize) -> Self {
40        // This is the same algorithm as in `inv_mod2k_vartime()`,
41        // but made constant-time w.r.t `k` as well.
42
43        let mut x = Self::ZERO; // keeps `x` during iterations
44        let mut b = Self::ONE; // keeps `b_i` during iterations
45        let mut i = 0;
46
47        while i < Self::BITS {
48            // Only iterations for i = 0..k need to change `x`,
49            // the rest are dummy ones performed for the sake of constant-timeness.
50            let within_range = CtChoice::from_usize_lt(i, k);
51
52            // X_i = b_i mod 2
53            let x_i = b.limbs[0].0 & 1;
54            let x_i_choice = CtChoice::from_lsb(x_i);
55            // b_{i+1} = (b_i - a * X_i) / 2
56            b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
57
58            // Store the X_i bit in the result (x = x | (1 << X_i))
59            // Don't change the result in dummy iterations.
60            let x_i_choice = x_i_choice.and(within_range);
61            x = x.set_bit(i, x_i_choice);
62
63            i += 1;
64        }
65
66        x
67    }
68
69    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
70    /// In other words `self^-1 mod modulus`.
71    /// `bits` and `modulus_bits` are the bounds on the bit size
72    /// of `self` and `modulus`, respectively
73    /// (the inversion speed will be proportional to `bits + modulus_bits`).
74    /// The second element of the tuple is the truthy value if an inverse exists,
75    /// otherwise it is a falsy value.
76    ///
77    /// **Note:** variable time in `bits` and `modulus_bits`.
78    ///
79    /// The algorithm is the same as in GMP 6.2.1's `mpn_sec_invert`.
80    pub const fn inv_odd_mod_bounded(
81        &self,
82        modulus: &Self,
83        bits: usize,
84        modulus_bits: usize,
85    ) -> (Self, CtChoice) {
86        debug_assert!(modulus.ct_is_odd().is_true_vartime());
87
88        let mut a = *self;
89
90        let mut u = Uint::ONE;
91        let mut v = Uint::ZERO;
92
93        let mut b = *modulus;
94
95        // `bit_size` can be anything >= `self.bits()` + `modulus.bits()`, setting to the minimum.
96        let bit_size = bits + modulus_bits;
97
98        let mut m1hp = *modulus;
99        let (m1hp_new, carry) = m1hp.shr_1();
100        debug_assert!(carry.is_true_vartime());
101        m1hp = m1hp_new.wrapping_add(&Uint::ONE);
102
103        let mut i = 0;
104        while i < bit_size {
105            debug_assert!(b.ct_is_odd().is_true_vartime());
106
107            let self_odd = a.ct_is_odd();
108
109            // Set `self -= b` if `self` is odd.
110            let (new_a, swap) = a.conditional_wrapping_sub(&b, self_odd);
111            // Set `b += self` if `swap` is true.
112            b = Uint::ct_select(&b, &b.wrapping_add(&new_a), swap);
113            // Negate `self` if `swap` is true.
114            a = new_a.conditional_wrapping_neg(swap);
115
116            let (new_u, new_v) = Uint::ct_swap(&u, &v, swap);
117            let (new_u, cy) = new_u.conditional_wrapping_sub(&new_v, self_odd);
118            let (new_u, cyy) = new_u.conditional_wrapping_add(modulus, cy);
119            debug_assert!(cy.is_true_vartime() == cyy.is_true_vartime());
120
121            let (new_a, overflow) = a.shr_1();
122            debug_assert!(!overflow.is_true_vartime());
123            let (new_u, cy) = new_u.shr_1();
124            let (new_u, cy) = new_u.conditional_wrapping_add(&m1hp, cy);
125            debug_assert!(!cy.is_true_vartime());
126
127            a = new_a;
128            u = new_u;
129            v = new_v;
130
131            i += 1;
132        }
133
134        debug_assert!(!a.ct_is_nonzero().is_true_vartime());
135
136        (v, Uint::ct_eq(&b, &Uint::ONE))
137    }
138
139    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
140    /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
141    /// otherwise `(undefined, CtChoice::FALSE)`.
142    pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) {
143        self.inv_odd_mod_bounded(modulus, Uint::<LIMBS>::BITS, Uint::<LIMBS>::BITS)
144    }
145
146    /// Computes the multiplicative inverse of `self` mod `modulus`.
147    /// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
148    /// otherwise `(undefined, CtChoice::FALSE)`.
149    pub const fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) {
150        // Decompose `modulus = s * 2^k` where `s` is odd
151        let k = modulus.trailing_zeros();
152        let s = modulus.shr(k);
153
154        // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
155        // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
156        let (a, a_is_some) = self.inv_odd_mod(&s);
157        let b = self.inv_mod2k(k);
158        // inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
159        let b_is_some = CtChoice::from_usize_being_nonzero(k)
160            .not()
161            .or(self.ct_is_odd());
162
163        // Restore from RNS:
164        // self^{-1} = a mod s = b mod 2^k
165        // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
166        // (essentially one step of the Garner's algorithm for recovery from RNS).
167
168        let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists
169
170        // This part is mod 2^k
171        let mask = Uint::ONE.shl(k).wrapping_sub(&Uint::ONE);
172        let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);
173
174        // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
175        // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
176        let result = a.wrapping_add(&s.wrapping_mul(&t));
177        (result, a_is_some.and(b_is_some))
178    }
179}
180
#[cfg(test)]
mod tests {
    use crate::{U1024, U256, U64};

    #[test]
    fn inv_mod2k() {
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.inv_mod2k(256);
        assert_eq!(e, a);

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.inv_mod2k(256);
        assert_eq!(e, a);
    }

    #[test]
    fn inv_mod2k_vartime() {
        // Same vectors as `inv_mod2k`: both variants must agree for `k == BITS`.
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        assert_eq!(e, v.inv_mod2k_vartime(256));

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        assert_eq!(e, v.inv_mod2k_vartime(256));
    }

    #[test]
    fn inv_mod2k_small() {
        // 3 * 3 = 9 ≡ 1 (mod 8); only the low `k` bits of the result may be set.
        let a = U64::from(3u64);
        assert_eq!(U64::from(3u64), a.inv_mod2k(3));
        assert_eq!(U64::from(3u64), a.inv_mod2k_vartime(3));
    }

    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ]);
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let (res, is_some) = a.inv_odd_mod(&m);
        assert!(is_some.is_true_vartime());
        assert_eq!(res, expected);

        // Even though it is less efficient, it still works
        let (res, is_some) = a.inv_mod(&m);
        assert!(is_some.is_true_vartime());
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ]);
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let (res, is_some) = a.inv_mod(&m);
        assert!(is_some.is_true_vartime());
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_even_small() {
        // 5 * 5 = 25 ≡ 1 (mod 12): exercises both the odd part (3) and the 2^2 part.
        let (res, is_some) = U64::from(5u64).inv_mod(&U64::from(12u64));
        assert!(is_some.is_true_vartime());
        assert_eq!(U64::from(5u64), res);

        // Power-of-two modulus, so the odd part is 1: 3 * 3 = 9 ≡ 1 (mod 8).
        let (res, is_some) = U64::from(3u64).inv_mod(&U64::from(8u64));
        assert!(is_some.is_true_vartime());
        assert_eq!(U64::from(3u64), res);
    }

    #[test]
    fn test_no_inverse_even_small() {
        // gcd(2, 4) = 2: an even `self` has no inverse modulo a power of two.
        let (_res, is_some) = U64::from(2u64).inv_mod(&U64::from(4u64));
        assert!(!is_some.is_true_vartime());

        // gcd(3, 12) = 3: the odd part of the modulus shares a factor with `self`.
        let (_res, is_some) = U64::from(3u64).inv_mod(&U64::from(12u64));
        assert!(!is_some.is_true_vartime());
    }

    #[test]
    fn test_invert_bounded() {
        let a = U1024::from_be_hex(concat![
            "0000000000000000000000000000000000000000000000000000000000000000",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "0000000000000000000000000000000000000000000000000000000000000000",
            "0000000000000000000000000000000000000000000000000000000000000000",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ]);

        let (res, is_some) = a.inv_odd_mod_bounded(&m, 768, 512);

        let expected = U1024::from_be_hex(concat![
            "0000000000000000000000000000000000000000000000000000000000000000",
            "0000000000000000000000000000000000000000000000000000000000000000",
            "0DCC94E2FE509E6EBBA0825645A38E73EF85D5927C79C1AD8FFE7C8DF9A822FA",
            "09EB396A21B1EF05CBE51E1A8EF284EF01EBDD36A9A4EA17039D8EEFDD934768"
        ]);
        assert!(is_some.is_true_vartime());
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64);

        let (res, is_some) = a.inv_odd_mod(&m);

        assert!(is_some.is_true_vartime());
        assert_eq!(U64::from(9u64), res);
    }

    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64);

        let (_res, is_some) = a.inv_odd_mod(&m);

        assert!(!is_some.is_true_vartime());
    }
}