crypto_bigint/uint/
mul.rs

//! [`Uint`] multiplication operations.
2
3use crate::{Checked, CheckedMul, Concat, ConcatMixed, Limb, Uint, WideWord, Word, Wrapping, Zero};
4use core::ops::{Mul, MulAssign};
5use subtle::CtOption;
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Multiply `self` by `rhs`, returning a concatenated "wide" result.
9    pub fn mul<const HLIMBS: usize>(
10        &self,
11        rhs: &Uint<HLIMBS>,
12    ) -> <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput
13    where
14        Uint<HLIMBS>: ConcatMixed<Self>,
15    {
16        let (lo, hi) = self.mul_wide(rhs);
17        hi.concat_mixed(&lo)
18    }
19
20    /// Compute "wide" multiplication, with a product twice the size of the input.
21    ///
22    /// Returns a tuple containing the `(lo, hi)` components of the product.
23    ///
24    /// # Ordering note
25    ///
26    /// Releases of `crypto-bigint` prior to v0.3 used `(hi, lo)` ordering
27    /// instead. This has been changed for better consistency with the rest of
28    /// the APIs in this crate.
29    ///
30    /// For more info see: <https://github.com/RustCrypto/crypto-bigint/issues/4>
31    pub const fn mul_wide<const HLIMBS: usize>(&self, rhs: &Uint<HLIMBS>) -> (Self, Uint<HLIMBS>) {
32        let mut i = 0;
33        let mut lo = Self::ZERO;
34        let mut hi = Uint::<HLIMBS>::ZERO;
35
36        // Schoolbook multiplication.
37        // TODO(tarcieri): use Karatsuba for better performance?
38        while i < LIMBS {
39            let mut j = 0;
40            let mut carry = Limb::ZERO;
41
42            while j < HLIMBS {
43                let k = i + j;
44
45                if k >= LIMBS {
46                    let (n, c) = hi.limbs[k - LIMBS].mac(self.limbs[i], rhs.limbs[j], carry);
47                    hi.limbs[k - LIMBS] = n;
48                    carry = c;
49                } else {
50                    let (n, c) = lo.limbs[k].mac(self.limbs[i], rhs.limbs[j], carry);
51                    lo.limbs[k] = n;
52                    carry = c;
53                }
54
55                j += 1;
56            }
57
58            if i + j >= LIMBS {
59                hi.limbs[i + j - LIMBS] = carry;
60            } else {
61                lo.limbs[i + j] = carry;
62            }
63            i += 1;
64        }
65
66        (lo, hi)
67    }
68
69    /// Perform saturating multiplication, returning `MAX` on overflow.
70    pub const fn saturating_mul<const HLIMBS: usize>(&self, rhs: &Uint<HLIMBS>) -> Self {
71        let (res, overflow) = self.mul_wide(rhs);
72        Self::ct_select(&res, &Self::MAX, overflow.ct_is_nonzero())
73    }
74
75    /// Perform wrapping multiplication, discarding overflow.
76    pub const fn wrapping_mul<const H: usize>(&self, rhs: &Uint<H>) -> Self {
77        self.mul_wide(rhs).0
78    }
79
80    /// Square self, returning a concatenated "wide" result.
81    pub fn square(&self) -> <Self as Concat>::Output
82    where
83        Self: Concat,
84    {
85        let (lo, hi) = self.square_wide();
86        hi.concat(&lo)
87    }
88
89    /// Square self, returning a "wide" result in two parts as (lo, hi).
90    pub const fn square_wide(&self) -> (Self, Self) {
91        // Translated from https://github.com/ucbrise/jedi-pairing/blob/c4bf151/include/core/bigint.hpp#L410
92        //
93        // Permission to relicense the resulting translation as Apache 2.0 + MIT was given
94        // by the original author Sam Kumar: https://github.com/RustCrypto/crypto-bigint/pull/133#discussion_r1056870411
95        let mut lo = Self::ZERO;
96        let mut hi = Self::ZERO;
97
98        // Schoolbook multiplication, but only considering half of the multiplication grid
99        let mut i = 1;
100        while i < LIMBS {
101            let mut j = 0;
102            let mut carry = Limb::ZERO;
103
104            while j < i {
105                let k = i + j;
106
107                if k >= LIMBS {
108                    let (n, c) = hi.limbs[k - LIMBS].mac(self.limbs[i], self.limbs[j], carry);
109                    hi.limbs[k - LIMBS] = n;
110                    carry = c;
111                } else {
112                    let (n, c) = lo.limbs[k].mac(self.limbs[i], self.limbs[j], carry);
113                    lo.limbs[k] = n;
114                    carry = c;
115                }
116
117                j += 1;
118            }
119
120            if (2 * i) < LIMBS {
121                lo.limbs[2 * i] = carry;
122            } else {
123                hi.limbs[2 * i - LIMBS] = carry;
124            }
125
126            i += 1;
127        }
128
129        // Double the current result, this accounts for the other half of the multiplication grid.
130        // TODO: The top word is empty so we can also use a special purpose shl.
131        (lo, hi) = Self::shl_vartime_wide((lo, hi), 1);
132
133        // Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
134        let mut carry = Limb::ZERO;
135        let mut i = 0;
136        while i < LIMBS {
137            if (i * 2) < LIMBS {
138                let (n, c) = lo.limbs[i * 2].mac(self.limbs[i], self.limbs[i], carry);
139                lo.limbs[i * 2] = n;
140                carry = c;
141            } else {
142                let (n, c) = hi.limbs[i * 2 - LIMBS].mac(self.limbs[i], self.limbs[i], carry);
143                hi.limbs[i * 2 - LIMBS] = n;
144                carry = c;
145            }
146
147            if (i * 2 + 1) < LIMBS {
148                let n = lo.limbs[i * 2 + 1].0 as WideWord + carry.0 as WideWord;
149                lo.limbs[i * 2 + 1] = Limb(n as Word);
150                carry = Limb((n >> Word::BITS) as Word);
151            } else {
152                let n = hi.limbs[i * 2 + 1 - LIMBS].0 as WideWord + carry.0 as WideWord;
153                hi.limbs[i * 2 + 1 - LIMBS] = Limb(n as Word);
154                carry = Limb((n >> Word::BITS) as Word);
155            }
156
157            i += 1;
158        }
159
160        (lo, hi)
161    }
162}
163
164impl<const LIMBS: usize, const HLIMBS: usize> CheckedMul<&Uint<HLIMBS>> for Uint<LIMBS> {
165    type Output = Self;
166
167    fn checked_mul(&self, rhs: &Uint<HLIMBS>) -> CtOption<Self> {
168        let (lo, hi) = self.mul_wide(rhs);
169        CtOption::new(lo, hi.is_zero())
170    }
171}
172
173impl<const LIMBS: usize, const HLIMBS: usize> Mul<Wrapping<Uint<HLIMBS>>>
174    for Wrapping<Uint<LIMBS>>
175{
176    type Output = Self;
177
178    fn mul(self, rhs: Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
179        Wrapping(self.0.wrapping_mul(&rhs.0))
180    }
181}
182
183impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Wrapping<Uint<HLIMBS>>>
184    for Wrapping<Uint<LIMBS>>
185{
186    type Output = Self;
187
188    fn mul(self, rhs: &Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
189        Wrapping(self.0.wrapping_mul(&rhs.0))
190    }
191}
192
193impl<const LIMBS: usize, const HLIMBS: usize> Mul<Wrapping<Uint<HLIMBS>>>
194    for &Wrapping<Uint<LIMBS>>
195{
196    type Output = Wrapping<Uint<LIMBS>>;
197
198    fn mul(self, rhs: Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
199        Wrapping(self.0.wrapping_mul(&rhs.0))
200    }
201}
202
203impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Wrapping<Uint<HLIMBS>>>
204    for &Wrapping<Uint<LIMBS>>
205{
206    type Output = Wrapping<Uint<LIMBS>>;
207
208    fn mul(self, rhs: &Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
209        Wrapping(self.0.wrapping_mul(&rhs.0))
210    }
211}
212
213impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<Wrapping<Uint<HLIMBS>>>
214    for Wrapping<Uint<LIMBS>>
215{
216    fn mul_assign(&mut self, other: Wrapping<Uint<HLIMBS>>) {
217        *self = *self * other;
218    }
219}
220
221impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<&Wrapping<Uint<HLIMBS>>>
222    for Wrapping<Uint<LIMBS>>
223{
224    fn mul_assign(&mut self, other: &Wrapping<Uint<HLIMBS>>) {
225        *self = *self * other;
226    }
227}
228
229impl<const LIMBS: usize, const HLIMBS: usize> Mul<Checked<Uint<HLIMBS>>> for Checked<Uint<LIMBS>> {
230    type Output = Self;
231
232    fn mul(self, rhs: Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
233        Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
234    }
235}
236
237impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Checked<Uint<HLIMBS>>> for Checked<Uint<LIMBS>> {
238    type Output = Checked<Uint<LIMBS>>;
239
240    fn mul(self, rhs: &Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
241        Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
242    }
243}
244
245impl<const LIMBS: usize, const HLIMBS: usize> Mul<Checked<Uint<HLIMBS>>> for &Checked<Uint<LIMBS>> {
246    type Output = Checked<Uint<LIMBS>>;
247
248    fn mul(self, rhs: Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
249        Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
250    }
251}
252
253impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Checked<Uint<HLIMBS>>>
254    for &Checked<Uint<LIMBS>>
255{
256    type Output = Checked<Uint<LIMBS>>;
257
258    fn mul(self, rhs: &Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
259        Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
260    }
261}
262
263impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<Checked<Uint<HLIMBS>>>
264    for Checked<Uint<LIMBS>>
265{
266    fn mul_assign(&mut self, other: Checked<Uint<HLIMBS>>) {
267        *self = *self * other;
268    }
269}
270
271impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<&Checked<Uint<HLIMBS>>>
272    for Checked<Uint<LIMBS>>
273{
274    fn mul_assign(&mut self, other: &Checked<Uint<HLIMBS>>) {
275        *self = *self * other;
276    }
277}
278
279impl<const LIMBS: usize, const HLIMBS: usize> Mul<Uint<HLIMBS>> for Uint<LIMBS>
280where
281    Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
282{
283    type Output = <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput;
284
285    fn mul(self, other: Uint<HLIMBS>) -> Self::Output {
286        Uint::mul(&self, &other)
287    }
288}
289
290impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Uint<HLIMBS>> for Uint<LIMBS>
291where
292    Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
293{
294    type Output = <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput;
295
296    fn mul(self, other: &Uint<HLIMBS>) -> Self::Output {
297        Uint::mul(&self, other)
298    }
299}
300
301impl<const LIMBS: usize, const HLIMBS: usize> Mul<Uint<HLIMBS>> for &Uint<LIMBS>
302where
303    Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
304{
305    type Output = <Uint<HLIMBS> as ConcatMixed<Uint<LIMBS>>>::MixedOutput;
306
307    fn mul(self, other: Uint<HLIMBS>) -> Self::Output {
308        Uint::mul(self, &other)
309    }
310}
311
312impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Uint<HLIMBS>> for &Uint<LIMBS>
313where
314    Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
315{
316    type Output = <Uint<HLIMBS> as ConcatMixed<Uint<LIMBS>>>::MixedOutput;
317
318    fn mul(self, other: &Uint<HLIMBS>) -> Self::Output {
319        Uint::mul(self, other)
320    }
321}
322
#[cfg(test)]
mod tests {
    use crate::{CheckedMul, Zero, U128, U192, U256, U64};

    #[test]
    fn mul_wide_zero_and_one() {
        assert_eq!(U64::ZERO.mul_wide(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ZERO.mul_wide(&U64::ONE), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.mul_wide(&U64::ZERO), (U64::ZERO, U64::ZERO));
        assert_eq!(U64::ONE.mul_wide(&U64::ONE), (U64::ONE, U64::ZERO));
    }

    #[test]
    fn mul_wide_lo_only() {
        let primes: &[u32] = &[3, 5, 17, 257, 65537];

        for &a_int in primes {
            for &b_int in primes {
                let (lo, hi) = U64::from_u32(a_int).mul_wide(&U64::from_u32(b_int));
                let expected = U64::from_u64(a_int as u64 * b_int as u64);
                assert_eq!(lo, expected);
                assert!(bool::from(hi.is_zero()));
            }
        }
    }

    #[test]
    fn mul_wide_mixed_sizes() {
        // (2^128 - 1) * (2^64 - 1) = (2^64 - 2) * 2^128 + (2^128 - 2^64 + 1)
        let (lo, hi) = U128::MAX.mul_wide(&U64::MAX);
        assert_eq!(lo, U128::from_u128(0xffffffffffffffff_0000000000000001));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
    }

    #[test]
    fn mul_concat_even() {
        assert_eq!(U64::ZERO * U64::MAX, U128::ZERO);
        assert_eq!(U64::MAX * U64::ZERO, U128::ZERO);
        assert_eq!(
            U64::MAX * U64::MAX,
            U128::from_u128(0xfffffffffffffffe_0000000000000001)
        );
        assert_eq!(
            U64::ONE * U64::MAX,
            U128::from_u128(0x0000000000000000_ffffffffffffffff)
        );
    }

    #[test]
    fn mul_concat_mixed() {
        let a = U64::from_u64(0x0011223344556677);
        let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
        assert_eq!(a * b, U192::from(&a).saturating_mul(&b));
        assert_eq!(b * a, U192::from(&b).saturating_mul(&a));
    }

    #[test]
    fn checked_mul_ok() {
        let n = U64::from_u32(0xffff_ffff);
        assert_eq!(
            n.checked_mul(&n).unwrap(),
            U64::from_u64(0xffff_fffe_0000_0001)
        );
    }

    #[test]
    fn checked_mul_overflow() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        assert!(bool::from(n.checked_mul(&n).is_none()));
    }

    #[test]
    fn saturating_mul_no_overflow() {
        let n = U64::from_u8(8);
        assert_eq!(n.saturating_mul(&n), U64::from_u8(64));
    }

    #[test]
    fn saturating_mul_overflow() {
        let a = U64::from(0xffff_ffff_ffff_ffffu64);
        let b = U64::from(2u8);
        assert_eq!(a.saturating_mul(&b), U64::MAX);
    }

    #[test]
    fn wrapping_mul_discards_overflow() {
        // (2^64 - 1) * 2 = 2^65 - 2, which wraps to 2^64 - 2.
        let a = U64::MAX;
        let b = U64::from_u8(2);
        assert_eq!(a.wrapping_mul(&b), U64::MAX.wrapping_sub(&U64::ONE));
    }

    #[test]
    fn square() {
        let n = U64::from_u64(0xffff_ffff_ffff_ffff);
        let (hi, lo) = n.square().split();
        assert_eq!(lo, U64::from_u64(1));
        assert_eq!(hi, U64::from_u64(0xffff_ffff_ffff_fffe));
    }

    #[test]
    fn square_larger() {
        let n = U256::MAX;
        let (hi, lo) = n.square().split();
        assert_eq!(lo, U256::ONE);
        assert_eq!(hi, U256::MAX.wrapping_sub(&U256::ONE));
    }

    #[test]
    fn square_wide_matches_mul_wide() {
        // A multi-limb value whose limbs are not all equal, so the off-diagonal
        // and diagonal phases of `square_wide` are both exercised non-trivially.
        let n = U256::MAX.wrapping_sub(&U256::from_u64(0x1234_5678_9abc_def0));
        assert_eq!(n.square_wide(), n.mul_wide(&n));
    }
}