1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2019 Oleg Andreev
5// See LICENSE for licensing information.
6//
7// Authors:
8// - Oleg Andreev <oleganza@gmail.com>
910#![allow(non_snake_case)]
1112#[curve25519_dalek_derive::unsafe_target_feature_specialize(
13"avx2",
14 conditional("avx512ifma,avx512vl", nightly)
15)]
16pub mod spec {
1718use alloc::vec::Vec;
1920use core::borrow::Borrow;
21use core::cmp::Ordering;
2223#[for_target_feature("avx2")]
24use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
2526#[for_target_feature("avx512ifma")]
27use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
2829use crate::edwards::EdwardsPoint;
30use crate::scalar::Scalar;
31use crate::traits::{Identity, VartimeMultiscalarMul};
3233/// Implements a version of Pippenger's algorithm.
34 ///
35 /// See the documentation in the serial `scalar_mul::pippenger` module for details.
36pub struct Pippenger;
3738impl VartimeMultiscalarMul for Pippenger {
39type Point = EdwardsPoint;
4041fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
42where
43I: IntoIterator,
44 I::Item: Borrow<Scalar>,
45 J: IntoIterator<Item = Option<EdwardsPoint>>,
46 {
47let mut scalars = scalars.into_iter();
48let size = scalars.by_ref().size_hint().0;
49let w = if size < 500 {
506
51} else if size < 800 {
527
53} else {
548
55};
5657let max_digit: usize = 1 << w;
58let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
59let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
6061 // Collect optimized scalars and points in a buffer for repeated access
62 // (scanning the whole collection per each digit position).
63let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
6465let points = points
66 .into_iter()
67 .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
6869let scalars_points = scalars
70 .zip(points)
71 .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
72 .collect::<Option<Vec<_>>>()?;
7374// Prepare 2^w/2 buckets.
75 // buckets[i] corresponds to a multiplication factor (i+1).
76let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
77 .map(|_| ExtendedPoint::identity())
78 .collect();
7980let mut columns = (0..digits_count).rev().map(|digit_index| {
81// Clear the buckets when processing another digit.
82for bucket in &mut buckets {
83*bucket = ExtendedPoint::identity();
84 }
8586// Iterate over pairs of (point, scalar)
87 // and add/sub the point to the corresponding bucket.
88 // Note: if we add support for precomputed lookup tables,
89 // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0].
90for (digits, pt) in scalars_points.iter() {
91// Widen digit so that we don't run into edge cases when w=8.
92let digit = digits[digit_index] as i16;
93match digit.cmp(&0) {
94 Ordering::Greater => {
95let b = (digit - 1) as usize;
96 buckets[b] = &buckets[b] + pt;
97 }
98 Ordering::Less => {
99let b = (-digit - 1) as usize;
100 buckets[b] = &buckets[b] - pt;
101 }
102 Ordering::Equal => {}
103 }
104 }
105106// Add the buckets applying the multiplication factor to each bucket.
107 // The most efficient way to do that is to have a single sum with two running sums:
108 // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
109 //
110 // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
111 // C
112 // C B
113 // C B A Sum = C + (C+B) + (C+B+A)
114let mut buckets_intermediate_sum = buckets[buckets_count - 1];
115let mut buckets_sum = buckets[buckets_count - 1];
116for i in (0..(buckets_count - 1)).rev() {
117 buckets_intermediate_sum =
118&buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
119 buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
120 }
121122 buckets_sum
123 });
124125// Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
126let hi_column = columns.next().expect("should have more than zero digits");
127128Some(
129 columns
130 .fold(hi_column, |total, p| {
131&total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
132 })
133 .into(),
134 )
135 }
136 }
#[cfg(test)]
mod test {
    /// Checks the vector Pippenger result against a naive sum of per-term scalar
    /// products, for input sizes 512, 256, ..., 1 (exercising the w = 7 and w = 6
    /// window widths).
    #[test]
    fn test_vartime_pippenger() {
        use super::*;
        use crate::constants;
        use crate::scalar::Scalar;

        // Build the inputs once and test ever-smaller prefixes of them.
        let mut n = 512;
        let x = Scalar::from(2128506u64).invert();
        let y = Scalar::from(4443282u64).invert();
        let points: Vec<_> = (0..n)
            .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
            .collect();
        let scalars: Vec<_> = (0..n)
            .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
            .collect();

        let premultiplied: Vec<EdwardsPoint> = scalars
            .iter()
            .zip(points.iter())
            .map(|(sc, pt)| sc * pt)
            .collect();

        while n > 0 {
            let control: EdwardsPoint = premultiplied[..n].iter().sum();

            // Pass the prefixes as slices instead of copying them into fresh Vecs
            // (and cloning those) on every iteration.
            let subject = Pippenger::vartime_multiscalar_mul(&scalars[..n], &points[..n]);

            assert_eq!(subject.compress(), control.compress());

            n /= 2;
        }
    }
}
176}