1use crate::{Checked, CheckedMul, Concat, ConcatMixed, Limb, Uint, WideWord, Word, Wrapping, Zero};
4use core::ops::{Mul, MulAssign};
5use subtle::CtOption;
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8 pub fn mul<const HLIMBS: usize>(
10 &self,
11 rhs: &Uint<HLIMBS>,
12 ) -> <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput
13 where
14 Uint<HLIMBS>: ConcatMixed<Self>,
15 {
16 let (lo, hi) = self.mul_wide(rhs);
17 hi.concat_mixed(&lo)
18 }
19
20 pub const fn mul_wide<const HLIMBS: usize>(&self, rhs: &Uint<HLIMBS>) -> (Self, Uint<HLIMBS>) {
32 let mut i = 0;
33 let mut lo = Self::ZERO;
34 let mut hi = Uint::<HLIMBS>::ZERO;
35
36 while i < LIMBS {
39 let mut j = 0;
40 let mut carry = Limb::ZERO;
41
42 while j < HLIMBS {
43 let k = i + j;
44
45 if k >= LIMBS {
46 let (n, c) = hi.limbs[k - LIMBS].mac(self.limbs[i], rhs.limbs[j], carry);
47 hi.limbs[k - LIMBS] = n;
48 carry = c;
49 } else {
50 let (n, c) = lo.limbs[k].mac(self.limbs[i], rhs.limbs[j], carry);
51 lo.limbs[k] = n;
52 carry = c;
53 }
54
55 j += 1;
56 }
57
58 if i + j >= LIMBS {
59 hi.limbs[i + j - LIMBS] = carry;
60 } else {
61 lo.limbs[i + j] = carry;
62 }
63 i += 1;
64 }
65
66 (lo, hi)
67 }
68
69 pub const fn saturating_mul<const HLIMBS: usize>(&self, rhs: &Uint<HLIMBS>) -> Self {
71 let (res, overflow) = self.mul_wide(rhs);
72 Self::ct_select(&res, &Self::MAX, overflow.ct_is_nonzero())
73 }
74
75 pub const fn wrapping_mul<const H: usize>(&self, rhs: &Uint<H>) -> Self {
77 self.mul_wide(rhs).0
78 }
79
80 pub fn square(&self) -> <Self as Concat>::Output
82 where
83 Self: Concat,
84 {
85 let (lo, hi) = self.square_wide();
86 hi.concat(&lo)
87 }
88
89 pub const fn square_wide(&self) -> (Self, Self) {
91 let mut lo = Self::ZERO;
96 let mut hi = Self::ZERO;
97
98 let mut i = 1;
100 while i < LIMBS {
101 let mut j = 0;
102 let mut carry = Limb::ZERO;
103
104 while j < i {
105 let k = i + j;
106
107 if k >= LIMBS {
108 let (n, c) = hi.limbs[k - LIMBS].mac(self.limbs[i], self.limbs[j], carry);
109 hi.limbs[k - LIMBS] = n;
110 carry = c;
111 } else {
112 let (n, c) = lo.limbs[k].mac(self.limbs[i], self.limbs[j], carry);
113 lo.limbs[k] = n;
114 carry = c;
115 }
116
117 j += 1;
118 }
119
120 if (2 * i) < LIMBS {
121 lo.limbs[2 * i] = carry;
122 } else {
123 hi.limbs[2 * i - LIMBS] = carry;
124 }
125
126 i += 1;
127 }
128
129 (lo, hi) = Self::shl_vartime_wide((lo, hi), 1);
132
133 let mut carry = Limb::ZERO;
135 let mut i = 0;
136 while i < LIMBS {
137 if (i * 2) < LIMBS {
138 let (n, c) = lo.limbs[i * 2].mac(self.limbs[i], self.limbs[i], carry);
139 lo.limbs[i * 2] = n;
140 carry = c;
141 } else {
142 let (n, c) = hi.limbs[i * 2 - LIMBS].mac(self.limbs[i], self.limbs[i], carry);
143 hi.limbs[i * 2 - LIMBS] = n;
144 carry = c;
145 }
146
147 if (i * 2 + 1) < LIMBS {
148 let n = lo.limbs[i * 2 + 1].0 as WideWord + carry.0 as WideWord;
149 lo.limbs[i * 2 + 1] = Limb(n as Word);
150 carry = Limb((n >> Word::BITS) as Word);
151 } else {
152 let n = hi.limbs[i * 2 + 1 - LIMBS].0 as WideWord + carry.0 as WideWord;
153 hi.limbs[i * 2 + 1 - LIMBS] = Limb(n as Word);
154 carry = Limb((n >> Word::BITS) as Word);
155 }
156
157 i += 1;
158 }
159
160 (lo, hi)
161 }
162}
163
164impl<const LIMBS: usize, const HLIMBS: usize> CheckedMul<&Uint<HLIMBS>> for Uint<LIMBS> {
165 type Output = Self;
166
167 fn checked_mul(&self, rhs: &Uint<HLIMBS>) -> CtOption<Self> {
168 let (lo, hi) = self.mul_wide(rhs);
169 CtOption::new(lo, hi.is_zero())
170 }
171}
172
173impl<const LIMBS: usize, const HLIMBS: usize> Mul<Wrapping<Uint<HLIMBS>>>
174 for Wrapping<Uint<LIMBS>>
175{
176 type Output = Self;
177
178 fn mul(self, rhs: Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
179 Wrapping(self.0.wrapping_mul(&rhs.0))
180 }
181}
182
183impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Wrapping<Uint<HLIMBS>>>
184 for Wrapping<Uint<LIMBS>>
185{
186 type Output = Self;
187
188 fn mul(self, rhs: &Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
189 Wrapping(self.0.wrapping_mul(&rhs.0))
190 }
191}
192
193impl<const LIMBS: usize, const HLIMBS: usize> Mul<Wrapping<Uint<HLIMBS>>>
194 for &Wrapping<Uint<LIMBS>>
195{
196 type Output = Wrapping<Uint<LIMBS>>;
197
198 fn mul(self, rhs: Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
199 Wrapping(self.0.wrapping_mul(&rhs.0))
200 }
201}
202
203impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Wrapping<Uint<HLIMBS>>>
204 for &Wrapping<Uint<LIMBS>>
205{
206 type Output = Wrapping<Uint<LIMBS>>;
207
208 fn mul(self, rhs: &Wrapping<Uint<HLIMBS>>) -> Wrapping<Uint<LIMBS>> {
209 Wrapping(self.0.wrapping_mul(&rhs.0))
210 }
211}
212
213impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<Wrapping<Uint<HLIMBS>>>
214 for Wrapping<Uint<LIMBS>>
215{
216 fn mul_assign(&mut self, other: Wrapping<Uint<HLIMBS>>) {
217 *self = *self * other;
218 }
219}
220
221impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<&Wrapping<Uint<HLIMBS>>>
222 for Wrapping<Uint<LIMBS>>
223{
224 fn mul_assign(&mut self, other: &Wrapping<Uint<HLIMBS>>) {
225 *self = *self * other;
226 }
227}
228
229impl<const LIMBS: usize, const HLIMBS: usize> Mul<Checked<Uint<HLIMBS>>> for Checked<Uint<LIMBS>> {
230 type Output = Self;
231
232 fn mul(self, rhs: Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
233 Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
234 }
235}
236
237impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Checked<Uint<HLIMBS>>> for Checked<Uint<LIMBS>> {
238 type Output = Checked<Uint<LIMBS>>;
239
240 fn mul(self, rhs: &Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
241 Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
242 }
243}
244
245impl<const LIMBS: usize, const HLIMBS: usize> Mul<Checked<Uint<HLIMBS>>> for &Checked<Uint<LIMBS>> {
246 type Output = Checked<Uint<LIMBS>>;
247
248 fn mul(self, rhs: Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
249 Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
250 }
251}
252
253impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Checked<Uint<HLIMBS>>>
254 for &Checked<Uint<LIMBS>>
255{
256 type Output = Checked<Uint<LIMBS>>;
257
258 fn mul(self, rhs: &Checked<Uint<HLIMBS>>) -> Checked<Uint<LIMBS>> {
259 Checked(self.0.and_then(|a| rhs.0.and_then(|b| a.checked_mul(&b))))
260 }
261}
262
263impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<Checked<Uint<HLIMBS>>>
264 for Checked<Uint<LIMBS>>
265{
266 fn mul_assign(&mut self, other: Checked<Uint<HLIMBS>>) {
267 *self = *self * other;
268 }
269}
270
271impl<const LIMBS: usize, const HLIMBS: usize> MulAssign<&Checked<Uint<HLIMBS>>>
272 for Checked<Uint<LIMBS>>
273{
274 fn mul_assign(&mut self, other: &Checked<Uint<HLIMBS>>) {
275 *self = *self * other;
276 }
277}
278
279impl<const LIMBS: usize, const HLIMBS: usize> Mul<Uint<HLIMBS>> for Uint<LIMBS>
280where
281 Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
282{
283 type Output = <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput;
284
285 fn mul(self, other: Uint<HLIMBS>) -> Self::Output {
286 Uint::mul(&self, &other)
287 }
288}
289
290impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Uint<HLIMBS>> for Uint<LIMBS>
291where
292 Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
293{
294 type Output = <Uint<HLIMBS> as ConcatMixed<Self>>::MixedOutput;
295
296 fn mul(self, other: &Uint<HLIMBS>) -> Self::Output {
297 Uint::mul(&self, other)
298 }
299}
300
301impl<const LIMBS: usize, const HLIMBS: usize> Mul<Uint<HLIMBS>> for &Uint<LIMBS>
302where
303 Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
304{
305 type Output = <Uint<HLIMBS> as ConcatMixed<Uint<LIMBS>>>::MixedOutput;
306
307 fn mul(self, other: Uint<HLIMBS>) -> Self::Output {
308 Uint::mul(self, &other)
309 }
310}
311
312impl<const LIMBS: usize, const HLIMBS: usize> Mul<&Uint<HLIMBS>> for &Uint<LIMBS>
313where
314 Uint<HLIMBS>: ConcatMixed<Uint<LIMBS>>,
315{
316 type Output = <Uint<HLIMBS> as ConcatMixed<Uint<LIMBS>>>::MixedOutput;
317
318 fn mul(self, other: &Uint<HLIMBS>) -> Self::Output {
319 Uint::mul(self, other)
320 }
321}
322
#[cfg(test)]
mod tests {
    use crate::{CheckedMul, Zero, U128, U192, U256, U64};

    /// Wide products of zero and one have the expected low half and an empty high half.
    #[test]
    fn mul_wide_zero_and_one() {
        let (zero, one) = (U64::ZERO, U64::ONE);

        assert_eq!(zero.mul_wide(&zero), (zero, zero));
        assert_eq!(zero.mul_wide(&one), (zero, zero));
        assert_eq!(one.mul_wide(&zero), (zero, zero));
        assert_eq!(one.mul_wide(&one), (one, zero));
    }

    /// Products of small 32-bit factors fit entirely in the low half.
    #[test]
    fn mul_wide_lo_only() {
        const FACTORS: [u32; 5] = [3, 5, 17, 257, 65537];

        for lhs in FACTORS.iter().copied() {
            for rhs in FACTORS.iter().copied() {
                let want = U64::from_u64(u64::from(lhs) * u64::from(rhs));
                let (low, high) = U64::from_u32(lhs).mul_wide(&U64::from_u32(rhs));

                assert_eq!(low, want);
                assert!(bool::from(high.is_zero()));
            }
        }
    }

    /// The `*` operator on equally sized operands yields the concatenated double-width value.
    #[test]
    fn mul_concat_even() {
        let max_squared = U128::from_u128(0xfffffffffffffffe_0000000000000001);
        let max_widened = U128::from_u128(0x0000000000000000_ffffffffffffffff);

        assert_eq!(U64::ZERO * U64::MAX, U128::ZERO);
        assert_eq!(U64::MAX * U64::ZERO, U128::ZERO);
        assert_eq!(U64::MAX * U64::MAX, max_squared);
        assert_eq!(U64::ONE * U64::MAX, max_widened);
    }

    /// The `*` operator on differently sized operands agrees with a widened
    /// `saturating_mul`, in both operand orders.
    #[test]
    fn mul_concat_mixed() {
        let narrow = U64::from_u64(0x0011223344556677);
        let wide = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);

        assert_eq!(narrow * wide, U192::from(&narrow).saturating_mul(&wide));
        assert_eq!(wide * narrow, U192::from(&wide).saturating_mul(&narrow));
    }

    /// A product that fits in the operand width is returned as `Some`.
    #[test]
    fn checked_mul_ok() {
        let value = U64::from_u32(0xffff_ffff);
        let squared = value.checked_mul(&value).unwrap();

        assert_eq!(squared, U64::from_u64(0xffff_fffe_0000_0001));
    }

    /// A product that exceeds the operand width is reported as `None`.
    #[test]
    fn checked_mul_overflow() {
        let value = U64::from_u64(0xffff_ffff_ffff_ffff);
        let squared = value.checked_mul(&value);

        assert!(bool::from(squared.is_none()));
    }

    /// Without overflow, saturating multiplication returns the exact product.
    #[test]
    fn saturating_mul_no_overflow() {
        let eight = U64::from_u8(8);

        assert_eq!(eight.saturating_mul(&eight), U64::from_u8(64));
    }

    /// On overflow, saturating multiplication clamps to `MAX`.
    #[test]
    fn saturating_mul_overflow() {
        let lhs = U64::from(u64::MAX);
        let rhs = U64::from(2u8);

        assert_eq!(lhs.saturating_mul(&rhs), U64::MAX);
    }

    /// Squaring the largest single-limb-width value gives the expected halves.
    #[test]
    fn square() {
        let value = U64::from_u64(u64::MAX);
        let (high, low) = value.square().split();

        assert_eq!(low, U64::from_u64(1));
        assert_eq!(high, U64::from_u64(0xffff_ffff_ffff_fffe));
    }

    /// Squaring the largest 256-bit value gives `1` in the low half and `MAX - 1` in the
    /// high half.
    #[test]
    fn square_larger() {
        let value = U256::MAX;
        let (high, low) = value.square().split();

        assert_eq!(low, U256::ONE);
        assert_eq!(high, U256::MAX.wrapping_sub(&U256::ONE));
    }
}