httparse/simd/
swar.rs

//! SWAR: SIMD Within A Register
//! SIMD validator backend that validates register-sized chunks of data at a time.
3use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};
4
// One block is exactly one native machine word (4 bytes on 32-bit targets, 8 on 64-bit),
// so that a whole block can be reinterpreted as a single `usize` by the SWAR matchers.
type ByteBlock = [u8; core::mem::size_of::<usize>()];
const BLOCK_SIZE: usize = core::mem::size_of::<ByteBlock>();
8
9#[inline]
10pub fn match_uri_vectored(bytes: &mut Bytes) {
11    loop {
12        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
13            let n = match_uri_char_8_swar(bytes8);
14            // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
15            // in `bytes`, so calling `advance(n)` is safe.
16            unsafe {
17                bytes.advance(n);
18            }
19            if n == BLOCK_SIZE {
20                continue;
21            }
22        }
23        if let Some(b) = bytes.peek() {
24            if is_uri_token(b) {
25                // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
26                // in bytes, so calling advance is safe.
27                unsafe {
28                    bytes.advance(1);
29                }
30                continue;
31            }
32        }
33        break;
34    }
35}
36
37#[inline]
38pub fn match_header_value_vectored(bytes: &mut Bytes) {
39    loop {
40        if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
41            let n = match_header_value_char_8_swar(bytes8);
42            // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
43            // in `bytes`, so calling `advance(n)` is safe.
44            unsafe {
45                bytes.advance(n);
46            }
47            if n == BLOCK_SIZE {
48                continue;
49            }
50        }
51        if let Some(b) = bytes.peek() {
52            if is_header_value_token(b) {
53                // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
54                // in bytes, so calling advance is safe.
55                unsafe {
56                    bytes.advance(1);
57                }
58                continue;
59            }
60        }
61        break;
62    }
63}
64
65#[inline]
66pub fn match_header_name_vectored(bytes: &mut Bytes) {
67    while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
68        let n = match_block(is_header_name_token, block);
69        // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
70        // in `bytes`, so calling `advance(n)` is safe.
71        unsafe {
72            bytes.advance(n);
73        }
74        if n != BLOCK_SIZE {
75            return;
76        }
77    }
78    // SAFETY: match_tail processes at most the remaining data in `bytes`. advances `bytes` to the
79    // end, but no further.
80    unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
81}
82
// Scalar matcher for the "tail" (fewer than BLOCK_SIZE bytes left in the buffer), which
// should be uncommon. Returns the length of the longest prefix of `bytes` whose bytes all
// satisfy `f`; that is `bytes.len()` when every byte does.
#[cold]
#[inline]
fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
    bytes
        .iter()
        .position(|&byte| !f(byte))
        .unwrap_or(bytes.len())
}
94
95// Naive fallback block matcher
96#[inline(always)]
97fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
98    for (i, &b) in block.iter().enumerate() {
99        if !f(b) {
100            return i;
101        }
102    }
103    BLOCK_SIZE
104}
105
// Returns a `usize` in which every byte equals `b`, e.g. `uniform_block(0x21)` is
// `0x2121_2121` on 32-bit targets and `0x2121_2121_2121_2121` on 64-bit targets.
// This is a `const` stand-in for `usize::from_ne_bytes([b; BLOCK_SIZE])`, which is only
// usable in `const` contexts from Rust 1.44 (avoids bumping MSRV 1.36 => 1.44).
// `!0 / 0xFF` is the "one in every byte" pattern 0x0101...01 at any pointer width, so no
// truncating `u64 as usize` cast is needed. Multiplying it by `b <= 0xFF` cannot overflow.
const fn uniform_block(b: u8) -> usize {
    (!0_usize / 0xFF) * b as usize
}
111
112// A byte-wise range-check on an enire word/block,
113// ensuring all bytes in the word satisfy
114// `33 <= x <= 126 && x != '>' && x != '<'`
115// IMPORTANT: it false negatives if the block contains '?'
116#[inline]
117fn match_uri_char_8_swar(block: ByteBlock) -> usize {
118    // 33 <= x <= 126
119    const M: u8 = 0x21;
120    const N: u8 = 0x7E;
121    const BM: usize = uniform_block(M);
122    const BN: usize = uniform_block(127 - N);
123    const M128: usize = uniform_block(128);
124
125    let x = usize::from_ne_bytes(block); // Really just a transmute
126    let lt = x.wrapping_sub(BM) & !x; // <= m
127    let gt = x.wrapping_add(BN) | x; // >= n
128
129    // XOR checks to catch '<' & '>' for correctness
130    //
131    // XOR can be thought of as a "distance function"
132    // (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0`
133    // (each u8 "xor key" providing a unique total ordering of u8)
134    // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
135    // xor(x, '>') <= 2 => {'>', '?', '<'}
136    // xor(x, '<') <= 2 => {'<', '=', '>'}
137    //
138    // We assume P('=') > P('?'),
139    // given well/commonly-formatted URLs with querystrings contain
140    // a single '?' but possibly many '='
141    //
142    // Thus it's preferable/near-optimal to "xor distance" on '>',
143    // since we'll slowpath at most one block per URL
144    //
145    // Some rust code to sanity check this yourself:
146    // ```rs
147    // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
148    //     (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
149    // }
150    // (xordist(b'<', 2), xordist(b'>', 2))
151    // ```
152    const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
153    const BGT: usize = uniform_block(b'>');
154
155    let xgt = x ^ BGT;
156    let ltgtq = xgt.wrapping_sub(B3) & !xgt;
157
158    offsetnz((ltgtq | lt | gt) & M128)
159}
160
161// A byte-wise range-check on an entire word/block,
162// ensuring all bytes in the word satisfy `32 <= x <= 126`
163// IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
164#[inline]
165fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
166    // 32 <= x <= 126
167    const M: u8 = 0x20;
168    const N: u8 = 0x7E;
169    const BM: usize = uniform_block(M);
170    const BN: usize = uniform_block(127 - N);
171    const M128: usize = uniform_block(128);
172
173    let x = usize::from_ne_bytes(block); // Really just a transmute
174    let lt = x.wrapping_sub(BM) & !x; // <= m
175    let gt = x.wrapping_add(BN) | x; // >= n
176    offsetnz((lt | gt) & M128)
177}
178
179/// Check block to find offset of first non-zero byte
180// NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
181#[inline]
182fn offsetnz(block: usize) -> usize {
183    // fast path optimistic case (common for long valid sequences)
184    if block == 0 {
185        return BLOCK_SIZE;
186    }
187
188    // perf: rust will unroll this loop
189    for (i, b) in block.to_ne_bytes().iter().copied().enumerate() {
190        if b != 0 {
191            return i;
192        }
193    }
194    unreachable!()
195}
196
#[test]
fn test_is_header_value_block() {
    let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;

    // 0..32 => false
    for b in 0..32_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 32..127 => true
    for b in 32..127_u8 {
        assert!(is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 127..=255 => false
    for b in 127..=255_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }

    // Non-uniform blocks, any pointer width: a single out-of-range byte among valid ones
    // must be reported at exactly its offset.
    for &bad in b"\n\r\x7f\x80\xff".iter() {
        for i in 0..BLOCK_SIZE {
            let mut block = [b'a'; BLOCK_SIZE];
            block[i] = bad;
            assert_eq!(
                match_header_value_char_8_swar(block),
                i,
                "bad={:#04x} i={}",
                bad,
                i
            );
        }
    }

    #[cfg(target_pointer_width = "64")]
    {
        // A few sanity checks on realistic non-uniform bytes for good measure
        assert!(!is_header_value_block(*b"foo.com\n"));
        assert!(!is_header_value_block(*b"o.com\r\nU"));
    }
}
222
#[test]
fn test_is_uri_block() {
    let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;

    // 0..33 => false
    for b in 0..33_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 33..127 => true if b not in { '<', '?', '>' }
    let falsy = |b| b"<?>".contains(&b);
    for b in 33..127_u8 {
        assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b);
    }
    // 127..=255 => false
    for b in 127..=255_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }

    // Non-uniform blocks, any pointer width: a single rejected byte among valid ones
    // must be reported at exactly its offset.
    for &bad in b" <>?\x7f\x80\xff".iter() {
        for i in 0..BLOCK_SIZE {
            let mut block = [b'a'; BLOCK_SIZE];
            block[i] = bad;
            assert_eq!(match_uri_char_8_swar(block), i, "bad={:#04x} i={}", bad, i);
        }
    }
}
241
#[test]
fn test_offsetnz() {
    // No non-zero byte at all: the whole block counts as matched.
    assert_eq!(offsetnz(0), BLOCK_SIZE);

    for i in 0..BLOCK_SIZE {
        // Exactly one non-zero byte, at position `i`.
        let mut single = [0_u8; BLOCK_SIZE];
        single[i] = 1;
        assert_eq!(offsetnz(usize::from_ne_bytes(single)), i);

        // Every byte from `i` onward set to 0x80 (as the SWAR masks produce): the first wins.
        let mut suffix = [0_u8; BLOCK_SIZE];
        for b in suffix[i..].iter_mut() {
            *b = 0x80;
        }
        assert_eq!(offsetnz(usize::from_ne_bytes(suffix)), i);
    }
}