curve25519_dalek/backend/vector/avx2/
field.rs

// -*- mode: rust; -*-
//
// This file is part of curve25519-dalek.
// Copyright (c) 2016-2021 isis lovecruft
// Copyright (c) 2016-2019 Henry de Valence
// See LICENSE for licensing information.
//
// Authors:
// - isis agora lovecruft <isis@patternsinthevoid.net>
// - Henry de Valence <hdevalence@hdevalence.ca>

//! An implementation of 4-way vectorized 32-bit field arithmetic using
//! AVX2.
//!
//! The `FieldElement2625x4` struct provides a vector of four field
//! elements, implemented using AVX2 operations.  Its API is designed
//! to abstract away the platform-dependent details, so that point
//! arithmetic can be implemented only in terms of a vector of field
//! elements.
//!
//! At this level, the API is optimized for speed and not safety.  The
//! `FieldElement2625x4` does not always perform reductions.  The pre-
//! and post-conditions on the bounds of the coefficients are
//! documented for each method, but it is the caller's responsibility
//! to ensure that there are no overflows.

#![allow(non_snake_case)]

const A_LANES: u8 = 0b0000_0101;
const B_LANES: u8 = 0b0000_1010;
const C_LANES: u8 = 0b0101_0000;
const D_LANES: u8 = 0b1010_0000;

#[allow(unused)]
const A_LANES64: u8 = 0b00_00_00_11;
#[allow(unused)]
const B_LANES64: u8 = 0b00_00_11_00;
#[allow(unused)]
const C_LANES64: u8 = 0b00_11_00_00;
#[allow(unused)]
const D_LANES64: u8 = 0b11_00_00_00;

use crate::backend::vector::packed_simd::{u32x8, u64x4};
use core::ops::{Add, Mul, Neg};

use crate::backend::serial::u64::field::FieldElement51;
use crate::backend::vector::avx2::constants::{
    P_TIMES_16_HI, P_TIMES_16_LO, P_TIMES_2_HI, P_TIMES_2_LO,
};

use curve25519_dalek_derive::unsafe_target_feature;

/// Unpack 32-bit lanes into 64-bit lanes:
/// ```ascii,no_run
/// (a0, b0, a1, b1, c0, d0, c1, d1)
/// ```
/// into
/// ```ascii,no_run
/// (a0, 0, b0, 0, c0, 0, d0, 0)
/// (a1, 0, b1, 0, c1, 0, d1, 0)
/// ```
#[unsafe_target_feature("avx2")]
#[inline(always)]
fn unpack_pair(src: u32x8) -> (u32x8, u32x8) {
    let a: u32x8;
    let b: u32x8;
    let zero = u32x8::splat(0);
    unsafe {
        use core::arch::x86_64::_mm256_unpackhi_epi32;
        use core::arch::x86_64::_mm256_unpacklo_epi32;
        a = _mm256_unpacklo_epi32(src.into(), zero.into()).into();
        b = _mm256_unpackhi_epi32(src.into(), zero.into()).into();
    }
    (a, b)
}

/// Repack 64-bit lanes into 32-bit lanes:
/// ```ascii,no_run
/// (a0, 0, b0, 0, c0, 0, d0, 0)
/// (a1, 0, b1, 0, c1, 0, d1, 0)
/// ```
/// into
/// ```ascii,no_run
/// (a0, b0, a1, b1, c0, d0, c1, d1)
/// ```
#[unsafe_target_feature("avx2")]
#[inline(always)]
fn repack_pair(x: u32x8, y: u32x8) -> u32x8 {
    unsafe {
        use core::arch::x86_64::_mm256_blend_epi32;
        use core::arch::x86_64::_mm256_shuffle_epi32;

        // Input: x = (a0, 0, b0, 0, c0, 0, d0, 0)
        // Input: y = (a1, 0, b1, 0, c1, 0, d1, 0)

        let x_shuffled = _mm256_shuffle_epi32(x.into(), 0b11_01_10_00);
        let y_shuffled = _mm256_shuffle_epi32(y.into(), 0b10_00_11_01);

        // x' = (a0, b0,  0,  0, c0, d0,  0,  0)
        // y' = ( 0,  0, a1, b1,  0,  0, c1, d1)

        _mm256_blend_epi32(x_shuffled, y_shuffled, 0b11_00_11_00).into()
    }
}

/// The `Lanes` enum represents a subset of the lanes `A,B,C,D` of a
/// `FieldElement2625x4`.
///
/// It's used to specify blend operations without
/// having to know details about the data layout of the
/// `FieldElement2625x4`.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, Debug)]
pub enum Lanes {
    C,
    D,
    AB,
    AC,
    CD,
    AD,
    BC,
    ABCD,
}

/// The `Shuffle` enum represents a shuffle of a `FieldElement2625x4`.
///
/// The enum variants are named by what they do to a vector \\(
/// (A,B,C,D) \\); for instance, `Shuffle::BADC` turns \\( (A, B, C,
/// D) \\) into \\( (B, A, D, C) \\).
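///
/// A sketch of the intended use (illustrative only, not compiled as a
/// doctest; `x0..x3` stand for arbitrary `FieldElement51` values):
///
/// ```ignore
/// let v = FieldElement2625x4::new(&x0, &x1, &x2, &x3);
/// // BADC swaps lanes within each (A,B) and (C,D) pair:
/// assert_eq!(v.shuffle(Shuffle::BADC).split(), [x1, x0, x3, x2]);
/// ```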
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, Debug)]
pub enum Shuffle {
    AAAA,
    BBBB,
    CACA,
    DBBD,
    ADDA,
    CBCB,
    ABAB,
    BADC,
    BACD,
    ABDC,
}

/// A vector of four field elements.
///
/// Each operation on a `FieldElement2625x4` has documented effects on
/// the bounds of the coefficients.  This API is designed for speed
/// and not safety; it is the caller's responsibility to ensure that
/// the post-conditions of one operation are compatible with the
/// pre-conditions of the next.
#[derive(Clone, Copy, Debug)]
pub struct FieldElement2625x4(pub(crate) [u32x8; 5]);

use subtle::Choice;
use subtle::ConditionallySelectable;

#[unsafe_target_feature("avx2")]
impl ConditionallySelectable for FieldElement2625x4 {
    fn conditional_select(
        a: &FieldElement2625x4,
        b: &FieldElement2625x4,
        choice: Choice,
    ) -> FieldElement2625x4 {
        let mask = (-(choice.unwrap_u8() as i32)) as u32;
        let mask_vec = u32x8::splat(mask);
        FieldElement2625x4([
            a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])),
            a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])),
            a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])),
            a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])),
            a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])),
        ])
    }

    fn conditional_assign(&mut self, other: &FieldElement2625x4, choice: Choice) {
        let mask = (-(choice.unwrap_u8() as i32)) as u32;
        let mask_vec = u32x8::splat(mask);
        self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]);
        self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]);
        self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]);
        self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]);
        self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]);
    }
}

#[unsafe_target_feature("avx2")]
impl FieldElement2625x4 {
    pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]);

    /// Split this vector into an array of four (serial) field
    /// elements.
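    ///
    /// A sketch (illustrative only, not compiled as a doctest): `split`
    /// undoes `new`, so the two round-trip:
    ///
    /// ```ignore
    /// let v = FieldElement2625x4::new(&x0, &x1, &x2, &x3);
    /// assert_eq!(v.split(), [x0, x1, x2, x3]);
    /// ```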
    #[rustfmt::skip] // keep alignment of extracted lanes
    pub fn split(&self) -> [FieldElement51; 4] {
        let mut out = [FieldElement51::ZERO; 4];
        for i in 0..5 {
            let a_2i   = self.0[i].extract::<0>() as u64; //
            let b_2i   = self.0[i].extract::<1>() as u64; //
            let a_2i_1 = self.0[i].extract::<2>() as u64; // `.
            let b_2i_1 = self.0[i].extract::<3>() as u64; //  | pre-swapped to avoid
            let c_2i   = self.0[i].extract::<4>() as u64; //  | a cross lane shuffle
            let d_2i   = self.0[i].extract::<5>() as u64; // .'
            let c_2i_1 = self.0[i].extract::<6>() as u64; //
            let d_2i_1 = self.0[i].extract::<7>() as u64; //

            out[0].0[i] = a_2i + (a_2i_1 << 26);
            out[1].0[i] = b_2i + (b_2i_1 << 26);
            out[2].0[i] = c_2i + (c_2i_1 << 26);
            out[3].0[i] = d_2i + (d_2i_1 << 26);
        }

        out
    }

    /// Rearrange the elements of this vector according to `control`.
    ///
    /// The `control` parameter should be a compile-time constant, so
    /// that when this function is inlined, LLVM is able to lower the
    /// shuffle using an immediate.
    #[inline]
    pub fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 {
        #[inline(always)]
        fn shuffle_lanes(x: u32x8, control: Shuffle) -> u32x8 {
            unsafe {
                use core::arch::x86_64::_mm256_permutevar8x32_epi32;

                let c: u32x8 = match control {
                    Shuffle::AAAA => u32x8::new(0, 0, 2, 2, 0, 0, 2, 2),
                    Shuffle::BBBB => u32x8::new(1, 1, 3, 3, 1, 1, 3, 3),
                    Shuffle::CACA => u32x8::new(4, 0, 6, 2, 4, 0, 6, 2),
                    Shuffle::DBBD => u32x8::new(5, 1, 7, 3, 1, 5, 3, 7),
                    Shuffle::ADDA => u32x8::new(0, 5, 2, 7, 5, 0, 7, 2),
                    Shuffle::CBCB => u32x8::new(4, 1, 6, 3, 4, 1, 6, 3),
                    Shuffle::ABAB => u32x8::new(0, 1, 2, 3, 0, 1, 2, 3),
                    Shuffle::BADC => u32x8::new(1, 0, 3, 2, 5, 4, 7, 6),
                    Shuffle::BACD => u32x8::new(1, 0, 3, 2, 4, 5, 6, 7),
                    Shuffle::ABDC => u32x8::new(0, 1, 2, 3, 5, 4, 7, 6),
                };
                // Note that this gets turned into a generic LLVM
                // shuffle-by-constants, which can be lowered to a simpler
                // instruction than a generic permute.
                _mm256_permutevar8x32_epi32(x.into(), c.into()).into()
            }
        }

        FieldElement2625x4([
            shuffle_lanes(self.0[0], control),
            shuffle_lanes(self.0[1], control),
            shuffle_lanes(self.0[2], control),
            shuffle_lanes(self.0[3], control),
            shuffle_lanes(self.0[4], control),
        ])
    }

    /// Blend `self` with `other`, taking lanes specified in `control` from `other`.
    ///
    /// The `control` parameter should be a compile-time constant, so
    /// that this function can be inlined and LLVM can lower it to a
    /// blend instruction using an immediate.
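    ///
    /// A sketch (illustrative only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let v = FieldElement2625x4::new(&a, &b, &c, &d);
    /// let w = FieldElement2625x4::new(&e, &f, &g, &h);
    /// // Take the C and D lanes from w, keeping A and B from v:
    /// assert_eq!(v.blend(w, Lanes::CD).split(), [a, b, g, h]);
    /// ```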
    #[inline]
    pub fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 {
        #[inline(always)]
        fn blend_lanes(x: u32x8, y: u32x8, control: Lanes) -> u32x8 {
            unsafe {
                use core::arch::x86_64::_mm256_blend_epi32;

                // This would be much cleaner if we could factor out the match
                // statement on the control. Unfortunately, rustc forgets
                // constant-info very quickly, so we can't even write
                // ```
                // match control {
                //     Lanes::C => {
                //         let imm = C_LANES as i32;
                //         _mm256_blend_epi32(..., imm)
                // ```
                // let alone
                // ```
                // let imm = match control {
                //     Lanes::C => C_LANES as i32,
                // }
                // _mm256_blend_epi32(..., imm)
                // ```
                // even though both of these would be constant-folded by LLVM
                // at a lower level (as happens in the shuffle implementation,
                // which does not require a shuffle immediate but *is* lowered
                // to immediate shuffles anyways).
                match control {
                    Lanes::C => _mm256_blend_epi32(x.into(), y.into(), C_LANES as i32).into(),
                    Lanes::D => _mm256_blend_epi32(x.into(), y.into(), D_LANES as i32).into(),
                    Lanes::AD => {
                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | D_LANES) as i32).into()
                    }
                    Lanes::AB => {
                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | B_LANES) as i32).into()
                    }
                    Lanes::AC => {
                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | C_LANES) as i32).into()
                    }
                    Lanes::CD => {
                        _mm256_blend_epi32(x.into(), y.into(), (C_LANES | D_LANES) as i32).into()
                    }
                    Lanes::BC => {
                        _mm256_blend_epi32(x.into(), y.into(), (B_LANES | C_LANES) as i32).into()
                    }
                    Lanes::ABCD => _mm256_blend_epi32(
                        x.into(),
                        y.into(),
                        (A_LANES | B_LANES | C_LANES | D_LANES) as i32,
                    )
                    .into(),
                }
            }
        }

        FieldElement2625x4([
            blend_lanes(self.0[0], other.0[0], control),
            blend_lanes(self.0[1], other.0[1], control),
            blend_lanes(self.0[2], other.0[2], control),
            blend_lanes(self.0[3], other.0[3], control),
            blend_lanes(self.0[4], other.0[4], control),
        ])
    }

    /// Convenience wrapper around `new(x,x,x,x)`.
    pub fn splat(x: &FieldElement51) -> FieldElement2625x4 {
        FieldElement2625x4::new(x, x, x, x)
    }

    /// Create a `FieldElement2625x4` from four `FieldElement51`s.
    ///
    /// # Postconditions
    ///
    /// The resulting `FieldElement2625x4` is bounded with \\( b < 0.0002 \\).
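    ///
    /// A sketch of the limb splitting this performs (illustrative
    /// only): each 51-bit input limb is cut at bit 26, and the two
    /// halves recombine as
    ///
    /// ```ignore
    /// let lo = (limb & ((1 << 26) - 1)) as u32; // 26-bit low half
    /// let hi = (limb >> 26) as u32;             // high half
    /// assert_eq!(limb, (lo as u64) + ((hi as u64) << 26));
    /// ```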
    #[rustfmt::skip] // keep alignment of computed lanes
    pub fn new(
        x0: &FieldElement51,
        x1: &FieldElement51,
        x2: &FieldElement51,
        x3: &FieldElement51,
    ) -> FieldElement2625x4 {
        let mut buf = [u32x8::splat(0); 5];
        let low_26_bits = (1 << 26) - 1;
        #[allow(clippy::needless_range_loop)]
        for i in 0..5 {
            let a_2i   = (x0.0[i] & low_26_bits) as u32;
            let a_2i_1 = (x0.0[i] >> 26) as u32;
            let b_2i   = (x1.0[i] & low_26_bits) as u32;
            let b_2i_1 = (x1.0[i] >> 26) as u32;
            let c_2i   = (x2.0[i] & low_26_bits) as u32;
            let c_2i_1 = (x2.0[i] >> 26) as u32;
            let d_2i   = (x3.0[i] & low_26_bits) as u32;
            let d_2i_1 = (x3.0[i] >> 26) as u32;

            buf[i] = u32x8::new(a_2i, b_2i, a_2i_1, b_2i_1, c_2i, d_2i, c_2i_1, d_2i_1);
        }

        // We don't know that the original `FieldElement51`s were
        // fully reduced, so the odd limbs may exceed 2^25.
        // Reduce them to be sure.
        FieldElement2625x4(buf).reduce()
    }

    /// Given \\((A,B,C,D)\\), compute \\((-A,-B,-C,-D)\\), without
    /// performing a reduction.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 0.999 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 1 \\).
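    ///
    /// A sketch (illustrative only, not compiled as a doctest): the
    /// result is the negation mod \\( p \\) even though no carry
    /// propagation is performed:
    ///
    /// ```ignore
    /// let v = FieldElement2625x4::new(&x0, &x1, &x2, &x3);
    /// assert_eq!(v.negate_lazy().split(), [-&x0, -&x1, -&x2, -&x3]);
    /// ```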
    #[inline]
    pub fn negate_lazy(&self) -> FieldElement2625x4 {
        // The limbs of self are bounded with b < 0.999, while the
        // smallest limb of 2*p is 67108845 > 2^{26+0.9999}, so
        // underflows are not possible.
        FieldElement2625x4([
            P_TIMES_2_LO - self.0[0],
            P_TIMES_2_HI - self.0[1],
            P_TIMES_2_HI - self.0[2],
            P_TIMES_2_HI - self.0[3],
            P_TIMES_2_HI - self.0[4],
        ])
    }

    /// Given `self = (A,B,C,D)`, compute `(B - A, B + A, D - C, D + C)`.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 0.01 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 1.6 \\).
    #[inline]
    pub fn diff_sum(&self) -> FieldElement2625x4 {
        // tmp1 = (B, A, D, C)
        let tmp1 = self.shuffle(Shuffle::BADC);
        // tmp2 = (-A, B, -C, D)
        let tmp2 = self.blend(self.negate_lazy(), Lanes::AC);
        // (B - A, B + A, D - C, D + C) bounded with b < 1.6
        tmp1 + tmp2
    }

    /// Reduce this vector of field elements \\(\mathrm{mod}\ p\\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.0002 \\).
    #[inline]
    pub fn reduce(&self) -> FieldElement2625x4 {
        let shifts = u32x8::new(26, 26, 25, 25, 26, 26, 25, 25);
        let masks = u32x8::new(
            (1 << 26) - 1,
            (1 << 26) - 1,
            (1 << 25) - 1,
            (1 << 25) - 1,
            (1 << 26) - 1,
            (1 << 26) - 1,
            (1 << 25) - 1,
            (1 << 25) - 1,
        );

        // Let c(x) denote the carryout of the coefficient x.
        //
        // Given    (   x0,    y0,    x1,    y1,    z0,    w0,    z1,    w1),
        // compute  (c(x1), c(y1), c(x0), c(y0), c(z1), c(w1), c(z0), c(w0)).
        //
        // The carryouts are bounded by 2^(32 - 25) = 2^7.
        let rotated_carryout = |v: u32x8| -> u32x8 {
            unsafe {
                use core::arch::x86_64::_mm256_shuffle_epi32;
                use core::arch::x86_64::_mm256_srlv_epi32;

                let c = _mm256_srlv_epi32(v.into(), shifts.into());
                _mm256_shuffle_epi32(c, 0b01_00_11_10).into()
            }
        };

        // Combine (lo, lo, lo, lo, lo, lo, lo, lo)
        //    with (hi, hi, hi, hi, hi, hi, hi, hi)
        //      to (lo, lo, hi, hi, lo, lo, hi, hi)
        //
        // This allows combining carryouts, e.g.,
        //
        // lo  (c(x1), c(y1), c(x0), c(y0), c(z1), c(w1), c(z0), c(w0))
        // hi  (c(x3), c(y3), c(x2), c(y2), c(z3), c(w3), c(z2), c(w2))
        // ->  (c(x1), c(y1), c(x2), c(y2), c(z1), c(w1), c(z2), c(w2))
        //
        // which is exactly the vector of carryins for
        //
        //     (   x2,    y2,    x3,    y3,    z2,    w2,    z3,    w3).
        //
        let combine = |v_lo: u32x8, v_hi: u32x8| -> u32x8 {
            unsafe {
                use core::arch::x86_64::_mm256_blend_epi32;
                _mm256_blend_epi32(v_lo.into(), v_hi.into(), 0b11_00_11_00).into()
            }
        };

        let mut v = self.0;

        let c10 = rotated_carryout(v[0]);
        v[0] = (v[0] & masks) + combine(u32x8::splat(0), c10);

        let c32 = rotated_carryout(v[1]);
        v[1] = (v[1] & masks) + combine(c10, c32);

        let c54 = rotated_carryout(v[2]);
        v[2] = (v[2] & masks) + combine(c32, c54);

        let c76 = rotated_carryout(v[3]);
        v[3] = (v[3] & masks) + combine(c54, c76);

        let c98 = rotated_carryout(v[4]);
        v[4] = (v[4] & masks) + combine(c76, c98);

        let c9_19: u32x8 = unsafe {
            use core::arch::x86_64::_mm256_mul_epu32;
            use core::arch::x86_64::_mm256_shuffle_epi32;

            // Need to rearrange c98, since vpmuludq uses the low
            // 32-bits of each 64-bit lane to compute the product:
            //
            // c98       = (c(x9), c(y9), c(x8), c(y8), c(z9), c(w9), c(z8), c(w8));
            // c9_spread = (c(x9), c(x8), c(y9), c(y8), c(z9), c(z8), c(w9), c(w8)).
            let c9_spread = _mm256_shuffle_epi32(c98.into(), 0b11_01_10_00);

            // Since the carryouts are bounded by 2^7, their products with 19
            // are bounded by 2^11.25.  This means that
            //
            // c9_19_spread = (19*c(x9), 0, 19*c(y9), 0, 19*c(z9), 0, 19*c(w9), 0).
            let c9_19_spread = _mm256_mul_epu32(c9_spread, u64x4::splat(19).into());

            // Unshuffle:
            // c9_19 = (19*c(x9), 19*c(y9), 0, 0, 19*c(z9), 19*c(w9), 0, 0).
            _mm256_shuffle_epi32(c9_19_spread, 0b11_01_10_00).into()
        };

        // Add the final carryin.
        v[0] += c9_19;

        // Each output coefficient has exactly one carryin, which is
        // bounded by 2^11.25, so they are bounded as
        //
        // c_even < 2^26 + 2^11.25 < 2^26.00006 < 2^{26+b}
        // c_odd  < 2^25 + 2^11.25 < 2^25.0001  < 2^{25+b}
        //
        // where b = 0.0002.
        FieldElement2625x4(v)
    }

    /// Given an array of wide coefficients, reduce them to a `FieldElement2625x4`.
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
    #[inline]
    #[rustfmt::skip] // keep alignment of carry chain
    fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 {
        // These aren't const because splat isn't a const fn
        let LOW_25_BITS: u64x4 = u64x4::splat((1 << 25) - 1);
        let LOW_26_BITS: u64x4 = u64x4::splat((1 << 26) - 1);

        // Carry the value from limb i = 0..8 to limb i+1
        let carry = |z: &mut [u64x4; 10], i: usize| {
            debug_assert!(i < 9);
            if i % 2 == 0 {
                // Even limbs have 26 bits
                z[i + 1] += z[i].shr::<26>();
                z[i] &= LOW_26_BITS;
            } else {
                // Odd limbs have 25 bits
                z[i + 1] += z[i].shr::<25>();
                z[i] &= LOW_25_BITS;
            }
        };

        // Perform two halves of the carry chain in parallel.
        carry(&mut z, 0); carry(&mut z, 4);
        carry(&mut z, 1); carry(&mut z, 5);
        carry(&mut z, 2); carry(&mut z, 6);
        carry(&mut z, 3); carry(&mut z, 7);
        // Since z[3] < 2^64, c < 2^(64-25) = 2^39,
        // so    z[4] < 2^26 + 2^39 < 2^39.0002
        carry(&mut z, 4); carry(&mut z, 8);
        // Now z[4] < 2^26
        // and z[5] < 2^25 + 2^13.0002 < 2^25.0004 (good enough)

        // Last carry has a multiplication by 19.  In the serial case we
        // do a 64-bit multiplication by 19, but here we want to do a
        // 32-bit multiplication.  However, if we only know z[9] < 2^64,
        // the carry is bounded as c < 2^(64-25) = 2^39, which is too
        // big.  To ensure c < 2^32, we would need z[9] < 2^57.
        // Instead, we split the carry in two, with c = c_0 + c_1*2^26.

        let c = z[9].shr::<25>();
        z[9] &= LOW_25_BITS;
        let mut c0: u64x4 = c & LOW_26_BITS; // c0 < 2^26;
        let mut c1: u64x4 = c.shr::<26>();   // c1 < 2^(39-26) = 2^13;

        let x19 = u64x4::splat(19);
        c0 = u32x8::from(c0).mul32(u32x8::from(x19));
        c1 = u32x8::from(c1).mul32(u32x8::from(x19));

        z[0] += c0; // z0 < 2^26 + 2^30.25 < 2^30.33
        z[1] += c1; // z1 < 2^25 + 2^17.25 < 2^25.0067
        carry(&mut z, 0); // z0 < 2^26, z1 < 2^25.0067 + 2^4.33 = 2^25.007

        // The output coefficients are bounded with
        //
        // b = 0.007  for z[1]
        // b = 0.0004 for z[5]
        // b = 0      for other z[i].
        //
        // So the packed result is bounded with b = 0.007.
        FieldElement2625x4([
            repack_pair(z[0].into(), z[1].into()),
            repack_pair(z[2].into(), z[3].into()),
            repack_pair(z[4].into(), z[5].into()),
            repack_pair(z[6].into(), z[7].into()),
            repack_pair(z[8].into(), z[9].into()),
        ])
    }

    /// Square this field element, and negate the result's \\(D\\) value.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 1.5 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
    #[rustfmt::skip] // keep alignment of z* calculations
    pub fn square_and_negate_D(&self) -> FieldElement2625x4 {
        #[inline(always)]
        fn m(x: u32x8, y: u32x8) -> u64x4 {
            x.mul32(y)
        }

        #[inline(always)]
        fn m_lo(x: u32x8, y: u32x8) -> u32x8 {
            x.mul32(y).into()
        }

        let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0);

        let (x0, x1) = unpack_pair(self.0[0]);
        let (x2, x3) = unpack_pair(self.0[1]);
        let (x4, x5) = unpack_pair(self.0[2]);
        let (x6, x7) = unpack_pair(self.0[3]);
        let (x8, x9) = unpack_pair(self.0[4]);

        let x0_2 = x0.shl::<1>();
        let x1_2 = x1.shl::<1>();
        let x2_2 = x2.shl::<1>();
        let x3_2 = x3.shl::<1>();
        let x4_2 = x4.shl::<1>();
        let x5_2 = x5.shl::<1>();
        let x6_2 = x6.shl::<1>();
        let x7_2 = x7.shl::<1>();

        let x5_19 = m_lo(v19, x5);
        let x6_19 = m_lo(v19, x6);
        let x7_19 = m_lo(v19, x7);
        let x8_19 = m_lo(v19, x8);
        let x9_19 = m_lo(v19, x9);

        let mut z0 = m(x0,   x0) + m(x2_2, x8_19) + m(x4_2, x6_19) + ((m(x1_2, x9_19) +   m(x3_2, x7_19) +    m(x5,   x5_19)).shl::<1>());
        let mut z1 = m(x0_2, x1) + m(x3_2, x8_19) + m(x5_2, x6_19) +                    ((m(x2,   x9_19) +    m(x4,   x7_19)).shl::<1>());
        let mut z2 = m(x0_2, x2) + m(x1_2,    x1) + m(x4_2, x8_19) +   m(x6,   x6_19) + ((m(x3_2, x9_19) +    m(x5_2, x7_19)).shl::<1>());
        let mut z3 = m(x0_2, x3) + m(x1_2,    x2) + m(x5_2, x8_19) +                    ((m(x4,   x9_19) +    m(x6,   x7_19)).shl::<1>());
        let mut z4 = m(x0_2, x4) + m(x1_2,  x3_2) + m(x2,      x2) +   m(x6_2, x8_19) + ((m(x5_2, x9_19) +    m(x7,   x7_19)).shl::<1>());
        let mut z5 = m(x0_2, x5) + m(x1_2,    x4) + m(x2_2,    x3) +   m(x7_2, x8_19)                    +  ((m(x6,   x9_19)).shl::<1>());
        let mut z6 = m(x0_2, x6) + m(x1_2,  x5_2) + m(x2_2,    x4) +   m(x3_2,    x3) +   m(x8,   x8_19) +  ((m(x7_2, x9_19)).shl::<1>());
        let mut z7 = m(x0_2, x7) + m(x1_2,    x6) + m(x2_2,    x5) +   m(x3_2,    x4)                    +  ((m(x8,   x9_19)).shl::<1>());
        let mut z8 = m(x0_2, x8) + m(x1_2,  x7_2) + m(x2_2,    x6) +   m(x3_2,  x5_2) +   m(x4,      x4) +  ((m(x9,   x9_19)).shl::<1>());
        let mut z9 = m(x0_2, x9) + m(x1_2,    x8) + m(x2_2,    x7) +   m(x3_2,    x6) +   m(x4_2,    x5)                                 ;

        // The biggest z_i is bounded as z_i < 249*2^(51 + 2*b);
        // if b < 1.5 we get z_i < 4485585228861014016.
        //
        // The limbs of the multiples of p are bounded above by
        //
        // 0x3fffffff << 37 = 9223371899415822336 < 2^63
        //
        // and below by
        //
        // 0x1fffffff << 37 = 4611685880988434432
        //                  > 4485585228861014016
        //
        // So these multiples of p are big enough to avoid underflow
        // in subtraction, and small enough to fit within u64
        // with room for a carry.

        let low__p37 = u64x4::splat(0x3ffffed << 37);
        let even_p37 = u64x4::splat(0x3ffffff << 37);
        let odd__p37 = u64x4::splat(0x1ffffff << 37);

        let negate_D = |x: u64x4, p: u64x4| -> u64x4 {
            unsafe {
                use core::arch::x86_64::_mm256_blend_epi32;
                _mm256_blend_epi32(x.into(), (p - x).into(), D_LANES64 as i32).into()
            }
        };

        z0 = negate_D(z0, low__p37);
        z1 = negate_D(z1, odd__p37);
        z2 = negate_D(z2, even_p37);
        z3 = negate_D(z3, odd__p37);
        z4 = negate_D(z4, even_p37);
        z5 = negate_D(z5, odd__p37);
        z6 = negate_D(z6, even_p37);
        z7 = negate_D(z7, odd__p37);
        z8 = negate_D(z8, even_p37);
        z9 = negate_D(z9, odd__p37);

        FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9])
    }
}

#[unsafe_target_feature("avx2")]
impl Neg for FieldElement2625x4 {
    type Output = FieldElement2625x4;

    /// Negate this field element, performing a reduction.
    ///
    /// If the coefficients are known to be small, use `negate_lazy`
    /// to avoid performing a reduction.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 4.0 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.0002 \\).
    #[inline]
    fn neg(self) -> FieldElement2625x4 {
        FieldElement2625x4([
            P_TIMES_16_LO - self.0[0],
            P_TIMES_16_HI - self.0[1],
            P_TIMES_16_HI - self.0[2],
            P_TIMES_16_HI - self.0[3],
            P_TIMES_16_HI - self.0[4],
        ])
        .reduce()
    }
}

#[unsafe_target_feature("avx2")]
impl Add<FieldElement2625x4> for FieldElement2625x4 {
    type Output = FieldElement2625x4;
    /// Add two `FieldElement2625x4`s, without performing a reduction.
    #[inline]
    fn add(self, rhs: FieldElement2625x4) -> FieldElement2625x4 {
        FieldElement2625x4([
            self.0[0] + rhs.0[0],
            self.0[1] + rhs.0[1],
            self.0[2] + rhs.0[2],
            self.0[3] + rhs.0[3],
            self.0[4] + rhs.0[4],
        ])
    }
}

#[unsafe_target_feature("avx2")]
impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 {
    type Output = FieldElement2625x4;
    /// Perform a multiplication by a vector of small constants.
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
    #[inline]
    fn mul(self, scalars: (u32, u32, u32, u32)) -> FieldElement2625x4 {
        let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0);

        let (b0, b1) = unpack_pair(self.0[0]);
        let (b2, b3) = unpack_pair(self.0[1]);
        let (b4, b5) = unpack_pair(self.0[2]);
        let (b6, b7) = unpack_pair(self.0[3]);
        let (b8, b9) = unpack_pair(self.0[4]);

        FieldElement2625x4::reduce64([
            b0.mul32(consts),
            b1.mul32(consts),
            b2.mul32(consts),
            b3.mul32(consts),
            b4.mul32(consts),
            b5.mul32(consts),
            b6.mul32(consts),
            b7.mul32(consts),
            b8.mul32(consts),
            b9.mul32(consts),
        ])
    }
}

#[unsafe_target_feature("avx2")]
impl Mul<&FieldElement2625x4> for &FieldElement2625x4 {
    type Output = FieldElement2625x4;
    /// Multiply `self` by `rhs`.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 2.5 \\).
    ///
    /// The coefficients of `rhs` must be bounded with \\( b < 1.75 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
    ///
    #[rustfmt::skip] // keep alignment of z* calculations
    #[inline]
    fn mul(self, rhs: &FieldElement2625x4) -> FieldElement2625x4 {
        #[inline(always)]
        fn m(x: u32x8, y: u32x8) -> u64x4 {
            x.mul32(y)
        }

        #[inline(always)]
        fn m_lo(x: u32x8, y: u32x8) -> u32x8 {
            x.mul32(y).into()
        }

        let (x0, x1) = unpack_pair(self.0[0]);
        let (x2, x3) = unpack_pair(self.0[1]);
        let (x4, x5) = unpack_pair(self.0[2]);
        let (x6, x7) = unpack_pair(self.0[3]);
        let (x8, x9) = unpack_pair(self.0[4]);

        let (y0, y1) = unpack_pair(rhs.0[0]);
        let (y2, y3) = unpack_pair(rhs.0[1]);
        let (y4, y5) = unpack_pair(rhs.0[2]);
        let (y6, y7) = unpack_pair(rhs.0[3]);
        let (y8, y9) = unpack_pair(rhs.0[4]);

        let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0);

        let y1_19 = m_lo(v19, y1); // This fits in a u32
        let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32
        let y3_19 = m_lo(v19, y3); // if  b < 32 - 26 - 4.248 = 1.752
        let y4_19 = m_lo(v19, y4);
        let y5_19 = m_lo(v19, y5);
        let y6_19 = m_lo(v19, y6);
        let y7_19 = m_lo(v19, y7);
        let y8_19 = m_lo(v19, y8);
        let y9_19 = m_lo(v19, y9);

        let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32
        let x3_2 = x3 + x3; //                    iff b < 6
        let x5_2 = x5 + x5;
        let x7_2 = x7 + x7;
        let x9_2 = x9 + x9;

        let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19);
        let z1 = m(x0, y1) + m(x1,      y0) + m(x2, y9_19) + m(x3,   y8_19) + m(x4, y7_19) + m(x5,   y6_19) + m(x6, y5_19) + m(x7,   y4_19) + m(x8, y3_19) + m(x9,   y2_19);
        let z2 = m(x0, y2) + m(x1_2,    y1) + m(x2,    y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19);
        let z3 = m(x0, y3) + m(x1,      y2) + m(x2,    y1) + m(x3,      y0) + m(x4, y9_19) + m(x5,   y8_19) + m(x6, y7_19) + m(x7,   y6_19) + m(x8, y5_19) + m(x9,   y4_19);
        let z4 = m(x0, y4) + m(x1_2,    y3) + m(x2,    y2) + m(x3_2,    y1) + m(x4,    y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19);
        let z5 = m(x0, y5) + m(x1,      y4) + m(x2,    y3) + m(x3,      y2) + m(x4,    y1) + m(x5,      y0) + m(x6, y9_19) + m(x7,   y8_19) + m(x8, y7_19) + m(x9,   y6_19);
        let z6 = m(x0, y6) + m(x1_2,    y5) + m(x2,    y4) + m(x3_2,    y3) + m(x4,    y2) + m(x5_2,    y1) + m(x6,    y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19);
        let z7 = m(x0, y7) + m(x1,      y6) + m(x2,    y5) + m(x3,      y4) + m(x4,    y3) + m(x5,      y2) + m(x6,    y1) + m(x7,      y0) + m(x8, y9_19) + m(x9,   y8_19);
        let z8 = m(x0, y8) + m(x1_2,    y7) + m(x2,    y6) + m(x3_2,    y5) + m(x4,    y4) + m(x5_2,    y3) + m(x6,    y2) + m(x7_2,    y1) + m(x8,    y0) + m(x9_2, y9_19);
        let z9 = m(x0, y9) + m(x1,      y8) + m(x2,    y7) + m(x3,      y6) + m(x4,    y5) + m(x5,      y4) + m(x6,    y3) + m(x7,      y2) + m(x8,    y1) + m(x9,      y0);

        // The bounds on z[i] are the same as in the serial 32-bit code
        // and the comment below is copied from there:

        // How big is the contribution to z[i+j] from x[i], y[j]?
        //
        // Using the bounds above, we get:
        //
        // i even, j even:   x[i]*y[j] <   2^(26+b)*2^(26+b) = 2*2^(51+2*b)
        // i  odd, j even:   x[i]*y[j] <   2^(25+b)*2^(26+b) = 1*2^(51+2*b)
        // i even, j  odd:   x[i]*y[j] <   2^(26+b)*2^(25+b) = 1*2^(51+2*b)
        // i  odd, j  odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b)
        //
        // We perform inline reduction mod p by replacing 2^255 by 19
        // (since 2^255 - 19 = 0 mod p).  This adds a factor of 19, so
        // we get the bounds (z0 is the biggest one, but calculated for
        // posterity here in case finer estimation is needed later):
        //
        //  z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b)
        //  z1 < ( 1 +  1   + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b)
        //  z2 < ( 2 +  1   +  2   + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b)
        //  z3 < ( 1 +  1   +  1   +  1   + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b)
        //  z4 < ( 2 +  1   +  2   +  1   +  2   + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b)
        //  z5 < ( 1 +  1   +  1   +  1   +  1   +  1   + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) =  82*2^(51 + 2*b)
        //  z6 < ( 2 +  1   +  2   +  1   +  2   +  1   +  2   + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) =  87*2^(51 + 2*b)
        //  z7 < ( 1 +  1   +  1   +  1   +  1   +  1   +  1   +  1   + 1*19 + 1*19 )*2^(51 + 2b) =  46*2^(51 + 2*b)
        //  z8 < ( 2 +  1   +  2   +  1   +  2   +  1   +  2   +  1   +  2   + 1*19 )*2^(51 + 2b) =  33*2^(51 + 2*b)
        //  z9 < ( 1 +  1   +  1   +  1   +  1   +  1   +  1   +  1   +  1   +  1   )*2^(51 + 2b) =  10*2^(51 + 2*b)
        //
        // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64
        //                         if b < 2.5.

        // In fact this bound is slightly sloppy, since it treats both
        // inputs x and y as being bounded by the same parameter b,
        // while they are in fact bounded by b_x and b_y, and we
        // already require that b_y < 1.75 in order to fit the
        // multiplications by 19 into a u32.  The tighter bound on b_y
        // means we could get a tighter bound on the outputs, or a
        // looser bound on b_x.
        FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9])
    }
}

#[cfg(target_feature = "avx2")]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn scale_by_curve_constants() {
        let mut x = FieldElement2625x4::splat(&FieldElement51::ONE);

        x = x * (121666, 121666, 2 * 121666, 2 * 121665);

        let xs = x.split();
        assert_eq!(xs[0], FieldElement51([121666, 0, 0, 0, 0]));
        assert_eq!(xs[1], FieldElement51([121666, 0, 0, 0, 0]));
        assert_eq!(xs[2], FieldElement51([2 * 121666, 0, 0, 0, 0]));
        assert_eq!(xs[3], FieldElement51([2 * 121665, 0, 0, 0, 0]));
    }
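
    // A hedged addition to the test suite: vectorized negation with
    // reduction should agree with the serial implementation's negation
    // in every lane.
    #[test]
    fn negation_vs_serial() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let result = (-FieldElement2625x4::new(&x0, &x1, &x2, &x3)).split();

        assert_eq!(result[0], -&x0);
        assert_eq!(result[1], -&x1);
        assert_eq!(result[2], -&x2);
        assert_eq!(result[3], -&x3);
    }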

    #[test]
    fn diff_sum_vs_serial() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3).diff_sum();

        let result = vec.split();

        assert_eq!(result[0], &x1 - &x0);
        assert_eq!(result[1], &x1 + &x0);
        assert_eq!(result[2], &x3 - &x2);
        assert_eq!(result[3], &x3 + &x2);
    }
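
    // A hedged addition: negate_lazy skips the reduction, but each lane
    // must still hold the negated value mod p.
    #[test]
    fn negate_lazy_vs_serial() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let result = FieldElement2625x4::new(&x0, &x1, &x2, &x3)
            .negate_lazy()
            .split();

        assert_eq!(result[0], -&x0);
        assert_eq!(result[1], -&x1);
        assert_eq!(result[2], -&x2);
        assert_eq!(result[3], -&x3);
    }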

    #[test]
    fn square_vs_serial() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let result = vec.square_and_negate_D().split();

        assert_eq!(result[0], &x0 * &x0);
        assert_eq!(result[1], &x1 * &x1);
        assert_eq!(result[2], &x2 * &x2);
        assert_eq!(result[3], -&(&x3 * &x3));
    }
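
    // A hedged addition: conditional_select with choice = 0 must return
    // the first argument and with choice = 1 the second, lanewise.
    #[test]
    fn conditional_select_vs_inputs() {
        let a = FieldElement2625x4::splat(&FieldElement51([1, 0, 0, 0, 0]));
        let b = FieldElement2625x4::splat(&FieldElement51([2, 0, 0, 0, 0]));

        let picked_a = FieldElement2625x4::conditional_select(&a, &b, Choice::from(0));
        let picked_b = FieldElement2625x4::conditional_select(&a, &b, Choice::from(1));

        assert_eq!(picked_a.split(), a.split());
        assert_eq!(picked_b.split(), b.split());
    }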

    #[test]
    fn multiply_vs_serial() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);
        let vecprime = vec.clone();

        let result = (&vec * &vecprime).split();

        assert_eq!(result[0], &x0 * &x0);
        assert_eq!(result[1], &x1 * &x1);
        assert_eq!(result[2], &x2 * &x2);
        assert_eq!(result[3], &x3 * &x3);
    }
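
    // A hedged addition: blending with Lanes::CD should take the C and
    // D lanes from the second operand and keep A and B from the first.
    #[test]
    fn blend_cd_vs_split() {
        let x0 = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let x1 = FieldElement51([10100, 10101, 10102, 10103, 10104]);
        let x2 = FieldElement51([10200, 10201, 10202, 10203, 10204]);
        let x3 = FieldElement51([10300, 10301, 10302, 10303, 10304]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);
        let other = FieldElement2625x4::new(&x3, &x2, &x1, &x0);

        let result = vec.blend(other, Lanes::CD).split();

        assert_eq!(result[0], x0);
        assert_eq!(result[1], x1);
        assert_eq!(result[2], x1);
        assert_eq!(result[3], x0);
    }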

    #[test]
    fn test_unpack_repack_pair() {
        let x0 = FieldElement51([10000 + (10001 << 26), 0, 0, 0, 0]);
        let x1 = FieldElement51([10100 + (10101 << 26), 0, 0, 0, 0]);
        let x2 = FieldElement51([10200 + (10201 << 26), 0, 0, 0, 0]);
        let x3 = FieldElement51([10300 + (10301 << 26), 0, 0, 0, 0]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let src = vec.0[0];

        let (a, b) = unpack_pair(src);

        let expected_a = u32x8::new(10000, 0, 10100, 0, 10200, 0, 10300, 0);
        let expected_b = u32x8::new(10001, 0, 10101, 0, 10201, 0, 10301, 0);

        assert_eq!(a, expected_a);
        assert_eq!(b, expected_b);

        let expected_src = repack_pair(a, b);

        assert_eq!(src, expected_src);
    }

    #[test]
    fn new_split_roundtrips() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let splits = vec.split();

        assert_eq!(x0, splits[0]);
        assert_eq!(x1, splits[1]);
        assert_eq!(x2, splits[2]);
        assert_eq!(x3, splits[3]);
    }
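
    // A hedged addition: splat(x) should be indistinguishable from
    // new(x, x, x, x).
    #[test]
    fn splat_vs_new() {
        let x = FieldElement51([10000, 10001, 10002, 10003, 10004]);
        let splits = FieldElement2625x4::splat(&x).split();
        for s in &splits {
            assert_eq!(*s, x);
        }
    }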
}