1pub(crate) fn decode_secret<'a>(input: &[u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
17 decode(input, output, CodePoint::decode_secret)
18}
19
20pub(crate) fn decode_public<'a>(input: &[u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
27 decode(input, output, CodePoint::decode_public)
28}
29
/// Upper bound on the number of bytes produced by decoding `base64_len` input bytes.
///
/// Every group of up to 4 input bytes decodes to at most 3 output bytes. A partial
/// final group is therefore rounded up to a whole group. Whitespace and `=` padding
/// are counted as input, so the bound can exceed the actual decoded length.
///
/// The result is `ceil(base64_len / 4) * 3`. It is computed without the
/// intermediate `base64_len + 3` so it cannot overflow, even for `usize::MAX`.
pub(crate) const fn decoded_length(base64_len: usize) -> usize {
    let whole_groups = base64_len / 4;
    // `bool as usize` is 1 when there is a trailing partial group, else 0.
    let partial_group = (base64_len % 4 != 0) as usize;
    (whole_groups + partial_group) * 3
}
35
36fn decode<'a>(
37 input: &[u8],
38 output: &'a mut [u8],
39 decode_byte: impl Fn(u8) -> CodePoint,
40) -> Result<&'a [u8], Error> {
41 let mut buffer = 0u64;
42 let mut used = 0;
43 let mut shift = SHIFT_INITIAL;
44 let mut pad_mask = 0;
45
46 let mut output_offset = 0;
47
48 const SHIFT_INITIAL: i32 = (8 - 1) * 6;
49
50 for byte in input.iter().copied() {
51 let (item, pad) = match decode_byte(byte) {
52 CodePoint::WHITESPACE => continue,
53 CodePoint::INVALID => return Err(Error::InvalidCharacter(byte)),
54 CodePoint::PAD => (0, 1),
55 CodePoint(n) => (n, 0),
56 };
57
58 if used == 8 {
61 if pad_mask != 0b0000_0000 {
62 return Err(Error::PrematurePadding);
63 }
64
65 let chunk = output
66 .get_mut(output_offset..output_offset + 6)
67 .ok_or(Error::InsufficientOutputSpace)?;
68
69 chunk[0] = (buffer >> 40) as u8;
70 chunk[1] = (buffer >> 32) as u8;
71 chunk[2] = (buffer >> 24) as u8;
72 chunk[3] = (buffer >> 16) as u8;
73 chunk[4] = (buffer >> 8) as u8;
74 chunk[5] = buffer as u8;
75
76 output_offset += 6;
77 buffer = 0;
78 used = 0;
79 pad_mask = 0;
80 shift = SHIFT_INITIAL;
81 }
82
83 buffer |= (item as u64) << shift;
84 shift -= 6;
85 pad_mask |= pad << used;
86 used += 1;
87 }
88
89 if used > 4 {
91 if pad_mask & 0b0000_1111 != 0 {
92 return Err(Error::PrematurePadding);
93 }
94 let chunk = output
95 .get_mut(output_offset..output_offset + 3)
96 .ok_or(Error::InsufficientOutputSpace)?;
97 chunk[0] = (buffer >> 40) as u8;
98 chunk[1] = (buffer >> 32) as u8;
99 chunk[2] = (buffer >> 24) as u8;
100
101 buffer <<= 24;
102 pad_mask >>= 4;
103 used -= 4;
104 output_offset += 3;
105 }
106
107 match (used, pad_mask) {
108 (0, 0b0000) => {}
110
111 (4, 0b0000) => {
113 let chunk = output
114 .get_mut(output_offset..output_offset + 3)
115 .ok_or(Error::InsufficientOutputSpace)?;
116 chunk[0] = (buffer >> 40) as u8;
117 chunk[1] = (buffer >> 32) as u8;
118 chunk[2] = (buffer >> 24) as u8;
119 output_offset += 3;
120 }
121
122 (4, 0b1000) | (3, 0b0000) => {
124 let chunk = output
125 .get_mut(output_offset..output_offset + 2)
126 .ok_or(Error::InsufficientOutputSpace)?;
127
128 chunk[0] = (buffer >> 40) as u8;
129 chunk[1] = (buffer >> 32) as u8;
130 output_offset += 2;
131 }
132
133 (4, 0b1100) | (2, 0b0000) => {
135 let chunk = output
136 .get_mut(output_offset..output_offset + 1)
137 .ok_or(Error::InsufficientOutputSpace)?;
138 chunk[0] = (buffer >> 40) as u8;
139 output_offset += 1;
140 }
141
142 _ => return Err(Error::InvalidTrailingPadding),
144 }
145
146 Ok(&output[..output_offset])
147}
148
/// Errors reported by the base64 decoder.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Error {
    /// A byte that is not a base64 symbol, `=`, or whitespace. Carries that byte.
    InvalidCharacter(u8),

    /// A 4-symbol group containing `=` padding was followed by further symbols.
    PrematurePadding,

    /// The final group had an illegal combination of symbol count and padding.
    /// Examples: a single leftover symbol, padding that is not at the end of a
    /// 4-symbol group, or a padded group shorter than 4 symbols.
    InvalidTrailingPadding,

    /// The output buffer was too small for the decoded bytes. The buffer may
    /// already have been partially written.
    InsufficientOutputSpace,
}
166
/// Classification of one input byte, as produced by the per-byte decoders.
///
/// Values `0..=63` are base64 sextets. The associated constants use values
/// outside that range to mark bytes that carry no sextet. `PartialEq`/`Eq` are
/// derived so these constants can be used as `match` patterns.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CodePoint(u8);

impl CodePoint {
    /// A byte to skip: ASCII space, or `\t` through `\r`.
    const WHITESPACE: Self = CodePoint(0xF0);
    /// The `=` padding character.
    const PAD: Self = CodePoint(0xF1);
    /// Any byte that may not appear in base64 input.
    const INVALID: Self = CodePoint(0xF2);
}
175
176impl CodePoint {
177 fn decode_secret(b: u8) -> Self {
184 let is_upper = u8_in_range(b, b'A', b'Z');
185 let is_lower = u8_in_range(b, b'a', b'z');
186 let is_digit = u8_in_range(b, b'0', b'9');
187 let is_plus = u8_equals(b, b'+');
188 let is_slash = u8_equals(b, b'/');
189 let is_pad = u8_equals(b, b'=');
190 let is_space = u8_in_range(b, b'\t', b'\r') | u8_equals(b, b' ');
191
192 let is_invalid = !(is_lower | is_upper | is_digit | is_plus | is_slash | is_pad | is_space);
193
194 Self(
195 (is_upper & b.wrapping_sub(b'A'))
196 | (is_lower & (b.wrapping_sub(b'a').wrapping_add(26)))
197 | (is_digit & (b.wrapping_sub(b'0').wrapping_add(52)))
198 | (is_plus & 62)
199 | (is_slash & 63)
200 | (is_space & Self::WHITESPACE.0)
201 | (is_pad & Self::PAD.0)
202 | (is_invalid & Self::INVALID.0),
203 )
204 }
205
206 fn decode_public(a: u8) -> Self {
207 const TABLE: [CodePoint; 256] = [
208 CodePoint::INVALID,
210 CodePoint::INVALID,
211 CodePoint::INVALID,
212 CodePoint::INVALID,
213 CodePoint::INVALID,
214 CodePoint::INVALID,
215 CodePoint::INVALID,
216 CodePoint::INVALID,
217 CodePoint::INVALID,
218 CodePoint::WHITESPACE,
219 CodePoint::WHITESPACE,
220 CodePoint::WHITESPACE,
221 CodePoint::WHITESPACE,
222 CodePoint::WHITESPACE,
223 CodePoint::INVALID,
224 CodePoint::INVALID,
225 CodePoint::INVALID,
227 CodePoint::INVALID,
228 CodePoint::INVALID,
229 CodePoint::INVALID,
230 CodePoint::INVALID,
231 CodePoint::INVALID,
232 CodePoint::INVALID,
233 CodePoint::INVALID,
234 CodePoint::INVALID,
235 CodePoint::INVALID,
236 CodePoint::INVALID,
237 CodePoint::INVALID,
238 CodePoint::INVALID,
239 CodePoint::INVALID,
240 CodePoint::INVALID,
241 CodePoint::INVALID,
242 CodePoint::WHITESPACE,
244 CodePoint::INVALID,
245 CodePoint::INVALID,
246 CodePoint::INVALID,
247 CodePoint::INVALID,
248 CodePoint::INVALID,
249 CodePoint::INVALID,
250 CodePoint::INVALID,
251 CodePoint::INVALID,
252 CodePoint::INVALID,
253 CodePoint::INVALID,
254 CodePoint(62),
255 CodePoint::INVALID,
256 CodePoint::INVALID,
257 CodePoint::INVALID,
258 CodePoint(63),
259 CodePoint(52),
261 CodePoint(53),
262 CodePoint(54),
263 CodePoint(55),
264 CodePoint(56),
265 CodePoint(57),
266 CodePoint(58),
267 CodePoint(59),
268 CodePoint(60),
269 CodePoint(61),
270 CodePoint::INVALID,
271 CodePoint::INVALID,
272 CodePoint::INVALID,
273 CodePoint::PAD,
274 CodePoint::INVALID,
275 CodePoint::INVALID,
276 CodePoint::INVALID,
278 CodePoint(0),
279 CodePoint(1),
280 CodePoint(2),
281 CodePoint(3),
282 CodePoint(4),
283 CodePoint(5),
284 CodePoint(6),
285 CodePoint(7),
286 CodePoint(8),
287 CodePoint(9),
288 CodePoint(10),
289 CodePoint(11),
290 CodePoint(12),
291 CodePoint(13),
292 CodePoint(14),
293 CodePoint(15),
295 CodePoint(16),
296 CodePoint(17),
297 CodePoint(18),
298 CodePoint(19),
299 CodePoint(20),
300 CodePoint(21),
301 CodePoint(22),
302 CodePoint(23),
303 CodePoint(24),
304 CodePoint(25),
305 CodePoint::INVALID,
306 CodePoint::INVALID,
307 CodePoint::INVALID,
308 CodePoint::INVALID,
309 CodePoint::INVALID,
310 CodePoint::INVALID,
312 CodePoint(26),
313 CodePoint(27),
314 CodePoint(28),
315 CodePoint(29),
316 CodePoint(30),
317 CodePoint(31),
318 CodePoint(32),
319 CodePoint(33),
320 CodePoint(34),
321 CodePoint(35),
322 CodePoint(36),
323 CodePoint(37),
324 CodePoint(38),
325 CodePoint(39),
326 CodePoint(40),
327 CodePoint(41),
329 CodePoint(42),
330 CodePoint(43),
331 CodePoint(44),
332 CodePoint(45),
333 CodePoint(46),
334 CodePoint(47),
335 CodePoint(48),
336 CodePoint(49),
337 CodePoint(50),
338 CodePoint(51),
339 CodePoint::INVALID,
340 CodePoint::INVALID,
341 CodePoint::INVALID,
342 CodePoint::INVALID,
343 CodePoint::INVALID,
344 CodePoint::INVALID,
346 CodePoint::INVALID,
347 CodePoint::INVALID,
348 CodePoint::INVALID,
349 CodePoint::INVALID,
350 CodePoint::INVALID,
351 CodePoint::INVALID,
352 CodePoint::INVALID,
353 CodePoint::INVALID,
354 CodePoint::INVALID,
355 CodePoint::INVALID,
356 CodePoint::INVALID,
357 CodePoint::INVALID,
358 CodePoint::INVALID,
359 CodePoint::INVALID,
360 CodePoint::INVALID,
361 CodePoint::INVALID,
363 CodePoint::INVALID,
364 CodePoint::INVALID,
365 CodePoint::INVALID,
366 CodePoint::INVALID,
367 CodePoint::INVALID,
368 CodePoint::INVALID,
369 CodePoint::INVALID,
370 CodePoint::INVALID,
371 CodePoint::INVALID,
372 CodePoint::INVALID,
373 CodePoint::INVALID,
374 CodePoint::INVALID,
375 CodePoint::INVALID,
376 CodePoint::INVALID,
377 CodePoint::INVALID,
378 CodePoint::INVALID,
380 CodePoint::INVALID,
381 CodePoint::INVALID,
382 CodePoint::INVALID,
383 CodePoint::INVALID,
384 CodePoint::INVALID,
385 CodePoint::INVALID,
386 CodePoint::INVALID,
387 CodePoint::INVALID,
388 CodePoint::INVALID,
389 CodePoint::INVALID,
390 CodePoint::INVALID,
391 CodePoint::INVALID,
392 CodePoint::INVALID,
393 CodePoint::INVALID,
394 CodePoint::INVALID,
395 CodePoint::INVALID,
397 CodePoint::INVALID,
398 CodePoint::INVALID,
399 CodePoint::INVALID,
400 CodePoint::INVALID,
401 CodePoint::INVALID,
402 CodePoint::INVALID,
403 CodePoint::INVALID,
404 CodePoint::INVALID,
405 CodePoint::INVALID,
406 CodePoint::INVALID,
407 CodePoint::INVALID,
408 CodePoint::INVALID,
409 CodePoint::INVALID,
410 CodePoint::INVALID,
411 CodePoint::INVALID,
412 CodePoint::INVALID,
414 CodePoint::INVALID,
415 CodePoint::INVALID,
416 CodePoint::INVALID,
417 CodePoint::INVALID,
418 CodePoint::INVALID,
419 CodePoint::INVALID,
420 CodePoint::INVALID,
421 CodePoint::INVALID,
422 CodePoint::INVALID,
423 CodePoint::INVALID,
424 CodePoint::INVALID,
425 CodePoint::INVALID,
426 CodePoint::INVALID,
427 CodePoint::INVALID,
428 CodePoint::INVALID,
429 CodePoint::INVALID,
431 CodePoint::INVALID,
432 CodePoint::INVALID,
433 CodePoint::INVALID,
434 CodePoint::INVALID,
435 CodePoint::INVALID,
436 CodePoint::INVALID,
437 CodePoint::INVALID,
438 CodePoint::INVALID,
439 CodePoint::INVALID,
440 CodePoint::INVALID,
441 CodePoint::INVALID,
442 CodePoint::INVALID,
443 CodePoint::INVALID,
444 CodePoint::INVALID,
445 CodePoint::INVALID,
446 CodePoint::INVALID,
448 CodePoint::INVALID,
449 CodePoint::INVALID,
450 CodePoint::INVALID,
451 CodePoint::INVALID,
452 CodePoint::INVALID,
453 CodePoint::INVALID,
454 CodePoint::INVALID,
455 CodePoint::INVALID,
456 CodePoint::INVALID,
457 CodePoint::INVALID,
458 CodePoint::INVALID,
459 CodePoint::INVALID,
460 CodePoint::INVALID,
461 CodePoint::INVALID,
462 CodePoint::INVALID,
463 CodePoint::INVALID,
465 CodePoint::INVALID,
466 CodePoint::INVALID,
467 CodePoint::INVALID,
468 CodePoint::INVALID,
469 CodePoint::INVALID,
470 CodePoint::INVALID,
471 CodePoint::INVALID,
472 CodePoint::INVALID,
473 CodePoint::INVALID,
474 CodePoint::INVALID,
475 CodePoint::INVALID,
476 CodePoint::INVALID,
477 CodePoint::INVALID,
478 CodePoint::INVALID,
479 CodePoint::INVALID,
480 ];
481
482 TABLE[a as usize]
483 }
484}
485
486fn u8_in_range(a: u8, lo: u8, hi: u8) -> u8 {
491 debug_assert!(lo <= hi);
492 debug_assert!(hi - lo != 255);
493 let a = a.wrapping_sub(lo);
494 u8_less_than(a, (hi - lo).wrapping_add(1))
495}
496
497fn u8_less_than(a: u8, b: u8) -> u8 {
499 let a = u16::from(a);
500 let b = u16::from(b);
501 u8_broadcast16(a.wrapping_sub(b))
502}
503
504fn u8_equals(a: u8, b: u8) -> u8 {
506 let diff = a ^ b;
507 u8_nonzero(diff)
508}
509
510fn u8_nonzero(x: u8) -> u8 {
512 u8_broadcast8(!x & x.wrapping_sub(1))
513}
514
/// Spreads bit 7 of `x` across the whole byte without branching:
/// returns `0xff` if the bit is set and `0x00` otherwise.
///
/// Reinterpreting `x` as `i8` and shifting right by 7 is an arithmetic shift,
/// which replicates the sign bit into every position.
fn u8_broadcast8(x: u8) -> u8 {
    ((x as i8) >> 7) as u8
}
523
/// Spreads bit 15 of `x` across a whole byte without branching:
/// returns `0xff` if the bit is set and `0x00` otherwise.
///
/// Reinterpreting `x` as `i16` and shifting right by 15 is an arithmetic shift
/// that yields `0` or `-1`. Truncating that to `u8` gives `0x00` or `0xff`.
fn u8_broadcast16(x: u16) -> u8 {
    ((x as i16) >> 15) as u8
}
532
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    #[test]
    fn decode_test() {
        assert_eq!(
            decode(b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"),
            b"\x00\x10\x83\x10\x51\x87\x20\x92\x8b\x30\xd3\x8f\x41\x14\x93\x51\x55\x97\
              \x61\x96\x9b\x71\xd7\x9f\x82\x18\xa3\x92\x59\xa7\xa2\x9a\xab\xb2\xdb\xaf\
              \xc3\x1c\xb3\xd3\x5d\xb7\xe3\x9e\xbb\xf3\xdf\xbf"
        );
        assert_eq!(decode(b"aGVsbG8="), b"hello");
        assert_eq!(decode(b"aGVsbG8gd29ybGQ="), b"hello world");
        assert_eq!(decode(b"aGVsbG8gd29ybGQh"), b"hello world!");
        assert_eq!(decode(b"////"), b"\xff\xff\xff");
        assert_eq!(decode(b"++++"), b"\xfb\xef\xbe");
        assert_eq!(decode(b"AAAA"), b"\x00\x00\x00");
        assert_eq!(decode(b"AAA="), b"\x00\x00");
        assert_eq!(decode(b"AA=="), b"\x00");

        // Padding is optional: unpadded 3- and 2-symbol trailing groups are accepted.
        assert_eq!(decode(b"AAA"), b"\x00\x00");
        assert_eq!(decode(b"AA"), b"\x00");

        assert_eq!(decode(b""), b"");
    }

    #[test]
    fn decode_skips_whitespace() {
        // Space and `\t`..=`\r` may appear anywhere, including around padding.
        assert_eq!(decode(b"aGVs bG8="), b"hello");
        assert_eq!(decode(b"aGVsbG8g\nd29ybGQh\r\n"), b"hello world!");
        assert_eq!(decode(b"\t AA\x0b\x0c=="), b"\x00");
        assert_eq!(decode(b" \n\r\t "), b"");
    }

    #[test]
    fn decode_errors() {
        let mut buf = [0u8; 6];

        // illegal trailing padding / group length
        assert_eq!(
            decode_both(b"A===", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"====", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"A==", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"AA=", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );
        assert_eq!(
            decode_both(b"A", &mut buf),
            Err(Error::InvalidTrailingPadding)
        );

        // padding followed by more data
        assert_eq!(
            decode_both(b"=AAAAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"A=AAAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"AA=AAA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        assert_eq!(
            decode_both(b"AAA=AA==", &mut buf),
            Err(Error::PrematurePadding)
        );
        // whitespace does not hide padding in the middle of the data
        assert_eq!(
            decode_both(b"AA==\nAAAA", &mut buf),
            Err(Error::PrematurePadding)
        );
        // a fully padded group followed by a further group
        assert_eq!(
            decode_both(b"AAAA====AAAA", &mut buf),
            Err(Error::PrematurePadding)
        );

        // invalid characters in every position
        assert_eq!(
            decode_both(b"%AAA", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"A%AA", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"AA%A", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );
        assert_eq!(
            decode_both(b"AAA%", &mut buf),
            Err(Error::InvalidCharacter(b'%'))
        );

        // output buffer sizing for a full 8-symbol window
        assert_eq!(decode_both(b"am9lIGJw", &mut [0u8; 7]), Ok(&b"joe bp"[..]));
        assert_eq!(decode_both(b"am9lIGJw", &mut [0u8; 6]), Ok(&b"joe bp"[..]));
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 5]),
            Err(Error::InsufficientOutputSpace)
        );
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 4]),
            Err(Error::InsufficientOutputSpace)
        );
        assert_eq!(
            decode_both(b"am9lIGJw", &mut [0u8; 3]),
            Err(Error::InsufficientOutputSpace)
        );

        // output buffer sizing for short trailing groups
        assert_eq!(decode_both(b"am9=", &mut [0u8; 2]), Ok(&b"jo"[..]));
        assert_eq!(decode_both(b"am==", &mut [0u8; 1]), Ok(&b"j"[..]));
        assert_eq!(decode_both(b"am9", &mut [0u8; 2]), Ok(&b"jo"[..]));
        assert_eq!(decode_both(b"am", &mut [0u8; 1]), Ok(&b"j"[..]));
    }

    #[test]
    fn check_models() {
        fn u8_broadcast8_model(x: u8) -> u8 {
            match x & 0x80 {
                0x80 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_broadcast16_model(x: u16) -> u8 {
            match x & 0x8000 {
                0x8000 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_nonzero_model(x: u8) -> u8 {
            match x {
                0 => 0xff,
                _ => 0x00,
            }
        }

        fn u8_equals_model(x: u8, y: u8) -> u8 {
            match x == y {
                true => 0xff,
                false => 0x00,
            }
        }

        fn u8_less_than_model(x: u8, y: u8) -> u8 {
            match x < y {
                true => 0xff,
                false => 0x00,
            }
        }

        fn u8_in_range_model(x: u8, y: u8, z: u8) -> u8 {
            match (y..=z).contains(&x) {
                true => 0xff,
                false => 0x00,
            }
        }

        for x in u8::MIN..=u8::MAX {
            assert_eq!(u8_broadcast8(x), u8_broadcast8_model(x));
            assert_eq!(u8_nonzero(x), u8_nonzero_model(x));
            assert_eq!(CodePoint::decode_secret(x), CodePoint::decode_public(x));

            for y in u8::MIN..=u8::MAX {
                assert_eq!(u8_equals(x, y), u8_equals_model(x, y));
                assert_eq!(u8_less_than(x, y), u8_less_than_model(x, y));

                let v = (x as u16) | ((y as u16) << 8);
                assert_eq!(u8_broadcast16(v), u8_broadcast16_model(v));

                for z in y..=u8::MAX {
                    // `u8_in_range` does not support a span of 255 (`0..=255`).
                    if z - y == 255 {
                        continue;
                    }
                    assert_eq!(u8_in_range(x, y, z), u8_in_range_model(x, y, z));
                }
            }
        }
    }

    #[cfg(all(feature = "std", target_os = "linux", target_arch = "x86_64"))]
    #[test]
    fn codepoint_decode_secret_does_not_branch_or_index_on_secret_input() {
        use crabgrind as cg;

        if matches!(cg::run_mode(), cg::RunMode::Native) {
            std::println!("SKIPPED: must be run under valgrind");
            return;
        }

        // Mark the input as undefined so valgrind reports any branch or memory
        // access that depends on it.
        let input = [b'a'];
        cg::monitor_command(std::format!(
            "make_memory undefined {:p} {}",
            input.as_ptr(),
            input.len()
        ))
        .unwrap();

        core::hint::black_box(CodePoint::decode_secret(input[0]));
    }

    /// Decodes `input` with both variants into a buffer sized by `decoded_length`,
    /// panicking on error, and returns the decoded bytes.
    #[track_caller]
    fn decode(input: &[u8]) -> alloc::vec::Vec<u8> {
        let length = decoded_length(input.len());

        let mut v = alloc::vec![0u8; length];
        let used = decode_both(input, &mut v).unwrap().len();
        v.truncate(used);

        v
    }

    /// Runs `decode_public` on a copy of `output` and `decode_secret` on `output`
    /// itself, asserts both results agree, and returns the secret variant's result.
    fn decode_both<'a>(input: &'_ [u8], output: &'a mut [u8]) -> Result<&'a [u8], Error> {
        let mut output_copy = output.to_vec();
        let r_pub = decode_public(input, &mut output_copy);

        let r_sec = decode_secret(input, output);

        assert_eq!(r_pub, r_sec);

        r_sec
    }
}