curve25519_dalek/backend/vector/avx2/
edwards.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2016-2021 isis lovecruft
5// Copyright (c) 2016-2019 Henry de Valence
6// See LICENSE for licensing information.
7//
8// Authors:
9// - isis agora lovecruft <isis@patternsinthevoid.net>
10// - Henry de Valence <hdevalence@hdevalence.ca>
11
12//! Parallel Edwards Arithmetic for Curve25519.
13//!
14//! This module currently has two point types:
15//!
16//! * `ExtendedPoint`: a point stored in vector-friendly format, with
17//! vectorized doubling and addition;
18//!
19//! * `CachedPoint`: used for readdition.
20//!
21//! Details on the formulas can be found in the documentation for the
22//! parent `avx2` module.
23//!
24//! This API is designed to be safe: vectorized points can only be
25//! created from serial points (which do validation on decompression),
26//! and operations on valid points return valid points, so invalid
27//! point states should be unrepresentable.
28//!
29//! This design goal is met, with one exception: the `Neg`
30//! implementation for the `CachedPoint` performs a lazy negation, so
31//! that subtraction can be efficiently implemented as a negation and
32//! an addition.  Repeatedly negating a `CachedPoint` will cause its
33//! coefficients to grow and eventually overflow.  Repeatedly negating
34//! a point should not be necessary anyway.
35
36#![allow(non_snake_case)]
37
38use core::ops::{Add, Neg, Sub};
39
40use subtle::Choice;
41use subtle::ConditionallySelectable;
42
43use curve25519_dalek_derive::unsafe_target_feature;
44
45use crate::edwards;
46use crate::window::{LookupTable, NafLookupTable5};
47
48#[cfg(any(feature = "precomputed-tables", feature = "alloc"))]
49use crate::window::NafLookupTable8;
50
51use crate::traits::Identity;
52
53use super::constants;
54use super::field::{FieldElement2625x4, Lanes, Shuffle};
55
56/// A point on Curve25519, using parallel Edwards formulas for curve
57/// operations.
58///
59/// # Invariant
60///
61/// The coefficients of an `ExtendedPoint` are bounded with
62/// \\( b < 0.007 \\).
63#[derive(Copy, Clone, Debug)]
64pub struct ExtendedPoint(pub(super) FieldElement2625x4);
65
66#[unsafe_target_feature("avx2")]
67impl From<edwards::EdwardsPoint> for ExtendedPoint {
68    fn from(P: edwards::EdwardsPoint) -> ExtendedPoint {
69        ExtendedPoint(FieldElement2625x4::new(&P.X, &P.Y, &P.Z, &P.T))
70    }
71}
72
73#[unsafe_target_feature("avx2")]
74impl From<ExtendedPoint> for edwards::EdwardsPoint {
75    fn from(P: ExtendedPoint) -> edwards::EdwardsPoint {
76        let tmp = P.0.split();
77        edwards::EdwardsPoint {
78            X: tmp[0],
79            Y: tmp[1],
80            Z: tmp[2],
81            T: tmp[3],
82        }
83    }
84}
85
86#[unsafe_target_feature("avx2")]
87impl ConditionallySelectable for ExtendedPoint {
88    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
89        ExtendedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
90    }
91
92    fn conditional_assign(&mut self, other: &Self, choice: Choice) {
93        self.0.conditional_assign(&other.0, choice);
94    }
95}
96
97#[unsafe_target_feature("avx2")]
98impl Default for ExtendedPoint {
99    fn default() -> ExtendedPoint {
100        ExtendedPoint::identity()
101    }
102}
103
104#[unsafe_target_feature("avx2")]
105impl Identity for ExtendedPoint {
106    fn identity() -> ExtendedPoint {
107        constants::EXTENDEDPOINT_IDENTITY
108    }
109}
110
111#[unsafe_target_feature("avx2")]
112impl ExtendedPoint {
113    /// Compute the double of this point.
114    pub fn double(&self) -> ExtendedPoint {
115        // Want to compute (X1 Y1 Z1 X1+Y1).
116        // Not sure how to do this less expensively than computing
117        // (X1 Y1 Z1 T1) --(256bit shuffle)--> (X1 Y1 X1 Y1)
118        // (X1 Y1 X1 Y1) --(2x128b shuffle)--> (Y1 X1 Y1 X1)
119        // and then adding.
120
121        // Set tmp0 = (X1 Y1 X1 Y1)
122        let mut tmp0 = self.0.shuffle(Shuffle::ABAB);
123
124        // Set tmp1 = (Y1 X1 Y1 X1)
125        let mut tmp1 = tmp0.shuffle(Shuffle::BADC);
126
127        // Set tmp0 = (X1 Y1 Z1 X1+Y1)
128        tmp0 = self.0.blend(tmp0 + tmp1, Lanes::D);
129
130        // Set tmp1 = tmp0^2, negating the D values
131        tmp1 = tmp0.square_and_negate_D();
132        // Now tmp1 = (S1 S2 S3 -S4) with b < 0.007
133
134        // See discussion of bounds in the module-level documentation.
135        // We want to compute
136        //
137        //    + | S1 | S1 | S1 | S1 |
138        //    + | S2 |    |    | S2 |
139        //    + |    |    | S3 |    |
140        //    + |    |    | S3 |    |
141        //    + |    |    |    |-S4 |
142        //    + |    | 2p | 2p |    |
143        //    - |    | S2 | S2 |    |
144        //    =======================
145        //        S5   S6   S8   S9
146
147        let zero = FieldElement2625x4::ZERO;
148        let S_1 = tmp1.shuffle(Shuffle::AAAA);
149        let S_2 = tmp1.shuffle(Shuffle::BBBB);
150
151        tmp0 = zero.blend(tmp1 + tmp1, Lanes::C);
152        // tmp0 = (0, 0,  2S_3, 0)
153        tmp0 = tmp0.blend(tmp1, Lanes::D);
154        // tmp0 = (0, 0,  2S_3, -S_4)
155        tmp0 = tmp0 + S_1;
156        // tmp0 = (  S_1,   S_1, S_1 + 2S_3, S_1 - S_4)
157        tmp0 = tmp0 + zero.blend(S_2, Lanes::AD);
158        // tmp0 = (S_1 + S_2,   S_1, S_1 + 2S_3, S_1 + S_2 - S_4)
159        tmp0 = tmp0 + zero.blend(S_2.negate_lazy(), Lanes::BC);
160        // tmp0 = (S_1 + S_2, S_1 - S_2, S_1 - S_2 + 2S_3, S_1 + S_2 - S_4)
161        //    b < (     1.01,       1.6,             2.33,             1.6)
162        // Now tmp0 = (S_5, S_6, S_8, S_9)
163
164        // Set tmp1 = ( S_9,  S_6,  S_6,  S_9)
165        //        b < ( 1.6,  1.6,  1.6,  1.6)
166        tmp1 = tmp0.shuffle(Shuffle::DBBD);
167        // Set tmp0 = ( S_8,  S_5,  S_8,  S_5)
168        //        b < (2.33, 1.01, 2.33, 1.01)
169        tmp0 = tmp0.shuffle(Shuffle::CACA);
170
171        // Bounds on (tmp0, tmp1) are (2.33, 1.6) < (2.5, 1.75).
172        ExtendedPoint(&tmp0 * &tmp1)
173    }
174
175    pub fn mul_by_pow_2(&self, k: u32) -> ExtendedPoint {
176        let mut tmp: ExtendedPoint = *self;
177        for _ in 0..k {
178            tmp = tmp.double();
179        }
180        tmp
181    }
182}
183
184/// A cached point with some precomputed variables used for readdition.
185///
186/// # Warning
187///
188/// It is not safe to negate this point more than once.
189///
190/// # Invariant
191///
192/// As long as the `CachedPoint` is not repeatedly negated, its
193/// coefficients will be bounded with \\( b < 1.0 \\).
194#[derive(Copy, Clone, Debug)]
195pub struct CachedPoint(pub(super) FieldElement2625x4);
196
197#[unsafe_target_feature("avx2")]
198impl From<ExtendedPoint> for CachedPoint {
199    fn from(P: ExtendedPoint) -> CachedPoint {
200        let mut x = P.0;
201
202        x = x.blend(x.diff_sum(), Lanes::AB);
203        // x = (Y2 - X2, Y2 + X2, Z2, T2) = (S2 S3 Z2 T2)
204
205        x = x * (121666, 121666, 2 * 121666, 2 * 121665);
206        // x = (121666*S2 121666*S3 2*121666*Z2 2*121665*T2)
207
208        x = x.blend(-x, Lanes::D);
209        // x = (121666*S2 121666*S3 2*121666*Z2 -2*121665*T2)
210
211        // The coefficients of the output are bounded with b < 0.007.
212        CachedPoint(x)
213    }
214}
215
216#[unsafe_target_feature("avx2")]
217impl Default for CachedPoint {
218    fn default() -> CachedPoint {
219        CachedPoint::identity()
220    }
221}
222
223#[unsafe_target_feature("avx2")]
224impl Identity for CachedPoint {
225    fn identity() -> CachedPoint {
226        constants::CACHEDPOINT_IDENTITY
227    }
228}
229
230#[unsafe_target_feature("avx2")]
231impl ConditionallySelectable for CachedPoint {
232    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
233        CachedPoint(FieldElement2625x4::conditional_select(&a.0, &b.0, choice))
234    }
235
236    fn conditional_assign(&mut self, other: &Self, choice: Choice) {
237        self.0.conditional_assign(&other.0, choice);
238    }
239}
240
241#[unsafe_target_feature("avx2")]
242impl Neg for &CachedPoint {
243    type Output = CachedPoint;
244    /// Lazily negate the point.
245    ///
246    /// # Warning
247    ///
248    /// Because this method does not perform a reduction, it is not
249    /// safe to repeatedly negate a point.
250    fn neg(self) -> CachedPoint {
251        let swapped = self.0.shuffle(Shuffle::BACD);
252        CachedPoint(swapped.blend(swapped.negate_lazy(), Lanes::D))
253    }
254}
255
256#[unsafe_target_feature("avx2")]
257impl Add<&CachedPoint> for &ExtendedPoint {
258    type Output = ExtendedPoint;
259
260    /// Add an `ExtendedPoint` and a `CachedPoint`.
261    fn add(self, other: &CachedPoint) -> ExtendedPoint {
262        // The coefficients of an `ExtendedPoint` are reduced after
263        // every operation.  If the `CachedPoint` was negated, its
264        // coefficients grow by one bit.  So on input, `self` is
265        // bounded with `b < 0.007` and `other` is bounded with
266        // `b < 1.0`.
267
268        let mut tmp = self.0;
269
270        tmp = tmp.blend(tmp.diff_sum(), Lanes::AB);
271        // tmp = (Y1-X1 Y1+X1 Z1 T1) = (S0 S1 Z1 T1) with b < 1.6
272
273        // (tmp, other) bounded with b < (1.6, 1.0) < (2.5, 1.75).
274        tmp = &tmp * &other.0;
275        // tmp = (S0*S2' S1*S3' Z1*Z2' T1*T2') = (S8 S9 S10 S11)
276
277        tmp = tmp.shuffle(Shuffle::ABDC);
278        // tmp = (S8 S9 S11 S10)
279
280        tmp = tmp.diff_sum();
281        // tmp = (S9-S8 S9+S8 S10-S11 S10+S11) = (S12 S13 S14 S15)
282
283        let t0 = tmp.shuffle(Shuffle::ADDA);
284        // t0 = (S12 S15 S15 S12)
285        let t1 = tmp.shuffle(Shuffle::CBCB);
286        // t1 = (S14 S13 S14 S13)
287
288        // All coefficients of t0, t1 are bounded with b < 1.6.
289        // Return (S12*S14 S15*S13 S15*S14 S12*S13) = (X3 Y3 Z3 T3)
290        ExtendedPoint(&t0 * &t1)
291    }
292}
293
294#[unsafe_target_feature("avx2")]
295impl Sub<&CachedPoint> for &ExtendedPoint {
296    type Output = ExtendedPoint;
297
298    /// Implement subtraction by negating the point and adding.
299    ///
300    /// Empirically, this seems about the same cost as a custom
301    /// subtraction impl (maybe because the benefit is cancelled by
302    /// increased code size?)
303    fn sub(self, other: &CachedPoint) -> ExtendedPoint {
304        self + &(-other)
305    }
306}
307
308#[unsafe_target_feature("avx2")]
309impl From<&edwards::EdwardsPoint> for LookupTable<CachedPoint> {
310    fn from(point: &edwards::EdwardsPoint) -> Self {
311        let P = ExtendedPoint::from(*point);
312        let mut points = [CachedPoint::from(P); 8];
313        for i in 0..7 {
314            points[i + 1] = (&P + &points[i]).into();
315        }
316        LookupTable(points)
317    }
318}
319
320#[unsafe_target_feature("avx2")]
321impl From<&edwards::EdwardsPoint> for NafLookupTable5<CachedPoint> {
322    fn from(point: &edwards::EdwardsPoint) -> Self {
323        let A = ExtendedPoint::from(*point);
324        let mut Ai = [CachedPoint::from(A); 8];
325        let A2 = A.double();
326        for i in 0..7 {
327            Ai[i + 1] = (&A2 + &Ai[i]).into();
328        }
329        // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A]
330        NafLookupTable5(Ai)
331    }
332}
333
334#[cfg(any(feature = "precomputed-tables", feature = "alloc"))]
335#[unsafe_target_feature("avx2")]
336impl From<&edwards::EdwardsPoint> for NafLookupTable8<CachedPoint> {
337    fn from(point: &edwards::EdwardsPoint) -> Self {
338        let A = ExtendedPoint::from(*point);
339        let mut Ai = [CachedPoint::from(A); 64];
340        let A2 = A.double();
341        for i in 0..63 {
342            Ai[i + 1] = (&A2 + &Ai[i]).into();
343        }
344        // Now Ai = [A, 3A, 5A, 7A, 9A, 11A, 13A, 15A, ..., 127A]
345        NafLookupTable8(Ai)
346    }
347}
348
349#[cfg(target_feature = "avx2")]
350#[cfg(test)]
351mod test {
352    use super::*;
353
354    #[rustfmt::skip] // keep alignment of some S* calculations
355    fn serial_add(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
356        use crate::backend::serial::u64::field::FieldElement51;
357
358        let (X1, Y1, Z1, T1) = (P.X, P.Y, P.Z, P.T);
359        let (X2, Y2, Z2, T2) = (Q.X, Q.Y, Q.Z, Q.T);
360
361        macro_rules! print_var {
362            ($x:ident) => {
363                println!("{} = {:?}", stringify!($x), $x.as_bytes());
364            };
365        }
366
367        let S0 = &Y1 - &X1; // R1
368        let S1 = &Y1 + &X1; // R3
369        let S2 = &Y2 - &X2; // R2
370        let S3 = &Y2 + &X2; // R4
371        print_var!(S0);
372        print_var!(S1);
373        print_var!(S2);
374        print_var!(S3);
375        println!("");
376
377        let S4 = &S0 * &S2; // R5 = R1 * R2
378        let S5 = &S1 * &S3; // R6 = R3 * R4
379        let S6 = &Z1 * &Z2; // R8
380        let S7 = &T1 * &T2; // R7
381        print_var!(S4);
382        print_var!(S5);
383        print_var!(S6);
384        print_var!(S7);
385        println!("");
386
387        let S8  =  &S4 *    &FieldElement51([  121666,0,0,0,0]);  // R5
388        let S9  =  &S5 *    &FieldElement51([  121666,0,0,0,0]);  // R6
389        let S10 =  &S6 *    &FieldElement51([2*121666,0,0,0,0]);  // R8
390        let S11 =  &S7 * &(-&FieldElement51([2*121665,0,0,0,0])); // R7
391        print_var!(S8);
392        print_var!(S9);
393        print_var!(S10);
394        print_var!(S11);
395        println!("");
396
397        let S12 =  &S9 - &S8;  // R1
398        let S13 =  &S9 + &S8;  // R4
399        let S14 = &S10 - &S11; // R2
400        let S15 = &S10 + &S11; // R3
401        print_var!(S12);
402        print_var!(S13);
403        print_var!(S14);
404        print_var!(S15);
405        println!("");
406
407        let X3 = &S12 * &S14; // R1 * R2
408        let Y3 = &S15 * &S13; // R3 * R4
409        let Z3 = &S15 * &S14; // R2 * R3
410        let T3 = &S12 * &S13; // R1 * R4
411
412        edwards::EdwardsPoint {
413            X: X3,
414            Y: Y3,
415            Z: Z3,
416            T: T3,
417        }
418    }
419
420    fn addition_test_helper(P: edwards::EdwardsPoint, Q: edwards::EdwardsPoint) {
421        // Test the serial implementation of the parallel addition formulas
422        let R_serial: edwards::EdwardsPoint = serial_add(P.into(), Q.into()).into();
423
424        // Test the vector implementation of the parallel readdition formulas
425        let cached_Q = CachedPoint::from(ExtendedPoint::from(Q));
426        let R_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) + &cached_Q).into();
427        let S_vector: edwards::EdwardsPoint = (&ExtendedPoint::from(P) - &cached_Q).into();
428
429        println!("Testing point addition:");
430        println!("P = {:?}", P);
431        println!("Q = {:?}", Q);
432        println!("cached Q = {:?}", cached_Q);
433        println!("R = P + Q = {:?}", &P + &Q);
434        println!("R_serial = {:?}", R_serial);
435        println!("R_vector = {:?}", R_vector);
436        println!("S = P - Q = {:?}", &P - &Q);
437        println!("S_vector = {:?}", S_vector);
438        assert_eq!(R_serial.compress(), (&P + &Q).compress());
439        assert_eq!(R_vector.compress(), (&P + &Q).compress());
440        assert_eq!(S_vector.compress(), (&P - &Q).compress());
441        println!("OK!\n");
442    }
443
444    #[test]
445    fn vector_addition_vs_serial_addition_vs_edwards_extendedpoint() {
446        use crate::constants;
447        use crate::scalar::Scalar;
448
449        println!("Testing id +- id");
450        let P = edwards::EdwardsPoint::identity();
451        let Q = edwards::EdwardsPoint::identity();
452        addition_test_helper(P, Q);
453
454        println!("Testing id +- B");
455        let P = edwards::EdwardsPoint::identity();
456        let Q = constants::ED25519_BASEPOINT_POINT;
457        addition_test_helper(P, Q);
458
459        println!("Testing B +- B");
460        let P = constants::ED25519_BASEPOINT_POINT;
461        let Q = constants::ED25519_BASEPOINT_POINT;
462        addition_test_helper(P, Q);
463
464        println!("Testing B +- kB");
465        let P = constants::ED25519_BASEPOINT_POINT;
466        let Q = constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64);
467        addition_test_helper(P, Q);
468    }
469
470    fn serial_double(P: edwards::EdwardsPoint) -> edwards::EdwardsPoint {
471        let (X1, Y1, Z1, _T1) = (P.X, P.Y, P.Z, P.T);
472
473        macro_rules! print_var {
474            ($x:ident) => {
475                println!("{} = {:?}", stringify!($x), $x.as_bytes());
476            };
477        }
478
479        let S0 = &X1 + &Y1; // R1
480        print_var!(S0);
481        println!("");
482
483        let S1 = X1.square();
484        let S2 = Y1.square();
485        let S3 = Z1.square();
486        let S4 = S0.square();
487        print_var!(S1);
488        print_var!(S2);
489        print_var!(S3);
490        print_var!(S4);
491        println!("");
492
493        let S5 = &S1 + &S2;
494        let S6 = &S1 - &S2;
495        let S7 = &S3 + &S3;
496        let S8 = &S7 + &S6;
497        let S9 = &S5 - &S4;
498        print_var!(S5);
499        print_var!(S6);
500        print_var!(S7);
501        print_var!(S8);
502        print_var!(S9);
503        println!("");
504
505        let X3 = &S8 * &S9;
506        let Y3 = &S5 * &S6;
507        let Z3 = &S8 * &S6;
508        let T3 = &S5 * &S9;
509
510        edwards::EdwardsPoint {
511            X: X3,
512            Y: Y3,
513            Z: Z3,
514            T: T3,
515        }
516    }
517
518    fn doubling_test_helper(P: edwards::EdwardsPoint) {
519        let R1: edwards::EdwardsPoint = serial_double(P.into()).into();
520        let R2: edwards::EdwardsPoint = ExtendedPoint::from(P).double().into();
521        println!("Testing point doubling:");
522        println!("P = {:?}", P);
523        println!("(serial) R1 = {:?}", R1);
524        println!("(vector) R2 = {:?}", R2);
525        println!("P + P = {:?}", &P + &P);
526        assert_eq!(R1.compress(), (&P + &P).compress());
527        assert_eq!(R2.compress(), (&P + &P).compress());
528        println!("OK!\n");
529    }
530
531    #[test]
532    fn vector_doubling_vs_serial_doubling_vs_edwards_extendedpoint() {
533        use crate::constants;
534        use crate::scalar::Scalar;
535
536        println!("Testing [2]id");
537        let P = edwards::EdwardsPoint::identity();
538        doubling_test_helper(P);
539
540        println!("Testing [2]B");
541        let P = constants::ED25519_BASEPOINT_POINT;
542        doubling_test_helper(P);
543
544        println!("Testing [2]([k]B)");
545        let P = constants::ED25519_BASEPOINT_TABLE * &Scalar::from(8475983829u64);
546        doubling_test_helper(P);
547    }
548
549    #[cfg(any(feature = "precomputed-tables", feature = "alloc"))]
550    #[test]
551    fn basepoint_odd_lookup_table_verify() {
552        use crate::backend::vector::avx2::constants::BASEPOINT_ODD_LOOKUP_TABLE;
553        use crate::constants;
554
555        let basepoint_odd_table =
556            NafLookupTable8::<CachedPoint>::from(&constants::ED25519_BASEPOINT_POINT);
557        println!("basepoint_odd_lookup_table = {:?}", basepoint_odd_table);
558
559        let table_B = &BASEPOINT_ODD_LOOKUP_TABLE;
560        for (b_vec, base_vec) in table_B.0.iter().zip(basepoint_odd_table.0.iter()) {
561            let b_splits = b_vec.0.split();
562            let base_splits = base_vec.0.split();
563
564            assert_eq!(base_splits[0], b_splits[0]);
565            assert_eq!(base_splits[1], b_splits[1]);
566            assert_eq!(base_splits[2], b_splits[2]);
567            assert_eq!(base_splits[3], b_splits[3]);
568        }
569    }
570}