1use super::{uint_mul_limbs, uint_square_limbs};
22use crate::{ConstChoice, Limb, Uint};
23
24#[cfg(feature = "alloc")]
25use super::square_limbs;
26#[cfg(feature = "alloc")]
27use crate::{WideWord, Word};
28
/// Limb-count threshold exported next to the slice-based Karatsuba routines.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// smallest operand size at which callers switch to Karatsuba — confirm
/// against the call sites.
#[cfg(feature = "alloc")]
pub const KARATSUBA_MIN_STARTING_LIMBS: usize = 32;
/// Recursion cut-off for the slice-based Karatsuba routines.
///
/// `karatsuba_mul_limbs` falls back to `adc_mul_limbs` once the shared
/// (rounded down to even) operand length is at most this many limbs, and
/// `karatsuba_square_limbs` falls back to `square_limbs` once its input is at
/// most twice this many limbs.
#[cfg(feature = "alloc")]
pub const KARATSUBA_MAX_REDUCE_LIMBS: usize = 24;
33
/// Field-less marker type parameterised by a limb count.
///
/// The `impl_uint_karatsuba_*` macros in this file attach `multiply` and
/// `square` associated functions to it for specific `LIMBS` values.
pub(crate) struct UintKaratsubaMul<const LIMBS: usize>;
36
macro_rules! impl_uint_karatsuba_multiplication {
    // One Karatsuba step: the full-size product is assembled from three
    // half-size products.
    (reduce $full_size:expr, $half_size:expr) => {
        impl UintKaratsubaMul<$full_size> {
            /// Computes the double-width product of `lhs` and `rhs`, returned as
            /// `(low, high)` halves.
            ///
            /// With the operands split as `x = x0 + x1·b` and `y = y0 + y1·b`,
            /// three half-size products are formed: `z0 = x0·y0`, `z2 = x1·y1`
            /// and `z1 = (x0 - x1)·(y1 - y0)`. The result is
            /// `z0 + (z0 + z1 + z2)·b + z2·b²`.
            ///
            /// Slice indexing panics if either input is shorter than twice the
            /// split point.
            pub(crate) const fn multiply(
                lhs: &[Limb],
                rhs: &[Limb],
            ) -> (Uint<$full_size>, Uint<$full_size>) {
                let (x_lo, x_hi) = lhs.split_at($half_size);
                let (y_lo, y_hi) = rhs.split_at($half_size);

                // Limb-wise `x_lo - x_hi` and `y_hi - y_lo`; the final borrow of
                // each chain records whether that difference went negative.
                let mut x_diff = Uint::<$half_size>::ZERO;
                let mut y_diff = Uint::<$half_size>::ZERO;
                let mut x_borrow = Limb::ZERO;
                let mut y_borrow = Limb::ZERO;
                let mut idx = 0;
                while idx < $half_size {
                    (x_diff.limbs[idx], x_borrow) = x_lo[idx].sbb(x_hi[idx], x_borrow);
                    (y_diff.limbs[idx], y_borrow) = y_hi[idx].sbb(y_lo[idx], y_borrow);
                    idx += 1;
                }

                // Replace each difference by its magnitude, selecting on the
                // borrow mask instead of branching.
                let x_negative = ConstChoice::from_word_mask(x_borrow.0);
                let y_negative = ConstChoice::from_word_mask(y_borrow.0);
                let x_abs = Uint::select(&x_diff, &x_diff.wrapping_neg(), x_negative);
                let y_abs = Uint::select(&y_diff, &y_diff.wrapping_neg(), y_negative);

                // |z1| and the sign of z1 (negative when exactly one factor was).
                let (mid_lo, mid_hi) =
                    UintKaratsubaMul::<$half_size>::multiply(&x_abs.limbs, &y_abs.limbs);
                let mid_negative = x_negative.xor(y_negative);

                // The four half-size quarters of the result start out holding
                // |z1|·b. When z1 is negative every quarter is bit-inverted here
                // and the missing `+ 1` of the wrapping negation is supplied as
                // the initial carry further down.
                let zero = Uint::<$half_size>::ZERO;
                let mut q0 = Uint::select(&zero, &zero.not(), mid_negative);
                let mut q1 = Uint::select(&mid_lo, &mid_lo.not(), mid_negative);
                let mut q2 = Uint::select(&mid_hi, &mid_hi.not(), mid_negative);
                let mut q3 = Uint::select(&zero, &zero.not(), mid_negative);

                let (low_lo, low_hi) = UintKaratsubaMul::<$half_size>::multiply(x_lo, y_lo);
                let (high_lo, high_hi) = UintKaratsubaMul::<$half_size>::multiply(x_hi, y_hi);

                // Accumulate z0 + (z0 + z2)·b + z2·b². Two carry chains run side
                // by side because the overlapping terms start at different
                // quarters; they are merged wherever they meet.
                let mut carry_a = Limb::select(Limb::ZERO, Limb::ONE, mid_negative);
                let mut carry_b;
                (q0, carry_a) = q0.adc(&low_lo, carry_a);
                (q1, carry_a) = q1.adc(&low_hi, carry_a);
                (q1, carry_b) = q1.adc(&low_lo, Limb::ZERO);
                (q2, carry_a) = q2.adc(&low_hi, carry_a.wrapping_add(carry_b));
                (q1, carry_b) = q1.adc(&high_lo, Limb::ZERO);
                (q2, carry_b) = q2.adc(&high_hi, carry_b);
                carry_a = carry_a.wrapping_add(carry_b);
                (q2, carry_b) = q2.adc(&high_lo, Limb::ZERO);
                (q3, _) = q3.adc(&high_hi, carry_a.wrapping_add(carry_b));

                (q0.concat(&q1), q2.concat(&q3))
            }
        }
    };
    // Base case: hand the operands straight to `uint_mul_limbs`.
    ($small_size:expr) => {
        impl UintKaratsubaMul<$small_size> {
            /// Computes the double-width product via `uint_mul_limbs`, with no
            /// further splitting.
            #[inline]
            pub(crate) const fn multiply(
                lhs: &[Limb],
                rhs: &[Limb],
            ) -> (Uint<$small_size>, Uint<$small_size>) {
                uint_mul_limbs(lhs, rhs)
            }
        }
    };
    // Driver: emit a reduction step for the first two sizes, then recurse on
    // the remaining list; the last size ends up in the base-case arm.
    ($full_size:tt, $half_size:tt $(,$rest:tt)*) => {
        impl_uint_karatsuba_multiplication!{reduce $full_size, $half_size}
        impl_uint_karatsuba_multiplication!{$half_size $(,$rest)*}
    }
}
115
macro_rules! impl_uint_karatsuba_squaring {
    // One Karatsuba step: the full-size square is assembled from three
    // half-size squares.
    (reduce $full_size:expr, $half_size:expr) => {
        impl UintKaratsubaMul<$full_size> {
            /// Computes the double-width square of `limbs`, returned as
            /// `(low, high)` halves.
            ///
            /// With the operand split as `x = x0 + x1·b`, three half-size
            /// squares are formed: `z0 = x0²`, `z2 = x1²` and
            /// `z1 = (x0 - x1)²`. The result is
            /// `z0 + (z0 - z1 + z2)·b + z2·b²`.
            ///
            /// Slice indexing panics if the input is shorter than twice the
            /// split point.
            pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$full_size>, Uint<$full_size>) {
                let (x_lo, x_hi) = limbs.split_at($half_size);
                let (low_lo, low_hi) = UintKaratsubaMul::<$half_size>::square(x_lo);
                let (high_lo, high_hi) = UintKaratsubaMul::<$half_size>::square(x_hi);
                let zero = Uint::<$half_size>::ZERO;

                // z0 + (z0 + z2)·b + z2·b², as four half-size quarters. The
                // lowest quarter is just the low half of z0; two carry chains
                // cover the overlapping terms and merge in the top quarter.
                let (q1, carry_a) = low_hi.adc(&low_lo, Limb::ZERO);
                let (q2, carry_a) = low_hi.adc(&high_lo, carry_a);
                let (q1, carry_b) = q1.adc(&high_lo, Limb::ZERO);
                let (q2, carry_b) = q2.adc(&high_hi, carry_b);
                let (q3, _) = high_hi.adc(&zero, carry_a.wrapping_add(carry_b));

                // |x0 - x1|: subtract limb-wise, then negate when the chain
                // ended with a borrow (selected on the mask, not branched on).
                let mut diff = Uint::<$half_size>::ZERO;
                let mut diff_borrow = Limb::ZERO;
                let mut idx = 0;
                while idx < $half_size {
                    (diff.limbs[idx], diff_borrow) = x_lo[idx].sbb(x_hi[idx], diff_borrow);
                    idx += 1;
                }
                let abs_diff = Uint::select(
                    &diff,
                    &diff.wrapping_neg(),
                    ConstChoice::from_word_mask(diff_borrow.0),
                );

                let (mid_lo, mid_hi) = UintKaratsubaMul::<$half_size>::square(&abs_diff.limbs);

                // Subtract z1·b, letting the borrow run into the top quarter.
                let (q1, sub_borrow) = q1.sbb(&mid_lo, Limb::ZERO);
                let (q2, sub_borrow) = q2.sbb(&mid_hi, sub_borrow);
                let (q3, _) = q3.sbb(&zero, sub_borrow);

                (low_lo.concat(&q1), q2.concat(&q3))
            }
        }
    };
    // Base case: hand the operand straight to `uint_square_limbs`.
    ($small_size:expr) => {
        impl UintKaratsubaMul<$small_size> {
            /// Computes the double-width square via `uint_square_limbs`, with no
            /// further splitting.
            #[inline]
            pub(crate) const fn square(limbs: &[Limb]) -> (Uint<$small_size>, Uint<$small_size>) {
                uint_square_limbs(limbs)
            }
        }
    };
    // Driver: emit a reduction step for the first two sizes, then recurse on
    // the remaining list; the last size ends up in the base-case arm.
    ($full_size:tt, $half_size:tt $(,$rest:tt)*) => {
        impl_uint_karatsuba_squaring!{reduce $full_size, $half_size}
        impl_uint_karatsuba_squaring!{$half_size $(,$rest)*}
    }
}
173
/// Multiplies `lhs` by `rhs` into `out`, recursing with Karatsuba splitting
/// over the even-length prefix the two operands have in common.
///
/// `size` below is the shorter operand length rounded down to an even number.
/// When it is at most `KARATSUBA_MAX_REDUCE_LIMBS`, `out` is zeroed and the
/// work is delegated to `adc_mul_limbs`. Otherwise the leading `size` limbs of
/// both operands go through one Karatsuba step and any trailing limbs are
/// folded in afterwards with `adc_mul_limbs`.
///
/// # Panics
///
/// On the Karatsuba path, panics unless `out.len() == lhs.len() + rhs.len()`
/// and `scratch` holds at least `2 * size` limbs. On the fallback path the
/// length check is the one performed by `adc_mul_limbs`.
#[cfg(feature = "alloc")]
#[inline(never)]
pub(crate) fn karatsuba_mul_limbs(
    lhs: &[Limb],
    rhs: &[Limb],
    out: &mut [Limb],
    scratch: &mut [Limb],
) {
    /// Adds `src` limb by limb into the equally long `dst`, seeded with
    /// `carry`, and returns the carry out of the last limb.
    fn add_into(dst: &mut [Limb], src: &[Limb], mut carry: Limb) -> Limb {
        for (acc, &limb) in dst.iter_mut().zip(src) {
            (*acc, carry) = (*acc).adc(limb, carry);
        }
        carry
    }

    // Shared operand length with the low bit cleared, so it halves evenly.
    let size = lhs.len().min(rhs.len()) & !1;
    if size <= KARATSUBA_MAX_REDUCE_LIMBS {
        out.fill(Limb::ZERO);
        adc_mul_limbs(lhs, rhs, out);
        return;
    }
    assert!(
        lhs.len() + rhs.len() == out.len() && scratch.len() >= 2 * size,
        "invalid arguments to karatsuba_mul_limbs"
    );

    let half = size / 2;
    let (x, x_tail) = lhs.split_at(size);
    let (y, y_tail) = rhs.split_at(size);
    let (x_lo, x_hi) = x.split_at(half);
    let (y_lo, y_hi) = y.split_at(half);
    // The first `size` limbs hold this level's temporaries; the remainder is
    // passed down to the recursive calls.
    let (scratch, ext_scratch) = scratch.split_at_mut(size);

    out.fill(Limb::ZERO);

    // |x_lo - x_hi| and |y_hi - y_lo|, stored side by side in `scratch`. The
    // final borrow of each chain records whether the difference was negative.
    let (x_diff, y_diff) = scratch.split_at_mut(half);
    let mut x_borrow = Limb::ZERO;
    let mut y_borrow = Limb::ZERO;
    for idx in 0..half {
        (x_diff[idx], x_borrow) = x_lo[idx].sbb(x_hi[idx], x_borrow);
        (y_diff[idx], y_borrow) = y_hi[idx].sbb(y_lo[idx], y_borrow);
    }
    let x_negative = ConstChoice::from_word_mask(x_borrow.0);
    let y_negative = ConstChoice::from_word_mask(y_borrow.0);
    conditional_wrapping_neg_assign(x_diff, x_negative);
    conditional_wrapping_neg_assign(y_diff, y_negative);

    // |z1| lands at limb offset `half`; the whole `2 * size` window is then
    // negated when exactly one of the two differences was negative.
    karatsuba_mul_limbs(x_diff, y_diff, &mut out[half..size + half], ext_scratch);
    conditional_wrapping_neg_assign(&mut out[..2 * size], x_negative.xor(y_negative));

    // z0 = x_lo·y_lo, added at limb offsets 0 and `half`.
    karatsuba_mul_limbs(x_lo, y_lo, scratch, ext_scratch);
    let carry_low = add_into(&mut out[..size], &scratch[..], Limb::ZERO);
    let carry_mid = add_into(&mut out[half..size], &scratch[..half], Limb::ZERO);
    let carry = add_into(
        &mut out[size..size + half],
        &scratch[half..],
        carry_low.wrapping_add(carry_mid),
    );

    // z2 = x_hi·y_hi, added at limb offsets `half` and `size`. Every carry
    // that arrives at offset `size + half` is merged before the last pass.
    karatsuba_mul_limbs(x_hi, y_hi, scratch, ext_scratch);
    let carry_mid = add_into(&mut out[half..half + size], &scratch[..], Limb::ZERO);
    let carry_high = add_into(&mut out[size..size + half], &scratch[..half], Limb::ZERO);
    let carry = carry.wrapping_add(carry_mid).wrapping_add(carry_high);
    add_into(&mut out[size + half..2 * size], &scratch[half..], carry);

    // Limbs beyond the common prefix: `x_tail` times all of `rhs`, then
    // `y_tail` times the prefix of `lhs`, both shifted up by `size` limbs.
    if !x_tail.is_empty() {
        adc_mul_limbs(x_tail, rhs, &mut out[size..]);
    }
    if !y_tail.is_empty() {
        let end_pos = 2 * size + y_tail.len();
        let mut tail_carry = adc_mul_limbs(y_tail, x, &mut out[size..end_pos]);
        for limb in out[end_pos..].iter_mut() {
            (*limb, tail_carry) = (*limb).adc(Limb::ZERO, tail_carry);
        }
    }
}
293
/// Squares `limbs` into `out`, recursing with Karatsuba splitting.
///
/// Inputs of odd length, or of at most `2 * KARATSUBA_MAX_REDUCE_LIMBS` limbs,
/// are handled by zeroing `out` and delegating to `square_limbs`. Anything
/// else goes through one Karatsuba step built from three half-length squares.
///
/// # Panics
///
/// On the Karatsuba path, panics unless `out.len() == 2 * limbs.len()` and
/// `scratch` holds at least `out.len()` limbs.
#[cfg(feature = "alloc")]
#[inline(never)]
pub(crate) fn karatsuba_square_limbs(limbs: &[Limb], out: &mut [Limb], scratch: &mut [Limb]) {
    /// Adds `src` limb by limb into the equally long `dst`, seeded with
    /// `carry`, and returns the carry out of the last limb.
    fn add_into(dst: &mut [Limb], src: &[Limb], mut carry: Limb) -> Limb {
        for (acc, &limb) in dst.iter_mut().zip(src) {
            (*acc, carry) = (*acc).adc(limb, carry);
        }
        carry
    }

    let size = limbs.len();
    let is_odd = (size & 1) == 1;
    if is_odd || size <= KARATSUBA_MAX_REDUCE_LIMBS * 2 {
        out.fill(Limb::ZERO);
        square_limbs(limbs, out);
        return;
    }
    assert!(
        2 * size == out.len() && scratch.len() >= out.len(),
        "invalid arguments to karatsuba_square_limbs"
    );

    let half = size / 2;
    let (x_lo, x_hi) = limbs.split_at(half);
    // The first `size` limbs hold this level's temporaries; the remainder is
    // passed down to the recursive calls.
    let (scratch, ext_scratch) = scratch.split_at_mut(size);

    out[..2 * size].fill(Limb::ZERO);

    // |x_lo - x_hi| in the first half of `scratch`: subtract limb-wise, then
    // negate when the chain ended with a borrow.
    let diff = &mut scratch[..half];
    let mut borrow = Limb::ZERO;
    for idx in 0..half {
        (diff[idx], borrow) = x_lo[idx].sbb(x_hi[idx], borrow);
    }
    conditional_wrapping_neg_assign(diff, ConstChoice::from_word_mask(borrow.0));

    // z1 = (x_lo - x_hi)² lands at limb offset `half`. The whole buffer is
    // then bit-inverted; the `+ 1` that completes the wrapping negation is
    // supplied as the initial carry of the first addition below.
    karatsuba_square_limbs(diff, &mut out[half..3 * half], ext_scratch);
    for limb in out[..2 * size].iter_mut() {
        *limb = !*limb;
    }

    // z0 = x_lo², added at limb offsets 0 and `half`.
    karatsuba_square_limbs(x_lo, scratch, ext_scratch);
    let carry_low = add_into(&mut out[..size], &scratch[..], Limb::ONE);
    let carry_mid = add_into(&mut out[half..size], &scratch[..half], Limb::ZERO);
    let carry = add_into(
        &mut out[size..size + half],
        &scratch[half..],
        carry_low.wrapping_add(carry_mid),
    );

    // z2 = x_hi², added at limb offsets `half` and `size`. Every carry that
    // arrives at offset `size + half` is merged before the last pass.
    karatsuba_square_limbs(x_hi, scratch, ext_scratch);
    let carry_mid = add_into(&mut out[half..half + size], &scratch[..], Limb::ZERO);
    let carry_high = add_into(&mut out[size..size + half], &scratch[..half], Limb::ZERO);
    let carry = carry.wrapping_add(carry_mid).wrapping_add(carry_high);
    add_into(&mut out[size + half..2 * size], &scratch[half..], carry);
}
374
/// Overwrites `limbs` with its wrapping (two's complement) negation when
/// `choice` is set, and leaves it unchanged otherwise.
///
/// Every limb is rewritten either way: the value to store is picked with
/// `select_word` rather than by branching on `choice`.
#[cfg(feature = "alloc")]
#[inline]
fn conditional_wrapping_neg_assign(limbs: &mut [Limb], choice: ConstChoice) {
    // Negation is "invert every limb, then add one"; the one enters here as
    // the initial carry and is zero when no negation is requested.
    let mut carry = choice.select_word(0, 1) as WideWord;
    for limb in limbs.iter_mut() {
        let word = limb.0;
        let sum = (choice.select_word(word, !word) as WideWord) + carry;
        limb.0 = sum as Word;
        carry = sum >> Word::BITS;
    }
}
389
/// Adds the product `lhs · rhs` into `out` using a row-by-row
/// multiply-accumulate, and returns the carry out of the top limb.
///
/// `out` is accumulated into, not overwritten, so callers that want a plain
/// product must zero it first.
///
/// # Panics
///
/// Panics unless `out.len() == lhs.len() + rhs.len()`.
#[cfg(feature = "alloc")]
fn adc_mul_limbs(lhs: &[Limb], rhs: &[Limb], out: &mut [Limb]) -> Limb {
    assert!(
        lhs.len() + rhs.len() == out.len(),
        "adc_mul_limbs length mismatch"
    );

    let mut carry_out = Limb::ZERO;
    for (row, &x) in lhs.iter().enumerate() {
        // Accumulate `x · rhs` into `out`, starting at limb `row`.
        let mut row_carry = Limb::ZERO;
        for (col, &y) in rhs.iter().enumerate() {
            let pos = row + col;
            (out[pos], row_carry) = out[pos].mac(x, y, row_carry);
        }

        // Fold this row's carry, plus whatever the previous row pushed out,
        // into the limb just above the row.
        let top = row + rhs.len();
        carry_out = carry_out.wrapping_add(row_carry);
        (out[top], carry_out) = out[top].adc(Limb::ZERO, carry_out);
    }

    carry_out
}
417
// Generates `multiply` for 128, 64, 32 and 16 limbs as Karatsuba steps, each
// recursing into the next size down, with 8 limbs as the base case that calls
// `uint_mul_limbs` directly.
impl_uint_karatsuba_multiplication!(128, 64, 32, 16, 8);
// Generates `square` for 128 and 64 limbs as Karatsuba steps, with 32 limbs as
// the base case that calls `uint_square_limbs` directly.
impl_uint_karatsuba_squaring!(128, 64, 32);