curve25519_dalek/backend/serial/scalar_mul/pippenger.rs
1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2019 Oleg Andreev
5// See LICENSE for licensing information.
6//
7// Authors:
8// - Oleg Andreev <oleganza@gmail.com>
9
10//! Implementation of a variant of Pippenger's algorithm.
11
12#![allow(non_snake_case)]
13
14use alloc::vec::Vec;
15
16use core::borrow::Borrow;
17use core::cmp::Ordering;
18
19use crate::edwards::EdwardsPoint;
20use crate::scalar::Scalar;
21use crate::traits::VartimeMultiscalarMul;
22
23/// Implements a version of Pippenger's algorithm.
24///
25/// The algorithm works as follows:
26///
27/// Let `n` be a number of point-scalar pairs.
28/// Let `w` be a window of bits (6..8, chosen based on `n`, see cost factor).
29///
/// 1. Prepare `2^(w-1)` buckets with indices `[1..2^(w-1)]` initialized with identity points.
///    Bucket 0 is not needed as it would contain points multiplied by 0.
32/// 2. Convert scalars to a radix-`2^w` representation with signed digits in `[-2^w/2, 2^w/2]`.
33/// Note: only the last digit may equal `2^w/2`.
/// 3. Starting with the last window, for each point `i=[0..n)` add it to a bucket indexed by
///    the point's scalar's value in the window.
36/// 4. Once all points in a window are sorted into buckets, add buckets by multiplying each
37/// by their index. Efficient way of doing it is to start with the last bucket and compute two sums:
38/// intermediate sum from the last to the first, and the full sum made of all intermediate sums.
39/// 5. Shift the resulting sum of buckets by `w` bits by using `w` doublings.
40/// 6. Add to the return value.
41/// 7. Repeat the loop.
42///
43/// Approximate cost w/o wNAF optimizations (A = addition, D = doubling):
44///
45/// ```ascii
/// cost = (n*A + 2*(2^w/2)*A + w*D + A)*256/w
///          |          |       |     |   |
///          |          |       |     |   looping over 256/w windows
///          |          |       |     adding to the result
///    sorting points   |       shifting the sum by w bits (to the next window, starting from last window)
///    one by one       |
///    into buckets     adding/subtracting all buckets
///                     multiplied by their indexes
///                     using a sum of intermediate sums
55/// ```
56///
57/// For large `n`, dominant factor is (n*256/w) additions.
58/// However, if `w` is too big and `n` is not too big, then `(2^w/2)*A` could dominate.
59/// Therefore, the optimal choice of `w` grows slowly as `n` grows.
60///
61/// This algorithm is adapted from section 4 of <https://eprint.iacr.org/2012/549.pdf>.
62pub struct Pippenger;
63
64impl VartimeMultiscalarMul for Pippenger {
65 type Point = EdwardsPoint;
66
67 fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
68 where
69 I: IntoIterator,
70 I::Item: Borrow<Scalar>,
71 J: IntoIterator<Item = Option<EdwardsPoint>>,
72 {
73 use crate::traits::Identity;
74
75 let mut scalars = scalars.into_iter();
76 let size = scalars.by_ref().size_hint().0;
77
78 // Digit width in bits. As digit width grows,
79 // number of point additions goes down, but amount of
80 // buckets and bucket additions grows exponentially.
81 let w = if size < 500 {
82 6
83 } else if size < 800 {
84 7
85 } else {
86 8
87 };
88
89 let max_digit: usize = 1 << w;
90 let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
91 let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
92
93 // Collect optimized scalars and points in buffers for repeated access
94 // (scanning the whole set per digit position).
95 let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
96
97 let points = points
98 .into_iter()
99 .map(|p| p.map(|P| P.as_projective_niels()));
100
101 let scalars_points = scalars
102 .zip(points)
103 .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
104 .collect::<Option<Vec<_>>>()?;
105
106 // Prepare 2^w/2 buckets.
107 // buckets[i] corresponds to a multiplication factor (i+1).
108 let mut buckets: Vec<_> = (0..buckets_count)
109 .map(|_| EdwardsPoint::identity())
110 .collect();
111
112 let mut columns = (0..digits_count).rev().map(|digit_index| {
113 // Clear the buckets when processing another digit.
114 for bucket in &mut buckets {
115 *bucket = EdwardsPoint::identity();
116 }
117
118 // Iterate over pairs of (point, scalar)
119 // and add/sub the point to the corresponding bucket.
120 // Note: if we add support for precomputed lookup tables,
121 // we'll be adding/subtracting point premultiplied by `digits[i]` to buckets[0].
122 for (digits, pt) in scalars_points.iter() {
123 // Widen digit so that we don't run into edge cases when w=8.
124 let digit = digits[digit_index] as i16;
125 match digit.cmp(&0) {
126 Ordering::Greater => {
127 let b = (digit - 1) as usize;
128 buckets[b] = (&buckets[b] + pt).as_extended();
129 }
130 Ordering::Less => {
131 let b = (-digit - 1) as usize;
132 buckets[b] = (&buckets[b] - pt).as_extended();
133 }
134 Ordering::Equal => {}
135 }
136 }
137
138 // Add the buckets applying the multiplication factor to each bucket.
139 // The most efficient way to do that is to have a single sum with two running sums:
140 // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
141 //
142 // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
143 // C
144 // C B
145 // C B A Sum = C + (C+B) + (C+B+A)
146 let mut buckets_intermediate_sum = buckets[buckets_count - 1];
147 let mut buckets_sum = buckets[buckets_count - 1];
148 for i in (0..(buckets_count - 1)).rev() {
149 buckets_intermediate_sum += buckets[i];
150 buckets_sum += buckets_intermediate_sum;
151 }
152
153 buckets_sum
154 });
155
156 // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
157 let hi_column = columns.next().expect("should have more than zero digits");
158
159 Some(columns.fold(hi_column, |total, p| total.mul_by_pow_2(w as u32) + p))
160 }
161}
162
#[cfg(test)]
mod test {
    use super::*;
    use crate::constants;

    /// Checks `Pippenger` against the sum of individually computed scalar
    /// multiplications for input sizes 512, 256, ..., 1 and for the empty input.
    #[test]
    fn test_vartime_pippenger() {
        // Largest input size; every smaller size reuses a prefix of the same data.
        const MAX_N: usize = 512;

        let x = Scalar::from(2128506u64).invert();
        let y = Scalar::from(4443282u64).invert();
        let points: Vec<_> = (0..MAX_N)
            .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
            .collect();
        let scalars: Vec<_> = (0..MAX_N)
            .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
            .collect();

        // Reference products scalars[i] * points[i], computed one at a time.
        let premultiplied: Vec<EdwardsPoint> = scalars
            .iter()
            .zip(points.iter())
            .map(|(sc, pt)| sc * pt)
            .collect();

        // Halve the size each round: 512, 256, ..., 1, and finally 0 so that
        // the empty input is covered as well.
        let sizes =
            core::iter::successors(Some(MAX_N), |&n| if n > 0 { Some(n / 2) } else { None });

        for n in sizes {
            let control: EdwardsPoint = premultiplied[..n].iter().sum();

            // Borrowed slices are enough here; no per-round copies of the inputs.
            let subject = Pippenger::vartime_multiscalar_mul(&scalars[..n], &points[..n]);

            assert_eq!(subject.compress(), control.compress());
        }
    }
}
199}