1use crate::{c, error, polyfill::ArrayFlatMap};
22
23#[cfg(any(test, feature = "alloc"))]
24use crate::bits;
25
26#[cfg(feature = "alloc")]
27use core::num::Wrapping;
28
// The limb type and its bit width are declared side by side for each
// supported pointer width so the two definitions cannot drift apart.

/// The unsigned machine word that multi-limb numbers in this module are
/// built from (64-bit targets).
#[cfg(target_pointer_width = "64")]
pub type Limb = u64;
/// Number of bits in one [`Limb`] (64-bit targets).
#[cfg(target_pointer_width = "64")]
pub const LIMB_BITS: usize = 64;

/// The unsigned machine word that multi-limb numbers in this module are
/// built from (32-bit targets).
#[cfg(target_pointer_width = "32")]
pub type Limb = u32;
/// Number of bits in one [`Limb`] (32-bit targets).
#[cfg(target_pointer_width = "32")]
pub const LIMB_BITS: usize = 32;
38
/// A limb-sized boolean: all bits set for true, all bits clear for false.
///
/// This is the return type of the external comparison routines declared in
/// this file, so its representation matches a limb on the target.
#[cfg(target_pointer_width = "64")]
#[derive(Debug, PartialEq)]
#[repr(u64)]
pub enum LimbMask {
    /// Every bit clear.
    False = 0,
    /// Every bit set.
    True = u64::MAX,
}

/// A limb-sized boolean: all bits set for true, all bits clear for false.
///
/// This is the return type of the external comparison routines declared in
/// this file, so its representation matches a limb on the target.
#[cfg(target_pointer_width = "32")]
#[derive(Debug, PartialEq)]
#[repr(u32)]
pub enum LimbMask {
    /// Every bit clear.
    False = 0,
    /// Every bit set.
    True = u32::MAX,
}
54
/// Number of bytes in one `Limb`: `LIMB_BITS` divided by 8, rounded up.
pub const LIMB_BYTES: usize = (LIMB_BITS + 7) / 8;
56
57#[inline]
58pub fn limbs_equal_limbs_consttime(a: &[Limb], b: &[Limb]) -> LimbMask {
59 prefixed_extern! {
60 fn LIMBS_equal(a: *const Limb, b: *const Limb, num_limbs: c::size_t) -> LimbMask;
61 }
62
63 assert_eq!(a.len(), b.len());
64 unsafe { LIMBS_equal(a.as_ptr(), b.as_ptr(), a.len()) }
65}
66
67#[inline]
68pub fn limbs_less_than_limbs_consttime(a: &[Limb], b: &[Limb]) -> LimbMask {
69 assert_eq!(a.len(), b.len());
70 unsafe { LIMBS_less_than(a.as_ptr(), b.as_ptr(), b.len()) }
71}
72
73#[inline]
74pub fn limbs_less_than_limbs_vartime(a: &[Limb], b: &[Limb]) -> bool {
75 limbs_less_than_limbs_consttime(a, b) == LimbMask::True
76}
77
/// Asks the external `LIMBS_less_than_limb` routine whether the multi-limb
/// value `a` (least significant limb first) is less than the single limb `b`.
#[inline]
#[cfg(feature = "alloc")]
pub fn limbs_less_than_limb_constant_time(a: &[Limb], b: Limb) -> LimbMask {
    let (limbs, num_limbs) = (a.as_ptr(), a.len());
    // SAFETY: `limbs` and `num_limbs` describe the live slice `a`; this
    // assumes the routine reads no more than `num_limbs` limbs.
    unsafe { LIMBS_less_than_limb(limbs, b, num_limbs) }
}
83
84#[inline]
85pub fn limbs_are_zero_constant_time(limbs: &[Limb]) -> LimbMask {
86 unsafe { LIMBS_are_zero(limbs.as_ptr(), limbs.len()) }
87}
88
/// Asks the external `LIMBS_are_even` routine whether the multi-limb value
/// `limbs` (least significant limb first) is even.
#[cfg(any(test, feature = "alloc"))]
#[inline]
pub fn limbs_are_even_constant_time(limbs: &[Limb]) -> LimbMask {
    let num_limbs = limbs.len();
    // SAFETY: the pointer and `num_limbs` describe the live slice `limbs`;
    // this assumes the routine reads no more than `num_limbs` limbs.
    unsafe { LIMBS_are_even(limbs.as_ptr(), num_limbs) }
}
94
/// Asks the external `LIMBS_equal_limb` routine whether the multi-limb value
/// `a` (least significant limb first) equals the single limb `b`.
#[cfg(any(test, feature = "alloc"))]
#[inline]
pub fn limbs_equal_limb_constant_time(a: &[Limb], b: Limb) -> LimbMask {
    let (limbs, num_limbs) = (a.as_ptr(), a.len());
    // SAFETY: `limbs` and `num_limbs` describe the live slice `a`; this
    // assumes the routine reads no more than `num_limbs` limbs.
    unsafe { LIMBS_equal_limb(limbs, b, num_limbs) }
}
100
/// Returns the minimal number of bits needed to represent `a`: the 1-based
/// position of its most significant set bit, or zero when no bit is set
/// (which includes an empty slice).
///
/// `a[0]` is treated as the least significant limb.
#[cfg(any(test, feature = "alloc"))]
pub fn limbs_minimal_bits(a: &[Limb]) -> bits::BitLength {
    // Walk the limbs from the most significant end downwards; `num_limbs` is
    // the count of limbs up to and including the one being examined.
    for num_limbs in (1..=a.len()).rev() {
        let high_limb = a[num_limbs - 1];

        // Linear scan from the top bit of this limb downwards. The first
        // shift amount that leaves a non-zero value marks the highest set bit.
        for high_limb_num_bits in (1..=LIMB_BITS).rev() {
            // NOTE(review): the shift goes through the external `LIMB_shr`
            // instead of `>>`; presumably to keep the operation opaque to the
            // optimizer for side-channel reasons — confirm.
            let shifted = unsafe { LIMB_shr(high_limb, high_limb_num_bits - 1) };
            if shifted != 0 {
                // All lower limbs contribute `LIMB_BITS` bits each.
                return bits::BitLength::from_usize_bits(
                    ((num_limbs - 1) * LIMB_BITS) + high_limb_num_bits,
                );
            }
        }
    }

    // No bit was set in any limb.
    bits::BitLength::from_usize_bits(0)
}
130
131#[inline]
133pub fn limbs_reduce_once_constant_time(r: &mut [Limb], m: &[Limb]) {
134 assert_eq!(r.len(), m.len());
135 unsafe { LIMBS_reduce_once(r.as_mut_ptr(), m.as_ptr(), m.len()) };
136}
137
/// Selects whether `parse_big_endian_in_range_and_pad_consttime` accepts a
/// parsed value of zero.
#[derive(Clone, Copy, PartialEq)]
pub enum AllowZero {
    /// A zero value is rejected.
    No,
    /// A zero value is accepted.
    Yes,
}
143
144pub fn parse_big_endian_in_range_and_pad_consttime(
153 input: untrusted::Input,
154 allow_zero: AllowZero,
155 max_exclusive: &[Limb],
156 result: &mut [Limb],
157) -> Result<(), error::Unspecified> {
158 parse_big_endian_and_pad_consttime(input, result)?;
159 if limbs_less_than_limbs_consttime(result, max_exclusive) != LimbMask::True {
160 return Err(error::Unspecified);
161 }
162 if allow_zero != AllowZero::Yes {
163 if limbs_are_zero_constant_time(result) != LimbMask::False {
164 return Err(error::Unspecified);
165 }
166 }
167 Ok(())
168}
169
170pub fn parse_big_endian_and_pad_consttime(
174 input: untrusted::Input,
175 result: &mut [Limb],
176) -> Result<(), error::Unspecified> {
177 if input.is_empty() {
178 return Err(error::Unspecified);
179 }
180
181 let mut bytes_in_current_limb = input.len() % LIMB_BYTES;
185 if bytes_in_current_limb == 0 {
186 bytes_in_current_limb = LIMB_BYTES;
187 }
188
189 let num_encoded_limbs = (input.len() / LIMB_BYTES)
190 + (if bytes_in_current_limb == LIMB_BYTES {
191 0
192 } else {
193 1
194 });
195 if num_encoded_limbs > result.len() {
196 return Err(error::Unspecified);
197 }
198
199 result.fill(0);
200
201 input.read_all(error::Unspecified, |input| {
204 for i in 0..num_encoded_limbs {
205 let mut limb: Limb = 0;
206 for _ in 0..bytes_in_current_limb {
207 let b: Limb = input.read_byte()?.into();
208 limb = (limb << 8) | b;
209 }
210 result[num_encoded_limbs - i - 1] = limb;
211 bytes_in_current_limb = LIMB_BYTES;
212 }
213 Ok(())
214 })
215}
216
217pub fn big_endian_from_limbs(limbs: &[Limb], out: &mut [u8]) {
218 let be_bytes = unstripped_be_bytes(limbs);
219 assert_eq!(out.len(), be_bytes.len());
220 out.iter_mut().zip(be_bytes).for_each(|(o, i)| {
221 *o = i;
222 });
223}
224
225pub fn unstripped_be_bytes(limbs: &[Limb]) -> impl ExactSizeIterator<Item = u8> + Clone + '_ {
230 ArrayFlatMap::new(limbs.iter().rev().copied(), Limb::to_be_bytes).unwrap()
232}
233
/// A window of bits extracted from a limb sequence, as passed to the
/// callbacks of `fold_5_bit_windows`; it is carried in a full `Limb`.
#[cfg(feature = "alloc")]
pub type Window = Limb;
236
/// Folds over `limbs` as a sequence of 5-bit windows, from the most
/// significant window to the least significant one.
///
/// The first (most significant) window is passed to `init` to produce the
/// initial accumulator; each later window is combined into the accumulator
/// with `fold`. When the total bit count is not a multiple of 5, the window
/// given to `init` covers only the leftover leading bits.
///
/// # Panics
///
/// Panics if `limbs` is empty.
#[cfg(feature = "alloc")]
pub fn fold_5_bit_windows<R, I: FnOnce(Window) -> R, F: Fn(R, Window) -> R>(
    limbs: &[Limb],
    init: I,
    fold: F,
) -> R {
    // Bit position of a window's lowest bit within a limb. Wrapping
    // arithmetic is intentional: the loops below rely on underflow.
    #[derive(Clone, Copy)]
    #[repr(transparent)]
    struct BitIndex(Wrapping<c::size_t>);

    const WINDOW_BITS: Wrapping<c::size_t> = Wrapping(5);

    prefixed_extern! {
        fn LIMBS_window5_split_window(
            lower_limb: Limb,
            higher_limb: Limb,
            index_within_word: BitIndex,
        ) -> Window;
        fn LIMBS_window5_unsplit_window(limb: Limb, index_within_word: BitIndex) -> Window;
    }

    let num_limbs = limbs.len();
    let mut window_low_bit = {
        // Bits left over at the top once the total is cut into whole 5-bit
        // windows; when it divides evenly the top window is a full one.
        let num_whole_windows = (num_limbs * LIMB_BITS) / 5;
        let mut leading_bits = (num_limbs * LIMB_BITS) - (num_whole_windows * 5);
        if leading_bits == 0 {
            leading_bits = WINDOW_BITS.0;
        }
        BitIndex(Wrapping(LIMB_BITS - leading_bits))
    };

    let initial_value = {
        // The leading window is taken from the most significant limb, with
        // zero supplied as the (non-existent) higher limb.
        let leading_partial_window =
            unsafe { LIMBS_window5_split_window(*limbs.last().unwrap(), 0, window_low_bit) };
        window_low_bit.0 -= WINDOW_BITS;
        init(leading_partial_window)
    };

    let mut low_limb = 0;
    limbs
        .iter()
        .rev()
        .fold(initial_value, |mut acc, current_limb| {
            // Remember the previously visited (more significant) limb so a
            // window straddling the limb boundary can be assembled.
            let higher_limb = low_limb;
            low_limb = *current_limb;

            // A low-bit index within the top 5 bits of the limb means the
            // window spills into the higher limb.
            if window_low_bit.0 > Wrapping(LIMB_BITS) - WINDOW_BITS {
                let window =
                    unsafe { LIMBS_window5_split_window(low_limb, higher_limb, window_low_bit) };
                window_low_bit.0 -= WINDOW_BITS;
                acc = fold(acc, window);
            };
            while window_low_bit.0 < Wrapping(LIMB_BITS) {
                let window = unsafe { LIMBS_window5_unsplit_window(low_limb, window_low_bit) };
                // The loop ends when this subtraction wraps below zero,
                // leaving `window_low_bit` at a value far above `LIMB_BITS`.
                window_low_bit.0 -= WINDOW_BITS;
                acc = fold(acc, window);
            }
            // Undo the wrap: the index is now relative to the next lower limb.
            window_low_bit.0 += Wrapping(LIMB_BITS);

            acc
        })
}
313
314#[inline]
315pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
316 debug_assert_eq!(a.len(), m.len());
317 debug_assert_eq!(b.len(), m.len());
318 prefixed_extern! {
319 fn LIMBS_add_mod(
321 r: *mut Limb,
322 a: *const Limb,
323 b: *const Limb,
324 m: *const Limb,
325 num_limbs: c::size_t,
326 );
327 }
328 unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
329}
330
331pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) {
333 assert_eq!(r.len(), m.len());
334 prefixed_extern! {
335 fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
336 }
337 unsafe {
338 LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len());
339 }
340}
341
342pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) {
344 debug_assert_eq!(r.len(), a.len());
345 r.iter_mut().zip(a.iter()).for_each(|(r, &a)| {
348 *r = !a;
349 });
350 r[0] |= 1;
353}
354
// External routines used by the always-available wrappers in this file.
// Each takes raw limb pointers plus the number of limbs they address.
prefixed_extern! {
    fn LIMBS_are_zero(a: *const Limb, num_limbs: c::size_t) -> LimbMask;
    fn LIMBS_less_than(a: *const Limb, b: *const Limb, num_limbs: c::size_t) -> LimbMask;
    fn LIMBS_reduce_once(r: *mut Limb, m: *const Limb, num_limbs: c::size_t);
}

// External routines needed only by tests or when the `alloc` feature is on.
#[cfg(any(test, feature = "alloc"))]
prefixed_extern! {
    fn LIMB_shr(a: Limb, shift: c::size_t) -> Limb;
    fn LIMBS_are_even(a: *const Limb, num_limbs: c::size_t) -> LimbMask;
    fn LIMBS_equal_limb(a: *const Limb, b: Limb, num_limbs: c::size_t) -> LimbMask;
}

// External routine needed only when the `alloc` feature is on.
#[cfg(feature = "alloc")]
prefixed_extern! {
    fn LIMBS_less_than_limb(a: *const Limb, b: Limb, num_limbs: c::size_t) -> LimbMask;
}
372
// Unit tests for the limb helpers. Multi-limb values in the fixtures are
// written least significant limb first.
#[cfg(test)]
mod tests {
    use super::*;

    // A limb with every bit set.
    const MAX: Limb = LimbMask::True as Limb;

    #[test]
    fn test_limbs_are_even() {
        // Parity depends only on the first (least significant) limb; the
        // empty slice counts as even.
        static EVENS: &[&[Limb]] = &[
            &[],
            &[0],
            &[2],
            &[0, 0],
            &[2, 0],
            &[0, 1],
            &[0, 2],
            &[0, 3],
            &[0, 0, 0, 0, MAX],
        ];
        for even in EVENS {
            assert_eq!(limbs_are_even_constant_time(even), LimbMask::True);
        }
        static ODDS: &[&[Limb]] = &[
            &[1],
            &[3],
            &[1, 0],
            &[3, 0],
            &[1, 1],
            &[1, 2],
            &[1, 3],
            &[1, 0, 0, 0, MAX],
        ];
        for odd in ODDS {
            assert_eq!(limbs_are_even_constant_time(odd), LimbMask::False);
        }
    }

    // All-zero values of various lengths, including the empty slice. Shared
    // by several tests below.
    static ZEROES: &[&[Limb]] = &[
        &[],
        &[0],
        &[0, 0],
        &[0, 0, 0],
        &[0, 0, 0, 0],
        &[0, 0, 0, 0, 0],
        &[0, 0, 0, 0, 0, 0, 0],
        &[0, 0, 0, 0, 0, 0, 0, 0],
        &[0, 0, 0, 0, 0, 0, 0, 0, 0],
    ];

    // Values with a single non-zero limb in each position, plus one value
    // with two non-zero limbs.
    static NONZEROES: &[&[Limb]] = &[
        &[1],
        &[0, 1],
        &[1, 1],
        &[1, 0, 0, 0],
        &[0, 1, 0, 0],
        &[0, 0, 1, 0],
        &[0, 0, 0, 1],
    ];

    #[test]
    fn test_limbs_are_zero() {
        for zero in ZEROES {
            assert_eq!(limbs_are_zero_constant_time(zero), LimbMask::True);
        }
        for nonzero in NONZEROES {
            assert_eq!(limbs_are_zero_constant_time(nonzero), LimbMask::False);
        }
    }

    #[test]
    fn test_limbs_equal_limb() {
        // Comparing against the limb 0 must agree with the zero test.
        for zero in ZEROES {
            assert_eq!(limbs_equal_limb_constant_time(zero, 0), LimbMask::True);
        }
        for nonzero in NONZEROES {
            assert_eq!(limbs_equal_limb_constant_time(nonzero, 0), LimbMask::False);
        }
        // Equal: the low limb matches and every higher limb is zero.
        static EQUAL: &[(&[Limb], Limb)] = &[
            (&[1], 1),
            (&[MAX], MAX),
            (&[1, 0], 1),
            (&[MAX, 0, 0], MAX),
            (&[0b100], 0b100),
            (&[0b100, 0], 0b100),
        ];
        for &(a, b) in EQUAL {
            assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::True);
        }
        // Unequal: either the low limb differs or a higher limb is non-zero.
        static UNEQUAL: &[(&[Limb], Limb)] = &[
            (&[0], 1),
            (&[2], 1),
            (&[3], 1),
            (&[1, 1], 1),
            (&[0b100, 0b100], 0b100),
            (&[1, 0, 0b100, 0, 0, 0, 0, 0], 1),
            (&[1, 0, 0, 0, 0, 0, 0, 0b100], 1),
            (&[MAX, MAX], MAX),
            (&[MAX, 1], MAX),
        ];
        for &(a, b) in UNEQUAL {
            assert_eq!(limbs_equal_limb_constant_time(a, b), LimbMask::False);
        }
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_limbs_less_than_limb_constant_time() {
        static LESSER: &[(&[Limb], Limb)] = &[
            (&[0], 1),
            (&[0, 0], 1),
            (&[1, 0], 2),
            (&[2, 0], 3),
            // NOTE(review): identical to the previous entry.
            (&[2, 0], 3),
            (&[MAX - 1], MAX),
            (&[MAX - 1, 0], MAX),
        ];
        for &(a, b) in LESSER {
            assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::True);
        }
        static EQUAL: &[(&[Limb], Limb)] = &[
            (&[0], 0),
            (&[0, 0, 0, 0], 0),
            (&[1], 1),
            (&[1, 0, 0, 0, 0, 0, 0], 1),
            (&[MAX], MAX),
        ];
        static GREATER: &[(&[Limb], Limb)] = &[
            (&[1], 0),
            (&[2, 0], 1),
            (&[3, 0, 0, 0], 1),
            (&[0, 1, 0, 0], 1),
            (&[0, 0, 1, 0], 1),
            (&[0, 0, 1, 1], 1),
            (&[MAX], MAX - 1),
        ];
        // Neither equal nor greater values count as "less than".
        for &(a, b) in EQUAL.iter().chain(GREATER.iter()) {
            assert_eq!(limbs_less_than_limb_constant_time(a, b), LimbMask::False);
        }
    }

    #[test]
    fn test_parse_big_endian_and_pad_consttime() {
        const LIMBS: usize = 4;

        {
            // Empty input is rejected.
            let inp = untrusted::Input::from(&[]);
            let mut result = [0; LIMBS];
            assert!(parse_big_endian_and_pad_consttime(inp, &mut result).is_err());
        }

        {
            // Nine input bytes do not fit into eight bytes' worth of limbs.
            let inp = [1, 2, 3, 4, 5, 6, 7, 8, 9];
            let inp = untrusted::Input::from(&inp);
            let mut result = [0; 8 / LIMB_BYTES];
            assert!(parse_big_endian_and_pad_consttime(inp, &mut result[..]).is_err());
        }

        {
            // A single byte lands in the lowest limb; the rest stay zero.
            let inp = [0xfe];
            let inp = untrusted::Input::from(&inp);
            let mut result = [0; LIMBS];
            assert_eq!(
                Ok(()),
                parse_big_endian_and_pad_consttime(inp, &mut result[..])
            );
            assert_eq!(&[0xfe, 0, 0, 0], &result);
        }

        {
            // Four bytes: a whole limb on 32-bit targets, half a limb on
            // 64-bit targets.
            let inp = [0xbe, 0xef, 0xf0, 0x0d];
            let inp = untrusted::Input::from(&inp);
            let mut result = [0; LIMBS];
            assert_eq!(Ok(()), parse_big_endian_and_pad_consttime(inp, &mut result));
            assert_eq!(&[0xbeeff00d, 0, 0, 0], &result);
        }

    }

    #[test]
    fn test_big_endian_from_limbs_same_length() {
        // The same 256-bit value, expressed as eight 32-bit limbs or four
        // 64-bit limbs depending on the target.
        #[cfg(target_pointer_width = "32")]
        let limbs = [
            0xbccddeef, 0x89900aab, 0x45566778, 0x01122334, 0xddeeff00, 0x99aabbcc, 0x55667788,
            0x11223344,
        ];

        #[cfg(target_pointer_width = "64")]
        let limbs = [
            0x8990_0aab_bccd_deef,
            0x0112_2334_4556_6778,
            0x99aa_bbcc_ddee_ff00,
            0x1122_3344_5566_7788,
        ];

        let expected = [
            0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee,
            0xff, 0x00, 0x01, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0x90, 0x0a, 0xab,
            0xbc, 0xcd, 0xde, 0xef,
        ];

        // Pre-filled with a marker byte so unwritten positions would show up.
        let mut out = [0xabu8; 32];
        big_endian_from_limbs(&limbs[..], &mut out);
        assert_eq!(&out[..], &expected[..]);
    }

    #[should_panic]
    #[test]
    fn test_big_endian_from_limbs_fewer_limbs() {
        // Only 24 bytes' worth of limbs for a 32-byte output buffer; the
        // length assertion in `big_endian_from_limbs` must fire.
        #[cfg(target_pointer_width = "32")]
        let limbs = [
            0xbccddeef, 0x89900aab, 0x45566778, 0x01122334, 0xddeeff00, 0x99aabbcc,
        ];

        #[cfg(target_pointer_width = "64")]
        let limbs = [
            0x8990_0aab_bccd_deef,
            0x0112_2334_4556_6778,
            0x99aa_bbcc_ddee_ff00,
        ];

        let mut out = [0xabu8; 32];

        big_endian_from_limbs(&limbs[..], &mut out);
    }

    #[test]
    fn test_limbs_minimal_bits() {
        const ALL_ONES: Limb = LimbMask::True as Limb;
        // Pairs of (value, expected minimal bit length).
        static CASES: &[(&[Limb], usize)] = &[
            (&[], 0),
            (&[0], 0),
            (&[ALL_ONES], LIMB_BITS),
            (&[ALL_ONES, 0], LIMB_BITS),
            (&[ALL_ONES, 1], LIMB_BITS + 1),
            (&[0, 0], 0),
            (&[1, 0], 1),
            (&[0, 1], LIMB_BITS + 1),
            (&[0, ALL_ONES], 2 * LIMB_BITS),
            (&[ALL_ONES, ALL_ONES], 2 * LIMB_BITS),
            (&[ALL_ONES, ALL_ONES >> 1], 2 * LIMB_BITS - 1),
            (&[ALL_ONES, 0b100_0000], LIMB_BITS + 7),
            (&[ALL_ONES, 0b101_0000], LIMB_BITS + 7),
            // NOTE(review): same input and expected value as the
            // `2 * LIMB_BITS - 1` case above, written differently.
            (&[ALL_ONES, ALL_ONES >> 1], LIMB_BITS + (LIMB_BITS) - 1),
        ];
        for (limbs, bits) in CASES {
            assert_eq!(limbs_minimal_bits(limbs).as_bits(), *bits);
        }
    }
}