memchr/vector.rs
/// A trait for describing vector operations used by vectorized searchers.
///
/// The trait is highly constrained to the low level vector operations needed.
/// In general, it was invented mostly to be generic over x86's __m128i and
/// __m256i types. At time of writing, it also supports wasm and aarch64
/// 128-bit vector types.
///
/// # Safety
///
/// All methods are `unsafe` since they are intended to be implemented using
/// vendor intrinsics, which are also `unsafe`. Callers must ensure that the
/// appropriate target features are enabled in the calling function, and that
/// the current CPU supports them. All implementations should avoid marking
/// the routines with #[target_feature] and instead mark them as
/// #[inline(always)] to ensure they get appropriately inlined.
/// (inline(always) cannot be used with target_feature.)
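///
/// # Example (illustrative)
///
/// A minimal sketch of the intended calling pattern. The names below are
/// hypothetical and not part of this crate: a generic helper marked
/// `#[inline(always)]` is called from an entry point that carries
/// `#[target_feature]`, and that entry point is only reached after a runtime
/// CPU feature check (std's feature detection is used here for brevity).
///
/// ```ignore
/// #[target_feature(enable = "avx2")]
/// unsafe fn memchr_avx2(needle: u8, haystack: &[u8]) -> Option<usize> {
///     // `search` is inlined into this function, so its intrinsics get
///     // compiled with AVX2 enabled.
///     search::<core::arch::x86_64::__m256i>(needle, haystack)
/// }
///
/// #[inline(always)]
/// unsafe fn search<V: Vector>(needle: u8, haystack: &[u8]) -> Option<usize> {
///     let splat = V::splat(needle);
///     // ... load chunks of the haystack, `cmpeq` them against `splat` and
///     // inspect the resulting move masks ...
///     unimplemented!()
/// }
///
/// fn memchr(needle: u8, haystack: &[u8]) -> Option<usize> {
///     if is_x86_feature_detected!("avx2") {
///         // SAFETY: the feature check above ensures AVX2 is available.
///         unsafe { memchr_avx2(needle, haystack) }
///     } else {
///         haystack.iter().position(|&b| b == needle)
///     }
/// }
/// ```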
pub(crate) trait Vector: Copy + core::fmt::Debug {
    /// The number of bytes in the vector. That is, this is the size of the
    /// vector in memory.
    const BYTES: usize;
    /// The bits that must be zero in order for a `*const u8` pointer to be
    /// correctly aligned to read vector values.
    const ALIGN: usize;

    /// The type of the value returned by `Vector::movemask`.
    ///
    /// This supports abstracting over the specific representation used in
    /// order to accommodate different representations in different ISAs.
    type Mask: MoveMask;

    /// Create a vector with 8-bit lanes with the given byte repeated into each
    /// lane.
    unsafe fn splat(byte: u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// must be aligned to the size of the vector.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data` and that `data` is aligned to a `BYTES` boundary.
    unsafe fn load_aligned(data: *const u8) -> Self;

    /// Read a vector-size number of bytes from the given pointer. The pointer
    /// does not need to be aligned.
    ///
    /// # Safety
    ///
    /// Callers must guarantee that at least `BYTES` bytes are readable from
    /// `data`.
    unsafe fn load_unaligned(data: *const u8) -> Self;

    /// _mm_movemask_epi8 or _mm256_movemask_epi8
    unsafe fn movemask(self) -> Self::Mask;
    /// _mm_cmpeq_epi8 or _mm256_cmpeq_epi8
    unsafe fn cmpeq(self, vector2: Self) -> Self;
    /// _mm_and_si128 or _mm256_and_si256
    unsafe fn and(self, vector2: Self) -> Self;
    /// _mm_or_si128 or _mm256_or_si256
    unsafe fn or(self, vector2: Self) -> Self;
    /// Returns true if and only if `Self::movemask` would return a mask that
    /// contains at least one non-zero bit.
    unsafe fn movemask_will_have_non_zero(self) -> bool {
        self.movemask().has_non_zero()
    }
}

/// A trait that abstracts over a vector-to-scalar operation called
/// "move mask."
///
/// On x86-64, this is `_mm_movemask_epi8` for SSE2 and `_mm256_movemask_epi8`
/// for AVX2. It takes a vector of `u8` lanes and returns a scalar where the
/// `i`th bit is set if and only if the most significant bit in the `i`th lane
/// of the vector is set. The simd128 ISA for wasm32 also supports this
/// exact same operation natively.
///
/// ... But aarch64 doesn't. So we have to fake it with more instructions and
/// a slightly different representation. We could do extra work to unify the
/// representations, but that would require additional costs in the hot path
/// for `memchr` and `packedpair`. So instead, we abstract over the specific
/// representation with this trait and define the operations we actually need.
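///
/// For example, for a 16-lane comparison where only lanes 1 and 4 matched,
/// the x86-64/wasm32 style mask (one bit per lane) is `0b0000_0000_0001_0010`,
/// while the aarch64 style mask (one nibble per lane, with only the high bit
/// of each nibble ever set) is `0x0008_0080`. Either way, `first_offset`
/// is 1 and `last_offset` is 4.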
pub(crate) trait MoveMask: Copy + core::fmt::Debug {
    /// Return a mask that is all zeros except for the least significant `n`
    /// lanes in a corresponding vector.
    fn all_zeros_except_least_significant(n: usize) -> Self;

    /// Returns true if and only if this mask has a non-zero bit anywhere.
    fn has_non_zero(self) -> bool;

    /// Returns the number of bits set to 1 in this mask.
    fn count_ones(self) -> usize;

    /// Does a bitwise `and` operation between `self` and `other`.
    fn and(self, other: Self) -> Self;

    /// Does a bitwise `or` operation between `self` and `other`.
    fn or(self, other: Self) -> Self;

    /// Returns a mask that is equivalent to `self` but with the least
    /// significant 1-bit set to 0.
    fn clear_least_significant_bit(self) -> Self;

    /// Returns the offset of the first non-zero lane this mask represents.
    fn first_offset(self) -> usize;

    /// Returns the offset of the last non-zero lane this mask represents.
    fn last_offset(self) -> usize;
}

/// This is a "sensible" movemask implementation where each bit represents
/// whether the most significant bit is set in each corresponding lane of a
/// vector. This is used on x86-64 and wasm, but such a mask is more expensive
/// to get on aarch64 so we use something a little different.
///
/// We call this "sensible" because this is what we get using native sse/avx
/// movemask instructions. But neon has no such native equivalent.
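///
/// # Example (illustrative)
///
/// A sketch of how the `MoveMask` operations behave on this representation,
/// assuming a little endian target. The raw `u32` below stands in for what
/// `Vector::movemask` would return:
///
/// ```ignore
/// let mask = SensibleMoveMask(0b100100);
/// assert!(mask.has_non_zero());
/// assert_eq!(mask.count_ones(), 2);
/// assert_eq!(mask.first_offset(), 2);
/// assert_eq!(mask.last_offset(), 5);
/// // Clearing the least significant set bit drops the first match.
/// assert_eq!(mask.clear_least_significant_bit().first_offset(), 5);
/// ```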
#[derive(Clone, Copy, Debug)]
pub(crate) struct SensibleMoveMask(u32);

impl SensibleMoveMask {
    /// Get the mask in a form suitable for computing offsets.
    ///
    /// Basically, this normalizes to little endian. On big endian, this swaps
    /// the bytes.
    #[inline(always)]
    fn get_for_offset(self) -> u32 {
        #[cfg(target_endian = "big")]
        {
            self.0.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            self.0
        }
    }
}

impl MoveMask for SensibleMoveMask {
    #[inline(always)]
    fn all_zeros_except_least_significant(n: usize) -> SensibleMoveMask {
        debug_assert!(n < 32);
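        // For example, `n == 3` produces `0xFFFF_FFF8`: the bits for the
        // three least significant lanes are zero and every other bit is set.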
        SensibleMoveMask(!((1 << n) - 1))
    }

    #[inline(always)]
    fn has_non_zero(self) -> bool {
        self.0 != 0
    }

    #[inline(always)]
    fn count_ones(self) -> usize {
        self.0.count_ones() as usize
    }

    #[inline(always)]
    fn and(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & other.0)
    }

    #[inline(always)]
    fn or(self, other: SensibleMoveMask) -> SensibleMoveMask {
        SensibleMoveMask(self.0 | other.0)
    }

    #[inline(always)]
    fn clear_least_significant_bit(self) -> SensibleMoveMask {
        SensibleMoveMask(self.0 & (self.0 - 1))
    }

    #[inline(always)]
    fn first_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte
        // is at a higher address. That means the least significant bit that
        // is set corresponds to the position of our first matching byte.
        // That position corresponds to the number of zeros after the least
        // significant bit.
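        //
        // For example, a mask of `0b100100` has two trailing zeros, so the
        // first matching lane is at offset 2.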
        self.get_for_offset().trailing_zeros() as usize
    }

    #[inline(always)]
    fn last_offset(self) -> usize {
        // We are dealing with little endian here (and if we aren't, we swap
        // the bytes so we are in practice), where the most significant byte is
        // at a higher address. That means the most significant bit that is set
        // corresponds to the position of our last matching byte. The position
        // from the end of the mask is therefore the number of leading zeros
        // in a 32 bit integer, and the position from the start of the mask is
        // therefore 32 - (leading zeros) - 1.
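        //
        // Continuing the example above, `0b100100u32` has 26 leading zeros,
        // so the last matching lane is at offset 32 - 26 - 1 = 5.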
        32 - self.get_for_offset().leading_zeros() as usize - 1
    }
}

#[cfg(target_arch = "x86_64")]
mod x86sse2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m128i {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m128i {
            _mm_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m128i {
            _mm_load_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m128i {
            _mm_loadu_si128(data as *const __m128i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m128i {
            _mm_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m128i {
            _mm_and_si128(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m128i {
            _mm_or_si128(self, vector2)
        }
    }
}

#[cfg(target_arch = "x86_64")]
mod x86avx2 {
    use core::arch::x86_64::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for __m256i {
        const BYTES: usize = 32;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> __m256i {
            _mm256_set1_epi8(byte as i8)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> __m256i {
            _mm256_load_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> __m256i {
            _mm256_loadu_si256(data as *const __m256i)
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(_mm256_movemask_epi8(self) as u32)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> __m256i {
            _mm256_cmpeq_epi8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> __m256i {
            _mm256_and_si256(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> __m256i {
            _mm256_or_si256(self, vector2)
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64neon {
    use core::arch::aarch64::*;

    use super::{MoveMask, Vector};

    impl Vector for uint8x16_t {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = NeonMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> uint8x16_t {
            vdupq_n_u8(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> uint8x16_t {
            // I've tried `data.cast::<uint8x16_t>().read()` instead, but
            // couldn't observe any benchmark differences.
            Self::load_unaligned(data)
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> uint8x16_t {
            vld1q_u8(data)
        }

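        /// This emulates `movemask` with the "shift right and narrow"
        /// trick[1]: the 16 byte lanes are reinterpreted as 8 u16 lanes, each
        /// u16 is shifted right by 4 and narrowed back to u8. Each resulting
        /// byte then holds the high nibble of one input byte and the low
        /// nibble of the next. Since comparison results are either 0x00 or
        /// 0xFF, masking with 0x888... leaves exactly one set bit in the
        /// scalar (bit `4i + 3`) for each matching lane `i`.
        ///
        /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon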
        #[inline(always)]
        unsafe fn movemask(self) -> NeonMoveMask {
            let asu16s = vreinterpretq_u16_u8(self);
            let mask = vshrn_n_u16(asu16s, 4);
            let asu64 = vreinterpret_u64_u8(mask);
            let scalar64 = vget_lane_u64(asu64, 0);
            NeonMoveMask(scalar64 & 0x8888888888888888)
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> uint8x16_t {
            vceqq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> uint8x16_t {
            vandq_u8(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> uint8x16_t {
            vorrq_u8(self, vector2)
        }

        /// This is the only interesting implementation of this routine.
        /// Basically, instead of doing the "shift right narrow" dance, we use
        /// adjacent folding max to determine whether there are any non-zero
        /// bytes in our mask. If there are, *then* we'll do the "shift right
        /// narrow" dance. In benchmarks, this does lead to slightly better
        /// throughput, but the win doesn't appear huge.
        #[inline(always)]
        unsafe fn movemask_will_have_non_zero(self) -> bool {
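            // `vpmaxq_u8(self, self)` computes the pairwise max of adjacent
            // bytes, so its low 64-bit lane holds the maxima of all 16 bytes
            // of `self`. That lane is non-zero precisely when some byte of
            // `self` is non-zero.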
            let low = vreinterpretq_u64_u8(vpmaxq_u8(self, self));
            vgetq_lane_u64(low, 0) != 0
        }
    }

    /// Neon doesn't have a `movemask` that works like the one in x86-64, so we
    /// wind up using a different method[1]. The different method also produces
    /// a mask, but 4 bits are set in the neon case instead of a single bit set
    /// in the x86-64 case. We do an extra step to zero out 3 of the 4 bits,
    /// but we still wind up with at least 3 zeroes between each set bit. This
    /// generally means that we need to do some division by 4 before extracting
    /// offsets.
    ///
    /// In fact, the existence of this type is the entire reason that we have
    /// the `MoveMask` trait in the first place. This basically lets us keep
    /// the different representations of masks without being forced to unify
    /// them into a single representation, which could result in extra and
    /// unnecessary work.
    ///
    /// [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
    #[derive(Clone, Copy, Debug)]
    pub(crate) struct NeonMoveMask(u64);

    impl NeonMoveMask {
        /// Get the mask in a form suitable for computing offsets.
        ///
        /// Basically, this normalizes to little endian. On big endian, this
        /// swaps the bytes.
        #[inline(always)]
        fn get_for_offset(self) -> u64 {
            #[cfg(target_endian = "big")]
            {
                self.0.swap_bytes()
            }
            #[cfg(target_endian = "little")]
            {
                self.0
            }
        }
    }

    impl MoveMask for NeonMoveMask {
        #[inline(always)]
        fn all_zeros_except_least_significant(n: usize) -> NeonMoveMask {
            debug_assert!(n < 16);
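            // Each lane occupies 4 bits in this representation (only bit
            // `4i + 3` can be set for lane `i`), so clear `4 * n` low bits
            // rather than `n`.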
            NeonMoveMask(!((1 << (n << 2)) - 1))
        }

        #[inline(always)]
        fn has_non_zero(self) -> bool {
            self.0 != 0
        }

        #[inline(always)]
        fn count_ones(self) -> usize {
            self.0.count_ones() as usize
        }

        #[inline(always)]
        fn and(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 & other.0)
        }

        #[inline(always)]
        fn or(self, other: NeonMoveMask) -> NeonMoveMask {
            NeonMoveMask(self.0 | other.0)
        }

        #[inline(always)]
        fn clear_least_significant_bit(self) -> NeonMoveMask {
            NeonMoveMask(self.0 & (self.0 - 1))
        }

        #[inline(always)]
        fn first_offset(self) -> usize {
            // We are dealing with little endian here (and if we aren't,
            // we swap the bytes so we are in practice), where the most
            // significant byte is at a higher address. That means the least
            // significant bit that is set corresponds to the position of our
            // first matching byte. That position corresponds to the number of
            // zeros after the least significant bit.
            //
            // Note that unlike `SensibleMoveMask`, this mask has its bits
            // spread out over 64 bits instead of 16 bits (for a 128 bit
            // vector). Namely, whereas x86-64 will turn
            //
            // 0x00 0xFF 0x00 0x00 0xFF
            //
            // into 10010, our neon approach will turn it into
            //
            // 10000000000010000000
            //
            // And this happens because neon doesn't have a native `movemask`
            // instruction, so we kind of fake it[1]. Thus, we divide the
            // number of trailing zeros by 4 to get the "real" offset.
            //
            // [1]: https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
            (self.get_for_offset().trailing_zeros() >> 2) as usize
        }

        #[inline(always)]
        fn last_offset(self) -> usize {
            // See comment in `first_offset` above. This is basically the same,
            // but coming from the other direction.
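            //
            // For example, if the last matching lane is lane 4, then bit 19
            // is the highest set bit, `leading_zeros` returns 44, and the
            // offset is 16 - (44 >> 2) - 1 = 4.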
            16 - (self.get_for_offset().leading_zeros() >> 2) as usize - 1
        }
    }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
mod wasm_simd128 {
    use core::arch::wasm32::*;

    use super::{SensibleMoveMask, Vector};

    impl Vector for v128 {
        const BYTES: usize = 16;
        const ALIGN: usize = Self::BYTES - 1;

        type Mask = SensibleMoveMask;

        #[inline(always)]
        unsafe fn splat(byte: u8) -> v128 {
            u8x16_splat(byte)
        }

        #[inline(always)]
        unsafe fn load_aligned(data: *const u8) -> v128 {
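            // The `Vector::load_aligned` contract guarantees that `data` is
            // aligned to a 16 byte boundary, so an aligned read through the
            // cast pointer is sound here.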
            *data.cast()
        }

        #[inline(always)]
        unsafe fn load_unaligned(data: *const u8) -> v128 {
            v128_load(data.cast())
        }

        #[inline(always)]
        unsafe fn movemask(self) -> SensibleMoveMask {
            SensibleMoveMask(u8x16_bitmask(self).into())
        }

        #[inline(always)]
        unsafe fn cmpeq(self, vector2: Self) -> v128 {
            u8x16_eq(self, vector2)
        }

        #[inline(always)]
        unsafe fn and(self, vector2: Self) -> v128 {
            v128_and(self, vector2)
        }

        #[inline(always)]
        unsafe fn or(self, vector2: Self) -> v128 {
            v128_or(self, vector2)
        }
    }
}