crypto_bigint/uint/mul/karatsuba.rs

//! Karatsuba multiplication
//!
//! Karatsuba's method reduces the complexity of multiplication from O(n^2) to
//! O(n^log2(3)) ≈ O(n^1.585). For smaller numbers, it is best to stick to schoolbook
//! multiplication, taking advantage of better cache locality and avoiding recursion.
//!
//! In general, we consider the multiplication of two numbers of equal size, `n` bits.
//! Setting b = 2^(n/2), we can decompose the values as:
//!   x•y = (x0 + x1•b)(y0 + y1•b)
//!
//! This product can be rewritten as a linear combination of three products of size `n/2`,
//! each of which may in turn be reduced by applying the same method recursively.
//! Setting z0 = x0•y0, z1 = (x0-x1)(y1-y0), z2 = x1•y1, and noting that
//! z1 = x0•y1 + x1•y0 - z0 - z2:
//!   x•y = z0 + (z0 + z1 + z2)•b + z2•b^2
//!
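//! Considering each sub-product as a tuple of integers `(lo, hi)`, the product is calculated as
//! follows, least significant part first (with appropriate carries):
//!   [z0.0, z0.0 + z0.1 + z1.0 + z2.0, z0.1 + z1.1 + z2.0 + z2.1, z2.1]
//!
//! When squaring, z1 = (x0-x1)^2 is used instead and is subtracted rather than added, since
//! 2•x0•x1 = z0 + z2 - z1.
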
21use super::{uint_mul_limbs, uint_square_limbs};
22use crate::{ConstChoice, Limb, Uint};
23
24#[cfg(feature = "alloc")]
25use super::square_limbs;
26#[cfg(feature = "alloc")]
27use crate::{WideWord, Word};
28
/// Operand size, in limbs, associated with switching to Karatsuba multiplication.
///
/// NOTE(review): this constant is not referenced in this file; presumably callers compare
/// operand lengths against it to decide whether to use Karatsuba at all — confirm.
#[cfg(feature = "alloc")]
pub const KARATSUBA_MIN_STARTING_LIMBS: usize = 32;

/// Recursion cut-off for the slice-based Karatsuba routines.
///
/// `karatsuba_mul_limbs` falls back to schoolbook multiplication when the even-rounded shared
/// operand length is at most this value; `karatsuba_square_limbs` falls back when the operand
/// length is at most twice this value.
#[cfg(feature = "alloc")]
pub const KARATSUBA_MAX_REDUCE_LIMBS: usize = 24;
33
34/// A helper struct for performing Karatsuba multiplication on Uints.
35pub(crate) struct UintKaratsubaMul<const LIMBS: usize>;
36
/// Generates `UintKaratsubaMul::<N>::multiply` for a descending chain of limb counts.
///
/// Invoked as `impl_uint_karatsuba_multiplication!(N0, N1, ..., Nk)`. Every size except the
/// last gets a Karatsuba step that splits its operands at the next size in the list and
/// recurses into that size's `multiply`; the last size delegates to `uint_mul_limbs` without
/// further splitting.
///
/// NOTE(review): the `reduce` step assumes each size is exactly twice the next one (both halves
/// are indexed up to `$half_size`) — keep the chain consistent with that.
macro_rules! impl_uint_karatsuba_multiplication {
    // TODO: revisit when `const_mut_refs` is stable
    (reduce $full_size:expr, $half_size:expr) => {
        impl UintKaratsubaMul<$full_size> {
            /// Multiply two full-size operands given as limb slices, returning the
            /// double-width product as a `(lo, hi)` pair.
            pub(crate) const fn multiply(
                lhs: &[Limb],
                rhs: &[Limb],
            ) -> (Uint<$full_size>, Uint<$full_size>) {
                let (x0, x1) = lhs.split_at($half_size);
                let (y0, y1) = rhs.split_at($half_size);

                // Calculate |x0 - x1| and |y1 - y0|, remembering whether each one borrowed
                let mut l0 = Uint::<$half_size>::ZERO;
                let mut l1 = Uint::<$half_size>::ZERO;
                let mut l0b = Limb::ZERO;
                let mut l1b = Limb::ZERO;
                let mut i = 0;
                while i < $half_size {
                    (l0.limbs[i], l0b) = x0[i].sbb(x1[i], l0b);
                    (l1.limbs[i], l1b) = y1[i].sbb(y0[i], l1b);
                    i += 1;
                }
                // The final borrow is used as a word mask: set when the difference was negative
                let l0_neg = ConstChoice::from_word_mask(l0b.0);
                let l1_neg = ConstChoice::from_word_mask(l1b.0);
                l0 = Uint::select(&l0, &l0.wrapping_neg(), l0_neg);
                l1 = Uint::select(&l1, &l1.wrapping_neg(), l1_neg);

                // |z1| = |x0 - x1|•|y1 - y0|; z1 itself is negative iff exactly one of the
                // two differences was negative
                let z1 = UintKaratsubaMul::<$half_size>::multiply(&l0.limbs, &l1.limbs);
                let z1_neg = l0_neg.xor(l1_neg);

                // Lay out z1•b and bitwise-invert it when z1 is negative; the `+1` that completes
                // the two's-complement negation is fed in as the initial carry below
                let mut res = (Uint::ZERO, z1.0, z1.1, Uint::ZERO);
                res.0 = Uint::select(&res.0, &res.0.not(), z1_neg);
                res.1 = Uint::select(&res.1, &res.1.not(), z1_neg);
                res.2 = Uint::select(&res.2, &res.2.not(), z1_neg);
                res.3 = Uint::select(&res.3, &res.3.not(), z1_neg);

                // Calculate z0 = x0•y0
                let z0 = UintKaratsubaMul::<$half_size>::multiply(x0, y0);
                // Calculate z2 = x1•y1
                let z2 = UintKaratsubaMul::<$half_size>::multiply(x1, y1);

                // Add z0 + (z0 + z2)•b + z2•b^2 on top of ±z1•b. Carries that land on the same
                // half-width word are summed before being passed on
                let mut carry = Limb::select(Limb::ZERO, Limb::ONE, z1_neg);
                (res.0, carry) = res.0.adc(&z0.0, carry);
                (res.1, carry) = res.1.adc(&z0.1, carry);
                let mut carry2;
                (res.1, carry2) = res.1.adc(&z0.0, Limb::ZERO);
                (res.2, carry) = res.2.adc(&z0.1, carry.wrapping_add(carry2));
                (res.1, carry2) = res.1.adc(&z2.0, Limb::ZERO);
                (res.2, carry2) = res.2.adc(&z2.1, carry2);
                carry = carry.wrapping_add(carry2);
                (res.2, carry2) = res.2.adc(&z2.0, Limb::ZERO);
                (res.3, _) = res.3.adc(&z2.1, carry.wrapping_add(carry2));

                (res.0.concat(&res.1), res.2.concat(&res.3))
            }
        }
    };
    ($small_size:expr) => {
        impl UintKaratsubaMul<$small_size> {
            /// Base case: multiply via `uint_mul_limbs` without further Karatsuba splitting.
            #[inline]
            pub(crate) const fn multiply(lhs: &[Limb], rhs: &[Limb]) -> (Uint<$small_size>, Uint<$small_size>) {
                uint_mul_limbs(lhs, rhs)
            }
        }
    };
    ($full_size:tt, $half_size:tt $(,$rest:tt)*) => {
        impl_uint_karatsuba_multiplication!{reduce $full_size, $half_size}
        impl_uint_karatsuba_multiplication!{$half_size $(,$rest)*}
    }
}
115
/// Generates `UintKaratsubaMul::<N>::square` for a descending chain of limb counts.
///
/// Mirrors `impl_uint_karatsuba_multiplication!`: every size except the last gets one Karatsuba
/// squaring step that splits at the next size in the list and recurses into that size's
/// `square`; the last size delegates to `uint_square_limbs`.
macro_rules! impl_uint_karatsuba_squaring {
    (reduce $full_size:expr, $half_size:expr) => {
        impl UintKaratsubaMul<$full_size> {
            /// Square a full-size operand given as a limb slice, returning the double-width
            /// result as a `(lo, hi)` pair.
            pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$full_size>, Uint<$full_size>) {
                let (x0, x1) = limbs.split_at($half_size);
                // Calculate z0 = x0^2 and z2 = x1^2
                let z0 = UintKaratsubaMul::<$half_size>::square(x0);
                let z2 = UintKaratsubaMul::<$half_size>::square(x1);

                // Calculate z0 + (z0 + z2)•b + z2•b^2
                let mut res = (z0.0, z0.1, Uint::<$half_size>::ZERO, Uint::<$half_size>::ZERO);
                let mut carry;
                (res.1, carry) = res.1.adc(&z0.0, Limb::ZERO);
                (res.2, carry) = z0.1.adc(&z2.0, carry);
                let mut carry2;
                (res.1, carry2) = res.1.adc(&z2.0, Limb::ZERO);
                (res.2, carry2) = res.2.adc(&z2.1, carry2);
                (res.3, _) = z2.1.adc(&Uint::ZERO, carry.wrapping_add(carry2));

                // Calculate z1 = (x0 - x1)^2; squaring discards the sign, so only the absolute
                // difference is needed
                let mut l0 = Uint::<$half_size>::ZERO;
                let mut l0b = Limb::ZERO;
                let mut i = 0;
                while i < $half_size {
                    (l0.limbs[i], l0b) = x0[i].sbb(x1[i], l0b);
                    i += 1;
                }
                l0 = Uint::select(
                    &l0,
                    &l0.wrapping_neg(),
                    ConstChoice::from_word_mask(l0b.0),
                );

                let z1 = UintKaratsubaMul::<$half_size>::square(&l0.limbs);

                // Subtract z1•b, since 2•x0•x1 = z0 + z2 - z1
                carry = Limb::ZERO;
                (res.1, carry) = res.1.sbb(&z1.0, carry);
                (res.2, carry) = res.2.sbb(&z1.1, carry);
                (res.3, _) = res.3.sbb(&Uint::ZERO, carry);

                (res.0.concat(&res.1), res.2.concat(&res.3))
            }
        }
    };
    ($small_size:expr) => {
        impl UintKaratsubaMul<$small_size> {
            /// Base case: square via `uint_square_limbs` without further Karatsuba splitting.
            #[inline]
            pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$small_size>, Uint<$small_size>) {
                uint_square_limbs(limbs)
            }
        }
    };
    ($full_size:tt, $half_size:tt $(,$rest:tt)*) => {
        impl_uint_karatsuba_squaring!{reduce $full_size, $half_size}
        impl_uint_karatsuba_squaring!{$half_size $(,$rest)*}
    }
}
173
/// Multiply `lhs` by `rhs` into `out`.
///
/// Karatsuba's method is applied to the largest even-length prefix both operands share. Any
/// limbs beyond that prefix are folded in with schoolbook multiplication. `out` is overwritten
/// with the full product.
///
/// # Panics
/// When the shared even prefix is longer than `KARATSUBA_MAX_REDUCE_LIMBS`, panics if
/// `out.len() != lhs.len() + rhs.len()` or if `scratch` holds fewer than twice that prefix
/// length. Otherwise the schoolbook fallback panics on an `out` length mismatch.
#[cfg(feature = "alloc")]
#[inline(never)]
pub(crate) fn karatsuba_mul_limbs(
    lhs: &[Limb],
    rhs: &[Limb],
    out: &mut [Limb],
    scratch: &mut [Limb],
) {
    // Shared prefix length, rounded down to an even number of limbs.
    let size = lhs.len().min(rhs.len()) & !1;
    if size <= KARATSUBA_MAX_REDUCE_LIMBS {
        out.fill(Limb::ZERO);
        adc_mul_limbs(lhs, rhs, out);
        return;
    }
    assert!(
        lhs.len() + rhs.len() == out.len() && scratch.len() >= 2 * size,
        "invalid arguments to karatsuba_mul_limbs"
    );

    let half = size / 2;
    // `work` holds the current sub-product; `deeper` is handed down to the recursive calls.
    let (work, deeper) = scratch.split_at_mut(size);
    let (x, xt) = lhs.split_at(size);
    let (y, yt) = rhs.split_at(size);
    let (x0, x1) = x.split_at(half);
    let (y0, y1) = y.split_at(half);

    out.fill(Limb::ZERO);

    // |x0 - x1| goes into the low half of `work`, |y1 - y0| into the high half.
    let (dx, dy) = work.split_at_mut(half);
    let mut dx_borrow = Limb::ZERO;
    let mut dy_borrow = Limb::ZERO;
    for i in 0..half {
        (dx[i], dx_borrow) = x0[i].sbb(x1[i], dx_borrow);
        (dy[i], dy_borrow) = y1[i].sbb(y0[i], dy_borrow);
    }
    let dx_neg = ConstChoice::from_word_mask(dx_borrow.0);
    let dy_neg = ConstChoice::from_word_mask(dy_borrow.0);
    conditional_wrapping_neg_assign(dx, dx_neg);
    conditional_wrapping_neg_assign(dy, dy_neg);

    // |z1| lands at offset `half`; when z1 is negative, negate the whole low 2*size limbs.
    karatsuba_mul_limbs(dx, dy, &mut out[half..size + half], deeper);
    conditional_wrapping_neg_assign(&mut out[..2 * size], dx_neg.xor(dy_neg));

    // Fold z0 = x0•y0 in at offsets 0 and `half`.
    karatsuba_mul_limbs(x0, y0, work, deeper);
    let mut carry = Limb::ZERO;
    let mut side = Limb::ZERO;
    for (o, &z) in out[..size].iter_mut().zip(work.iter()) {
        (*o, carry) = o.adc(z, carry);
    }
    for (o, &z) in out[half..size].iter_mut().zip(&work[..half]) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    for (o, &z) in out[size..size + half].iter_mut().zip(&work[half..]) {
        (*o, carry) = o.adc(z, carry);
    }

    // Fold z2 = x1•y1 in at offsets `half` and `size`.
    karatsuba_mul_limbs(x1, y1, work, deeper);
    side = Limb::ZERO;
    for (o, &z) in out[half..size + half].iter_mut().zip(work.iter()) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    side = Limb::ZERO;
    for (o, &z) in out[size..size + half].iter_mut().zip(&work[..half]) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    for (o, &z) in out[size + half..2 * size].iter_mut().zip(&work[half..]) {
        (*o, carry) = o.adc(z, carry);
    }

    // Schoolbook the limbs beyond the Karatsuba prefix.
    if !xt.is_empty() {
        adc_mul_limbs(xt, rhs, &mut out[size..]);
    }
    if !yt.is_empty() {
        let end_pos = 2 * size + yt.len();
        carry = adc_mul_limbs(yt, x, &mut out[size..end_pos]);
        for o in &mut out[end_pos..] {
            (*o, carry) = o.adc(Limb::ZERO, carry);
        }
    }
}
293
/// Square `limbs` into `out` using Karatsuba's method.
///
/// Falls back to schoolbook squaring when the length is odd or at most
/// `2 * KARATSUBA_MAX_REDUCE_LIMBS`. `out` is overwritten with the full square.
///
/// # Panics
/// On the Karatsuba path, panics if `out.len() != 2 * limbs.len()` or if `scratch` holds fewer
/// than `out.len()` limbs.
#[cfg(feature = "alloc")]
#[inline(never)]
pub(crate) fn karatsuba_square_limbs(limbs: &[Limb], out: &mut [Limb], scratch: &mut [Limb]) {
    let size = limbs.len();
    let use_schoolbook = size % 2 == 1 || size <= 2 * KARATSUBA_MAX_REDUCE_LIMBS;
    if use_schoolbook {
        out.fill(Limb::ZERO);
        square_limbs(limbs, out);
        return;
    }
    assert!(
        out.len() == 2 * size && scratch.len() >= out.len(),
        "invalid arguments to karatsuba_square_limbs"
    );

    let half = size / 2;
    // `work` holds the current sub-square; `deeper` is handed down to the recursive calls.
    let (work, deeper) = scratch.split_at_mut(size);
    let (x0, x1) = limbs.split_at(half);

    out.fill(Limb::ZERO);

    // |x0 - x1| goes into the low half of `work`.
    let mut borrow = Limb::ZERO;
    for ((d, &a), &b) in work[..half].iter_mut().zip(x0).zip(x1) {
        (*d, borrow) = a.sbb(b, borrow);
    }
    conditional_wrapping_neg_assign(&mut work[..half], ConstChoice::from_word_mask(borrow.0));

    // z1 = (x0 - x1)^2 lands at offset `half`; it is then bitwise-inverted across all of `out`,
    // and the `+1` completing the negation enters as the initial carry of the z0 pass.
    karatsuba_square_limbs(&work[..half], &mut out[half..3 * half], deeper);
    for limb in out.iter_mut() {
        *limb = !*limb;
    }

    // Fold z0 = x0^2 in at offsets 0 and `half`.
    karatsuba_square_limbs(x0, work, deeper);
    let mut carry = Limb::ONE;
    let mut side = Limb::ZERO;
    for (o, &z) in out[..size].iter_mut().zip(work.iter()) {
        (*o, carry) = o.adc(z, carry);
    }
    for (o, &z) in out[half..size].iter_mut().zip(&work[..half]) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    for (o, &z) in out[size..size + half].iter_mut().zip(&work[half..]) {
        (*o, carry) = o.adc(z, carry);
    }

    // Fold z2 = x1^2 in at offsets `half` and `size`.
    karatsuba_square_limbs(x1, work, deeper);
    side = Limb::ZERO;
    for (o, &z) in out[half..size + half].iter_mut().zip(work.iter()) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    side = Limb::ZERO;
    for (o, &z) in out[size..size + half].iter_mut().zip(&work[..half]) {
        (*o, side) = o.adc(z, side);
    }
    carry = carry.wrapping_add(side);
    for (o, &z) in out[size + half..2 * size].iter_mut().zip(&work[half..]) {
        (*o, carry) = o.adc(z, carry);
    }
}
374
/// Replace `limbs` with its two's-complement (wrapping) negation when `choice` is set, leaving
/// it unchanged otherwise.
///
/// Negation is computed as `!v + 1`: each limb is optionally inverted via
/// `choice.select_word`, and the `+1` enters as the initial carry of the limb-by-limb add.
#[cfg(feature = "alloc")]
#[inline]
fn conditional_wrapping_neg_assign(limbs: &mut [Limb], choice: ConstChoice) {
    let mut carry = choice.select_word(0, 1) as WideWord;
    for limb in limbs.iter_mut() {
        let sum = choice.select_word(limb.0, !limb.0) as WideWord + carry;
        limb.0 = sum as Word;
        carry = sum >> Word::BITS;
    }
}
389
/// Add the schoolbook product `lhs * rhs` onto the existing contents of `out`, returning the
/// carry out of the most significant limb.
///
/// `out` does not need to be zeroed beforehand: the product is accumulated in place.
///
/// # Panics
/// If `out.len() != lhs.len() + rhs.len()`.
#[cfg(feature = "alloc")]
fn adc_mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb {
    if lhs.len() + rhs.len() != out.len() {
        panic!("adc_mul_limbs length mismatch");
    }

    // Carry (0 or 1) left over from the previous row, destined for `out[i + rhs.len()]`.
    let mut carry = Limb::ZERO;
    for (i, &xi) in lhs.iter().enumerate() {
        // High word of the multiply-accumulate row; it can be as large as `Limb::MAX`.
        let mut row_carry = Limb::ZERO;
        for (j, &yj) in rhs.iter().enumerate() {
            (out[i + j], row_carry) = out[i + j].mac(xi, yj, row_carry);
        }
        // Fold both carries into the next limb with one three-way add. This sum is at most
        // 2*MAX + 1, so the new carry is always 0 or 1. Pre-summing them with `wrapping_add`
        // would wrap to zero (silently dropping a carry) when `row_carry == Limb::MAX` and
        // `carry == 1`.
        let top = i + rhs.len();
        (out[top], carry) = out[top].adc(row_carry, carry);
    }

    carry
}
417
418impl_uint_karatsuba_multiplication!(128, 64, 32, 16, 8);
419impl_uint_karatsuba_squaring!(128, 64, 32);