curve25519_dalek/backend/vector/packed_simd.rs
// -*- mode: rust; -*-
//
// This file is part of curve25519-dalek.
// See LICENSE for licensing information.

//! This module defines wrappers over platform-specific SIMD types to make them
//! more convenient to use.
//!
//! UNSAFETY: Everything in this module assumes that we're running on hardware
//! which supports at least AVX2. This invariant *must* be enforced
//! by the callers of this code.

use core::ops::{Add, AddAssign, BitAnd, BitAndAssign, BitXor, BitXorAssign, Sub};

use curve25519_dalek_derive::unsafe_target_feature;

macro_rules! impl_shared {
    (
        $ty:ident,
        $lane_ty:ident,
        $add_intrinsic:ident,
        $sub_intrinsic:ident,
        $shl_intrinsic:ident,
        $shr_intrinsic:ident,
        $extract_intrinsic:ident
    ) => {
        #[allow(non_camel_case_types)]
        #[derive(Copy, Clone, Debug)]
        #[repr(transparent)]
        pub struct $ty(core::arch::x86_64::__m256i);

        #[unsafe_target_feature("avx2")]
        impl From<$ty> for core::arch::x86_64::__m256i {
            #[inline]
            fn from(value: $ty) -> core::arch::x86_64::__m256i {
                value.0
            }
        }

        #[unsafe_target_feature("avx2")]
        impl From<core::arch::x86_64::__m256i> for $ty {
            #[inline]
            fn from(value: core::arch::x86_64::__m256i) -> $ty {
                $ty(value)
            }
        }

        #[unsafe_target_feature("avx2")]
        impl PartialEq for $ty {
            #[inline]
            fn eq(&self, rhs: &$ty) -> bool {
                unsafe {
                    // This compares each pair of 8-bit packed integers and returns either 0xFF or
                    // 0x00 depending on whether they're equal.
                    //
                    // So the values are equal if (and only if) this returns a value that's filled
                    // with only 0xFF.
                    //
                    // Pseudocode of what this does:
                    //     self.0
                    //         .bytes()
                    //         .zip(rhs.0.bytes())
                    //         .map(|a, b| if a == b { 0xFF } else { 0x00 })
                    //         .join();
                    let m = core::arch::x86_64::_mm256_cmpeq_epi8(self.0, rhs.0);

                    // Now we need to reduce the 256-bit value to something on which we can branch.
                    //
                    // This will just take the most significant bit of every 8-bit packed integer
                    // and build an `i32` out of it. If the values we previously compared were
                    // equal then all of the most significant bits will be equal to 1, which means
                    // that this will return 0xFFFFFFFF, which is equal to -1 when represented as
                    // an `i32`.
                    core::arch::x86_64::_mm256_movemask_epi8(m) == -1
                }
            }
        }

        impl Eq for $ty {}

        #[unsafe_target_feature("avx2")]
        impl Add for $ty {
            type Output = Self;

            #[inline]
            fn add(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::$add_intrinsic(self.0, rhs.0).into() }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl AddAssign for $ty {
            #[inline]
            fn add_assign(&mut self, rhs: $ty) {
                *self = *self + rhs
            }
        }

        #[unsafe_target_feature("avx2")]
        impl Sub for $ty {
            type Output = Self;

            #[inline]
            fn sub(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::$sub_intrinsic(self.0, rhs.0).into() }
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitAnd for $ty {
            type Output = Self;

            #[inline]
            fn bitand(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::_mm256_and_si256(self.0, rhs.0).into() }
            }
        }

        #[unsafe_target_feature("avx2")]
        impl BitXor for $ty {
            type Output = Self;

            #[inline]
            fn bitxor(self, rhs: $ty) -> Self {
                unsafe { core::arch::x86_64::_mm256_xor_si256(self.0, rhs.0).into() }
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitAndAssign for $ty {
            #[inline]
            fn bitand_assign(&mut self, rhs: $ty) {
                *self = *self & rhs;
            }
        }

        #[allow(clippy::assign_op_pattern)]
        #[unsafe_target_feature("avx2")]
        impl BitXorAssign for $ty {
            #[inline]
            fn bitxor_assign(&mut self, rhs: $ty) {
                *self = *self ^ rhs;
            }
        }

        #[unsafe_target_feature("avx2")]
        #[allow(dead_code)]
        impl $ty {
            #[inline]
            pub fn shl<const N: i32>(self) -> Self {
                unsafe { core::arch::x86_64::$shl_intrinsic(self.0, N).into() }
            }

            #[inline]
            pub fn shr<const N: i32>(self) -> Self {
                unsafe { core::arch::x86_64::$shr_intrinsic(self.0, N).into() }
            }

            #[inline]
            pub fn extract<const N: i32>(self) -> $lane_ty {
                unsafe { core::arch::x86_64::$extract_intrinsic(self.0, N) as $lane_ty }
            }
        }
    };
}

macro_rules! impl_conv {
    ($src:ident => $($dst:ident),+) => {
        $(
            #[unsafe_target_feature("avx2")]
            impl From<$src> for $dst {
                #[inline]
                fn from(value: $src) -> $dst {
                    $dst(value.0)
                }
            }
        )+
    }
}

// We define SIMD functionality over packed unsigned integer types. However, all the integer
// intrinsics deal with signed integers. So we cast unsigned to signed, pack it into SIMD, do
// add/sub/shl/shr arithmetic, and finally cast back to unsigned at the end. Why is this equivalent
// to doing the same thing on unsigned integers? Shl/shr is clear, because casting does not change
// the bits of the integer. But what about add/sub? This is due to the following:
//
// 1) Rust uses two's complement to represent signed integers. So we're assured that the values
//    we cast into SIMD and extract out at the end are two's complement.
//
//    https://doc.rust-lang.org/reference/types/numeric.html
//
// 2) Wrapping add/sub is compatible between two's complement signed and unsigned integers.
//    That is, for all x,y: u64 (or any unsigned integer type),
//
//        x.wrapping_add(y) == (x as i64).wrapping_add(y as i64) as u64, and
//        x.wrapping_sub(y) == (x as i64).wrapping_sub(y as i64) as u64
//
//    https://julesjacobs.com/2019/03/20/why-twos-complement-works.html
//
// 3) The add/sub functions we use for SIMD are indeed wrapping. The docs indicate that
//    _mm256_add/sub compile to vpaddX/vpsubX instructions, where X = w, d, or q depending on
//    the bitwidth. From the x86 docs:
//
//        When an individual result is too large to be represented in X bits (overflow), the
//        result is wrapped around and the low X bits are written to the destination operand
//        (that is, the carry is ignored).
//
//    https://www.felixcloutier.com/x86/paddb:paddw:paddd:paddq
//    https://www.felixcloutier.com/x86/psubb:psubw:psubd
//    https://www.felixcloutier.com/x86/psubq
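
// The sketch below is illustrative only (it is not part of the upstream file): a hypothetical
// test module spot-checking property (2) above on plain scalars, including the wrap-around
// edge cases. The module and test names are ours.
#[cfg(test)]
mod wrapping_cast_sketch {
    #[test]
    fn unsigned_and_signed_wrapping_arithmetic_agree() {
        let cases = [
            (u64::MAX, 1u64),         // unsigned overflow on addition
            (0u64, 1u64),             // unsigned underflow on subtraction
            (1u64 << 63, 1u64 << 63), // signed overflow after the cast
            (7u64, 11u64),            // no wrap-around at all
        ];
        for &(x, y) in &cases {
            assert_eq!(x.wrapping_add(y), (x as i64).wrapping_add(y as i64) as u64);
            assert_eq!(x.wrapping_sub(y), (x as i64).wrapping_sub(y as i64) as u64);
        }
    }
}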

impl_shared!(
    u64x4,
    u64,
    _mm256_add_epi64,
    _mm256_sub_epi64,
    _mm256_slli_epi64,
    _mm256_srli_epi64,
    _mm256_extract_epi64
);
impl_shared!(
    u32x8,
    u32,
    _mm256_add_epi32,
    _mm256_sub_epi32,
    _mm256_slli_epi32,
    _mm256_srli_epi32,
    _mm256_extract_epi32
);

impl_conv!(u64x4 => u32x8);
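
// Illustrative sketch (not part of the upstream file): a hypothetical, AVX2-gated test of the
// `PartialEq` impl generated by `impl_shared!` above. It exercises the comparison described in
// `eq`: the vectors compare equal only when the byte-wise compare/movemask reduction sees every
// lane matching. Only compiled when AVX2 is statically enabled, since this module requires it.
#[cfg(all(test, target_feature = "avx2"))]
mod packed_eq_sketch {
    use super::*;

    #[test]
    fn equal_iff_all_lanes_equal() {
        let a = u64x4::new(1, 2, 3, 4);
        assert_eq!(a, u64x4::new(1, 2, 3, 4));
        // A single differing lane zeroes some of the comparison bytes, so the movemask is no
        // longer all ones and `eq` returns false.
        assert_ne!(a, u64x4::new(1, 2, 3, 5));
    }
}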

#[allow(dead_code)]
impl u64x4 {
    /// A constified variant of `new`.
    ///
    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
    #[inline]
    pub const fn new_const(x0: u64, x1: u64, x2: u64, x3: u64) -> Self {
        // SAFETY: Transmuting between an array and a SIMD type is safe
        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
        unsafe { Self(core::mem::transmute([x0, x1, x2, x3])) }
    }

    /// A constified variant of `splat`.
    ///
    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
    #[inline]
    pub const fn splat_const<const N: u64>() -> Self {
        Self::new_const(N, N, N, N)
    }

    /// Constructs a new instance.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn new(x0: u64, x1: u64, x2: u64, x3: u64) -> u64x4 {
        unsafe {
            // _mm256_set_epi64x sets the underlying vector in reverse order of the args
            u64x4(core::arch::x86_64::_mm256_set_epi64x(
                x3 as i64, x2 as i64, x1 as i64, x0 as i64,
            ))
        }
    }

    /// Constructs a new instance with all of the elements initialized to the given value.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn splat(x: u64) -> u64x4 {
        unsafe { u64x4(core::arch::x86_64::_mm256_set1_epi64x(x as i64)) }
    }
}
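
// Illustrative sketch (not part of the upstream file): a hypothetical, AVX2-gated test showing
// that although `_mm256_set_epi64x` takes its arguments high-lane-first, `new(x0, x1, x2, x3)`
// still places `x0` in lane 0 as observed by `extract`.
#[cfg(all(test, target_feature = "avx2"))]
mod u64x4_lane_order_sketch {
    use super::*;

    #[test]
    fn new_places_first_argument_in_lane_zero() {
        let v = u64x4::new(10, 11, 12, 13);
        assert_eq!(v.extract::<0>(), 10);
        assert_eq!(v.extract::<1>(), 11);
        assert_eq!(v.extract::<3>(), 13);
    }
}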

#[allow(dead_code)]
impl u32x8 {
    /// A constified variant of `new`.
    ///
    /// Should only be called from `const` contexts. At runtime `new` is going to be faster.
    #[allow(clippy::too_many_arguments)]
    #[inline]
    pub const fn new_const(
        x0: u32,
        x1: u32,
        x2: u32,
        x3: u32,
        x4: u32,
        x5: u32,
        x6: u32,
        x7: u32,
    ) -> Self {
        // SAFETY: Transmuting between an array and a SIMD type is safe
        // https://rust-lang.github.io/unsafe-code-guidelines/layout/packed-simd-vectors.html
        unsafe { Self(core::mem::transmute([x0, x1, x2, x3, x4, x5, x6, x7])) }
    }

    /// A constified variant of `splat`.
    ///
    /// Should only be called from `const` contexts. At runtime `splat` is going to be faster.
    #[inline]
    pub const fn splat_const<const N: u32>() -> Self {
        Self::new_const(N, N, N, N, N, N, N, N)
    }

    /// Constructs a new instance.
    #[allow(clippy::too_many_arguments)]
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn new(x0: u32, x1: u32, x2: u32, x3: u32, x4: u32, x5: u32, x6: u32, x7: u32) -> u32x8 {
        unsafe {
            // _mm256_set_epi32 sets the underlying vector in reverse order of the args
            u32x8(core::arch::x86_64::_mm256_set_epi32(
                x7 as i32, x6 as i32, x5 as i32, x4 as i32, x3 as i32, x2 as i32, x1 as i32,
                x0 as i32,
            ))
        }
    }

    /// Constructs a new instance with all of the elements initialized to the given value.
    #[unsafe_target_feature("avx2")]
    #[inline]
    pub fn splat(x: u32) -> u32x8 {
        unsafe { u32x8(core::arch::x86_64::_mm256_set1_epi32(x as i32)) }
    }
}
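
// Illustrative sketch (not part of the upstream file): hypothetical, AVX2-gated checks that
// `u32x8::new` follows the same lane ordering as `u64x4::new`, and that the `u64x4 => u32x8`
// conversion generated by `impl_conv!` is a plain bit-level reinterpretation: each 64-bit lane
// splits into a low and a high 32-bit lane.
#[cfg(all(test, target_feature = "avx2"))]
mod u32x8_lane_order_sketch {
    use super::*;

    #[test]
    fn new_places_first_argument_in_lane_zero() {
        let v = u32x8::new(0, 1, 2, 3, 4, 5, 6, 7);
        assert_eq!(v.extract::<0>(), 0);
        assert_eq!(v.extract::<7>(), 7);
    }

    #[test]
    fn conversion_from_u64x4_reinterprets_lanes() {
        let v = u32x8::from(u64x4::new(1, 2, 3, 4));
        // Low 32 bits of the first 64-bit lane ...
        assert_eq!(v.extract::<0>(), 1);
        // ... and its high 32 bits, which are zero for small values.
        assert_eq!(v.extract::<1>(), 0);
        assert_eq!(v.extract::<2>(), 2);
    }
}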

#[unsafe_target_feature("avx2")]
impl u32x8 {
    /// Multiplies the low unsigned 32-bits from each packed 64-bit element
    /// and returns the unsigned 64-bit results.
    ///
    /// (This ignores the upper 32-bits from each packed 64-bits!)
    #[inline]
    pub fn mul32(self, rhs: u32x8) -> u64x4 {
        // NOTE: This ignores the upper 32-bits from each packed 64-bits.
        unsafe { core::arch::x86_64::_mm256_mul_epu32(self.0, rhs.0).into() }
    }
}
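
// Illustrative sketch (not part of the upstream file): a hypothetical, AVX2-gated test of the
// `mul32` semantics documented above, checked against plain 64-bit scalar products. Only the
// low 32 bits of each 64-bit lane participate; the odd 32-bit lanes hold sentinels to show
// that they are ignored.
#[cfg(all(test, target_feature = "avx2"))]
mod mul32_sketch {
    use super::*;

    #[test]
    fn mul32_multiplies_low_halves_of_each_64_bit_lane() {
        let a = u32x8::new(1, 0xDEAD, 2, 0xDEAD, 3, 0xDEAD, u32::MAX, 0xDEAD);
        let b = u32x8::new(5, 0xBEEF, 6, 0xBEEF, 7, 0xBEEF, u32::MAX, 0xBEEF);
        let p = a.mul32(b);
        assert_eq!(p.extract::<0>(), 5);
        assert_eq!(p.extract::<1>(), 12);
        assert_eq!(p.extract::<2>(), 21);
        // The full 64-bit product is kept, so even u32::MAX * u32::MAX does not overflow.
        assert_eq!(p.extract::<3>(), (u32::MAX as u64) * (u32::MAX as u64));
    }
}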