curve25519_dalek/backend/vector/packed_simd.rs

// -*- mode: rust; -*-
//
// This file is part of curve25519-dalek.
// See LICENSE for licensing information.

//! This module defines wrappers over platform-specific SIMD types to make them
//! more convenient to use.
//!
//! UNSAFETY: Everything in this module assumes that we're running on hardware
//!           which supports at least AVX2. This invariant *must* be enforced
//!           by the callers of this code.

use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub};

use curve25519_dalek_derive::unsafe_target_feature;

macro_rules! impl_shared {
    (
        $ty:ident,
        $lane_ty:ident,
        $add_intrinsic:ident,
        $sub_intrinsic:ident,
        $shl_intrinsic:ident,
        $shr_intrinsic:ident,
        $extract_intrinsic:ident
    ) => {
        #[allow(non_camel_case_types)]
        #[derive(Copy, Clone, Debug)]
        #[repr(transparent)]
        pub struct $ty(core::arch::x86_64::__m256i);

        #[unsafe_target_feature("avx2")]
        impl From<$ty> for core::arch::x86_64::__m256i {
            #[inline]
            fn from(value: $ty) -> core::arch::x86_64::__m256i {
                value.0
            }
        }

        #[unsafe_target_feature("avx2")]
        impl From<core::arch::x86_64::__m256i> for $ty {
            #[inline]
            fn from(value: core::arch::x86_64::__m256i) -> $ty {
                $ty(value)
            }
        }

        #[unsafe_target_feature("avx2")]
        impl PartialEq for $ty {
            #[inline]
            fn eq(&self, rhs: &$ty) -> bool {
                unsafe {
                    // This compares each pair of 8-bit packed integers and returns either 0xFF or
                    // 0x00 depending on whether they're equal.
                    //
                    // So the values are equal if (and only if) this returns a value that's filled
                    // with only 0xFF.
                    //
                    // Pseudocode of what this does:
                    //     self.0
                    //         .bytes()
                    //         .zip(rhs.0.bytes())
                    //         .map(|(a, b)| if a == b { 0xFF } else { 0x00 })
                    //         .join();
                    let m = core::arch::x86_64::_mm256_cmpeq_epi8(self.0, rhs.0);

                    // Now we need to reduce the 256-bit value to something on which we can branch.
                    //
                    // This will just take the most significant bit of every 8-bit packed integer
                    // and build an `i32` out of it. If the values we previously compared were
                    // equal then all of the most significant bits will be equal to 1, which means
                    // that this will return 0xFFFFFFFF, which is equal to -1 when represented as
                    // an `i32`. (A sanity check of this trick lives in the test module just after
                    // this macro definition.)
                    core::arch::x86_64::_mm256_movemask_epi8(m) == -1
                }
            }
        }

        impl Eq for $ty {}

        #[unsafe_target_feature("avx2")]
        impl Add for $ty {
            type Output = Self;

            #[inline]
            fn add(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::$add_intrinsic(self.0, rhs.0).into() }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl AddAssign for $ty {
            #[inline]
            fn add_assign(&mut self, rhs: $ty) {
                *self = *self + rhs
            }
        }

        #[unsafe_target_feature("avx2")]
        impl Sub for $ty {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::$sub_intrinsic(self.0, rhs.0).into() }
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitAnd for $ty {
            type Output = Self;

            #[inline]
            fn bitand(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::_mm256_and_si256(self.0, rhs.0).into() }
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitXor for $ty {
            type Output = Self;

            #[inline]
            fn bitxor(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::_mm256_xor_si256(self.0, rhs.0).into() }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitAndAssign for $ty {
            #[inline]
            fn bitand_assign(&mut self, rhs: $ty) {
                *self = *self & rhs;
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitXorAssign for $ty {
            #[inline]
            fn bitxor_assign(&mut self, rhs: $ty) {
                *self = *self ^ rhs;
            }
        }

        #[unsafe_target_feature("avx2")]
        #[allow(dead_code)]
        impl $ty {
            #[inline]
            pub fn shl<const N: i32>(self) -> Self {
                // The shift-by-immediate intrinsics take their shift amount as a const generic,
                // so the wrapper's `N` is forwarded with a turbofish.
                unsafe { core::arch::x86_64::$shl_intrinsic::<N>(self.0).into() }
            }

            #[inline]
            pub fn shr<const N: i32>(self) -> Self {
                unsafe { core::arch::x86_64::$shr_intrinsic::<N>(self.0).into() }
            }

            #[inline]
            pub fn extract<const N: i32>(self) -> $lane_ty {
                unsafe { core::arch::x86_64::$extract_intrinsic::<N>(self.0) as $lane_ty }
            }
        }
    };
}
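
// A minimal sanity check of the cmpeq/movemask equality trick used above, written against the
// raw intrinsics rather than the macro-generated types. A sketch: it assumes the test harness
// links `std` (for runtime feature detection) and runs on x86_64.
#[cfg(test)]
mod eq_trick_tests {
    #[test]
    fn movemask_is_minus_one_iff_all_bytes_equal() {
        if !std::is_x86_feature_detected!("avx2") {
            return;
        }
        unsafe {
            use core::arch::x86_64::*;
            let a = _mm256_set1_epi8(0x5A);
            // Equal inputs: every comparison byte is 0xFF, so the movemask is
            // all ones, i.e. -1 as an `i32`.
            let m = _mm256_cmpeq_epi8(a, a);
            assert_eq!(_mm256_movemask_epi8(m), -1);
            // Fully differing inputs: every comparison byte is 0x00.
            let b = _mm256_set1_epi8(0x5B);
            let m = _mm256_cmpeq_epi8(a, b);
            assert_eq!(_mm256_movemask_epi8(m), 0);
        }
    }
}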

macro_rules! impl_conv {
    ($src:ident => $($dst:ident),+) => {
        $(
            #[unsafe_target_feature("avx2")]
            impl From<$src> for $dst {
                #[inline]
                fn from(value: $src) -> $dst {
                    $dst(value.0)
                }
            }
        )+
    }
}

// We define SIMD functionality over packed unsigned integer types. However, all the integer
// intrinsics deal with signed integers. So we cast unsigned to signed, pack it into SIMD, do
// add/sub/shl/shr arithmetic, and finally cast back to unsigned at the end. Why is this equivalent
// to doing the same thing on unsigned integers? Shl/shr are clear: casting does not change the
// bits of the integer, and the `slli`/`srli` intrinsics used here are logical (zero-filling)
// shifts rather than arithmetic ones. But what about add/sub? This is due to the following:
//
//     1) Rust uses two's complement to represent signed integers. So we're assured that the values
//        we cast into SIMD and extract out at the end are two's complement.
//
//        https://doc.rust-lang.org/reference/types/numeric.html
//
//     2) Wrapping add/sub is compatible between two's complement signed and unsigned integers.
//        That is, for all x,y: u64 (or any unsigned integer type),
//
//            x.wrapping_add(y) == (x as i64).wrapping_add(y as i64) as u64, and
//            x.wrapping_sub(y) == (x as i64).wrapping_sub(y as i64) as u64
//
//        https://julesjacobs.com/2019/03/20/why-twos-complement-works.html
//
//     3) The add/sub functions we use for SIMD are indeed wrapping. The docs indicate that
//        _mm256_add_epiN/_mm256_sub_epiN compile to vpaddX/vpsubX instructions, where X = w, d,
//        or q depending on the bit width. From the x86 docs:
//
//            When an individual result is too large to be represented in X bits (overflow), the
//            result is wrapped around and the low X bits are written to the destination operand
//            (that is, the carry is ignored).
//
//        https://www.felixcloutier.com/x86/paddb:paddw:paddd:paddq
//        https://www.felixcloutier.com/x86/psubb:psubw:psubd
//        https://www.felixcloutier.com/x86/psubq
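
// A quick illustration of point (2): wrapping add/sub produce the same bit pattern whether
// computed on `u64` or on `i64`. Plain scalar Rust, no SIMD involved.
#[cfg(test)]
mod wrapping_cast_tests {
    #[test]
    fn wrapping_add_sub_commute_with_casts() {
        let cases = [(0u64, 0u64), (1, u64::MAX), (u64::MAX, u64::MAX), (1 << 63, 1 << 63)];
        for (x, y) in cases {
            assert_eq!(x.wrapping_add(y), (x as i64).wrapping_add(y as i64) as u64);
            assert_eq!(x.wrapping_sub(y), (x as i64).wrapping_sub(y as i64) as u64);
        }
    }
}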

impl_shared!(
    u64x4,
    u64,
    _mm256_add_epi64,
    _mm256_sub_epi64,
    _mm256_slli_epi64,
    _mm256_srli_epi64,
    _mm256_extract_epi64
);
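
// The instantiation above wires `u64x4::shr` to `_mm256_srli_epi64`, a logical (zero-filling)
// shift, which is what gives `shr` its unsigned semantics. A small check of that claim against
// the raw intrinsic; a sketch, assuming a `std` test harness on x86_64.
#[cfg(test)]
mod shift_tests {
    #[test]
    fn srli_shifts_in_zeros() {
        if !std::is_x86_feature_detected!("avx2") {
            return;
        }
        unsafe {
            use core::arch::x86_64::*;
            // An arithmetic shift would replicate the sign bit; the logical
            // shift must instead produce `u64::MAX >> 1`.
            let v = _mm256_set1_epi64x(u64::MAX as i64);
            let s = _mm256_srli_epi64::<1>(v);
            assert_eq!(_mm256_extract_epi64::<0>(s) as u64, u64::MAX >> 1);
        }
    }
}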
impl_shared!(
    u32x8,
    u32,
    _mm256_add_epi32,
    _mm256_sub_epi32,
    _mm256_slli_epi32,
    _mm256_srli_epi32,
    _mm256_extract_epi32
);

impl_conv!(u64x4 => u32x8);
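
// The generated `From<u64x4> for u32x8` just rewraps the underlying `__m256i`: each 64-bit lane
// is reinterpreted as two 32-bit lanes (the low half lands in the even-numbered lane), with no
// per-lane truncation. A sketch of that lane layout using the raw intrinsics; assumes a `std`
// test harness on x86_64.
#[cfg(test)]
mod conv_tests {
    #[test]
    fn u64_to_u32_is_a_lane_reinterpretation() {
        if !std::is_x86_feature_detected!("avx2") {
            return;
        }
        unsafe {
            use core::arch::x86_64::*;
            // 64-bit lane 0 holds 0x00000002_00000001; viewed as 32-bit lanes
            // it reads back as 1 (lane 0) and 2 (lane 1).
            let v = _mm256_set_epi64x(0, 0, 0, 0x0000_0002_0000_0001);
            assert_eq!(_mm256_extract_epi32::<0>(v), 1);
            assert_eq!(_mm256_extract_epi32::<1>(v), 2);
        }
    }
}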

#[allow(dead_code)]
impl u64x4 {
    /// A constified variant of `new`.
    ///
    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
    #[inline]
    pub const fn new_const(x0: u64, x1: u64, x2: u64, x3: u64) -> Self {
        // SAFETY: Transmuting between an array and a SIMD type is safe
        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
        unsafe { Self(core::mem::transmute([x0, x1, x2, x3])) }
    }

    /// A constified variant of `splat`.
    ///
    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
    #[inline]
    pub const fn splat_const<const N: u64>() -> Self {
        Self::new_const(N, N, N, N)
    }

    /// Constructs a new instance.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 {
        unsafe {
            // _mm256_set_epi64x takes its arguments in reverse (high-to-low) lane order, so pass
            // x3..x0 to put x0 in lane 0 (see the lane-order check after this impl block)
            u64x4(core::arch::x86_64::_mm256_set_epi64x(
                x3 as i64, x2 as i64, x1 as i64, x0 as i64,
            ))
        }
    }

    /// Constructs a new instance with all of the elements initialized to the given value.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn splat(x: u64) -> u64x4 {
        unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) }
    }
}
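
// A small check that `new_const` lays out `x0..x3` in ascending memory order, which on x86 is
// lane order 0..3 — the same order `new` produces via its reversed `_mm256_set_epi64x` arguments.
// `new_const` is a plain `const fn`, so this needs no AVX2 at runtime; transmuting back out is
// sound by the same layout guarantee cited in `new_const`. A sketch, assuming the usual `std`
// test harness.
#[cfg(test)]
mod u64x4_construction_tests {
    use super::u64x4;

    #[test]
    fn new_const_lane_order_matches_the_argument_order() {
        let v = u64x4::new_const(10, 20, 30, 40);
        let lanes: [u64; 4] = unsafe { core::mem::transmute(v) };
        assert_eq!(lanes, [10, 20, 30, 40]);
    }
}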

#[allow(dead_code)]
impl u32x8 {
    /// A constified variant of `new`.
    ///
    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    pub const fn new_const(
        x0: u32,
        x1: u32,
        x2: u32,
        x3: u32,
        x4: u32,
        x5: u32,
        x6: u32,
        x7: u32,
    ) -> Self {
        // SAFETY: Transmuting between an array and a SIMD type is safe
        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
        unsafe { Self(core::mem::transmute([x0, x1, x2, x3, x4, x5, x6, x7])) }
    }

    /// A constified variant of `splat`.
    ///
    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
    #[inline]
    pub const fn splat_const<const N: u32>() -> Self {
        Self::new_const(N, N, N, N, N, N, N, N)
    }

    /// Constructs a new instance.
    #[allow(clippy::too_many_arguments)]
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 {
        unsafe {
            // _mm256_set_epi32 sets the underlying vector in reverse order of the args
            u32x8(core::arch::x86_64::_mm256_set_epi32(
                x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32,
                x0 as i32,
            ))
        }
    }

    /// Constructs a new instance with all of the elements initialized to the given value.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn splat(x: u32) -> u32x8 {
        unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
    }
}

#[unsafe_target_feature("avx2")]
impl u32x8 {
    /// Multiplies the low unsigned 32 bits of each packed 64-bit element
    /// and returns the unsigned 64-bit results.
    ///
    /// (This ignores the upper 32 bits of each packed 64-bit element!)
    #[inline]
    pub fn mul32(self, rhs: u32x8) -> u64x4 {
        // NOTE: This ignores the upper 32 bits of each packed 64-bit element.
        unsafe { core::arch::x86_64::_mm256_mul_epu32(self.0, rhs.0).into() }
    }
}
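
// A small check of the `mul32` semantics described above: `_mm256_mul_epu32` reads only the low
// 32 bits of each 64-bit lane, so the high halves cannot leak into the product. A sketch against
// the raw intrinsic; assumes a `std` test harness on x86_64.
#[cfg(test)]
mod mul32_tests {
    #[test]
    fn mul_epu32_ignores_the_high_halves() {
        if !std::is_x86_feature_detected!("avx2") {
            return;
        }
        unsafe {
            use core::arch::x86_64::*;
            // Low 32 bits of lane 0 are 3 and 5; the 0xDEADBEEF high halves
            // must not affect the 64-bit product.
            let a = _mm256_set_epi64x(0, 0, 0, ((0xDEAD_BEEF_u64 << 32) | 3) as i64);
            let b = _mm256_set_epi64x(0, 0, 0, ((0xDEAD_BEEF_u64 << 32) | 5) as i64);
            let p = _mm256_mul_epu32(a, b);
            assert_eq!(_mm256_extract_epi64::<0>(p), 15);
        }
    }
}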